use crate::traits::{CerebroError, KVStore, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// In-memory key-value store backed by a `HashMap` behind an `Arc<RwLock<_>>`.
///
/// Cloning the store only clones the `Arc`, so every clone is a handle to the
/// same underlying map: writes made through one clone are visible through all
/// of them.
#[derive(Clone, Debug, Default)]
pub struct MemoryKVStore {
    data: Arc<RwLock<HashMap<String, String>>>,
}

impl MemoryKVStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        // The derived `Default` builds exactly `Arc::new(RwLock::new(HashMap::new()))`.
        Self::default()
    }
}
#[async_trait]
impl KVStore for MemoryKVStore {
    /// Stores `value` under `key`, replacing any value previously stored there.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError("Lock poisoned")` if the map's lock
    /// has been poisoned.
    async fn set(&self, key: &str, value: &str) -> Result<()> {
        // The write guard is a temporary and is released at the end of this statement.
        self.data
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?
            .insert(key.to_string(), value.to_string());
        Ok(())
    }

    /// Returns an owned copy of the value stored under `key`, or `None` if the
    /// key is absent.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError("Lock poisoned")` if the map's lock
    /// has been poisoned.
    async fn get(&self, key: &str) -> Result<Option<String>> {
        let map = self
            .data
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        let found = map.get(key).map(String::clone);
        Ok(found)
    }
}
/// Append-only, event-sourced key-value store.
///
/// Each write is recorded as a `(key, value)` event appended to a shared
/// timeline, and earlier events are never removed. The current value of a key
/// is the value of its most recent event. Clones share the same timeline
/// through the `Arc`.
#[derive(Clone, Debug, Default)]
pub struct EventSourcedKVStore {
    timeline: Arc<RwLock<Vec<(String, String)>>>,
}
impl EventSourcedKVStore {
    /// Creates a store with an empty event timeline.
    pub fn new() -> Self {
        Self::default()
    }

    /// Rebuilds the key-value state by replaying the first `step` events of
    /// the timeline.
    ///
    /// Events are replayed oldest-first, so a later event overwrites an
    /// earlier one for the same key. `step == 0` yields an empty map. A `step`
    /// larger than the number of recorded events replays the whole timeline.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError("Lock poisoned")` if the timeline's
    /// lock has been poisoned.
    pub async fn get_state_at_step(&self, step: usize) -> Result<HashMap<String, String>> {
        let events = self
            .timeline
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        // Collecting into a HashMap inserts pairs in iteration order, so the
        // last event per key wins, exactly as in a sequential replay.
        let replayed: HashMap<String, String> = events.iter().take(step).cloned().collect();
        Ok(replayed)
    }
}
#[async_trait]
impl KVStore for EventSourcedKVStore {
    /// Appends a `(key, value)` event to the end of the timeline.
    ///
    /// Earlier events for the same key are kept, so history is preserved.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError("Lock poisoned")` if the timeline's
    /// lock has been poisoned.
    async fn set(&self, key: &str, value: &str) -> Result<()> {
        // The write guard is a temporary and is released at the end of this statement.
        self.timeline
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?
            .push((key.to_string(), value.to_string()));
        Ok(())
    }

    /// Returns the value carried by the most recent event for `key`, or `None`
    /// if no event for `key` exists.
    ///
    /// The timeline is scanned from newest to oldest, so the cost grows with
    /// the number of recorded events.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError("Lock poisoned")` if the timeline's
    /// lock has been poisoned.
    async fn get(&self, key: &str) -> Result<Option<String>> {
        let events = self
            .timeline
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        let latest = events
            .iter()
            .rev()
            .find(|(event_key, _)| event_key.as_str() == key)
            .map(|(_, event_value)| event_value.clone());
        Ok(latest)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic round trip on the in-memory store, including a missing key.
    #[tokio::test]
    async fn test_set_and_get() {
        let store = MemoryKVStore::new();
        store.set("session_id", "12345").await.unwrap();
        let val = store.get("session_id").await.unwrap();
        assert_eq!(val, Some("12345".to_string()));
        let empty = store.get("missing").await.unwrap();
        assert_eq!(empty, None);
    }

    /// A second `set` on the same key replaces the value, and clones observe
    /// writes made through any handle because they share the same `Arc`.
    #[tokio::test]
    async fn test_memory_overwrite_and_shared_clone() {
        let store = MemoryKVStore::new();
        let handle = store.clone();
        store.set("k", "old").await.unwrap();
        handle.set("k", "new").await.unwrap();
        assert_eq!(store.get("k").await.unwrap(), Some("new".to_string()));
    }

    /// `get` on the event-sourced store returns the most recent value.
    #[tokio::test]
    async fn test_event_sourced_get_returns_latest() {
        let store = EventSourcedKVStore::new();
        store.set("k", "v1").await.unwrap();
        store.set("k", "v2").await.unwrap();
        assert_eq!(store.get("k").await.unwrap(), Some("v2".to_string()));
        assert_eq!(store.get("missing").await.unwrap(), None);
    }

    /// `get_state_at_step(n)` replays exactly the first `n` events.
    #[tokio::test]
    async fn test_event_sourced_state_at_step() {
        let store = EventSourcedKVStore::new();
        store.set("a", "1").await.unwrap();
        store.set("b", "2").await.unwrap();
        store.set("a", "3").await.unwrap();

        assert!(store.get_state_at_step(0).await.unwrap().is_empty());

        let first = store.get_state_at_step(1).await.unwrap();
        assert_eq!(first.len(), 1);
        assert_eq!(first.get("a"), Some(&"1".to_string()));

        let full = store.get_state_at_step(3).await.unwrap();
        assert_eq!(full.len(), 2);
        assert_eq!(full.get("a"), Some(&"3".to_string()));
        assert_eq!(full.get("b"), Some(&"2".to_string()));

        // A step past the end of the timeline replays everything.
        assert_eq!(store.get_state_at_step(100).await.unwrap(), full);
    }
}