use crate::error::Result;
use crate::value::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
/// Storage for state values keyed by node id.
///
/// The `Send + Sync` supertraits let a single store be shared across threads.
/// Every operation returns [`Result`], leaving room for fallible backends; the
/// in-memory implementation in this file never returns `Err`.
pub trait StateStore: Send + Sync {
/// Returns the state stored for `node_id`, or `Ok(None)` if there is none.
///
/// State is handed out as `Arc<Value>` so callers can share it without
/// copying the underlying value.
fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>>;
/// Stores `state` for `node_id`.
///
/// NOTE(review): `MemoryStateStore` replaces any existing entry; confirm other
/// implementations are expected to do the same.
fn set(&self, node_id: &str, state: Value) -> Result<()>;
/// Removes the state for `node_id`.
///
/// NOTE(review): `MemoryStateStore` treats removing an absent key as a no-op
/// returning `Ok(())`; presumably all implementations should — confirm.
fn remove(&self, node_id: &str) -> Result<()>;
/// Removes all stored state.
fn clear(&self) -> Result<()>;
/// Returns the ids of all nodes that currently have stored state.
///
/// No ordering is guaranteed by `MemoryStateStore` (it reads `HashMap` keys).
fn keys(&self) -> Result<Vec<String>>;
}
/// In-process [`StateStore`] backed by a mutex-guarded `HashMap`.
///
/// Values are stored behind `Arc`, so `get` returns a reference-counted handle
/// to the stored value rather than a copy. Nothing is persisted beyond the
/// lifetime of the store.
#[derive(Default)]
pub struct MemoryStateStore {
// Node id -> stored state. The `Mutex` supplies the interior mutability needed
// to mutate through the `&self` receivers required by `StateStore`.
inner: Mutex<HashMap<String, Arc<Value>>>,
}
impl MemoryStateStore {
pub fn new() -> Self {
Self::default()
}
}
impl StateStore for MemoryStateStore {
fn get(&self, node_id: &str) -> Result<Option<Arc<Value>>> {
let guard = self.inner.lock().expect("MemoryStateStore poisoned");
Ok(guard.get(node_id).cloned())
}
fn set(&self, node_id: &str, state: Value) -> Result<()> {
let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
guard.insert(node_id.to_string(), Arc::new(state));
Ok(())
}
fn remove(&self, node_id: &str) -> Result<()> {
let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
guard.remove(node_id);
Ok(())
}
fn clear(&self) -> Result<()> {
let mut guard = self.inner.lock().expect("MemoryStateStore poisoned");
guard.clear();
Ok(())
}
fn keys(&self) -> Result<Vec<String>> {
let guard = self.inner.lock().expect("MemoryStateStore poisoned");
Ok(guard.keys().cloned().collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic set/get roundtrip; repeated reads share the same `Arc`.
    #[test]
    fn memory_store_roundtrip() {
        let store = MemoryStateStore::new();
        assert!(store.get("a").unwrap().is_none());
        store
            .set("a", Value::Json(serde_json::json!({"mean": 5.0})))
            .unwrap();
        let state = store.get("a").unwrap().unwrap();
        assert_eq!(state.as_json().unwrap()["mean"], 5.0);
        let s1 = store.get("a").unwrap().unwrap();
        let s2 = store.get("a").unwrap().unwrap();
        assert!(Arc::ptr_eq(&s1, &s2));
    }

    /// `keys`, `remove` (including absent keys) and `clear` behave as expected.
    #[test]
    fn memory_store_remove_and_clear() {
        let store = MemoryStateStore::new();
        store.set("a", Value::Empty).unwrap();
        store.set("b", Value::Empty).unwrap();
        // `keys` has no defined order, so sort before comparing contents.
        let mut keys = store.keys().unwrap();
        keys.sort();
        assert_eq!(keys, ["a", "b"]);
        store.remove("a").unwrap();
        assert!(store.get("a").unwrap().is_none());
        assert!(store.get("b").unwrap().is_some());
        // Removing a key that is not present is a no-op, not an error.
        store.remove("a").unwrap();
        store.remove("never-set").unwrap();
        assert_eq!(store.keys().unwrap(), ["b"]);
        store.clear().unwrap();
        assert!(store.keys().unwrap().is_empty());
        // The store remains usable after being cleared.
        store.set("c", Value::Empty).unwrap();
        assert_eq!(store.keys().unwrap(), ["c"]);
    }

    /// Overwriting replaces the stored value without duplicating the key, and
    /// handles obtained before the overwrite keep the old value.
    #[test]
    fn memory_store_overwrites() {
        let store = MemoryStateStore::new();
        store
            .set("a", Value::Json(serde_json::json!({"v": 1})))
            .unwrap();
        let before = store.get("a").unwrap().unwrap();
        store
            .set("a", Value::Json(serde_json::json!({"v": 2})))
            .unwrap();
        let state = store.get("a").unwrap().unwrap();
        assert_eq!(state.as_json().unwrap()["v"], 2);
        assert_eq!(before.as_json().unwrap()["v"], 1);
        assert_eq!(store.keys().unwrap(), ["a"]);
    }
}