use std::fmt;
use std::fs::File;
use std::ops::{Deref, DerefMut};
use atomicwrites::{AllowOverwrite, AtomicFile};
use serde::de::DeserializeOwned;
use serde::Serialize;
use super::{Storage, StorageReadGuard, StorageWriteGuard};
/// Read-only view of a [`YamlStorage`], returned (boxed) by `Storage::read`.
///
/// Holds a shared borrow of the storage; dropping it does not touch the backing file.
pub struct YamlStorageReadGuard<'a, T>
where
    T: Serialize + DeserializeOwned + 'a,
{
    storage: &'a YamlStorage<T>,
}
impl<'a, T: Serialize + DeserializeOwned> YamlStorageReadGuard<'a, T> {
fn new(storage: &'a YamlStorage<T>) -> Self {
Self { storage }
}
}
impl<'a, T> Deref for YamlStorageReadGuard<'a, T>
where
    T: Serialize + DeserializeOwned + 'a,
{
    type Target = T;

    /// Exposes the in-memory value held by the underlying storage.
    fn deref(&self) -> &Self::Target {
        let backing = self.storage;
        &backing.data
    }
}
impl<'a, T> fmt::Display for YamlStorageReadGuard<'a, T>
where
    T: 'a + Serialize + DeserializeOwned + fmt::Display,
{
    /// Formats the guarded value using `T`'s own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Deref coercion turns `&YamlStorageReadGuard` into `&T`.
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}
/// Marker impl: no methods are provided here. The `Target = T` bound used in `Storage::read`
/// suggests the trait is satisfied by the `Deref` impl above — NOTE(review): confirm against
/// the trait definition.
impl<'a, T> StorageReadGuard<'a, T> for YamlStorageReadGuard<'a, T>
where
    T: 'a + Serialize + DeserializeOwned,
{
}
/// Mutable view of a [`YamlStorage`], returned (boxed) by `Storage::write`.
///
/// Holds an exclusive borrow of the storage; when the guard is dropped the current value is
/// serialized back to the backing YAML file.
pub struct YamlStorageWriteGuard<'a, T>
where
    T: Serialize + DeserializeOwned + 'a,
{
    storage: &'a mut YamlStorage<T>,
}
impl<'a, T: Serialize + DeserializeOwned> YamlStorageWriteGuard<'a, T> {
fn new(storage: &'a mut YamlStorage<T>) -> Self {
Self { storage }
}
}
impl<'a, T: Serialize + DeserializeOwned> Drop for YamlStorageWriteGuard<'a, T> {
    /// Serializes the (possibly modified) value back to the storage's YAML file.
    ///
    /// # Panics
    ///
    /// Panics if the write fails, unless the current thread is already unwinding from another
    /// panic. In that case the failure is printed to stderr instead, because panicking inside
    /// `drop` during unwinding aborts the whole process.
    fn drop(&mut self) {
        let result = self
            .storage
            .file
            .write(|f| serde_yaml::to_writer(f, &self.storage.data));

        if std::thread::panicking() {
            // A second panic here would abort; report the failure and let unwinding continue.
            if let Err(err) = result {
                eprintln!(
                    "File write failed while dropping YamlStorageWriteGuard!: {:?}",
                    err
                );
            }
        } else {
            result.expect("File write failed while dropping YamlStorageWriteGuard!");
        }
    }
}
impl<'a, T> Deref for YamlStorageWriteGuard<'a, T>
where
    T: Serialize + DeserializeOwned + 'a,
{
    type Target = T;

    /// Exposes the in-memory value for reading without consuming the guard.
    fn deref(&self) -> &Self::Target {
        let backing = &*self.storage;
        &backing.data
    }
}
impl<'a, T> DerefMut for YamlStorageWriteGuard<'a, T>
where
    T: Serialize + DeserializeOwned + 'a,
{
    /// Exposes the in-memory value for mutation; changes are persisted when the guard drops.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let backing = &mut *self.storage;
        &mut backing.data
    }
}
impl<'a, T> fmt::Display for YamlStorageWriteGuard<'a, T>
where
    T: 'a + Serialize + DeserializeOwned + fmt::Display,
{
    /// Formats the guarded value using `T`'s own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Deref coercion turns `&YamlStorageWriteGuard` into `&T`.
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}
/// Marker impl: no methods are provided here. The `Target = T` bound used in `Storage::write`
/// suggests the trait is satisfied by the `Deref`/`DerefMut` impls — NOTE(review): confirm
/// against the trait definition.
impl<'a, T> StorageWriteGuard<'a, T> for YamlStorageWriteGuard<'a, T>
where
    T: 'a + Serialize + DeserializeOwned,
{
}
/// A value of type `T` held in memory and persisted to a YAML file.
///
/// The file is written when it is first created by `YamlStorage::new` and whenever a write
/// guard obtained from `Storage::write` is dropped.
pub struct YamlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    // In-memory copy of the stored value; guards borrow this directly.
    data: T,
    // Backing file, written through `atomicwrites::AtomicFile`.
    file: AtomicFile,
}
impl<T: Serialize + DeserializeOwned> YamlStorage<T> {
    /// Loads storage backed by the YAML file at `path`.
    ///
    /// If the file exists it is deserialized into `T`. If it does not exist, `default()` is
    /// called to build the initial value, which is written to `path` before being returned.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` if:
    /// - the file exists but cannot be opened (e.g. permission denied),
    /// - the file's contents cannot be parsed as YAML into `T`, or
    /// - writing the default value to a newly created file fails.
    pub fn new<P: Into<String>, F: Fn() -> T>(path: P, default: F) -> Result<Self, String> {
        let path = path.into();
        let file = AtomicFile::new(path, AllowOverwrite);
        let data = match File::open(file.path()) {
            Ok(f) => {
                serde_yaml::from_reader(f).map_err(|err| format!("Couldn't read file: {}", err))?
            }
            // Only a genuinely missing file means "start from the default". Any other open
            // failure must be reported: falling through would atomically overwrite an existing
            // (merely unreadable) file with the default value and lose its contents.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                let data = default();
                file.write(|f| serde_yaml::to_writer(f, &data))
                    .map_err(|err| format!("File write failed: {}", err))?;
                data
            }
            Err(err) => return Err(format!("Couldn't open file: {}", err)),
        };
        Ok(Self { data, file })
    }
}
impl<T> fmt::Display for YamlStorage<T>
where
    T: fmt::Display + Serialize + DeserializeOwned,
{
    /// Displays the in-memory value exactly as `T` would display itself.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.data, f)
    }
}
impl<T> Storage for YamlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    type S = T;

    /// Borrows the stored value immutably. The returned guard never writes to disk.
    fn read<'a>(&'a self) -> Box<dyn StorageReadGuard<'a, T, Target = T> + 'a> {
        let guard = YamlStorageReadGuard::new(self);
        Box::new(guard)
    }

    /// Borrows the stored value mutably. The returned guard writes the value back to the YAML
    /// file when it is dropped.
    fn write<'a>(&'a mut self) -> Box<dyn StorageWriteGuard<'a, T, Target = T> + 'a> {
        let guard = YamlStorageWriteGuard::new(self);
        Box::new(guard)
    }
}
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::path::PathBuf;

    use tempdir::TempDir;

    use super::*;
    use crate::circuit::directory::CircuitDirectory;
    use crate::circuit::service::SplinterNode;
    use crate::circuit::{AuthorizationType, Circuit, DurabilityType, PersistenceType, RouteType};

    /// Serializes `state` as YAML to `<temp_dir>/circuits.yaml` and returns that path.
    fn write_state_file(mut temp_dir: PathBuf, state: &CircuitDirectory) -> String {
        let state_string = serde_yaml::to_string(state).unwrap();
        temp_dir.push("circuits.yaml");
        let path = temp_dir.to_str().unwrap().to_string();
        let mut file = File::create(&path).unwrap();
        file.write_all(state_string.as_bytes()).unwrap();
        path
    }

    /// Writes a state file holding node "123" and circuit "alpha" (members ["123"],
    /// roster ["abc", "def"]) and returns its path.
    fn set_up_mock_state_file(temp_dir: PathBuf) -> String {
        let mut state = CircuitDirectory::new();
        let node = SplinterNode::new("123".into(), vec!["tcp://127.0.0.1:8000".into()]);
        state.add_node("123".into(), node);
        let circuit = Circuit::builder()
            .with_id("alpha".into())
            .with_auth(AuthorizationType::Trust)
            .with_members(vec!["123".into()])
            .with_roster(vec!["abc".into(), "def".into()])
            .with_persistence(PersistenceType::Any)
            .with_durability(DurabilityType::NoDurability)
            .with_routes(RouteType::Any)
            .with_circuit_management_type("state_test_app".into())
            .build()
            .expect("Should have built a correct circuit");
        state.add_circuit("alpha".into(), circuit);
        write_state_file(temp_dir, &state)
    }

    /// Writes a state file holding an empty `CircuitDirectory` and returns its path.
    fn setup_empty_state_file(temp_dir: PathBuf) -> String {
        write_state_file(temp_dir, &CircuitDirectory::new())
    }

    /// An existing file containing an empty directory loads as empty.
    #[test]
    fn test_load_empty_state() {
        let temp_dir = TempDir::new("test_empty_state").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = setup_empty_state_file(temp_dir_path);
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert!(storage.data.nodes().is_empty());
        assert!(storage.data.circuits().is_empty());
    }

    /// A missing file falls back to the default (empty) directory.
    #[test]
    fn test_load_no_state() {
        let temp_dir = TempDir::new("test_load_no_state").unwrap();
        let mut temp_dir_path = temp_dir.path().to_path_buf();
        temp_dir_path.push("circuits.yaml");
        let path = temp_dir_path.to_str().unwrap().to_string();
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert!(storage.data.nodes().is_empty());
        assert!(storage.data.circuits().is_empty());
    }

    /// A populated file round-trips its nodes and circuits.
    #[test]
    fn test_load_state() {
        let temp_dir = TempDir::new("test_load_state").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = set_up_mock_state_file(temp_dir_path);
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert_eq!(storage.data.nodes().len(), 1);
        assert_eq!(storage.data.circuits().len(), 1);
        assert!(storage.data.nodes().contains_key("123"));
        assert!(storage.data.circuits().contains_key("alpha"));
        assert_eq!(
            storage
                .data
                .nodes()
                .get("123")
                .unwrap()
                .endpoints()
                .to_vec(),
            vec!["tcp://127.0.0.1:8000".to_string()]
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .roster()
                .to_vec(),
            vec!["abc".into(), "def".into()]
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .members()
                .to_vec(),
            vec!["123".to_string()],
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .circuit_management_type(),
            "state_test_app"
        );
    }

    /// Adding a node through a write guard is persisted when the guard drops.
    #[test]
    fn test_write_node_state() {
        let temp_dir = TempDir::new("test_write_node").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = set_up_mock_state_file(temp_dir_path);
        {
            let mut storage = YamlStorage::new(path.clone(), CircuitDirectory::new).unwrap();
            let node = SplinterNode::new("777".into(), vec!["tcp://127.0.0.1:5000".into()]);
            storage.write().add_node("777".into(), node);
        }
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert_eq!(storage.data.nodes().len(), 2);
        assert_eq!(storage.data.circuits().len(), 1);
        assert!(storage.data.nodes().contains_key("123"));
        assert!(storage.data.nodes().contains_key("777"));
        assert_eq!(
            storage
                .data
                .nodes()
                .get("123")
                .unwrap()
                .endpoints()
                .to_vec(),
            vec!["tcp://127.0.0.1:8000".to_string()]
        );
        assert_eq!(
            storage
                .data
                .nodes()
                .get("777")
                .unwrap()
                .endpoints()
                .to_vec(),
            vec!["tcp://127.0.0.1:5000".to_string()]
        );
    }

    /// Removing a node through a write guard is persisted and leaves circuits intact.
    #[test]
    fn test_remove_node_from_state() {
        let temp_dir = TempDir::new("test_remove_node").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = set_up_mock_state_file(temp_dir_path);
        {
            let mut storage = YamlStorage::new(path.clone(), CircuitDirectory::new).unwrap();
            storage.write().remove_node("123".into());
        }
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert_eq!(storage.data.nodes().len(), 0);
        assert_eq!(storage.data.circuits().len(), 1);
        assert!(!storage.data.nodes().contains_key("123"));
        assert!(storage.data.circuits().contains_key("alpha"));
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .roster()
                .to_vec(),
            vec!["abc".into(), "def".into()]
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .members()
                .to_vec(),
            vec!["123".to_string()],
        );
    }

    /// Adding a circuit through a write guard is persisted alongside the existing one.
    #[test]
    fn test_write_circuit_directory() {
        let temp_dir = TempDir::new("test_write_circuit").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = set_up_mock_state_file(temp_dir_path);
        {
            let mut storage = YamlStorage::new(path.clone(), CircuitDirectory::new).unwrap();
            let circuit = Circuit::builder()
                .with_id("beta".into())
                .with_auth(AuthorizationType::Trust)
                .with_members(vec!["456".into(), "789".into()])
                .with_roster(vec!["qwe".into(), "rty".into(), "uio".into()])
                .with_persistence(PersistenceType::Any)
                .with_durability(DurabilityType::NoDurability)
                .with_routes(RouteType::Any)
                .with_circuit_management_type("state_write_test_app".into())
                .build()
                .expect("Should have built a correct circuit");
            storage.write().add_circuit("beta".into(), circuit);
        }
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert_eq!(storage.data.circuits().len(), 2);
        assert!(storage.data.circuits().contains_key("alpha"));
        assert!(storage.data.circuits().contains_key("beta"));
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .roster()
                .to_vec(),
            vec!["abc".into(), "def".into()]
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("alpha")
                .unwrap()
                .members()
                .to_vec(),
            vec!["123".to_string()],
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("beta")
                .unwrap()
                .roster()
                .to_vec(),
            vec!["qwe".into(), "rty".into(), "uio".into()]
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("beta")
                .unwrap()
                .members()
                .to_vec(),
            vec!["456".to_string(), "789".to_string()],
        );
        assert_eq!(
            storage
                .data
                .circuits()
                .get("beta")
                .unwrap()
                .circuit_management_type(),
            "state_write_test_app"
        );
    }

    /// Removing a circuit through a write guard is persisted and leaves nodes intact.
    #[test]
    fn test_remove_circuit_from_state() {
        let temp_dir = TempDir::new("test_remove_circuit").unwrap();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let path = set_up_mock_state_file(temp_dir_path);
        {
            let mut storage = YamlStorage::new(path.clone(), CircuitDirectory::new).unwrap();
            storage.write().remove_circuit("alpha".into());
        }
        let storage = YamlStorage::new(path, CircuitDirectory::new).unwrap();
        assert_eq!(storage.data.nodes().len(), 1);
        assert_eq!(storage.data.circuits().len(), 0);
        assert!(storage.data.nodes().contains_key("123"));
        assert!(!storage.data.circuits().contains_key("alpha"));
        assert_eq!(
            storage
                .data
                .nodes()
                .get("123")
                .unwrap()
                .endpoints()
                .to_vec(),
            vec!["tcp://127.0.0.1:8000".to_string()]
        );
    }
}