1use crate::{Error, Result};
2use async_trait::async_trait;
3use std::time::Duration;
4
5#[async_trait]
7pub trait StateManager: Send + Sync {
8 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
10
11 async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
13
14 async fn delete(&self, key: &str) -> Result<()>;
16
17 async fn exists(&self, key: &str) -> Result<bool>;
19}
20
21pub struct SledStateManager {
23 db: sled::Db,
24}
25
26impl SledStateManager {
27 pub fn new(path: &str) -> Result<Self> {
28 let db =
29 sled::open(path).map_err(|e| Error::Handler(format!("Sled open failed: {}", e)))?;
30 Ok(Self { db })
31 }
32}
33
34#[async_trait]
35impl StateManager for SledStateManager {
36 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
37 let value = self
38 .db
39 .get(key)
40 .map_err(|e| Error::Handler(format!("Sled get failed: {}", e)))?;
41
42 Ok(value.map(|v| v.to_vec()))
43 }
44
45 async fn set(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
46 self.db
47 .insert(key, value)
48 .map_err(|e| Error::Handler(format!("Sled insert failed: {}", e)))?;
49
50 self.db
51 .flush()
52 .map_err(|e| Error::Handler(format!("Sled flush failed: {}", e)))?;
53
54 Ok(())
56 }
57
58 async fn delete(&self, key: &str) -> Result<()> {
59 self.db
60 .remove(key)
61 .map_err(|e| Error::Handler(format!("Sled remove failed: {}", e)))?;
62 Ok(())
63 }
64
65 async fn exists(&self, key: &str) -> Result<bool> {
66 let exists = self
67 .db
68 .contains_key(key)
69 .map_err(|e| Error::Handler(format!("Sled contains_key failed: {}", e)))?;
70 Ok(exists)
71 }
72}
73
74pub struct MemoryStateManager {
76 store: dashmap::DashMap<String, Vec<u8>>,
77}
78
79impl MemoryStateManager {
80 pub fn new() -> Self {
81 Self {
82 store: dashmap::DashMap::new(),
83 }
84 }
85}
86
87impl Default for MemoryStateManager {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93#[async_trait]
94impl StateManager for MemoryStateManager {
95 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
96 Ok(self.store.get(key).map(|v| v.clone()))
97 }
98
99 async fn set(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
100 self.store.insert(key.to_string(), value);
101 Ok(())
103 }
104
105 async fn delete(&self, key: &str) -> Result<()> {
106 self.store.remove(key);
107 Ok(())
108 }
109
110 async fn exists(&self, key: &str) -> Result<bool> {
111 Ok(self.store.contains_key(key))
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_state_basic() {
        let state = MemoryStateManager::new();

        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        let value = state.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        assert!(state.exists("key1").await.unwrap());
        assert!(!state.exists("key2").await.unwrap());

        state.delete("key1").await.unwrap();
        assert!(!state.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_state_overwrite_and_missing() {
        let state = MemoryStateManager::new();

        // Reading a key that was never written yields `None`, not an error.
        assert_eq!(state.get("missing").await.unwrap(), None);

        // A second `set` replaces the first value.
        state.set("key1", b"first".to_vec(), None).await.unwrap();
        state.set("key1", b"second".to_vec(), None).await.unwrap();
        assert_eq!(state.get("key1").await.unwrap(), Some(b"second".to_vec()));

        // Deleting an absent key is a no-op.
        state.delete("missing").await.unwrap();
        assert!(state.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_sled_state_basic() {
        let temp_dir = tempfile::tempdir().unwrap();
        let state = SledStateManager::new(temp_dir.path().to_str().unwrap()).unwrap();

        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        let value = state.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Reopen the same directory to check the value was persisted.
        drop(state);
        let state = SledStateManager::new(temp_dir.path().to_str().unwrap()).unwrap();
        let value = state.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_sled_state_exists_and_delete() {
        let temp_dir = tempfile::tempdir().unwrap();
        let state = SledStateManager::new(temp_dir.path().to_str().unwrap()).unwrap();

        assert!(!state.exists("key1").await.unwrap());

        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        assert!(state.exists("key1").await.unwrap());

        state.delete("key1").await.unwrap();
        assert!(!state.exists("key1").await.unwrap());
        assert_eq!(state.get("key1").await.unwrap(), None);

        // Deleting an absent key is a no-op.
        state.delete("key1").await.unwrap();
    }
}