oxigdal_streaming/state/
backend.rs1use crate::error::{Result, StreamingError};
4use async_trait::async_trait;
5use std::collections::HashMap;
6#[cfg(feature = "rocksdb-backend")]
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11#[async_trait]
13pub trait StateBackend: Send + Sync {
14 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
16
17 async fn put(&self, key: &[u8], value: &[u8]) -> Result<()>;
19
20 async fn delete(&self, key: &[u8]) -> Result<()>;
22
23 async fn contains(&self, key: &[u8]) -> Result<bool>;
25
26 async fn clear(&self) -> Result<()>;
28
29 async fn snapshot(&self) -> Result<Vec<u8>>;
31
32 async fn restore(&self, snapshot: &[u8]) -> Result<()>;
34
35 async fn keys(&self) -> Result<Vec<Vec<u8>>>;
37
38 fn name(&self) -> &str;
40}
41
42pub struct MemoryStateBackend {
44 state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
45}
46
47impl MemoryStateBackend {
48 pub fn new() -> Self {
50 Self {
51 state: Arc::new(RwLock::new(HashMap::new())),
52 }
53 }
54}
55
56impl Default for MemoryStateBackend {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62#[async_trait]
63impl StateBackend for MemoryStateBackend {
64 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
65 Ok(self.state.read().await.get(key).cloned())
66 }
67
68 async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
69 self.state
70 .write()
71 .await
72 .insert(key.to_vec(), value.to_vec());
73 Ok(())
74 }
75
76 async fn delete(&self, key: &[u8]) -> Result<()> {
77 self.state.write().await.remove(key);
78 Ok(())
79 }
80
81 async fn contains(&self, key: &[u8]) -> Result<bool> {
82 Ok(self.state.read().await.contains_key(key))
83 }
84
85 async fn clear(&self) -> Result<()> {
86 self.state.write().await.clear();
87 Ok(())
88 }
89
90 async fn snapshot(&self) -> Result<Vec<u8>> {
91 let state = self.state.read().await;
92 oxicode::encode_to_vec(&*state)
94 .map_err(|e| StreamingError::SerializationError(e.to_string()))
95 }
96
97 async fn restore(&self, snapshot: &[u8]) -> Result<()> {
98 let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
99 .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
100 *self.state.write().await = restored;
101 Ok(())
102 }
103
104 async fn keys(&self) -> Result<Vec<Vec<u8>>> {
105 Ok(self.state.read().await.keys().cloned().collect())
106 }
107
108 fn name(&self) -> &str {
109 "MemoryStateBackend"
110 }
111}
112
/// [`StateBackend`] that stores entries in a RocksDB database.
///
/// Only compiled when the `rocksdb-backend` feature is enabled.
#[cfg(feature = "rocksdb-backend")]
pub struct RocksDBStateBackend {
    /// Shared handle to the open database.
    db: Arc<rocksdb::DB>,
    /// Path the database was opened at.
    path: PathBuf,
}
119
#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Opens the RocksDB database at `path`, creating it if it is missing.
    ///
    /// # Errors
    ///
    /// Propagates the error from `rocksdb::DB::open` if the database cannot
    /// be opened.
    pub fn new(path: PathBuf) -> Result<Self> {
        let mut options = rocksdb::Options::default();
        options.create_if_missing(true);

        let handle = rocksdb::DB::open(&options, &path)?;
        let db = Arc::new(handle);

        Ok(Self { db, path })
    }

    /// Returns the path this backend's database was opened at.
    pub fn path(&self) -> &PathBuf {
        &self.path
    }
}
140
// NOTE(review): the `rocksdb` calls below are not awaited and appear to be
// synchronous; if they stall the executor, consider `spawn_blocking`.
#[cfg(feature = "rocksdb-backend")]
#[async_trait]
impl StateBackend for RocksDBStateBackend {
    /// Returns a copy of the value stored under `key`, or `None` if absent.
    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        Ok(self.db.get(key)?)
    }

    /// Writes `value` under `key`.
    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
        self.db.put(key, value)?;
        Ok(())
    }

    /// Deletes `key` from the database.
    async fn delete(&self, key: &[u8]) -> Result<()> {
        self.db.delete(key)?;
        Ok(())
    }

    /// Reports whether a lookup of `key` yields a value.
    async fn contains(&self, key: &[u8]) -> Result<bool> {
        Ok(self.db.get(key)?.is_some())
    }

    /// Deletes every key in the database.
    ///
    /// All keys are collected first and then deleted one by one, so the
    /// iterator is never advanced while the database is being modified.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingError::StateError`] if iteration fails, or the
    /// converted RocksDB error if a delete fails.
    async fn clear(&self) -> Result<()> {
        let keys: Vec<Vec<u8>> = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| {
                let (key, _) = item.map_err(|e| StreamingError::StateError(e.to_string()))?;
                Ok(key.to_vec())
            })
            .collect::<Result<Vec<_>>>()?;

        for key in keys {
            self.db.delete(&key)?;
        }

        Ok(())
    }

    /// Serializes every entry visible in a RocksDB snapshot.
    ///
    /// The output is a sequence of frames, one per entry: a 4-byte
    /// little-endian length followed by that many bytes holding the
    /// `oxicode`-encoded `(key, value)` pair.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingError::SerializationError`] if an entry cannot be
    /// encoded or its encoding does not fit in the `u32` length prefix, or
    /// the converted RocksDB error if iteration fails.
    async fn snapshot(&self) -> Result<Vec<u8>> {
        let snapshot = self.db.snapshot();
        let mut data = Vec::new();

        for item in snapshot.iterator(rocksdb::IteratorMode::Start) {
            let (key, value) = item?;
            let entry = (key.to_vec(), value.to_vec());
            let serialized = oxicode::encode_to_vec(&entry)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;

            // A plain `as u32` would silently truncate an oversized entry and
            // corrupt the framing, so fail loudly instead.
            let frame_len = u32::try_from(serialized.len()).map_err(|_| {
                StreamingError::SerializationError(format!(
                    "snapshot entry of {} bytes exceeds the u32 frame length limit",
                    serialized.len()
                ))
            })?;

            data.extend_from_slice(&frame_len.to_le_bytes());
            data.extend_from_slice(&serialized);
        }

        Ok(data)
    }

    /// Replaces the database contents with the entries in `snapshot`.
    ///
    /// The whole snapshot is validated and decoded before anything is
    /// deleted, so a truncated or undecodable snapshot leaves the existing
    /// state untouched instead of wiping it or restoring only part of it.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingError::SerializationError`] if the framing is
    /// truncated or an entry fails to decode; otherwise propagates errors
    /// from clearing or writing the database.
    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
        let frames = split_snapshot_frames(snapshot)?;

        let mut entries: Vec<(Vec<u8>, Vec<u8>)> = Vec::with_capacity(frames.len());
        for frame in frames {
            let (entry, _): ((Vec<u8>, Vec<u8>), _) = oxicode::decode_from_slice(frame)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
            entries.push(entry);
        }

        // Only discard the current state once the snapshot is known good.
        self.clear().await?;

        for (key, value) in entries {
            self.db.put(&key, &value)?;
        }

        Ok(())
    }

    /// Returns copies of all keys, in the database iterator's order.
    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
        let keys: Vec<Vec<u8>> = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| {
                let (key, _) = item?;
                Ok(key.to_vec())
            })
            .collect::<std::result::Result<Vec<_>, rocksdb::Error>>()?;

        Ok(keys)
    }

    /// Returns the fixed backend name `"RocksDBStateBackend"`.
    fn name(&self) -> &str {
        "RocksDBStateBackend"
    }
}

