oxigdal_streaming/state/
backend.rs1use crate::error::{Result, StreamingError};
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[async_trait]
12pub trait StateBackend: Send + Sync {
13 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
15
16 async fn put(&self, key: &[u8], value: &[u8]) -> Result<()>;
18
19 async fn delete(&self, key: &[u8]) -> Result<()>;
21
22 async fn contains(&self, key: &[u8]) -> Result<bool>;
24
25 async fn clear(&self) -> Result<()>;
27
28 async fn snapshot(&self) -> Result<Vec<u8>>;
30
31 async fn restore(&self, snapshot: &[u8]) -> Result<()>;
33
34 async fn keys(&self) -> Result<Vec<Vec<u8>>>;
36
37 fn name(&self) -> &str;
39}
40
41pub struct MemoryStateBackend {
43 state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
44}
45
46impl MemoryStateBackend {
47 pub fn new() -> Self {
49 Self {
50 state: Arc::new(RwLock::new(HashMap::new())),
51 }
52 }
53}
54
55impl Default for MemoryStateBackend {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61#[async_trait]
62impl StateBackend for MemoryStateBackend {
63 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
64 Ok(self.state.read().await.get(key).cloned())
65 }
66
67 async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
68 self.state
69 .write()
70 .await
71 .insert(key.to_vec(), value.to_vec());
72 Ok(())
73 }
74
75 async fn delete(&self, key: &[u8]) -> Result<()> {
76 self.state.write().await.remove(key);
77 Ok(())
78 }
79
80 async fn contains(&self, key: &[u8]) -> Result<bool> {
81 Ok(self.state.read().await.contains_key(key))
82 }
83
84 async fn clear(&self) -> Result<()> {
85 self.state.write().await.clear();
86 Ok(())
87 }
88
89 async fn snapshot(&self) -> Result<Vec<u8>> {
90 let state = self.state.read().await;
91 oxicode::encode_to_vec(&*state)
93 .map_err(|e| StreamingError::SerializationError(e.to_string()))
94 }
95
96 async fn restore(&self, snapshot: &[u8]) -> Result<()> {
97 let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
98 .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
99 *self.state.write().await = restored;
100 Ok(())
101 }
102
103 async fn keys(&self) -> Result<Vec<Vec<u8>>> {
104 Ok(self.state.read().await.keys().cloned().collect())
105 }
106
107 fn name(&self) -> &str {
108 "MemoryStateBackend"
109 }
110}
111
/// [`StateBackend`] that stores its entries in a RocksDB database.
#[cfg(feature = "rocksdb-backend")]
pub struct RocksDBStateBackend {
    /// Shared handle to the opened database.
    db: Arc<rocksdb::DB>,
    /// Location the database was opened at.
    path: PathBuf,
}

#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Opens the database at `path`, creating it when it does not exist yet.
    ///
    /// # Errors
    /// Propagates the error reported by RocksDB if the database cannot be opened.
    pub fn new(path: PathBuf) -> Result<Self> {
        let mut options = rocksdb::Options::default();
        options.create_if_missing(true);

        let handle = rocksdb::DB::open(&options, &path)?;
        let db = Arc::new(handle);

        Ok(Self { db, path })
    }

    /// Returns the path this backend was constructed with.
    pub fn path(&self) -> &PathBuf {
        &self.path
    }
}
139
#[cfg(feature = "rocksdb-backend")]
#[async_trait]
impl StateBackend for RocksDBStateBackend {
    /// Reads the value stored under `key`, or `None` if the key is absent.
    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        Ok(self.db.get(key)?)
    }

    /// Writes `value` under `key`.
    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
        self.db.put(key, value)?;
        Ok(())
    }

    /// Deletes the entry stored under `key`.
    async fn delete(&self, key: &[u8]) -> Result<()> {
        self.db.delete(key)?;
        Ok(())
    }

    /// Reports whether a value exists for `key`.
    async fn contains(&self, key: &[u8]) -> Result<bool> {
        Ok(self.db.get(key)?.is_some())
    }

    /// Deletes every key in the database.
    ///
    /// All keys are collected first and then deleted one by one.
    ///
    /// # Errors
    /// Iteration failures are reported as `StreamingError::StateError`;
    /// delete failures are propagated from RocksDB.
    async fn clear(&self) -> Result<()> {
        let keys = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| {
                item.map(|(key, _)| key.to_vec())
                    .map_err(|e| StreamingError::StateError(e.to_string()))
            })
            .collect::<Result<Vec<Vec<u8>>>>()?;

        for key in keys {
            self.db.delete(&key)?;
        }

        Ok(())
    }

    /// Serializes every entry visible in a RocksDB snapshot of the database.
    ///
    /// The output is a sequence of frames, each a little-endian `u32` byte
    /// length followed by the `oxicode` encoding of one `(key, value)` pair.
    ///
    /// # Errors
    /// Returns `StreamingError::SerializationError` if an entry cannot be
    /// encoded or its encoding is too large for the `u32` length prefix, and
    /// propagates RocksDB iteration errors.
    async fn snapshot(&self) -> Result<Vec<u8>> {
        let snapshot = self.db.snapshot();
        let mut data = Vec::new();

        for item in snapshot.iterator(rocksdb::IteratorMode::Start) {
            let (key, value) = item?;
            let entry = (key.to_vec(), value.to_vec());
            let serialized = oxicode::encode_to_vec(&entry)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
            // A plain `as u32` cast would wrap for oversized entries and emit a
            // prefix that no longer matches the payload, so fail loudly instead.
            let frame_len = u32::try_from(serialized.len()).map_err(|_| {
                StreamingError::SerializationError(format!(
                    "snapshot entry of {} bytes does not fit the u32 length prefix",
                    serialized.len()
                ))
            })?;
            data.extend_from_slice(&frame_len.to_le_bytes());
            data.extend_from_slice(&serialized);
        }

        Ok(data)
    }

    /// Replaces the database contents with the entries encoded in `snapshot`.
    ///
    /// The snapshot is fully validated and decoded before the database is
    /// modified, so a malformed snapshot leaves the existing state untouched.
    ///
    /// # Errors
    /// Returns `StreamingError::SerializationError` if the snapshot is
    /// truncated or an entry cannot be decoded, and propagates errors from
    /// clearing or writing the database.
    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
        let frames = split_snapshot_frames(snapshot)?;
        let mut entries: Vec<(Vec<u8>, Vec<u8>)> = Vec::with_capacity(frames.len());
        for frame in frames {
            let (entry, _): ((Vec<u8>, Vec<u8>), _) = oxicode::decode_from_slice(frame)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
            entries.push(entry);
        }

        // Only now is it safe to drop the current contents.
        self.clear().await?;
        for (key, value) in entries {
            self.db.put(&key, &value)?;
        }

        Ok(())
    }

    /// Returns every key in the database, in the iterator's order from the start.
    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
        let keys = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| item.map(|(key, _)| key.to_vec()))
            .collect::<std::result::Result<Vec<Vec<u8>>, rocksdb::Error>>()?;

        Ok(keys)
    }

    /// Returns the fixed identifier of this backend.
    fn name(&self) -> &str {
        "RocksDBStateBackend"
    }
}

