use std::collections::HashMap;
use std::convert::Infallible;
use std::hash::Hash;
use std::sync::RwLock;
/// Storage abstraction mapping node identifiers to per-node state.
///
/// Every operation takes `&self`, so implementations that mutate must rely on
/// interior mutability; the `Send + Sync` supertraits let one store be shared
/// across threads.
pub trait StateStore: Send + Sync {
    /// Identifier used as the lookup key for a node.
    type NodeId: Eq + Hash + Clone + Send + Sync;

    /// Value kept for each node; `get`/`list` hand it back by value.
    type State: Clone + Send + Sync;

    /// Error produced by a failing operation.
    type Error: Send + Sync + 'static;

    /// Looks up the state for `id`; `Ok(None)` means nothing is stored for it.
    fn get(&self, id: &Self::NodeId) -> Result<Option<Self::State>, Self::Error>;

    /// Stores `state` under `id`.
    fn put(&self, id: &Self::NodeId, state: Self::State) -> Result<(), Self::Error>;

    /// Deletes whatever is stored under `id`.
    fn remove(&self, id: &Self::NodeId) -> Result<(), Self::Error>;

    /// Returns every `(id, state)` pair currently held by the store.
    fn list(&self) -> Result<Vec<(Self::NodeId, Self::State)>, Self::Error>;
}
/// A store that keeps every entry in a process-local `HashMap` guarded by an
/// `RwLock`, so lookups can proceed concurrently while writes are exclusive.
///
/// `Debug` is derived so the store can be inspected in logs and assertions;
/// it is only available when both `Id` and `State` implement `Debug`.
#[derive(Debug)]
pub struct InMemoryStore<Id, State> {
    /// Node id -> state map; every access goes through the lock.
    data: RwLock<HashMap<Id, State>>,
}
impl<Id, State> InMemoryStore<Id, State> {
pub fn new() -> Self {
Self {
data: RwLock::new(HashMap::new()),
}
}
}
impl<Id, State> Default for InMemoryStore<Id, State> {
fn default() -> Self {
Self::new()
}
}
impl<Id, State> StateStore for InMemoryStore<Id, State>
where
Id: Eq + Hash + Clone + Send + Sync,
State: Clone + Send + Sync,
{
type NodeId = Id;
type State = State;
type Error = Infallible;
fn get(&self, id: &Id) -> Result<Option<State>, Self::Error> {
Ok(self.data.read().unwrap().get(id).cloned())
}
fn put(&self, id: &Id, state: State) -> Result<(), Self::Error> {
self.data.write().unwrap().insert(id.clone(), state);
Ok(())
}
fn remove(&self, id: &Id) -> Result<(), Self::Error> {
self.data.write().unwrap().remove(id);
Ok(())
}
fn list(&self) -> Result<Vec<(Id, State)>, Self::Error> {
Ok(self
.data
.read()
.unwrap()
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `String` key from a literal.
    fn key(name: &str) -> String {
        name.to_string()
    }

    #[test]
    fn basic_crud() {
        let store: InMemoryStore<String, u64> = InMemoryStore::new();
        let a = key("a");

        // A key that was never stored reads back as `None`.
        assert_eq!(store.get(&a).unwrap(), None);

        // First insert, then overwrite; each write must be visible immediately.
        for value in [42, 99] {
            store.put(&a, value).unwrap();
            assert_eq!(store.get(&a).unwrap(), Some(value));
        }

        // Removal makes the key absent again.
        store.remove(&a).unwrap();
        assert_eq!(store.get(&a).unwrap(), None);
    }

    #[test]
    fn list_entries() {
        let store: InMemoryStore<String, i32> = InMemoryStore::new();
        for (name, value) in [("x", 1), ("y", 2)] {
            store.put(&key(name), value).unwrap();
        }

        // `list` order is unspecified; keys are unique, so a full tuple sort
        // yields the same order as sorting by key alone.
        let mut entries = store.list().unwrap();
        entries.sort();
        assert_eq!(entries, vec![(key("x"), 1), (key("y"), 2)]);
    }
}