/// Splits a RocksDB backend snapshot into its length-prefixed frames.
///
/// Each frame is a 4-byte little-endian length followed by that many payload
/// bytes; the returned slices are the payloads, in order. An empty snapshot
/// yields no frames.
///
/// # Errors
///
/// Returns [`StreamingError::SerializationError`] if the input ends in the
/// middle of a length prefix or a payload.
#[cfg(feature = "rocksdb-backend")]
fn split_snapshot_frames(snapshot: &[u8]) -> Result<Vec<&[u8]>> {
    const LEN_PREFIX: usize = std::mem::size_of::<u32>();

    let mut frames = Vec::new();
    let mut rest = snapshot;

    while !rest.is_empty() {
        if rest.len() < LEN_PREFIX {
            return Err(StreamingError::SerializationError(
                "snapshot is truncated: incomplete frame length prefix".to_string(),
            ));
        }
        let (prefix, tail) = rest.split_at(LEN_PREFIX);

        // Widening cast: `u32` fits in `usize` on 32- and 64-bit targets.
        let len = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;

        if tail.len() < len {
            return Err(StreamingError::SerializationError(format!(
                "snapshot is truncated: frame declares {} bytes but only {} remain",
                len,
                tail.len()
            )));
        }
        let (frame, remaining) = tail.split_at(len);

        frames.push(frame);
        rest = remaining;
    }

    Ok(frames)
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Put, get, contains and delete round-trip on the in-memory backend.
    #[tokio::test]
    async fn test_memory_backend() -> Result<()> {
        let store = MemoryStateBackend::new();

        store.put(b"key1", b"value1").await?;
        let fetched = store.get(b"key1").await?;
        assert_eq!(fetched, Some(b"value1".to_vec()));

        let has_key1 = store.contains(b"key1").await?;
        let has_key2 = store.contains(b"key2").await?;
        assert!(has_key1);
        assert!(!has_key2);

        store.delete(b"key1").await?;
        let still_there = store.contains(b"key1").await?;
        assert!(!still_there);

        Ok(())
    }

    /// A snapshot of one in-memory backend restores into a fresh one.
    #[tokio::test]
    async fn test_memory_backend_snapshot() -> Result<()> {
        let source = MemoryStateBackend::new();
        source.put(b"key1", b"value1").await?;
        source.put(b"key2", b"value2").await?;

        let bytes = source.snapshot().await?;

        let target = MemoryStateBackend::new();
        target.restore(&bytes).await?;

        let first = target.get(b"key1").await?;
        let second = target.get(b"key2").await?;
        assert_eq!(first, Some(b"value1".to_vec()));
        assert_eq!(second, Some(b"value2".to_vec()));

        Ok(())
    }

    /// Put, get, contains and delete round-trip on a RocksDB backend opened
    /// in a temporary directory.
    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend() -> Result<()> {
        // Keep the directory guard bound for the whole test so the database
        // files exist while the backend is in use.
        let scratch = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let db_path = scratch.path().to_path_buf();
        let store = RocksDBStateBackend::new(db_path)?;

        store.put(b"key1", b"value1").await?;
        let fetched = store.get(b"key1").await?;
        assert_eq!(fetched, Some(b"value1".to_vec()));

        let present = store.contains(b"key1").await?;
        assert!(present);

        store.delete(b"key1").await?;
        let present_after_delete = store.contains(b"key1").await?;
        assert!(!present_after_delete);

        Ok(())
    }
}