Skip to main content

oxigdal_streaming/state/
backend.rs

1//! State backend implementations.
2
3use crate::error::{Result, StreamingError};
4use async_trait::async_trait;
5use std::collections::HashMap;
6#[cfg(feature = "rocksdb-backend")]
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Trait for state backends.
12#[async_trait]
13pub trait StateBackend: Send + Sync {
14    /// Get a value from the state.
15    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
16
17    /// Put a value into the state.
18    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()>;
19
20    /// Delete a value from the state.
21    async fn delete(&self, key: &[u8]) -> Result<()>;
22
23    /// Check if a key exists.
24    async fn contains(&self, key: &[u8]) -> Result<bool>;
25
26    /// Clear all state.
27    async fn clear(&self) -> Result<()>;
28
29    /// Create a snapshot of the state.
30    async fn snapshot(&self) -> Result<Vec<u8>>;
31
32    /// Restore state from a snapshot.
33    async fn restore(&self, snapshot: &[u8]) -> Result<()>;
34
35    /// Get all keys.
36    async fn keys(&self) -> Result<Vec<Vec<u8>>>;
37
38    /// Get the backend name.
39    fn name(&self) -> &str;
40}
41
42/// In-memory state backend.
43pub struct MemoryStateBackend {
44    state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
45}
46
47impl MemoryStateBackend {
48    /// Create a new memory state backend.
49    pub fn new() -> Self {
50        Self {
51            state: Arc::new(RwLock::new(HashMap::new())),
52        }
53    }
54}
55
56impl Default for MemoryStateBackend {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62#[async_trait]
63impl StateBackend for MemoryStateBackend {
64    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
65        Ok(self.state.read().await.get(key).cloned())
66    }
67
68    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
69        self.state
70            .write()
71            .await
72            .insert(key.to_vec(), value.to_vec());
73        Ok(())
74    }
75
76    async fn delete(&self, key: &[u8]) -> Result<()> {
77        self.state.write().await.remove(key);
78        Ok(())
79    }
80
81    async fn contains(&self, key: &[u8]) -> Result<bool> {
82        Ok(self.state.read().await.contains_key(key))
83    }
84
85    async fn clear(&self) -> Result<()> {
86        self.state.write().await.clear();
87        Ok(())
88    }
89
90    async fn snapshot(&self) -> Result<Vec<u8>> {
91        let state = self.state.read().await;
92        // Use oxicode for binary serialization since JSON requires string keys
93        oxicode::encode_to_vec(&*state)
94            .map_err(|e| StreamingError::SerializationError(e.to_string()))
95    }
96
97    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
98        let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
99            .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
100        *self.state.write().await = restored;
101        Ok(())
102    }
103
104    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
105        Ok(self.state.read().await.keys().cloned().collect())
106    }
107
108    fn name(&self) -> &str {
109        "MemoryStateBackend"
110    }
111}
112
/// State backend persisted in a RocksDB database opened at a filesystem path.
///
/// Only compiled when the `rocksdb-backend` feature is enabled.
#[cfg(feature = "rocksdb-backend")]
pub struct RocksDBStateBackend {
    /// Handle to the opened database, wrapped in an `Arc`.
    db: Arc<rocksdb::DB>,
    /// Path the database was opened at.
    path: PathBuf,
}
119
#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Opens the RocksDB database at `path`, creating it if it is missing.
    ///
    /// # Errors
    ///
    /// Propagates any failure from `rocksdb::DB::open`, converted into the
    /// crate's error type.
    pub fn new(path: PathBuf) -> Result<Self> {
        let opts = {
            let mut options = rocksdb::Options::default();
            options.create_if_missing(true);
            options
        };
        let db = Arc::new(rocksdb::DB::open(&opts, &path)?);
        Ok(Self { db, path })
    }

    /// Returns the path this backend's database was opened at.
    pub fn path(&self) -> &PathBuf {
        &self.path
    }
}
140
/// Splits a snapshot produced by [`RocksDBStateBackend`]'s `snapshot` method
/// back into `(key, value)` pairs without touching any database.
///
/// The format is a sequence of frames. Each frame is a little-endian `u32`
/// payload length followed by that many bytes of an oxicode-encoded
/// `(Vec<u8>, Vec<u8>)` tuple. An empty input decodes to no entries.
///
/// # Errors
///
/// Returns [`StreamingError::SerializationError`] if the input stops in the
/// middle of a length prefix or a payload, or if a payload fails to decode.
#[cfg(feature = "rocksdb-backend")]
fn decode_rocksdb_snapshot(snapshot: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
    const PREFIX_LEN: usize = 4;

    let mut entries = Vec::new();
    let mut rest = snapshot;
    while !rest.is_empty() {
        if rest.len() < PREFIX_LEN {
            return Err(StreamingError::SerializationError(format!(
                "truncated RocksDB snapshot: {} trailing byte(s) where a {}-byte length prefix was expected",
                rest.len(),
                PREFIX_LEN
            )));
        }
        let (prefix, body) = rest.split_at(PREFIX_LEN);
        let frame_len = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
        // Compare against the remaining length instead of computing
        // `offset + len`, which could overflow on 32-bit targets.
        if body.len() < frame_len {
            return Err(StreamingError::SerializationError(format!(
                "truncated RocksDB snapshot: entry declares {} byte(s) but only {} remain",
                frame_len,
                body.len()
            )));
        }
        let (frame, tail) = body.split_at(frame_len);
        let ((key, value), _): ((Vec<u8>, Vec<u8>), _) = oxicode::decode_from_slice(frame)
            .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
        entries.push((key, value));
        rest = tail;
    }
    Ok(entries)
}

