Skip to main content

oxigdal_streaming/state/
backend.rs

1//! State backend implementations.
2
3use crate::error::{Result, StreamingError};
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10/// Trait for state backends.
11#[async_trait]
12pub trait StateBackend: Send + Sync {
13    /// Get a value from the state.
14    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
15
16    /// Put a value into the state.
17    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()>;
18
19    /// Delete a value from the state.
20    async fn delete(&self, key: &[u8]) -> Result<()>;
21
22    /// Check if a key exists.
23    async fn contains(&self, key: &[u8]) -> Result<bool>;
24
25    /// Clear all state.
26    async fn clear(&self) -> Result<()>;
27
28    /// Create a snapshot of the state.
29    async fn snapshot(&self) -> Result<Vec<u8>>;
30
31    /// Restore state from a snapshot.
32    async fn restore(&self, snapshot: &[u8]) -> Result<()>;
33
34    /// Get all keys.
35    async fn keys(&self) -> Result<Vec<Vec<u8>>>;
36
37    /// Get the backend name.
38    fn name(&self) -> &str;
39}
40
/// In-memory state backend.
///
/// All entries live in a `HashMap` behind an `Arc<RwLock<_>>` (tokio's async
/// `RwLock`): read-only operations take the shared lock, mutations take the
/// exclusive lock. Nothing is written to disk by this type; state can be carried
/// elsewhere only through its `snapshot`/`restore` methods.
pub struct MemoryStateBackend {
    // Key/value store shared by every `StateBackend` method on this type.
    state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
}
45
46impl MemoryStateBackend {
47    /// Create a new memory state backend.
48    pub fn new() -> Self {
49        Self {
50            state: Arc::new(RwLock::new(HashMap::new())),
51        }
52    }
53}
54
55impl Default for MemoryStateBackend {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61#[async_trait]
62impl StateBackend for MemoryStateBackend {
63    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
64        Ok(self.state.read().await.get(key).cloned())
65    }
66
67    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
68        self.state
69            .write()
70            .await
71            .insert(key.to_vec(), value.to_vec());
72        Ok(())
73    }
74
75    async fn delete(&self, key: &[u8]) -> Result<()> {
76        self.state.write().await.remove(key);
77        Ok(())
78    }
79
80    async fn contains(&self, key: &[u8]) -> Result<bool> {
81        Ok(self.state.read().await.contains_key(key))
82    }
83
84    async fn clear(&self) -> Result<()> {
85        self.state.write().await.clear();
86        Ok(())
87    }
88
89    async fn snapshot(&self) -> Result<Vec<u8>> {
90        let state = self.state.read().await;
91        // Use oxicode for binary serialization since JSON requires string keys
92        oxicode::encode_to_vec(&*state)
93            .map_err(|e| StreamingError::SerializationError(e.to_string()))
94    }
95
96    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
97        let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
98            .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
99        *self.state.write().await = restored;
100        Ok(())
101    }
102
103    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
104        Ok(self.state.read().await.keys().cloned().collect())
105    }
106
107    fn name(&self) -> &str {
108        "MemoryStateBackend"
109    }
110}
111
/// RocksDB state backend.
///
/// Only compiled with the `rocksdb-backend` feature. Entries are stored in the
/// RocksDB database opened at `path`.
#[cfg(feature = "rocksdb-backend")]
pub struct RocksDBStateBackend {
    // Open database handle used by every `StateBackend` method.
    db: Arc<rocksdb::DB>,
    // Directory the database was opened at, exposed through `path()`.
    path: PathBuf,
}
118
#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Opens the RocksDB database at `path`, creating it if it does not exist yet.
    ///
    /// # Errors
    /// Propagates any error RocksDB reports while opening the database.
    pub fn new(path: PathBuf) -> Result<Self> {
        let mut options = rocksdb::Options::default();
        options.create_if_missing(true);
        let db = Arc::new(rocksdb::DB::open(&options, &path)?);
        Ok(Self { db, path })
    }

    /// Returns the filesystem location the database was opened at.
    pub fn path(&self) -> &PathBuf {
        &self.path
    }
}
139
/// [`StateBackend`] over the RocksDB handle in [`RocksDBStateBackend`].
///
/// NOTE(review): these methods call the synchronous RocksDB API directly inside
/// the `async fn`s (no `spawn_blocking`), so each call occupies the executor
/// thread until RocksDB returns.
#[cfg(feature = "rocksdb-backend")]
#[async_trait]
impl StateBackend for RocksDBStateBackend {
    async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        Ok(self.db.get(key)?)
    }

    async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
        self.db.put(key, value)?;
        Ok(())
    }

    async fn delete(&self, key: &[u8]) -> Result<()> {
        self.db.delete(key)?;
        Ok(())
    }

    async fn contains(&self, key: &[u8]) -> Result<bool> {
        // `get_pinned` avoids copying the value out of RocksDB just to test presence.
        Ok(self.db.get_pinned(key)?.is_some())
    }

    /// Deletes every key in a single write batch, so the wipe is applied
    /// atomically instead of key by key.
    async fn clear(&self) -> Result<()> {
        let mut batch = rocksdb::WriteBatch::default();
        self.queue_delete_all(&mut batch)?;
        self.db.write(batch)?;
        Ok(())
    }

    /// Snapshot layout is a sequence of records. Each record is a little-endian
    /// `u32` byte length followed by that many bytes of an oxicode-encoded
    /// `(key, value)` pair.
    ///
    /// # Errors
    /// Returns `SerializationError` if an entry fails to encode or its encoding
    /// is too large for the `u32` length prefix.
    async fn snapshot(&self) -> Result<Vec<u8>> {
        // Iterate a RocksDB snapshot so writes during the scan are not observed.
        let snapshot = self.db.snapshot();
        let mut data = Vec::new();

        for item in snapshot.iterator(rocksdb::IteratorMode::Start) {
            let (key, value) = item?;
            let entry = (key.into_vec(), value.into_vec());
            let serialized = oxicode::encode_to_vec(&entry)
                .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
            // Refuse entries the 4-byte prefix cannot describe rather than
            // silently truncating the length (which would corrupt the snapshot).
            let len = u32::try_from(serialized.len()).map_err(|_| {
                StreamingError::SerializationError(format!(
                    "snapshot entry of {} bytes exceeds the u32 length prefix",
                    serialized.len()
                ))
            })?;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(&serialized);
        }

        Ok(data)
    }

    /// Replaces the database contents with the records in `snapshot`.
    ///
    /// The whole snapshot is decoded before the database is touched. The existing
    /// keys are then deleted and the snapshot entries written in one atomic
    /// batch. A malformed snapshot therefore leaves the current state intact.
    ///
    /// # Errors
    /// Returns `SerializationError` if the snapshot is truncated or a record
    /// fails to decode. Database errors are propagated.
    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
        let entries = decode_snapshot_entries(snapshot)?;

        // Deletes are queued before puts; batch operations apply in order, so a
        // key present in both ends up with the snapshot's value.
        let mut batch = rocksdb::WriteBatch::default();
        self.queue_delete_all(&mut batch)?;
        for (key, value) in entries {
            batch.put(key, value);
        }
        self.db.write(batch)?;
        Ok(())
    }

    async fn keys(&self) -> Result<Vec<Vec<u8>>> {
        let keys = self
            .db
            .iterator(rocksdb::IteratorMode::Start)
            .map(|item| item.map(|(key, _)| key.into_vec()))
            .collect::<std::result::Result<Vec<_>, rocksdb::Error>>()?;
        Ok(keys)
    }

    fn name(&self) -> &str {
        "RocksDBStateBackend"
    }
}

