1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::AofResult;
7
8#[async_trait]
12pub trait MemoryBackend: Send + Sync {
13 async fn store(&self, key: &str, entry: MemoryEntry) -> AofResult<()>;
15
16 async fn retrieve(&self, key: &str) -> AofResult<Option<MemoryEntry>>;
18
19 async fn delete(&self, key: &str) -> AofResult<()>;
21
22 async fn list_keys(&self, prefix: Option<&str>) -> AofResult<Vec<String>>;
24
25 async fn clear(&self) -> AofResult<()>;
27
28 async fn search(&self, query: &MemoryQuery) -> AofResult<Vec<MemoryEntry>> {
30 let keys = self.list_keys(query.prefix.as_deref()).await?;
32 let mut results = Vec::new();
33
34 for key in keys {
35 if let Some(entry) = self.retrieve(&key).await? {
36 if query.matches(&entry) {
37 results.push(entry);
38 }
39 }
40 }
41
42 Ok(results)
43 }
44}
45
46#[async_trait]
48pub trait Memory: Send + Sync {
49 async fn store(&self, key: &str, value: serde_json::Value) -> AofResult<()>;
51
52 async fn retrieve<T: serde::de::DeserializeOwned>(&self, key: &str) -> AofResult<Option<T>>;
54
55 async fn delete(&self, key: &str) -> AofResult<()>;
57
58 async fn list_keys(&self) -> AofResult<Vec<String>>;
60
61 async fn clear(&self) -> AofResult<()>;
63}
64
/// A single value held by a [`MemoryBackend`], plus bookkeeping data.
///
/// Serialized with serde; fields are emitted in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Key this entry was created with.
    pub key: String,

    /// Arbitrary JSON payload.
    pub value: serde_json::Value,

    /// Creation time in milliseconds since the Unix epoch (as set by
    /// `MemoryEntry::new`); TTL expiry is measured from this instant.
    pub timestamp: u64,

    /// Free-form string tags, matched exactly by `MemoryQuery::matches`.
    /// Missing in serialized input → empty map.
    #[serde(default)]
    pub metadata: HashMap<String, String>,

    /// Time-to-live in seconds relative to `timestamp`; `None` means the entry
    /// never expires. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<u64>,
}
85
86impl MemoryEntry {
87 pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
89 Self {
90 key: key.into(),
91 value,
92 timestamp: std::time::SystemTime::now()
93 .duration_since(std::time::UNIX_EPOCH)
94 .unwrap()
95 .as_millis() as u64,
96 metadata: HashMap::new(),
97 ttl: None,
98 }
99 }
100
101 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
103 self.metadata.insert(key.into(), value.into());
104 self
105 }
106
107 pub fn with_ttl(mut self, ttl_secs: u64) -> Self {
109 self.ttl = Some(ttl_secs);
110 self
111 }
112
113 pub fn is_expired(&self) -> bool {
115 if let Some(ttl) = self.ttl {
116 let now = std::time::SystemTime::now()
117 .duration_since(std::time::UNIX_EPOCH)
118 .unwrap()
119 .as_millis() as u64;
120 let expiry = self.timestamp + (ttl * 1000);
121 now > expiry
122 } else {
123 false
124 }
125 }
126}
127
/// Filter criteria used when searching a [`MemoryBackend`].
///
/// `Default` produces a query that matches every non-expired entry.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryQuery {
    /// Key prefix handed to `MemoryBackend::list_keys` by the default
    /// `MemoryBackend::search`. Not checked by `MemoryQuery::matches`.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix: Option<String>,

    /// Required metadata: every pair must be present, with an equal value, in
    /// the entry's metadata. Missing in serialized input → empty map.
    #[serde(default)]
    pub metadata: HashMap<String, String>,

    /// Optional cap on the number of results. Not consulted by
    /// `MemoryQuery::matches`; enforcing it is up to the search implementation.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<usize>,

    /// When `false` (the default), `matches` rejects entries whose
    /// `is_expired()` returns `true`.
    #[serde(default)]
    pub include_expired: bool,
}
147
148impl MemoryQuery {
149 pub fn matches(&self, entry: &MemoryEntry) -> bool {
151 if !self.include_expired && entry.is_expired() {
153 return false;
154 }
155
156 for (key, value) in &self.metadata {
158 if entry.metadata.get(key) != Some(value) {
159 return false;
160 }
161 }
162
163 true
164 }
165}
166
/// Shared handle to a [`MemoryBackend`] trait object; the trait requires
/// `Send + Sync`, so the handle can be shared across threads.
pub type MemoryBackendRef = Arc<dyn MemoryBackend>;

