1mod in_memory;
7mod vertex_ai_memory_bank;
8mod vertex_ai_rag;
9
10pub use in_memory::InMemoryMemoryService;
11pub use vertex_ai_memory_bank::{VertexAiMemoryBankConfig, VertexAiMemoryBankService};
12pub use vertex_ai_rag::{VertexAiRagMemoryConfig, VertexAiRagMemoryService};
13
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct MemoryEntry {
20 pub key: String,
22 pub value: serde_json::Value,
24 pub created_at: u64,
26 pub updated_at: u64,
28}
29
30impl MemoryEntry {
31 pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
33 let now = now_secs();
34 Self {
35 key: key.into(),
36 value,
37 created_at: now,
38 updated_at: now,
39 }
40 }
41}
42
43#[derive(Debug, thiserror::Error)]
45pub enum MemoryError {
46 #[error("Memory key not found: {0}")]
48 NotFound(String),
49 #[error("Storage error: {0}")]
51 Storage(String),
52}
53
54#[async_trait]
58pub trait MemoryService: Send + Sync {
59 async fn store(&self, session_id: &str, entry: MemoryEntry) -> Result<(), MemoryError>;
61
62 async fn get(&self, session_id: &str, key: &str) -> Result<Option<MemoryEntry>, MemoryError>;
64
65 async fn list(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
67
68 async fn search(&self, session_id: &str, query: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
70
71 async fn delete(&self, session_id: &str, key: &str) -> Result<(), MemoryError>;
73
74 async fn clear(&self, session_id: &str) -> Result<(), MemoryError>;
76}
77
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error, so callers
/// never have to handle a failure here.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_entry_new() {
        let entry = MemoryEntry::new("topic", serde_json::json!("Rust"));
        assert_eq!(entry.key, "topic");
        assert_eq!(entry.value, serde_json::json!("Rust"));
        assert!(entry.created_at > 0);
        // `new` samples the clock once, so a fresh entry has never been updated.
        assert_eq!(entry.updated_at, entry.created_at);
    }

    #[test]
    fn memory_error_display() {
        // Pin the user-facing messages produced by the `#[error(...)]` attributes.
        assert_eq!(
            MemoryError::NotFound("topic".to_string()).to_string(),
            "Memory key not found: topic"
        );
        assert_eq!(
            MemoryError::Storage("disk full".to_string()).to_string(),
            "Storage error: disk full"
        );
    }

    #[test]
    fn memory_service_is_object_safe() {
        // Compiles only if `MemoryService` can be used as a trait object.
        fn _assert(_: &dyn MemoryService) {}
    }

    #[tokio::test]
    async fn store_and_get() {
        let svc = InMemoryMemoryService::new();
        let entry = MemoryEntry::new("topic", serde_json::json!("AI"));
        svc.store("s1", entry).await.unwrap();

        let fetched = svc
            .get("s1", "topic")
            .await
            .unwrap()
            .expect("entry should exist after store");
        assert_eq!(fetched.key, "topic");
        assert_eq!(fetched.value, serde_json::json!("AI"));
    }

    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let svc = InMemoryMemoryService::new();
        let fetched = svc.get("s1", "missing").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn list_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        // `list` order is unspecified, so compare the sorted key set; this also
        // proves the entry stored under "s2" did not leak into "s1".
        let mut keys: Vec<String> = svc
            .list("s1")
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.key)
            .collect();
        keys.sort();
        assert_eq!(keys, ["a", "b"]);
    }

    #[tokio::test]
    async fn search_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store(
            "s1",
            MemoryEntry::new("rust_topic", serde_json::json!("Rust programming")),
        )
        .await
        .unwrap();
        svc.store(
            "s1",
            MemoryEntry::new("python_topic", serde_json::json!("Python scripting")),
        )
        .await
        .unwrap();

        let results = svc.search("s1", "rust").await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "rust_topic");
    }

    #[tokio::test]
    async fn delete_entry() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("k", serde_json::json!(1)))
            .await
            .unwrap();
        svc.delete("s1", "k").await.unwrap();
        let fetched = svc.get("s1", "k").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn clear_session() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.clear("s1").await.unwrap();
        let entries = svc.list("s1").await.unwrap();
        assert!(entries.is_empty());
    }
}