mod in_memory;
mod vertex_ai_memory_bank;
mod vertex_ai_rag;
pub use in_memory::InMemoryMemoryService;
pub use vertex_ai_memory_bank::{VertexAiMemoryBankConfig, VertexAiMemoryBankService};
pub use vertex_ai_rag::{VertexAiRagMemoryConfig, VertexAiRagMemoryService};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single keyed JSON value together with creation and update timestamps.
///
/// This is the unit that [`MemoryService`] methods accept and return. It derives
/// `Serialize`/`Deserialize`, so it can be converted to and from serde data formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Name of the entry; the tests in this file store an entry and then look it up by this key.
pub key: String,
/// Arbitrary JSON payload associated with `key`.
pub value: serde_json::Value,
/// Creation time. `MemoryEntry::new` fills this with whole seconds since the Unix epoch.
pub created_at: u64,
/// Last-update time. `MemoryEntry::new` initialises it to the same value as `created_at`;
/// what (if anything) bumps it afterwards is not visible in this file.
pub updated_at: u64,
}
impl MemoryEntry {
pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
let now = now_secs();
Self {
key: key.into(),
value,
created_at: now,
updated_at: now,
}
}
}
/// Error type returned by every [`MemoryService`] method.
///
/// `Display` and `std::error::Error` are generated by `thiserror` from the
/// `#[error(...)]` attributes below.
#[derive(Debug, thiserror::Error)]
pub enum MemoryError {
/// A requested key does not exist. The payload is interpolated into the message
/// after `Memory key not found: `.
#[error("Memory key not found: {0}")]
NotFound(String),
/// A storage-level failure. The payload is a free-form description that is
/// interpolated into the message after `Storage error: `.
#[error("Storage error: {0}")]
Storage(String),
}
/// Asynchronous, session-scoped store of [`MemoryEntry`] values.
///
/// Every method takes a `session_id` and reports failures as [`MemoryError`].
/// Implementors must be `Send + Sync`. The trait is declared through `#[async_trait]`
/// and is used as `dyn MemoryService` (see the object-safety test in this file).
///
/// The exact semantics of each operation (overwrite behaviour, matching rules,
/// handling of missing keys) are defined by the implementing type and are not
/// fixed by this declaration.
#[async_trait]
pub trait MemoryService: Send + Sync {
/// Saves `entry` for the session `session_id`.
/// NOTE(review): presumably replaces an existing entry with the same key — confirm per backend.
async fn store(&self, session_id: &str, entry: MemoryEntry) -> Result<(), MemoryError>;
/// Looks up the entry named `key` in session `session_id`.
/// The `Option` in the return type allows "no such entry" to be reported as `Ok(None)`;
/// the in-memory backend is tested to do so in this file.
async fn get(&self, session_id: &str, key: &str) -> Result<Option<MemoryEntry>, MemoryError>;
/// Returns the entries belonging to session `session_id`.
/// The tests in this file expect entries stored under other sessions to be excluded.
async fn list(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
/// Returns the entries of session `session_id` that match `query`.
/// How a match is decided (key vs. value, case sensitivity, ranking) is up to the implementation.
async fn search(&self, session_id: &str, query: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
/// Removes the entry named `key` from session `session_id`.
/// NOTE(review): whether a missing key yields `MemoryError::NotFound` or `Ok(())` is not visible here.
async fn delete(&self, session_id: &str, key: &str) -> Result<(), MemoryError>;
/// Removes every entry of session `session_id`.
async fn clear(&self, session_id: &str) -> Result<(), MemoryError>;
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, the error is discarded
/// and `0` is returned instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: fall back to the zero duration.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A new entry keeps the key/value it was given and carries a non-zero creation time.
    #[test]
    fn memory_entry_new() {
        let built = MemoryEntry::new("topic", serde_json::json!("Rust"));
        assert_eq!(built.key, "topic");
        assert_eq!(built.value, serde_json::json!("Rust"));
        assert!(built.created_at > 0);
    }

    // Compile-time check only: the trait must be usable as `dyn MemoryService`.
    #[test]
    fn memory_service_is_object_safe() {
        fn _assert(_: &dyn MemoryService) {}
    }

    // A stored entry can be read back by session and key.
    #[tokio::test]
    async fn store_and_get() {
        let service = InMemoryMemoryService::new();
        service
            .store("s1", MemoryEntry::new("topic", serde_json::json!("AI")))
            .await
            .unwrap();

        let found = service.get("s1", "topic").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().value, serde_json::json!("AI"));
    }

    // Looking up a key that was never stored yields `Ok(None)`.
    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let service = InMemoryMemoryService::new();
        let found = service.get("s1", "missing").await.unwrap();
        assert!(found.is_none());
    }

    // Listing a session returns only that session's entries.
    #[tokio::test]
    async fn list_entries() {
        let service = InMemoryMemoryService::new();
        for (session, key, number) in [("s1", "a", 1), ("s1", "b", 2), ("s2", "c", 3)] {
            service
                .store(session, MemoryEntry::new(key, serde_json::json!(number)))
                .await
                .unwrap();
        }

        let listed = service.list("s1").await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    // Searching for "rust" finds the rust entry and not the python one.
    #[tokio::test]
    async fn search_entries() {
        let service = InMemoryMemoryService::new();

        let rust_entry = MemoryEntry::new("rust_topic", serde_json::json!("Rust programming"));
        service.store("s1", rust_entry).await.unwrap();

        let python_entry = MemoryEntry::new("python_topic", serde_json::json!("Python scripting"));
        service.store("s1", python_entry).await.unwrap();

        let hits = service.search("s1", "rust").await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].key, "rust_topic");
    }

    // After deletion the key can no longer be fetched.
    #[tokio::test]
    async fn delete_entry() {
        let service = InMemoryMemoryService::new();
        let entry = MemoryEntry::new("k", serde_json::json!(1));
        service.store("s1", entry).await.unwrap();
        service.delete("s1", "k").await.unwrap();

        let found = service.get("s1", "k").await.unwrap();
        assert!(found.is_none());
    }

    // Clearing a session leaves it with no entries.
    #[tokio::test]
    async fn clear_session() {
        let service = InMemoryMemoryService::new();
        for (key, number) in [("a", 1), ("b", 2)] {
            service
                .store("s1", MemoryEntry::new(key, serde_json::json!(number)))
                .await
                .unwrap();
        }
        service.clear("s1").await.unwrap();

        let remaining = service.list("s1").await.unwrap();
        assert!(remaining.is_empty());
    }
}