#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Builds a write batch that deletes every key the iterator currently
    /// yields. Iterator errors are reported as [`StreamingError::StateError`].
    fn delete_all_batch(&self) -> Result<rocksdb::WriteBatch> {
        let mut batch = rocksdb::WriteBatch::default();
        for item in self.db.iterator(rocksdb::IteratorMode::Start) {
            let (key, _) = item.map_err(|e| StreamingError::StateError(e.to_string()))?;
            batch.delete(key);
        }
        Ok(batch)
    }
}

/// RocksDB-backed [`StateBackend`].
///
/// NOTE(review): every method calls the synchronous RocksDB API directly
/// inside an `async fn`, so it blocks the executor thread for the duration
/// of the I/O. Consider `tokio::task::spawn_blocking` for large
/// `clear`/`snapshot`/`restore` operations.
#[cfg(feature = "rocksdb-backend")]
#[async_trait]
impl StateBackend for RocksDBStateBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        Ok(self.db.get(key)?)
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
        self.db.put(key, value)?;
        Ok(())
    }

    async fn delete(&self, key: &[u8]) -> Result<()> {
        self.db.delete(key)?;
        Ok(())
    }

    async fn contains(&self, key: &[u8]) -> Result<bool> {
        // `get_pinned` avoids copying the value into a fresh `Vec` just to
        // test for presence.
        Ok(self.db.get_pinned(key)?.is_some())
    }

    async fn clear(&self) -> Result<()> {
        // A single write batch is applied atomically, so a failure cannot
        // leave the database half-cleared.
        let batch = self.delete_all_batch()?;
        self.db.write(batch)?;
        Ok(())
    }

    async fn snapshot(&self) -> Result<Vec<u8>> {
        // Iterate a RocksDB snapshot so concurrent writes don't tear the output.
        let snapshot = self.db.snapshot();
        let mut data = Vec::new();

        for item in snapshot.iterator(rocksdb::IteratorMode::Start) {
            let (key, value) = item?;
            let entry = (key.to_vec(), value.to_vec());
            // Use oxicode for binary serialization
            let serialized = oxicode::encode_to_vec(&entry)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
            // Frames are prefixed with a u32 length. Refuse oversized entries
            // instead of silently truncating the prefix, which would corrupt
            // every frame that follows.
            if serialized.len() > u32::MAX as usize {
                return Err(StreamingError::SerializationError(format!(
                    "RocksDB snapshot entry of {} bytes exceeds the u32 length prefix",
                    serialized.len()
                )));
            }
            data.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
            data.extend_from_slice(&serialized);
        }

        Ok(data)
    }

    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
        // Decode and validate the whole snapshot before touching the database.
        // A truncated or corrupt snapshot is reported as an error and leaves
        // the current state intact, instead of clearing the DB and silently
        // restoring only a prefix.
        let entries = decode_rocksdb_snapshot(snapshot)?;

        // Delete-all plus reload in one atomic batch. Operations in a batch
        // apply in order, so a key that is both deleted and re-put ends up
        // with the snapshot value.
        let mut batch = self.delete_all_batch()?;
        for (key, value) in entries {
            batch.put(key, value);
        }
        self.db.write(batch)?;
        Ok(())
    }

    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
        let keys: Vec<Vec<u8>> = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| {
                let (key, _) = item?;
                Ok(key.to_vec())
            })
            .collect::<std::result::Result<Vec<_>, rocksdb::Error>>()?;

        Ok(keys)
    }

    fn name(&self) -> &str {
        "RocksDBStateBackend"
    }
}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_backend() -> Result<()> {
        let backend = MemoryStateBackend::new();

        backend.put(b"key1", b"value1").await?;
        let value = backend.get(b"key1").await?;
        assert_eq!(value, Some(b"value1".to_vec()));

        assert!(backend.contains(b"key1").await?);
        assert!(!backend.contains(b"key2").await?);

        backend.delete(b"key1").await?;
        assert!(!backend.contains(b"key1").await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_snapshot() -> Result<()> {
        let backend = MemoryStateBackend::new();

        backend.put(b"key1", b"value1").await?;
        backend.put(b"key2", b"value2").await?;

        let snapshot = backend.snapshot().await?;

        let backend2 = MemoryStateBackend::new();
        backend2.restore(&snapshot).await?;

        assert_eq!(backend2.get(b"key1").await?, Some(b"value1".to_vec()));
        assert_eq!(backend2.get(b"key2").await?, Some(b"value2".to_vec()));

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_keys_and_clear() -> Result<()> {
        let backend = MemoryStateBackend::new();
        backend.put(b"a", b"1").await?;
        backend.put(b"b", b"2").await?;

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut keys = backend.keys().await?;
        keys.sort();
        assert_eq!(keys, vec![b"a".to_vec(), b"b".to_vec()]);

        backend.clear().await?;
        assert!(backend.keys().await?.is_empty());
        assert_eq!(backend.get(b"a").await?, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_restore_replaces_existing_state() -> Result<()> {
        let source = MemoryStateBackend::new();
        source.put(b"key1", b"value1").await?;
        let snapshot = source.snapshot().await?;

        let target = MemoryStateBackend::new();
        target.put(b"stale", b"old").await?;
        target.restore(&snapshot).await?;

        assert_eq!(target.get(b"key1").await?, Some(b"value1".to_vec()));
        assert!(!target.contains(b"stale").await?);

        Ok(())
    }

    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend() -> Result<()> {
        let temp_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let backend = RocksDBStateBackend::new(temp_dir.path().to_path_buf())?;

        backend.put(b"key1", b"value1").await?;
        let value = backend.get(b"key1").await?;
        assert_eq!(value, Some(b"value1".to_vec()));

        assert!(backend.contains(b"key1").await?);

        backend.delete(b"key1").await?;
        assert!(!backend.contains(b"key1").await?);

        Ok(())
    }

    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend_snapshot_restore() -> Result<()> {
        let source_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let target_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;

        let source = RocksDBStateBackend::new(source_dir.path().to_path_buf())?;
        source.put(b"key1", b"value1").await?;
        source.put(b"key2", b"value2").await?;
        let snapshot = source.snapshot().await?;

        let target = RocksDBStateBackend::new(target_dir.path().to_path_buf())?;
        target.put(b"stale", b"old").await?;
        target.restore(&snapshot).await?;

        assert_eq!(target.get(b"key1").await?, Some(b"value1".to_vec()));
        assert_eq!(target.get(b"key2").await?, Some(b"value2".to_vec()));
        assert!(!target.contains(b"stale").await?);
        // The default RocksDB comparator iterates keys in bytewise order.
        assert_eq!(target.keys().await?, vec![b"key1".to_vec(), b"key2".to_vec()]);

        Ok(())
    }

    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend_clear() -> Result<()> {
        let temp_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let backend = RocksDBStateBackend::new(temp_dir.path().to_path_buf())?;

        backend.put(b"a", b"1").await?;
        backend.put(b"b", b"2").await?;
        backend.clear().await?;

        assert!(backend.keys().await?.is_empty());
        assert!(!backend.contains(b"a").await?);

        Ok(())
    }
}