use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
/// Errors reported through [`StateStoreResult`] by state store operations.
///
/// Every variant carries a `String` that is interpolated into the `Display`
/// message produced by the `#[error]` attribute on that variant. None of the
/// variants is constructed in this file.
#[derive(Error, Debug)]
pub enum StateStoreError {
/// Displayed as `Key not found: <payload>`; presumably the payload names the
/// missing key — confirm against the code that constructs it.
#[error("Key not found: {0}")]
KeyNotFound(String),
/// Displayed as `Serialization error: <payload>`.
#[error("Serialization error: {0}")]
SerializationError(String),
/// Displayed as `Storage error: <payload>`.
#[error("Storage error: {0}")]
StorageError(String),
/// Displayed as `State store error: <payload>`; the catch-all variant.
#[error("State store error: {0}")]
Other(String),
}
/// Shorthand for a `Result` whose error type is [`StateStoreError`].
pub type StateStoreResult<T> = Result<T, StateStoreError>;
/// Asynchronous key/value storage over byte values, addressed by a store id
/// plus a key.
///
/// Implementors must be `Send + Sync`. The per-method notes below describe what
/// `MemoryStateStoreProvider` (the implementation in this file) does; whether
/// other implementors are required to match is not stated here — TODO confirm.
#[async_trait]
pub trait StateStoreProvider: Send + Sync {
/// Fetches the value stored under `key` in `store_id`.
///
/// The in-memory provider returns `Ok(None)` when the store or the key is absent.
async fn get(&self, store_id: &str, key: &str) -> StateStoreResult<Option<Vec<u8>>>;
/// Stores `value` under `key` in `store_id`.
///
/// The in-memory provider creates the store on first write and replaces any
/// value already present under `key`.
async fn set(&self, store_id: &str, key: &str, value: Vec<u8>) -> StateStoreResult<()>;
/// Removes `key` from `store_id`.
///
/// The in-memory provider returns `Ok(true)` only when the key was present.
async fn delete(&self, store_id: &str, key: &str) -> StateStoreResult<bool>;
/// Reports whether `key` is present in `store_id`.
async fn contains_key(&self, store_id: &str, key: &str) -> StateStoreResult<bool>;
/// Fetches several keys from `store_id` in one call.
///
/// The in-memory provider returns a map holding only the keys it found;
/// missing keys are omitted rather than reported as errors.
async fn get_many(
&self,
store_id: &str,
keys: &[&str],
) -> StateStoreResult<HashMap<String, Vec<u8>>>;
/// Writes every `(key, value)` pair in `entries` into `store_id`.
async fn set_many(&self, store_id: &str, entries: &[(&str, &[u8])]) -> StateStoreResult<()>;
/// Removes each of `keys` from `store_id`.
///
/// The in-memory provider returns how many of the keys were actually removed.
async fn delete_many(&self, store_id: &str, keys: &[&str]) -> StateStoreResult<usize>;
/// Removes every key in `store_id`.
///
/// The in-memory provider returns the number of keys the store held, or `0`
/// when the store did not exist.
async fn clear_store(&self, store_id: &str) -> StateStoreResult<usize>;
/// Lists the keys present in `store_id`.
///
/// The in-memory provider returns them in `HashMap` iteration order, which
/// is unspecified, and returns an empty list for an unknown store.
async fn list_keys(&self, store_id: &str) -> StateStoreResult<Vec<String>>;
/// Reports whether `store_id` exists.
///
/// The in-memory provider answers `true` only when the store holds at least one key.
async fn store_exists(&self, store_id: &str) -> StateStoreResult<bool>;
/// Returns the number of keys in `store_id`; the in-memory provider returns
/// `0` for an unknown store.
async fn key_count(&self, store_id: &str) -> StateStoreResult<usize>;
/// Default implementation does nothing and returns `Ok(())`.
///
/// NOTE(review): presumably a flush hook for implementors that persist data —
/// confirm the intended contract.
async fn sync(&self) -> StateStoreResult<()> {
Ok(())
}
}
/// [`StateStoreProvider`] implementation that keeps all data in process memory.
pub struct MemoryStateStoreProvider {
// Outer map: store id -> that store's entries. Inner map: key -> value bytes.
// The whole structure sits behind one `tokio::sync::RwLock` inside an `Arc`.
stores: Arc<RwLock<HashMap<String, HashMap<String, Vec<u8>>>>>,
}
impl Default for MemoryStateStoreProvider {
fn default() -> Self {
Self::new()
}
}
impl MemoryStateStoreProvider {
    /// Creates a provider that holds no stores and no keys.
    pub fn new() -> Self {
        let empty = HashMap::new();
        let stores = Arc::new(RwLock::new(empty));
        Self { stores }
    }
}
/// In-memory implementation of [`StateStoreProvider`].
///
/// Every operation takes the single lock guarding all stores: a read lock for
/// lookups, a write lock for mutations. A store that loses its last key is
/// removed from the outer map, so an empty store is never retained.
#[async_trait]
impl StateStoreProvider for MemoryStateStoreProvider {
    /// Returns a copy of the value under `key` in `store_id`, or `None` when
    /// either the store or the key is absent.
    async fn get(&self, store_id: &str, key: &str) -> StateStoreResult<Option<Vec<u8>>> {
        let stores = self.stores.read().await;
        let value = stores.get(store_id).and_then(|store| store.get(key));
        Ok(value.cloned())
    }

