1use crate::{Error, Result};
2use async_trait::async_trait;
3use std::time::Duration;
4
5#[async_trait]
7pub trait StateManager: Send + Sync {
8 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
10
11 async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
13
14 async fn delete(&self, key: &str) -> Result<()>;
16
17 async fn exists(&self, key: &str) -> Result<bool>;
19}
20
21pub struct SledStateManager {
23 db: sled::Db,
24}
25
26impl SledStateManager {
27 pub fn new(path: &str) -> Result<Self> {
28 let db = sled::open(path).map_err(|e| Error::Handler(format!("Sled open failed: {}", e)))?;
29 Ok(Self { db })
30 }
31}
32
33#[async_trait]
34impl StateManager for SledStateManager {
35 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
36 let value = self
37 .db
38 .get(key)
39 .map_err(|e| Error::Handler(format!("Sled get failed: {}", e)))?;
40
41 Ok(value.map(|v| v.to_vec()))
42 }
43
44 async fn set(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
45 self.db
46 .insert(key, value)
47 .map_err(|e| Error::Handler(format!("Sled insert failed: {}", e)))?;
48
49 self.db
50 .flush()
51 .map_err(|e| Error::Handler(format!("Sled flush failed: {}", e)))?;
52
53 Ok(())
55 }
56
57 async fn delete(&self, key: &str) -> Result<()> {
58 self.db
59 .remove(key)
60 .map_err(|e| Error::Handler(format!("Sled remove failed: {}", e)))?;
61 Ok(())
62 }
63
64 async fn exists(&self, key: &str) -> Result<bool> {
65 let exists = self
66 .db
67 .contains_key(key)
68 .map_err(|e| Error::Handler(format!("Sled contains_key failed: {}", e)))?;
69 Ok(exists)
70 }
71}
72
73pub struct MemoryStateManager {
75 store: dashmap::DashMap<String, Vec<u8>>,
76}
77
78impl MemoryStateManager {
79 pub fn new() -> Self {
80 Self {
81 store: dashmap::DashMap::new(),
82 }
83 }
84}
85
86impl Default for MemoryStateManager {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92#[async_trait]
93impl StateManager for MemoryStateManager {
94 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
95 Ok(self.store.get(key).map(|v| v.clone()))
96 }
97
98 async fn set(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
99 self.store.insert(key.to_string(), value);
100 Ok(())
102 }
103
104 async fn delete(&self, key: &str) -> Result<()> {
105 self.store.remove(key);
106 Ok(())
107 }
108
109 async fn exists(&self, key: &str) -> Result<bool> {
110 Ok(self.store.contains_key(key))
111 }
112}
113
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a sled-backed manager rooted in `dir`.
    fn open_sled(dir: &tempfile::TempDir) -> SledStateManager {
        SledStateManager::new(dir.path().to_str().unwrap()).unwrap()
    }

    #[tokio::test]
    async fn test_memory_state_basic() {
        let state = MemoryStateManager::new();

        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        assert_eq!(state.get("key1").await.unwrap(), Some(b"value1".to_vec()));

        assert!(state.exists("key1").await.unwrap());
        assert!(!state.exists("key2").await.unwrap());

        state.delete("key1").await.unwrap();
        assert!(!state.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_sled_state_basic() {
        let temp_dir = tempfile::tempdir().unwrap();
        let expected = Some(b"value1".to_vec());

        let state = open_sled(&temp_dir);
        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        assert_eq!(state.get("key1").await.unwrap(), expected);

        // Close and reopen the database to confirm the value survived on disk.
        drop(state);
        let reopened = open_sled(&temp_dir);
        assert_eq!(reopened.get("key1").await.unwrap(), expected);
    }
}