use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use crate::error::StorageError;
use crate::key::StateKey;
/// Asynchronous storage, keyed by [`StateKey`], holding an optional state
/// string plus a map of named JSON data fields for each key.
///
/// Implementors are required to be `Send + Sync + 'static`, so one store can
/// be shared between tasks and threads.
///
/// # Errors
///
/// Every method reports a backend failure as a [`StorageError`].
#[async_trait]
pub trait StateStorage: Send + Sync + 'static {
    /// Fetches the state recorded for `key`, or `None` when no state is set.
    async fn get_state(&self, key: StateKey) -> Result<Option<String>, StorageError>;

    /// Records `state` as the state for `key`.
    async fn set_state(&self, key: StateKey, state: String) -> Result<(), StorageError>;

    /// Clears the state recorded for `key`.
    async fn clear_state(&self, key: StateKey) -> Result<(), StorageError>;

    /// Fetches the data field named `field` for `key`, or `None` when it is not set.
    async fn get_data(
        &self,
        key: StateKey,
        field: &str,
    ) -> Result<Option<serde_json::Value>, StorageError>;

    /// Stores `value` under the data field named `field` for `key`.
    async fn set_data(
        &self,
        key: StateKey,
        field: &str,
        value: serde_json::Value,
    ) -> Result<(), StorageError>;

    /// Returns every data field stored for `key` as a field-name → value map.
    async fn get_all_data(
        &self,
        key: StateKey,
    ) -> Result<HashMap<String, serde_json::Value>, StorageError>;

    /// Clears all data fields stored for `key`.
    async fn clear_data(&self, key: StateKey) -> Result<(), StorageError>;

    /// Clears everything stored for `key`.
    async fn clear_all(&self, key: StateKey) -> Result<(), StorageError>;
}
/// Everything [`MemoryStorage`] keeps for a single key.
#[derive(Clone, Default)]
struct StorageEntry {
    /// The state string, if one has been set.
    state: Option<String>,
    /// Named JSON data fields.
    data: HashMap<String, serde_json::Value>,
}

/// In-process store whose entries live in a concurrent [`DashMap`].
///
/// The map is held behind an [`Arc`], so cloning a `MemoryStorage` is cheap
/// and the clone shares the very same entries as the original.
#[derive(Clone, Default)]
pub struct MemoryStorage {
    entries: Arc<DashMap<StateKey, StorageEntry>>,
}

impl MemoryStorage {
    /// Builds an empty store.
    pub fn new() -> Self {
        // Same construction the derived `Default` performs.
        Self {
            entries: Arc::default(),
        }
    }

    /// Number of keys that currently have an entry in the store.
    pub fn len(&self) -> usize {
        let map = &self.entries;
        map.len()
    }

    /// `true` when no key currently has an entry in the store.
    pub fn is_empty(&self) -> bool {
        let map = &self.entries;
        map.is_empty()
    }
}
#[async_trait]
impl StateStorage for MemoryStorage {
async fn get_state(&self, key: StateKey) -> Result<Option<String>, StorageError> {
Ok(self.entries.get(&key).and_then(|e| e.state.clone()))
}
async fn set_state(&self, key: StateKey, state: String) -> Result<(), StorageError> {
self.entries.entry(key).or_default().state = Some(state);
Ok(())
}
async fn clear_state(&self, key: StateKey) -> Result<(), StorageError> {
if let Some(mut entry) = self.entries.get_mut(&key) {
entry.state = None;
if entry.data.is_empty() {
drop(entry);
self.entries.remove(&key);
}
}
Ok(())
}
async fn get_data(
&self,
key: StateKey,
field: &str,
) -> Result<Option<serde_json::Value>, StorageError> {
Ok(self
.entries
.get(&key)
.and_then(|e| e.data.get(field).cloned()))
}
async fn set_data(
&self,
key: StateKey,
field: &str,
value: serde_json::Value,
) -> Result<(), StorageError> {
self.entries
.entry(key)
.or_default()
.data
.insert(field.to_string(), value);
Ok(())
}
async fn get_all_data(
&self,
key: StateKey,
) -> Result<HashMap<String, serde_json::Value>, StorageError> {
Ok(self
.entries
.get(&key)
.map(|e| e.data.clone())
.unwrap_or_default())
}
async fn clear_data(&self, key: StateKey) -> Result<(), StorageError> {
if let Some(mut entry) = self.entries.get_mut(&key) {
entry.data.clear();
if entry.state.is_none() {
drop(entry);
self.entries.remove(&key);
}
}
Ok(())
}
async fn clear_all(&self, key: StateKey) -> Result<(), StorageError> {
self.entries.remove(&key);
Ok(())
}
}