use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use crate::error::{Error, Result};
/// Asynchronous key/value storage for JSON values.
///
/// Keys are string slices and values are `serde_json::Value`s. Implementors
/// must be `Send + Sync + 'static`, as required by the supertrait bounds.
#[async_trait]
pub trait StateStore: Send + Sync + 'static {
/// Loads the value stored under `key`.
///
/// The `Option` in the return type lets an implementation report a missing
/// key without raising an error; `InMemoryStore` returns `Ok(None)` for a
/// key that has never been saved.
async fn load(&self, key: &str) -> Result<Option<serde_json::Value>>;
/// Stores `value` under `key`.
///
/// NOTE(review): the trait itself does not say whether an existing value is
/// replaced; `InMemoryStore` overwrites it. Confirm for other implementors.
async fn save(&self, key: &str, value: serde_json::Value) -> Result<()>;
/// Flushes the store.
///
/// The default implementation does nothing and returns `Ok(())`.
/// NOTE(review): presumably intended to be overridden by stores that buffer
/// writes; confirm against the other implementors.
async fn flush(&self) -> Result<()> {
Ok(())
}
}
/// A key/value store that keeps its entries in a `HashMap` in memory.
///
/// The map sits behind an `Arc<Mutex<..>>`, so cloning the store produces a
/// second handle to the same entries rather than an independent copy.
#[derive(Clone, Debug, Default)]
pub struct InMemoryStore {
    /// Shared, mutex-guarded map from key to stored JSON value.
    inner: Arc<Mutex<HashMap<String, serde_json::Value>>>,
}
impl InMemoryStore {
    /// Creates a store with no entries.
    pub fn new() -> Self {
        Self {
            inner: Arc::default(),
        }
    }

    /// Returns the number of keys currently stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn len(&self) -> usize {
        let entries = self.inner.lock().expect("InMemoryStore poisoned");
        entries.len()
    }

    /// Returns `true` when the store holds no keys.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn is_empty(&self) -> bool {
        let entries = self.inner.lock().expect("InMemoryStore poisoned");
        entries.is_empty()
    }
}
#[async_trait]
impl StateStore for InMemoryStore {
async fn load(&self, key: &str) -> Result<Option<serde_json::Value>> {
Ok(self
.inner
.lock()
.map_err(|_| Error::Storage("InMemoryStore mutex poisoned".into()))?
.get(key)
.cloned())
}
async fn save(&self, key: &str, value: serde_json::Value) -> Result<()> {
self.inner
.lock()
.map_err(|_| Error::Storage("InMemoryStore mutex poisoned".into()))?
.insert(key.to_string(), value);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A key that was never saved loads as `None`, and the store stays empty.
    #[tokio::test]
    async fn load_missing_key_returns_none() {
        let store = InMemoryStore::new();
        let found = store.load("nope").await.unwrap();
        assert!(found.is_none());
        assert!(store.is_empty());
    }

    /// A saved value is returned unchanged by a later load.
    #[tokio::test]
    async fn save_then_load_roundtrips() {
        let store = InMemoryStore::new();
        let key = "bot/BTCUSDT";
        let payload = json!({"realised": 12.5, "halted": true});

        store.save(key, payload.clone()).await.unwrap();

        let loaded = store.load(key).await.unwrap();
        assert_eq!(loaded, Some(payload));
        assert_eq!(store.len(), 1);
    }

    /// Saving twice under one key keeps only the latest value.
    #[tokio::test]
    async fn save_overwrites_previous() {
        let store = InMemoryStore::new();
        let first = json!(1);
        let second = json!(2);

        store.save("k", first).await.unwrap();
        store.save("k", second.clone()).await.unwrap();

        let loaded = store.load("k").await.unwrap();
        assert_eq!(loaded, Some(second));
        assert_eq!(store.len(), 1);
    }

    /// A clone of the store observes writes made through the original handle.
    #[tokio::test]
    async fn clones_share_state() {
        let writer = InMemoryStore::new();
        let reader = writer.clone();

        writer.save("k", json!("v")).await.unwrap();

        let seen = reader.load("k").await.unwrap();
        assert_eq!(seen, Some(json!("v")));
    }

    /// The trait's default `flush` succeeds for the in-memory store.
    #[tokio::test]
    async fn flush_default_is_ok() {
        let store = InMemoryStore::new();
        let outcome = store.flush().await;
        assert!(outcome.is_ok());
    }
}