1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::hash::Hash;
6use std::sync::RwLock;
7
/// Keyed storage for per-node state.
///
/// Every method takes `&self`, so an implementation that mutates has to bring its own
/// interior mutability. The `Send + Sync` supertraits require the store itself to be
/// shareable between threads.
pub trait StateStore: Send + Sync {
    /// Key under which a node's state is stored.
    type NodeId: Eq + Hash + Clone + Send + Sync;
    /// Value kept for each node. It must be `Clone` because lookups hand back owned copies.
    type State: Clone + Send + Sync;
    /// Failure type reported by the backing storage.
    type Error: Send + Sync + 'static;

    /// Looks up the state stored under `id`.
    ///
    /// The `Option` in the return type lets an implementation report a missing entry as
    /// `Ok(None)` instead of as an error; `InMemoryStore` in this file does exactly that.
    fn get(&self, id: &Self::NodeId) -> Result<Option<Self::State>, Self::Error>;

    /// Stores `state` under `id`.
    ///
    /// `InMemoryStore` in this file replaces any value already stored for the same id.
    fn put(&self, id: &Self::NodeId, state: Self::State) -> Result<(), Self::Error>;

    /// Removes the entry stored under `id`.
    ///
    /// `InMemoryStore` in this file treats removing an absent id as a successful no-op.
    fn remove(&self, id: &Self::NodeId) -> Result<(), Self::Error>;

    /// Returns every stored `(id, state)` pair as owned values.
    ///
    /// The trait itself does not promise any particular ordering of the pairs.
    fn list(&self) -> Result<Vec<(Self::NodeId, Self::State)>, Self::Error>;
}
27
/// In-process store that keeps every entry in a `HashMap` guarded by an `RwLock`.
///
/// Nothing is persisted: the contents exist only for as long as the value itself does.
///
/// `Debug` is derived so the store can be logged or embedded in other `Debug` types. The
/// derived impl is only available when both `Id` and `State` implement `Debug`, so it
/// places no new requirement on existing uses of the type.
#[derive(Debug)]
pub struct InMemoryStore<Id, State> {
    /// All entries, keyed by id. Access goes through the lock because the store is
    /// mutated through `&self`.
    data: RwLock<HashMap<Id, State>>,
}
36
37impl<Id, State> InMemoryStore<Id, State> {
38 pub fn new() -> Self {
39 Self {
40 data: RwLock::new(HashMap::new()),
41 }
42 }
43}
44
45impl<Id, State> Default for InMemoryStore<Id, State> {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl<Id, State> StateStore for InMemoryStore<Id, State>
52where
53 Id: Eq + Hash + Clone + Send + Sync,
54 State: Clone + Send + Sync,
55{
56 type NodeId = Id;
57 type State = State;
58 type Error = Infallible;
59
60 fn get(&self, id: &Id) -> Result<Option<State>, Self::Error> {
61 Ok(self.data.read().unwrap().get(id).cloned())
62 }
63
64 fn put(&self, id: &Id, state: State) -> Result<(), Self::Error> {
65 self.data.write().unwrap().insert(id.clone(), state);
66 Ok(())
67 }
68
69 fn remove(&self, id: &Id) -> Result<(), Self::Error> {
70 self.data.write().unwrap().remove(id);
71 Ok(())
72 }
73
74 fn list(&self) -> Result<Vec<(Id, State)>, Self::Error> {
75 Ok(self
76 .data
77 .read()
78 .unwrap()
79 .iter()
80 .map(|(k, v)| (k.clone(), v.clone()))
81 .collect())
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one key through insert, overwrite and removal, checking `get` after each step.
    #[test]
    fn basic_crud() {
        let store = InMemoryStore::<String, u64>::new();
        let key = String::from("a");

        // Nothing has been stored yet.
        assert_eq!(store.get(&key).unwrap(), None);

        store.put(&key, 42).unwrap();
        assert_eq!(store.get(&key).unwrap(), Some(42));

        // A second `put` for the same key replaces the value.
        store.put(&key, 99).unwrap();
        assert_eq!(store.get(&key).unwrap(), Some(99));

        store.remove(&key).unwrap();
        assert_eq!(store.get(&key).unwrap(), None);
    }

    /// `list` returns every stored pair; the result is sorted before comparing because the
    /// store does not promise an order.
    #[test]
    fn list_entries() {
        let store = InMemoryStore::<String, i32>::new();
        for (name, value) in [("x", 1), ("y", 2)] {
            store.put(&name.to_string(), value).unwrap();
        }

        let mut entries = store.list().unwrap();
        entries.sort_by(|left, right| left.0.cmp(&right.0));

        let expected = vec![(String::from("x"), 1), (String::from("y"), 2)];
        assert_eq!(entries, expected);
    }
}