    /// Stores `value` under `key`, creating the store if needed and replacing
    /// any previous value.
    async fn set(&self, store_id: &str, key: &str, value: Vec<u8>) -> StateStoreResult<()> {
        let mut stores = self.stores.write().await;
        stores
            .entry(store_id.to_string())
            .or_default()
            .insert(key.to_string(), value);
        Ok(())
    }

    /// Removes `key` from `store_id`; returns `true` only if the key was present.
    async fn delete(&self, store_id: &str, key: &str) -> StateStoreResult<bool> {
        let mut stores = self.stores.write().await;
        let Some(store) = stores.get_mut(store_id) else {
            return Ok(false);
        };
        let existed = store.remove(key).is_some();
        // Drop the store once its last key is gone so empty maps do not pile up.
        if store.is_empty() {
            stores.remove(store_id);
        }
        Ok(existed)
    }

    /// Reports whether `key` is present in `store_id`; an unknown store yields `false`.
    async fn contains_key(&self, store_id: &str, key: &str) -> StateStoreResult<bool> {
        let stores = self.stores.read().await;
        Ok(stores
            .get(store_id)
            .is_some_and(|store| store.contains_key(key)))
    }

    /// Returns copies of the values for those `keys` that exist in `store_id`.
    ///
    /// Missing keys are omitted; an unknown store yields an empty map.
    async fn get_many(
        &self,
        store_id: &str,
        keys: &[&str],
    ) -> StateStoreResult<HashMap<String, Vec<u8>>> {
        let stores = self.stores.read().await;
        let Some(store) = stores.get(store_id) else {
            return Ok(HashMap::new());
        };
        let found = keys
            .iter()
            .filter_map(|&key| store.get(key).map(|value| (key.to_string(), value.clone())))
            .collect();
        Ok(found)
    }

    /// Writes every `(key, value)` pair in `entries` into `store_id`.
    ///
    /// When a key appears more than once, the last occurrence wins. An empty
    /// `entries` slice is a no-op and does not create the store.
    async fn set_many(&self, store_id: &str, entries: &[(&str, &[u8])]) -> StateStoreResult<()> {
        // Without this guard an empty batch would insert an empty inner map,
        // which `delete` and `delete_many` take care never to leave behind.
        if entries.is_empty() {
            return Ok(());
        }
        let mut stores = self.stores.write().await;
        let store = stores.entry(store_id.to_string()).or_default();
        store.extend(
            entries
                .iter()
                .map(|&(key, value)| (key.to_string(), value.to_vec())),
        );
        Ok(())
    }

    /// Removes each of `keys` from `store_id` and returns how many were present.
    async fn delete_many(&self, store_id: &str, keys: &[&str]) -> StateStoreResult<usize> {
        let mut stores = self.stores.write().await;
        let Some(store) = stores.get_mut(store_id) else {
            return Ok(0);
        };
        let mut removed = 0;
        for &key in keys {
            if store.remove(key).is_some() {
                removed += 1;
            }
        }
        // Same pruning rule as `delete`: an emptied store is dropped entirely.
        if store.is_empty() {
            stores.remove(store_id);
        }
        Ok(removed)
    }

