use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::AofResult;
/// Storage contract for memory entries addressed by string keys.
///
/// Implementors supply the five primitive operations; [`MemoryBackend::search`] has a
/// default implementation built on top of `list_keys` and `retrieve`.
#[async_trait]
pub trait MemoryBackend: Send + Sync {
    /// Stores `entry` under `key`.
    ///
    /// Whether an existing entry is overwritten is up to the implementor.
    async fn store(&self, key: &str, entry: MemoryEntry) -> AofResult<()>;

    /// Looks up the entry stored under `key`, yielding `Ok(None)` when there is none.
    async fn retrieve(&self, key: &str) -> AofResult<Option<MemoryEntry>>;

    /// Removes the entry stored under `key`.
    async fn delete(&self, key: &str) -> AofResult<()>;

    /// Lists stored keys, optionally restricted by `prefix`.
    ///
    /// The default `search` forwards `MemoryQuery::prefix` here unchanged, so prefix
    /// filtering is the implementor's responsibility.
    async fn list_keys(&self, prefix: Option<&str>) -> AofResult<Vec<String>>;

    /// Removes every entry held by the backend.
    async fn clear(&self) -> AofResult<()>;

    /// Returns the entries that satisfy `query`.
    ///
    /// The default implementation lists the keys for `query.prefix`, retrieves them one
    /// at a time in the order `list_keys` returned them, and keeps every entry accepted
    /// by [`MemoryQuery::matches`]. Scanning stops as soon as `query.limit` results have
    /// been collected, so no further `retrieve` calls are issued past that point.
    ///
    /// # Errors
    ///
    /// The first error returned by `list_keys` or `retrieve` aborts the search and is
    /// propagated to the caller.
    async fn search(&self, query: &MemoryQuery) -> AofResult<Vec<MemoryEntry>> {
        // `None` means "no cap"; `usize::MAX` can never be reached by a `Vec` length.
        let max_results = query.limit.unwrap_or(usize::MAX);
        let keys = self.list_keys(query.prefix.as_deref()).await?;
        let mut results = Vec::new();
        for key in keys {
            // Checked before retrieving so that a satisfied (or zero) limit costs no
            // additional backend round-trips.
            if results.len() >= max_results {
                break;
            }
            // A key can disappear between listing and retrieval; such keys are skipped.
            if let Some(entry) = self.retrieve(&key).await? {
                if query.matches(&entry) {
                    results.push(entry);
                }
            }
        }
        Ok(results)
    }
}
/// Key/value memory interface that accepts JSON values and hands back typed values.
///
/// Only signatures are declared here; storage semantics are defined by implementors.
///
/// NOTE(review): `retrieve` takes a generic type parameter, which normally prevents a
/// trait from being used as a trait object. Confirm that `dyn Memory` (see the
/// `MemoryRef` alias in this file) is usable as intended.
#[async_trait]
pub trait Memory: Send + Sync {
/// Stores the JSON `value` under `key`.
async fn store(&self, key: &str, value: serde_json::Value) -> AofResult<()>;
/// Looks up `key` and returns the value as `T`, or `Ok(None)` when nothing is available.
///
/// `T` only needs to be deserializable; presumably the stored JSON is deserialized
/// into it. Verify against the implementor.
async fn retrieve<T: serde::de::DeserializeOwned>(&self, key: &str) -> AofResult<Option<T>>;
/// Removes the value stored under `key`.
async fn delete(&self, key: &str) -> AofResult<()>;
/// Lists the stored keys.
async fn list_keys(&self) -> AofResult<Vec<String>>;
/// Removes every stored value.
async fn clear(&self) -> AofResult<()>;
}
/// A single stored value together with its bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Key the entry was created with.
pub key: String,
/// JSON payload of the entry.
pub value: serde_json::Value,
/// Creation time; `MemoryEntry::new` fills it with milliseconds since the Unix epoch.
pub timestamp: u64,
/// Free-form string tags; an absent field deserializes to an empty map.
#[serde(default)]
pub metadata: HashMap<String, String>,
/// Time-to-live in seconds, measured from `timestamp`. `None` means the entry never
/// expires; a `None` value is left out of the serialized form.
#[serde(skip_serializing_if = "Option::is_none")]
pub ttl: Option<u64>,
}
impl MemoryEntry {
    /// Reads the system clock as milliseconds since the Unix epoch.
    ///
    /// A clock set before the epoch yields `0` instead of panicking, so constructing or
    /// checking an entry can never abort the process because of a misconfigured clock.
    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            // u64 milliseconds cover roughly 584 million years, so the narrowing cast
            // cannot truncate for any realistic clock value.
            .map_or(0, |elapsed| elapsed.as_millis() as u64)
    }

    /// Creates an entry for `key` holding `value`.
    ///
    /// The timestamp is set to the current time in milliseconds since the Unix epoch,
    /// the metadata map starts empty and no TTL is set.
    pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
        Self {
            key: key.into(),
            value,
            timestamp: Self::now_millis(),
            metadata: HashMap::new(),
            ttl: None,
        }
    }

    /// Adds (or replaces) one metadata pair and returns the entry for chaining.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Sets the time-to-live to `ttl_secs` seconds and returns the entry for chaining.
    pub fn with_ttl(mut self, ttl_secs: u64) -> Self {
        self.ttl = Some(ttl_secs);
        self
    }

    /// Reports whether the TTL has elapsed.
    ///
    /// Entries without a TTL never expire. Otherwise the entry is expired once the
    /// current time is strictly later than `timestamp + ttl` (TTL converted to
    /// milliseconds).
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            None => false,
            Some(ttl_secs) => {
                // Saturate instead of overflowing: a very large TTL (for example
                // `u64::MAX` used as "effectively forever") must not panic in debug
                // builds or wrap around to an already-past expiry in release builds.
                let expiry = self
                    .timestamp
                    .saturating_add(ttl_secs.saturating_mul(1000));
                Self::now_millis() > expiry
            }
        }
    }
}
/// Filter describing which memory entries a search should return.
///
/// The `Default` value has no prefix, no metadata constraints, no limit and excludes
/// expired entries.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryQuery {
/// Optional key prefix; `MemoryBackend::search` passes it to `list_keys`.
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefix: Option<String>,
/// Metadata pairs an entry must carry with exactly equal values to match.
/// An absent field deserializes to an empty map, which constrains nothing.
#[serde(default)]
pub metadata: HashMap<String, String>,
/// Optional cap on the number of results; not consulted by `matches`.
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<usize>,
/// When `false` (the default), entries whose TTL has elapsed do not match.
#[serde(default)]
pub include_expired: bool,
}
impl MemoryQuery {
    /// Decides whether `entry` satisfies this query.
    ///
    /// An entry is accepted when it is not hidden by expiry (expired entries only pass
    /// if `include_expired` is set) and it carries every metadata pair listed in the
    /// query with an identical value. The key prefix and the limit are not examined.
    pub fn matches(&self, entry: &MemoryEntry) -> bool {
        // `||` short-circuits, so the expiry check only runs when it can matter.
        let passes_expiry = self.include_expired || !entry.is_expired();

        passes_expiry
            && self
                .metadata
                .iter()
                .all(|(name, wanted)| entry.metadata.get(name) == Some(wanted))
    }
}
/// Reference-counted handle to a `MemoryBackend` trait object.
pub type MemoryBackendRef = Arc<dyn MemoryBackend>;
/// Reference-counted handle to a `Memory` trait object.
pub type MemoryRef = Arc<dyn Memory>;
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built entry carries the given key/value and empty bookkeeping.
    #[test]
    fn test_memory_entry_new() {
        let entry = MemoryEntry::new("test-key", serde_json::json!({"value": 42}));
        assert_eq!(entry.key, "test-key");
        assert_eq!(entry.value, serde_json::json!({"value": 42}));
        assert!(entry.timestamp > 0);
        assert!(entry.metadata.is_empty());
        assert!(entry.ttl.is_none());
    }

    /// Chained `with_metadata` calls accumulate pairs.
    #[test]
    fn test_memory_entry_with_metadata() {
        let entry = MemoryEntry::new("key", serde_json::json!(null))
            .with_metadata("type", "session")
            .with_metadata("agent", "test-agent");
        assert_eq!(entry.metadata.get("type"), Some(&"session".to_string()));
        assert_eq!(entry.metadata.get("agent"), Some(&"test-agent".to_string()));
    }

    /// `with_ttl` stores the TTL verbatim.
    #[test]
    fn test_memory_entry_with_ttl() {
        let entry = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(3600);
        assert_eq!(entry.ttl, Some(3600));
    }

    /// Expiry depends on the TTL and on how far back the timestamp lies.
    #[test]
    fn test_memory_entry_is_expired() {
        let entry_no_ttl = MemoryEntry::new("key", serde_json::json!(null));
        assert!(!entry_no_ttl.is_expired());

        let entry_long_ttl = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(3600);
        assert!(!entry_long_ttl.is_expired());

        // Back-date the entry by 2 s so that its 1 s TTL has already elapsed.
        let mut entry_expired = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(1);
        entry_expired.timestamp -= 2000;
        assert!(entry_expired.is_expired());
    }

    /// Every field of an entry survives a JSON round trip.
    #[test]
    fn test_memory_entry_serialization() {
        let entry = MemoryEntry::new("my-key", serde_json::json!({"data": "test"}))
            .with_metadata("source", "api")
            .with_ttl(60);
        let json = serde_json::to_string(&entry).unwrap();
        let deserialized: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.key, "my-key");
        assert_eq!(deserialized.value, serde_json::json!({"data": "test"}));
        assert_eq!(deserialized.timestamp, entry.timestamp);
        assert_eq!(deserialized.metadata.get("source"), Some(&"api".to_string()));
        assert_eq!(deserialized.ttl, Some(60));
    }

    /// An unset TTL is omitted on output, and optional fields may be absent on input.
    #[test]
    fn test_memory_entry_optional_fields() {
        let entry = MemoryEntry::new("k", serde_json::json!(null));
        let json = serde_json::to_value(&entry).unwrap();
        assert!(json.get("ttl").is_none());

        let parsed: MemoryEntry =
            serde_json::from_str(r#"{"key":"k","value":null,"timestamp":1}"#).unwrap();
        assert_eq!(parsed.timestamp, 1);
        assert!(parsed.metadata.is_empty());
        assert!(parsed.ttl.is_none());
    }

    /// The default query constrains nothing and hides expired entries.
    #[test]
    fn test_memory_query_default() {
        let query = MemoryQuery::default();
        assert!(query.prefix.is_none());
        assert!(query.metadata.is_empty());
        assert!(query.limit.is_none());
        assert!(!query.include_expired);
    }

    /// A default query accepts a plain, non-expired entry.
    #[test]
    fn test_memory_query_matches_basic() {
        let query = MemoryQuery::default();
        let entry = MemoryEntry::new("key", serde_json::json!(null));
        assert!(query.matches(&entry));
    }

    /// Metadata constraints require the pair to be present with an equal value.
    #[test]
    fn test_memory_query_matches_metadata() {
        let mut query = MemoryQuery::default();
        query.metadata.insert("type".to_string(), "session".to_string());

        let entry_no_meta = MemoryEntry::new("key", serde_json::json!(null));
        assert!(!query.matches(&entry_no_meta));

        let entry_match =
            MemoryEntry::new("key", serde_json::json!(null)).with_metadata("type", "session");
        assert!(query.matches(&entry_match));

        let entry_wrong =
            MemoryEntry::new("key", serde_json::json!(null)).with_metadata("type", "permanent");
        assert!(!query.matches(&entry_wrong));
    }

    /// Expired entries only match when the query opts in.
    #[test]
    fn test_memory_query_matches_expired() {
        // Back-date the entry by 2 s so that its 1 s TTL has already elapsed.
        let mut entry_expired = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(1);
        entry_expired.timestamp -= 2000;

        let query_default = MemoryQuery::default();
        assert!(!query_default.matches(&entry_expired));

        let query_include = MemoryQuery {
            include_expired: true,
            ..Default::default()
        };
        assert!(query_include.matches(&entry_expired));
    }

    /// Every field of a query survives a JSON round trip.
    #[test]
    fn test_memory_query_serialization() {
        let mut query = MemoryQuery {
            prefix: Some("agent:".to_string()),
            metadata: HashMap::new(),
            limit: Some(100),
            include_expired: true,
        };
        query.metadata.insert("type".to_string(), "context".to_string());

        let json = serde_json::to_string(&query).unwrap();
        let deserialized: MemoryQuery = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prefix, Some("agent:".to_string()));
        assert_eq!(deserialized.metadata.get("type"), Some(&"context".to_string()));
        assert_eq!(deserialized.limit, Some(100));
        assert!(deserialized.include_expired);
    }

    /// An empty JSON object deserializes to the default query.
    #[test]
    fn test_memory_query_deserialize_empty_object() {
        let query: MemoryQuery = serde_json::from_str("{}").unwrap();
        assert!(query.prefix.is_none());
        assert!(query.metadata.is_empty());
        assert!(query.limit.is_none());
        assert!(!query.include_expired);
    }
}