1use std::sync::Arc;
2
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use pulse_core::{KvState, Result, SnapshotId};
6
7#[derive(Default)]
8struct Inner {
9 map: std::collections::HashMap<Vec<u8>, Vec<u8>>,
10 snapshots: std::collections::HashMap<SnapshotId, std::collections::HashMap<Vec<u8>, Vec<u8>>>,
11}
12
13#[derive(Clone)]
14pub struct InMemoryState(Arc<Mutex<Inner>>);
16
17impl Default for InMemoryState {
18 fn default() -> Self {
19 Self(Arc::new(Mutex::new(Inner::default())))
20 }
21}
22
23#[async_trait]
24impl KvState for InMemoryState {
25 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
26 Ok(self.0.lock().map.get(key).cloned())
27 }
28 async fn put(&self, key: &[u8], value: Vec<u8>) -> Result<()> {
29 self.0.lock().map.insert(key.to_vec(), value);
30 Ok(())
31 }
32 async fn delete(&self, key: &[u8]) -> Result<()> {
33 self.0.lock().map.remove(key);
34 Ok(())
35 }
36 async fn iter_prefix(&self, prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
37 let guard = self.0.lock();
38 let mut out = Vec::new();
39 if let Some(p) = prefix {
40 for (k, v) in guard.map.iter() {
41 if k.starts_with(p) {
42 out.push((k.clone(), v.clone()));
43 }
44 }
45 } else {
46 out.extend(guard.map.iter().map(|(k, v)| (k.clone(), v.clone())));
47 }
48 Ok(out)
49 }
50 async fn snapshot(&self) -> Result<SnapshotId> {
51 use std::time::{SystemTime, UNIX_EPOCH};
52 let mut guard = self.0.lock();
53 let ts = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis();
54 let id: SnapshotId = format!("mem-{}", ts);
55 let current = guard.map.clone();
56 guard.snapshots.insert(id.clone(), current);
57 Ok(id)
58 }
59 async fn restore(&self, snapshot: SnapshotId) -> Result<()> {
60 let mut guard = self.0.lock();
61 if let Some(m) = guard.snapshots.get(&snapshot) {
62 guard.map = m.clone();
63 }
64 Ok(())
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one store through put, get, full iteration, snapshot, restore and delete.
    #[tokio::test]
    async fn in_memory_state_put_get_delete_iter_snapshot_restore() {
        let state = InMemoryState::default();
        let key = b"k1";

        // A fresh store holds nothing.
        assert!(state.get(key).await.unwrap().is_none());

        state.put(key, b"v1".to_vec()).await.unwrap();
        assert_eq!(state.get(key).await.unwrap().unwrap(), b"v1".to_vec());

        let all = state.iter_prefix(None).await.unwrap();
        assert_eq!(all.len(), 1);

        // A write made after the snapshot must disappear once the snapshot is restored.
        let snap = state.snapshot().await.unwrap();
        state.put(b"k2", b"v2".to_vec()).await.unwrap();
        state.restore(snap).await.unwrap();
        assert!(state.get(b"k2").await.unwrap().is_none());
        assert_eq!(state.get(b"k1").await.unwrap().unwrap(), b"v1".to_vec());

        state.delete(key).await.unwrap();
        assert!(state.get(key).await.unwrap().is_none());
    }

    /// Exercises the `Some(prefix)` branch of `iter_prefix`.
    #[tokio::test]
    async fn iter_prefix_returns_only_matching_keys() {
        let state = InMemoryState::default();
        state.put(b"a/1", b"x".to_vec()).await.unwrap();
        state.put(b"a/2", b"y".to_vec()).await.unwrap();
        state.put(b"b/1", b"z".to_vec()).await.unwrap();

        // Iteration order is unspecified, so sort before comparing.
        let mut hits = state.iter_prefix(Some(&b"a/"[..])).await.unwrap();
        hits.sort();
        assert_eq!(
            hits,
            vec![
                (b"a/1".to_vec(), b"x".to_vec()),
                (b"a/2".to_vec(), b"y".to_vec()),
            ]
        );

        assert!(state
            .iter_prefix(Some(&b"zzz"[..]))
            .await
            .unwrap()
            .is_empty());
        assert_eq!(state.iter_prefix(None).await.unwrap().len(), 3);
    }
}