use std::collections::HashMap;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::error::PersistenceError;
/// A single JSON value held by a [`Store`], together with where it lives and
/// when it was written.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StoreItem {
    /// The stored JSON payload.
    pub value: Value,
    /// Key identifying the item within its namespace.
    pub key: String,
    /// Namespace segments the item was stored under, in the order they were
    /// passed to `Store::put`.
    pub namespace: Vec<String>,
    /// Creation timestamp; `StoreItem::new` sets it to the current UTC time.
    pub created_at: DateTime<Utc>,
    /// Last-update timestamp; `StoreItem::new` initialises it to the same
    /// instant as `created_at`.
    pub updated_at: DateTime<Utc>,
}
impl StoreItem {
pub fn new(value: Value, key: String, namespace: Vec<String>) -> Self {
let now = Utc::now();
Self {
value,
key,
namespace,
created_at: now,
updated_at: now,
}
}
}
/// Asynchronous store of JSON values addressed by a namespace (a sequence of
/// string segments) plus a key.
///
/// Implementations must be `Send + Sync` so one store can be shared between
/// tasks, e.g. through [`StoreBox`].
#[async_trait]
pub trait Store: Send + Sync {
    /// Stores `value` under `namespace`/`key`.
    ///
    /// # Errors
    /// Returns a [`PersistenceError`] if the implementation fails to write.
    async fn put(
        &self,
        namespace: &[&str],
        key: &str,
        value: Value,
    ) -> Result<(), PersistenceError>;
    /// Returns the item stored at `namespace`/`key`, or `None` if there is none.
    ///
    /// # Errors
    /// Returns a [`PersistenceError`] if the implementation fails to read.
    async fn get(
        &self,
        namespace: &[&str],
        key: &str,
    ) -> Result<Option<StoreItem>, PersistenceError>;
    /// Lists items in `namespace`, optionally filtered by `query` and capped
    /// at `limit` results.
    ///
    /// How `query` is matched is implementation-defined; see
    /// [`Store::supports_semantic_search`].
    ///
    /// # Errors
    /// Returns a [`PersistenceError`] if the implementation fails to read.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<StoreItem>, PersistenceError>;
    /// Removes the item at `namespace`/`key`.
    ///
    /// # Errors
    /// Returns a [`PersistenceError`] if the implementation fails to write.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), PersistenceError>;
    /// Whether `search` interprets `query` semantically rather than literally.
    /// The default returns `false`.
    fn supports_semantic_search(&self) -> bool {
        false
    }
    /// The default returns `None`. Presumably the dimensionality of the
    /// embeddings backing semantic search — confirm with implementors.
    fn embedding_dims(&self) -> Option<usize> {
        None
    }
}
/// [`Store`] implementation that keeps every item in a process-local `HashMap`.
///
/// Nothing is persisted; contents live only as long as this value. Reads
/// (`get`, `search`) take the `tokio::sync::RwLock` in shared mode, and writes
/// (`put`, `delete`) take it exclusively.
pub struct InMemoryStore {
    /// Items indexed by the string that `make_key` builds from namespace + key.
    data: tokio::sync::RwLock<HashMap<String, StoreItem>>,
}
impl InMemoryStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            data: tokio::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Encodes a namespace path plus key into the single string used as the
    /// map key.
    ///
    /// Segments are joined with `:`. Inside each segment, `\` is written as
    /// `\\` and `:` is written as `\:`. This escaping keeps the encoding
    /// injective. Without it, `(["a"], "b:c")`, `(["a:b"], "c")` and
    /// `([], "a:b:c")` would all become `"a:b:c"` and overwrite one another.
    /// Segments that contain neither character encode as a plain `:`-join,
    /// for example `"user-1:memories:memory-1"`.
    fn make_key(namespace: &[&str], key: &str) -> String {
        let segments = namespace.iter().copied().chain(std::iter::once(key));
        let mut encoded = String::with_capacity(
            namespace.iter().map(|s| s.len() + 1).sum::<usize>() + key.len(),
        );
        for (index, segment) in segments.enumerate() {
            // Separator goes before every segment but the first. Empty
            // segments still get one, so they stay distinguishable.
            if index > 0 {
                encoded.push(':');
            }
            for ch in segment.chars() {
                if ch == ':' || ch == '\\' {
                    encoded.push('\\');
                }
                encoded.push(ch);
            }
        }
        encoded
    }
}
impl Default for InMemoryStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Store for InMemoryStore {
    /// Inserts `value` at `namespace`/`key`, replacing any existing item.
    ///
    /// On overwrite, the previous item's `created_at` is kept and only
    /// `updated_at` advances. `created_at` therefore records the first write
    /// and `updated_at` the latest one. Never fails.
    async fn put(
        &self,
        namespace: &[&str],
        key: &str,
        value: Value,
    ) -> Result<(), PersistenceError> {
        let store_key = Self::make_key(namespace, key);
        let namespace_vec: Vec<String> = namespace.iter().map(|s| s.to_string()).collect();
        let mut data = self.data.write().await;
        // Stamp the time while holding the write lock. Writes to one key then
        // get timestamps in lock order, so `updated_at` cannot land before a
        // concurrently written `created_at`.
        let mut item = StoreItem::new(value, key.to_string(), namespace_vec);
        if let Some(previous) = data.get(&store_key) {
            item.created_at = previous.created_at;
        }
        data.insert(store_key, item);
        Ok(())
    }