#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
    /// Adds a delete for every key currently in the database to `batch`.
    ///
    /// Iteration errors are reported as `StreamingError::StateError`, the same
    /// kind `clear` has always produced for them.
    fn queue_delete_all(&self, batch: &mut rocksdb::WriteBatch) -> Result<()> {
        for item in self.db.iterator(rocksdb::IteratorMode::Start) {
            let (key, _) = item.map_err(|e| StreamingError::StateError(e.to_string()))?;
            batch.delete(key);
        }
        Ok(())
    }
}

/// Decodes the length-prefixed record stream produced by the RocksDB backend's
/// `snapshot`.
///
/// # Errors
/// Returns `StreamingError::SerializationError` if the buffer ends inside a
/// length prefix or a record, or if a record fails to decode.
#[cfg(feature = "rocksdb-backend")]
fn decode_snapshot_entries(snapshot: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
    let mut entries = Vec::new();
    let mut rest = snapshot;

    while !rest.is_empty() {
        if rest.len() < 4 {
            return Err(StreamingError::SerializationError(format!(
                "truncated snapshot: {} trailing byte(s) cannot hold a length prefix",
                rest.len()
            )));
        }
        let (prefix, tail) = rest.split_at(4);
        let len = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;

        if tail.len() < len {
            return Err(StreamingError::SerializationError(format!(
                "truncated snapshot: record declares {} bytes but only {} remain",
                len,
                tail.len()
            )));
        }
        let (record, remaining) = tail.split_at(len);
        let (entry, _): ((Vec<u8>, Vec<u8>), _) = oxicode::decode_from_slice(record)
            .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
        entries.push(entry);
        rest = remaining;
    }

    Ok(entries)
}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_backend() -> Result<()> {
        let backend = MemoryStateBackend::new();

        backend.put(b"key1", b"value1").await?;
        let value = backend.get(b"key1").await?;
        assert_eq!(value, Some(b"value1".to_vec()));

        assert!(backend.contains(b"key1").await?);
        assert!(!backend.contains(b"key2").await?);

        backend.delete(b"key1").await?;
        assert!(!backend.contains(b"key1").await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_snapshot() -> Result<()> {
        let backend = MemoryStateBackend::new();

        backend.put(b"key1", b"value1").await?;
        backend.put(b"key2", b"value2").await?;

        let snapshot = backend.snapshot().await?;

        let backend2 = MemoryStateBackend::new();
        backend2.restore(&snapshot).await?;

        assert_eq!(backend2.get(b"key1").await?, Some(b"value1".to_vec()));
        assert_eq!(backend2.get(b"key2").await?, Some(b"value2".to_vec()));

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_keys_and_clear() -> Result<()> {
        let backend = MemoryStateBackend::new();
        backend.put(b"b", b"2").await?;
        backend.put(b"a", b"1").await?;

        // HashMap iteration order is unspecified, so compare sorted keys.
        let mut keys = backend.keys().await?;
        keys.sort();
        assert_eq!(keys, vec![b"a".to_vec(), b"b".to_vec()]);

        backend.clear().await?;
        assert!(backend.keys().await?.is_empty());
        assert!(!backend.contains(b"a").await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_backend_restore_replaces_existing_state() -> Result<()> {
        let source = MemoryStateBackend::new();
        source.put(b"key1", b"value1").await?;
        let snapshot = source.snapshot().await?;

        let target = MemoryStateBackend::new();
        target.put(b"stale", b"old").await?;
        target.restore(&snapshot).await?;

        assert_eq!(target.get(b"key1").await?, Some(b"value1".to_vec()));
        assert!(!target.contains(b"stale").await?);

        Ok(())
    }

    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend() -> Result<()> {
        let temp_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let backend = RocksDBStateBackend::new(temp_dir.path().to_path_buf())?;

        backend.put(b"key1", b"value1").await?;
        let value = backend.get(b"key1").await?;
        assert_eq!(value, Some(b"value1".to_vec()));

        assert!(backend.contains(b"key1").await?);

        backend.delete(b"key1").await?;
        assert!(!backend.contains(b"key1").await?);

        Ok(())
    }

    #[cfg(feature = "rocksdb-backend")]
    #[tokio::test]
    async fn test_rocksdb_backend_snapshot_roundtrip() -> Result<()> {
        let source_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
        let target_dir = tempfile::tempdir()
            .map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;

        let source = RocksDBStateBackend::new(source_dir.path().to_path_buf())?;
        source.put(b"key1", b"value1").await?;
        source.put(b"key2", b"value2").await?;
        let snapshot = source.snapshot().await?;

        let target = RocksDBStateBackend::new(target_dir.path().to_path_buf())?;
        target.put(b"stale", b"old").await?;
        target.restore(&snapshot).await?;

        assert_eq!(target.get(b"key1").await?, Some(b"value1".to_vec()));
        assert_eq!(target.get(b"key2").await?, Some(b"value2".to_vec()));
        assert!(!target.contains(b"stale").await?);
        // RocksDB's default comparator iterates keys in bytewise order.
        assert_eq!(
            target.keys().await?,
            vec![b"key1".to_vec(), b"key2".to_vec()]
        );

        target.clear().await?;
        assert!(target.keys().await?.is_empty());

        Ok(())
    }
}