/// Splits a RocksDB backend snapshot into its still-encoded entries.
///
/// Each frame is a little-endian `u32` byte length followed by that many bytes.
///
/// # Errors
/// Returns `StreamingError::SerializationError` when the buffer ends in the
/// middle of a length prefix or of an entry.
#[cfg(feature = "rocksdb-backend")]
fn split_snapshot_frames(snapshot: &[u8]) -> Result<Vec<&[u8]>> {
    const PREFIX_LEN: usize = std::mem::size_of::<u32>();

    let mut frames = Vec::new();
    let mut rest = snapshot;
    while !rest.is_empty() {
        if rest.len() < PREFIX_LEN {
            return Err(StreamingError::SerializationError(format!(
                "truncated snapshot: {} trailing byte(s), expected a {}-byte length prefix",
                rest.len(),
                PREFIX_LEN
            )));
        }
        let (prefix, tail) = rest.split_at(PREFIX_LEN);
        let len = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;

        // Comparing against the remaining length avoids the `offset + len`
        // addition, which could overflow on 32-bit targets.
        if len > tail.len() {
            return Err(StreamingError::SerializationError(format!(
                "truncated snapshot: entry declares {} byte(s) but only {} remain",
                len,
                tail.len()
            )));
        }
        let (frame, remaining) = tail.split_at(len);
        frames.push(frame);
        rest = remaining;
    }

    Ok(frames)
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises put, get, contains and delete on the in-memory backend.
    #[tokio::test]
    async fn test_memory_backend() -> Result<()> {
        let store = MemoryStateBackend::new();

        store.put(b"key1", b"value1").await?;
        assert_eq!(store.get(b"key1").await?, Some(b"value1".to_vec()));

        let has_key1 = store.contains(b"key1").await?;
        assert!(has_key1);
        let has_key2 = store.contains(b"key2").await?;
        assert!(!has_key2);

        store.delete(b"key1").await?;
        let still_has_key1 = store.contains(b"key1").await?;
        assert!(!still_has_key1);

        Ok(())
    }

    /// Checks that a snapshot taken from one in-memory backend restores into another.
    #[tokio::test]
    async fn test_memory_backend_snapshot() -> Result<()> {
        let source = MemoryStateBackend::new();

        source.put(b"key1", b"value1").await?;
        source.put(b"key2", b"value2").await?;

        let bytes = source.snapshot().await?;

        let target = MemoryStateBackend::new();
        target.restore(&bytes).await?;

        let first = target.get(b"key1").await?;
        assert_eq!(first, Some(b"value1".to_vec()));
        let second = target.get(b"key2").await?;
        assert_eq!(second, Some(b"value2".to_vec()));

        Ok(())
    }

    /// Exercises put, get, contains and delete on a RocksDB backend in a temp directory.
    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend() -> Result<()> {
        let scratch = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let db_path = scratch.path().to_path_buf();
        let store = RocksDBStateBackend::new(db_path)?;

        store.put(b"key1", b"value1").await?;
        assert_eq!(store.get(b"key1").await?, Some(b"value1".to_vec()));

        let present = store.contains(b"key1").await?;
        assert!(present);

        store.delete(b"key1").await?;
        let present_after_delete = store.contains(b"key1").await?;
        assert!(!present_after_delete);

        Ok(())
    }
}