    /// Removes the whole store and returns the number of keys it held
    /// (`0` when the store did not exist).
    async fn clear_store(&self, store_id: &str) -> StateStoreResult<usize> {
        let mut stores = self.stores.write().await;
        Ok(stores.remove(store_id).map_or(0, |store| store.len()))
    }

    /// Returns the keys of `store_id` in unspecified order; an unknown store
    /// yields an empty list.
    async fn list_keys(&self, store_id: &str) -> StateStoreResult<Vec<String>> {
        let stores = self.stores.read().await;
        Ok(stores
            .get(store_id)
            .map(|store| store.keys().cloned().collect())
            .unwrap_or_default())
    }

    /// Reports whether `store_id` currently holds at least one key.
    async fn store_exists(&self, store_id: &str) -> StateStoreResult<bool> {
        let stores = self.stores.read().await;
        Ok(stores.get(store_id).is_some_and(|store| !store.is_empty()))
    }

    /// Returns the number of keys in `store_id` (`0` for an unknown store).
    async fn key_count(&self, store_id: &str) -> StateStoreResult<usize> {
        let stores = self.stores.read().await;
        Ok(stores.get(store_id).map_or(0, |store| store.len()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a provider and writes each `(store_id, key, value)` triple
    /// through `set`, in the order given.
    async fn seeded(entries: &[(&str, &str, &str)]) -> MemoryStateStoreProvider {
        let provider = MemoryStateStoreProvider::new();
        for &(store_id, key, value) in entries {
            provider
                .set(store_id, key, value.as_bytes().to_vec())
                .await
                .unwrap();
        }
        provider
    }

    /// A stored value is readable; unknown keys and unknown stores read as `None`.
    #[tokio::test]
    async fn test_memory_state_store_get_set() {
        let provider = seeded(&[("store1", "key1", "value1")]).await;

        let hit = provider.get("store1", "key1").await.unwrap();
        assert_eq!(hit, Some(b"value1".to_vec()));

        let missing_key = provider.get("store1", "nonexistent").await.unwrap();
        assert!(missing_key.is_none());

        let missing_store = provider.get("nonexistent", "key1").await.unwrap();
        assert!(missing_store.is_none());
    }

    /// Deleting reports whether the key existed and makes it unreadable.
    #[tokio::test]
    async fn test_memory_state_store_delete() {
        let provider = seeded(&[("store1", "key1", "value1")]).await;

        assert!(provider.delete("store1", "key1").await.unwrap());
        assert!(provider.get("store1", "key1").await.unwrap().is_none());
        assert!(!provider.delete("store1", "nonexistent").await.unwrap());
    }

    /// `contains_key` is false before a write and true only for written keys.
    #[tokio::test]
    async fn test_memory_state_store_contains_key() {
        let provider = MemoryStateStoreProvider::new();
        let absent_before = !provider.contains_key("store1", "key1").await.unwrap();
        assert!(absent_before);

        provider
            .set("store1", "key1", b"value1".to_vec())
            .await
            .unwrap();

        let present = provider.contains_key("store1", "key1").await.unwrap();
        assert!(present);
        let other_absent = !provider.contains_key("store1", "key2").await.unwrap();
        assert!(other_absent);
    }

    /// `get_many` returns only the keys that exist.
    #[tokio::test]
    async fn test_memory_state_store_get_many() {
        let provider = seeded(&[
            ("store1", "key1", "value1"),
            ("store1", "key2", "value2"),
        ])
        .await;

        let wanted = ["key1", "key2", "nonexistent"];
        let found = provider.get_many("store1", &wanted).await.unwrap();

        assert_eq!(found.len(), 2);
        assert_eq!(found.get("key1"), Some(&b"value1".to_vec()));
        assert_eq!(found.get("key2"), Some(&b"value2".to_vec()));
    }

    /// Every pair written by `set_many` is readable afterwards.
    #[tokio::test]
    async fn test_memory_state_store_set_many() {
        let provider = MemoryStateStoreProvider::new();
        provider
            .set_many("store1", &[("key1", b"value1"), ("key2", b"value2")])
            .await
            .unwrap();

        for (key, expected) in [("key1", "value1"), ("key2", "value2")] {
            let stored = provider.get("store1", key).await.unwrap();
            assert_eq!(stored, Some(expected.as_bytes().to_vec()));
        }
    }

    /// `delete_many` counts only keys that existed and leaves the others untouched.
    #[tokio::test]
    async fn test_memory_state_store_delete_many() {
        let provider = MemoryStateStoreProvider::new();
        provider
            .set_many(
                "store1",
                &[("key1", b"value1"), ("key2", b"value2"), ("key3", b"value3")],
            )
            .await
            .unwrap();

        let removed = provider
            .delete_many("store1", &["key1", "key2", "nonexistent"])
            .await
            .unwrap();
        assert_eq!(removed, 2);

        assert!(provider.get("store1", "key1").await.unwrap().is_none());
        let survivor = provider.get("store1", "key3").await.unwrap();
        assert_eq!(survivor, Some(b"value3".to_vec()));
    }

    /// Clearing one store reports its size and does not affect other stores.
    #[tokio::test]
    async fn test_memory_state_store_clear_store() {
        let provider = seeded(&[
            ("store1", "key1", "value1"),
            ("store1", "key2", "value2"),
            ("store2", "key1", "value1"),
        ])
        .await;

        let cleared = provider.clear_store("store1").await.unwrap();
        assert_eq!(cleared, 2);

        assert!(provider.get("store1", "key1").await.unwrap().is_none());
        let untouched = provider.get("store2", "key1").await.unwrap();
        assert_eq!(untouched, Some(b"value1".to_vec()));
    }

    /// `list_keys` returns every key of a store and nothing for an unknown store.
    #[tokio::test]
    async fn test_memory_state_store_list_keys() {
        let provider = seeded(&[
            ("store1", "key1", "value1"),
            ("store1", "key2", "value2"),
        ])
        .await;

        // Key order is unspecified, so sort before comparing.
        let mut listed = provider.list_keys("store1").await.unwrap();
        listed.sort();
        assert_eq!(listed, vec!["key1", "key2"]);

        let none = provider.list_keys("nonexistent").await.unwrap();
        assert!(none.is_empty());
    }

    /// A store exists only while it holds at least one key.
    #[tokio::test]
    async fn test_memory_state_store_store_exists() {
        let provider = MemoryStateStoreProvider::new();
        let exists_initially = provider.store_exists("store1").await.unwrap();
        assert!(!exists_initially);

        provider
            .set("store1", "key1", b"value1".to_vec())
            .await
            .unwrap();
        let exists_after_set = provider.store_exists("store1").await.unwrap();
        assert!(exists_after_set);

        provider.delete("store1", "key1").await.unwrap();
        let exists_after_delete = provider.store_exists("store1").await.unwrap();
        assert!(!exists_after_delete);
    }

    /// `key_count` tracks the number of distinct keys written.
    #[tokio::test]
    async fn test_memory_state_store_key_count() {
        let provider = MemoryStateStoreProvider::new();
        assert_eq!(provider.key_count("store1").await.unwrap(), 0);

        let writes = [("key1", "value1"), ("key2", "value2")];
        for (index, (key, value)) in writes.into_iter().enumerate() {
            provider
                .set("store1", key, value.as_bytes().to_vec())
                .await
                .unwrap();
            assert_eq!(provider.key_count("store1").await.unwrap(), index + 1);
        }
    }

    /// The same key in two stores holds independent values.
    #[tokio::test]
    async fn test_memory_state_store_partitioning() {
        let provider = seeded(&[
            ("store1", "key", "value1"),
            ("store2", "key", "value2"),
        ])
        .await;

        let first = provider.get("store1", "key").await.unwrap();
        let second = provider.get("store2", "key").await.unwrap();

        assert_eq!(first, Some(b"value1".to_vec()));
        assert_eq!(second, Some(b"value2".to_vec()));
    }

    /// `sync` succeeds both on an empty provider and after a write.
    #[tokio::test]
    async fn test_memory_state_store_sync() {
        let provider = MemoryStateStoreProvider::new();
        provider.sync().await.unwrap();

        provider
            .set("store1", "key1", b"value1".to_vec())
            .await
            .unwrap();
        provider.sync().await.unwrap();
    }
}