/// Shared handle to a [`Memory`] trait object.
///
/// NOTE(review): this alias is only usable while `Memory` is dyn-compatible.
/// Generic methods such as `retrieve<T>` must be excluded from the vtable
/// (for example with `where Self: Sized`) for `dyn Memory` to be formed.
pub type MemoryRef = Arc<dyn Memory>;
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_entry_new() {
        let entry = MemoryEntry::new("test-key", serde_json::json!({"value": 42}));

        assert_eq!(entry.key, "test-key");
        assert_eq!(entry.value, serde_json::json!({"value": 42}));
        assert!(entry.timestamp > 0);
        assert!(entry.metadata.is_empty());
        assert!(entry.ttl.is_none());
    }

    #[test]
    fn test_memory_entry_with_metadata() {
        let entry = MemoryEntry::new("key", serde_json::json!(null))
            .with_metadata("type", "session")
            .with_metadata("agent", "test-agent");

        assert_eq!(entry.metadata.get("type"), Some(&"session".to_string()));
        assert_eq!(entry.metadata.get("agent"), Some(&"test-agent".to_string()));
    }

    #[test]
    fn test_memory_entry_with_ttl() {
        let entry = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(3600);

        assert_eq!(entry.ttl, Some(3600));
    }

    #[test]
    fn test_memory_entry_is_expired() {
        let entry_no_ttl = MemoryEntry::new("key", serde_json::json!(null));
        assert!(!entry_no_ttl.is_expired());

        let entry_long_ttl = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(3600);
        assert!(!entry_long_ttl.is_expired());

        // Backdate creation by 2s so that a 1s TTL has definitely elapsed.
        let mut entry_expired = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(1);
        entry_expired.timestamp -= 2000;
        assert!(entry_expired.is_expired());
    }

    #[test]
    fn test_memory_entry_serialization() {
        let entry = MemoryEntry::new("my-key", serde_json::json!({"data": "test"}))
            .with_metadata("source", "api")
            .with_ttl(60);

        let json = serde_json::to_string(&entry).unwrap();
        let deserialized: MemoryEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.key, "my-key");
        assert_eq!(deserialized.value, serde_json::json!({"data": "test"}));
        assert_eq!(deserialized.timestamp, entry.timestamp);
        assert_eq!(deserialized.metadata.get("source"), Some(&"api".to_string()));
        assert_eq!(deserialized.ttl, Some(60));
    }

    #[test]
    fn test_memory_entry_deserialization_defaults() {
        // `metadata` and `ttl` may be absent from serialized input.
        let entry: MemoryEntry =
            serde_json::from_str(r#"{"key":"k","value":1,"timestamp":5}"#).unwrap();

        assert_eq!(entry.key, "k");
        assert_eq!(entry.value, serde_json::json!(1));
        assert_eq!(entry.timestamp, 5);
        assert!(entry.metadata.is_empty());
        assert!(entry.ttl.is_none());
    }

    #[test]
    fn test_memory_entry_serialization_omits_missing_ttl() {
        let entry = MemoryEntry::new("k", serde_json::json!(null));
        let json = serde_json::to_value(&entry).unwrap();

        assert!(json.get("ttl").is_none());
    }

    #[test]
    fn test_memory_query_default() {
        let query = MemoryQuery::default();

        assert!(query.prefix.is_none());
        assert!(query.metadata.is_empty());
        assert!(query.limit.is_none());
        assert!(!query.include_expired);
    }

    #[test]
    fn test_memory_query_matches_basic() {
        let query = MemoryQuery::default();
        let entry = MemoryEntry::new("key", serde_json::json!(null));

        assert!(query.matches(&entry));
    }

    #[test]
    fn test_memory_query_matches_metadata() {
        let mut query = MemoryQuery::default();
        query.metadata.insert("type".to_string(), "session".to_string());

        let entry_no_meta = MemoryEntry::new("key", serde_json::json!(null));
        assert!(!query.matches(&entry_no_meta));

        let entry_match =
            MemoryEntry::new("key", serde_json::json!(null)).with_metadata("type", "session");
        assert!(query.matches(&entry_match));

        let entry_wrong =
            MemoryEntry::new("key", serde_json::json!(null)).with_metadata("type", "permanent");
        assert!(!query.matches(&entry_wrong));
    }

    #[test]
    fn test_memory_query_matches_requires_all_metadata() {
        let mut query = MemoryQuery::default();
        query.metadata.insert("type".to_string(), "session".to_string());
        query.metadata.insert("agent".to_string(), "a1".to_string());

        let partial =
            MemoryEntry::new("key", serde_json::json!(null)).with_metadata("type", "session");
        assert!(!query.matches(&partial));

        // Extra metadata on the entry is allowed.
        let full = partial
            .clone()
            .with_metadata("agent", "a1")
            .with_metadata("extra", "x");
        assert!(query.matches(&full));
    }

    #[test]
    fn test_memory_query_matches_expired() {
        // Backdate creation by 2s so that a 1s TTL has definitely elapsed.
        let mut entry_expired = MemoryEntry::new("key", serde_json::json!(null)).with_ttl(1);
        entry_expired.timestamp -= 2000;

        let query_default = MemoryQuery::default();
        assert!(!query_default.matches(&entry_expired));

        let query_include = MemoryQuery {
            include_expired: true,
            ..Default::default()
        };
        assert!(query_include.matches(&entry_expired));
    }

    #[test]
    fn test_memory_query_serialization() {
        let mut query = MemoryQuery {
            prefix: Some("agent:".to_string()),
            metadata: HashMap::new(),
            limit: Some(100),
            include_expired: true,
        };
        query.metadata.insert("type".to_string(), "context".to_string());

        let json = serde_json::to_string(&query).unwrap();
        let deserialized: MemoryQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.prefix, Some("agent:".to_string()));
        assert_eq!(deserialized.metadata.get("type"), Some(&"context".to_string()));
        assert_eq!(deserialized.limit, Some(100));
        assert!(deserialized.include_expired);
    }

    #[test]
    fn test_memory_query_default_serialization_omits_none() {
        // `prefix` and `limit` are skipped when `None`.
        let json = serde_json::to_value(MemoryQuery::default()).unwrap();

        assert_eq!(
            json,
            serde_json::json!({"metadata": {}, "include_expired": false})
        );
    }
}
315}