    /// Returns a clone of the item at `namespace`/`key`, or `None` if absent.
    /// Never fails.
    async fn get(
        &self,
        namespace: &[&str],
        key: &str,
    ) -> Result<Option<StoreItem>, PersistenceError> {
        let store_key = Self::make_key(namespace, key);
        let data = self.data.read().await;
        Ok(data.get(&store_key).cloned())
    }

    /// Returns items whose namespace equals `namespace` exactly, newest
    /// `created_at` first.
    ///
    /// * `query`: when `Some`, keeps only items whose JSON-serialised value
    ///   contains it, compared case-insensitively. The match runs over the
    ///   whole serialised text, so object keys and JSON punctuation can match
    ///   as well as string values.
    /// * `limit`: when `Some(n)`, at most `n` items are returned.
    ///
    /// An empty `namespace` matches only items stored with an empty namespace.
    /// Never fails.
    async fn search(
        &self,
        namespace: &[&str],
        query: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<StoreItem>, PersistenceError> {
        // Lowercase the needle once instead of once per candidate item.
        let needle = query.map(str::to_lowercase);
        let data = self.data.read().await;
        let mut hits: Vec<&StoreItem> = data
            .values()
            // Compare segment by segment. Joining with ':' would make
            // `["a:b"]` and `["a", "b"]` indistinguishable.
            .filter(|item| {
                item.namespace
                    .iter()
                    .map(String::as_str)
                    .eq(namespace.iter().copied())
            })
            .filter(|item| match &needle {
                Some(needle) => serde_json::to_string(&item.value)
                    .unwrap_or_default()
                    .to_lowercase()
                    .contains(needle.as_str()),
                None => true,
            })
            .collect();
        hits.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        if let Some(limit) = limit {
            hits.truncate(limit);
        }
        // Clone only the items that survived the limit.
        Ok(hits.into_iter().cloned().collect())
    }

    /// Removes the item at `namespace`/`key`. Removing a missing item is not an
    /// error. Never fails.
    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), PersistenceError> {
        let store_key = Self::make_key(namespace, key);
        let mut data = self.data.write().await;
        data.remove(&store_key);
        Ok(())
    }
}
/// Shared, type-erased handle to any [`Store`] implementation.
///
/// Despite the name, this is an `Arc`, not a `Box`. Cloning it hands out
/// another reference to the same underlying store.
pub type StoreBox = std::sync::Arc<dyn Store>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one item through put → get → search → delete → get on an
    /// [`InMemoryStore`].
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        let ns: &[&str] = &["user-1", "memories"];
        let payload = serde_json::json!({"food_preference": "I like pizza"});

        store.put(ns, "memory-1", payload.clone()).await.unwrap();

        // The stored value comes back unchanged.
        let fetched = store.get(ns, "memory-1").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().value, payload);

        // A substring of the stored text finds the item.
        let hits = store.search(ns, Some("pizza"), None).await.unwrap();
        assert_eq!(hits.len(), 1);

        store.delete(ns, "memory-1").await.unwrap();

        // After deletion the key no longer resolves.
        let after_delete = store.get(ns, "memory-1").await.unwrap();
        assert!(after_delete.is_none());
    }
}