use crate::error::{Result, StreamingError};
use async_trait::async_trait;
use std::collections::HashMap;
#[cfg(feature = "rocksdb-backend")]
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
#[async_trait]
pub trait StateBackend: Send + Sync {
async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
async fn put(&self, key: &[u8], value: &[u8]) -> Result<()>;
async fn delete(&self, key: &[u8]) -> Result<()>;
async fn contains(&self, key: &[u8]) -> Result<bool>;
async fn clear(&self) -> Result<()>;
async fn snapshot(&self) -> Result<Vec<u8>>;
async fn restore(&self, snapshot: &[u8]) -> Result<()>;
async fn keys(&self) -> Result<Vec<Vec<u8>>>;
fn name(&self) -> &str;
}
pub struct MemoryStateBackend {
state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
}
impl MemoryStateBackend {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for MemoryStateBackend {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl StateBackend for MemoryStateBackend {
async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.state.read().await.get(key).cloned())
}
async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
self.state
.write()
.await
.insert(key.to_vec(), value.to_vec());
Ok(())
}
async fn delete(&self, key: &[u8]) -> Result<()> {
self.state.write().await.remove(key);
Ok(())
}
async fn contains(&self, key: &[u8]) -> Result<bool> {
Ok(self.state.read().await.contains_key(key))
}
async fn clear(&self) -> Result<()> {
self.state.write().await.clear();
Ok(())
}
async fn snapshot(&self) -> Result<Vec<u8>> {
let state = self.state.read().await;
oxicode::encode_to_vec(&*state)
.map_err(|e| StreamingError::SerializationError(e.to_string()))
}
async fn restore(&self, snapshot: &[u8]) -> Result<()> {
let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
.map_err(|e| StreamingError::SerializationError(e.to_string()))?;
*self.state.write().await = restored;
Ok(())
}
async fn keys(&self) -> Result<Vec<Vec<u8>>> {
Ok(self.state.read().await.keys().cloned().collect())
}
fn name(&self) -> &str {
"MemoryStateBackend"
}
}
#[cfg(feature = "rocksdb-backend")]
pub struct RocksDBStateBackend {
db: Arc<rocksdb::DB>,
path: PathBuf,
}
#[cfg(feature = "rocksdb-backend")]
impl RocksDBStateBackend {
pub fn new(path: PathBuf) -> Result<Self> {
let mut opts = rocksdb::Options::default();
opts.create_if_missing(true);
let db = rocksdb::DB::open(&opts, &path)?;
Ok(Self {
db: Arc::new(db),
path,
})
}
pub fn path(&self) -> &PathBuf {
&self.path
}
}
#[cfg(feature = "rocksdb-backend")]
#[async_trait]
impl StateBackend for RocksDBStateBackend {
async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.db.get(key)?)
}
async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
self.db.put(key, value)?;
Ok(())
}
async fn delete(&self, key: &[u8]) -> Result<()> {
self.db.delete(key)?;
Ok(())
}
async fn contains(&self, key: &[u8]) -> Result<bool> {
Ok(self.db.get(key)?.is_some())
}
async fn clear(&self) -> Result<()> {
let keys: Vec<Vec<u8>> = self
.db
.iterator(rocksdb::IteratorMode::Start)
.map(|item| {
let (key, _) = item.map_err(|e| StreamingError::StateError(e.to_string()))?;
Ok(key.to_vec())
})
.collect::<Result<Vec<_>>>()?;
for key in keys {
self.db.delete(&key)?;
}
Ok(())
}
async fn snapshot(&self) -> Result<Vec<u8>> {
let snapshot = self.db.snapshot();
let mut data = Vec::new();
for item in snapshot.iterator(rocksdb::IteratorMode::Start) {
let (key, value) = item?;
let entry = (key.to_vec(), value.to_vec());
let serialized = oxicode::encode_to_vec(&entry)
.map_err(|e| StreamingError::SerializationError(e.to_string()))?;
data.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
data.extend_from_slice(&serialized);
}
Ok(data)
}
async fn restore(&self, snapshot: &[u8]) -> Result<()> {
self.clear().await?;
let mut offset = 0;
while offset < snapshot.len() {
if offset + 4 > snapshot.len() {
break;
}
let len = u32::from_le_bytes([
snapshot[offset],
snapshot[offset + 1],
snapshot[offset + 2],
snapshot[offset + 3],
]) as usize;
offset += 4;
if offset + len > snapshot.len() {
break;
}
let entry_data = &snapshot[offset..offset + len];
let ((key, value), _): ((Vec<u8>, Vec<u8>), _) = oxicode::decode_from_slice(entry_data)
.map_err(|e| StreamingError::SerializationError(e.to_string()))?;
self.db.put(&key, &value)?;
offset += len;
}
Ok(())
}
async fn keys(&self) -> Result<Vec<Vec<u8>>> {
let keys: Vec<Vec<u8>> = self
.db
.iterator(rocksdb::IteratorMode::Start)
.map(|item| {
let (key, _) = item?;
Ok(key.to_vec())
})
.collect::<std::result::Result<Vec<_>, rocksdb::Error>>()?;
Ok(keys)
}
fn name(&self) -> &str {
"RocksDBStateBackend"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_memory_backend() -> Result<()> {
let backend = MemoryStateBackend::new();
backend.put(b"key1", b"value1").await?;
let value = backend.get(b"key1").await?;
assert_eq!(value, Some(b"value1".to_vec()));
assert!(backend.contains(b"key1").await?);
assert!(!backend.contains(b"key2").await?);
backend.delete(b"key1").await?;
assert!(!backend.contains(b"key1").await?);
Ok(())
}
#[tokio::test]
async fn test_memory_backend_snapshot() -> Result<()> {
let backend = MemoryStateBackend::new();
backend.put(b"key1", b"value1").await?;
backend.put(b"key2", b"value2").await?;
let snapshot = backend.snapshot().await?;
let backend2 = MemoryStateBackend::new();
backend2.restore(&snapshot).await?;
assert_eq!(backend2.get(b"key1").await?, Some(b"value1".to_vec()));
assert_eq!(backend2.get(b"key2").await?, Some(b"value2".to_vec()));
Ok(())
}
#[cfg(feature = "rocksdb-backend")]
#[tokio::test]
async fn test_rocksdb_backend() -> Result<()> {
let temp_dir = tempfile::tempdir()
.map_err(|e| StreamingError::StateError(format!("Failed to create temp dir: {}", e)))?;
let backend = RocksDBStateBackend::new(temp_dir.path().to_path_buf())?;
backend.put(b"key1", b"value1").await?;
let value = backend.get(b"key1").await?;
assert_eq!(value, Some(b"value1".to_vec()));
assert!(backend.contains(b"key1").await?);
backend.delete(b"key1").await?;
assert!(!backend.contains(b"key1").await?);
Ok(())
}
}