aof_core/
memory.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::AofResult;
7
8/// Memory backend trait - pluggable persistence for agent state
9///
10/// Implementations should optimize for read performance and minimize allocations.
11#[async_trait]
12pub trait MemoryBackend: Send + Sync {
13    /// Store memory entry
14    async fn store(&self, key: &str, entry: MemoryEntry) -> AofResult<()>;
15
16    /// Retrieve memory entry
17    async fn retrieve(&self, key: &str) -> AofResult<Option<MemoryEntry>>;
18
19    /// Delete memory entry
20    async fn delete(&self, key: &str) -> AofResult<()>;
21
22    /// List keys (with optional prefix filter)
23    async fn list_keys(&self, prefix: Option<&str>) -> AofResult<Vec<String>>;
24
25    /// Clear all entries
26    async fn clear(&self) -> AofResult<()>;
27
28    /// Search entries by metadata
29    async fn search(&self, query: &MemoryQuery) -> AofResult<Vec<MemoryEntry>> {
30        // Default implementation: filter in-memory
31        let keys = self.list_keys(query.prefix.as_deref()).await?;
32        let mut results = Vec::new();
33
34        for key in keys {
35            if let Some(entry) = self.retrieve(&key).await? {
36                if query.matches(&entry) {
37                    results.push(entry);
38                }
39            }
40        }
41
42        Ok(results)
43    }
44}
45
46/// High-level memory interface for agents
47#[async_trait]
48pub trait Memory: Send + Sync {
49    /// Store a value
50    async fn store(&self, key: &str, value: serde_json::Value) -> AofResult<()>;
51
52    /// Retrieve a value
53    async fn retrieve<T: serde::de::DeserializeOwned>(&self, key: &str) -> AofResult<Option<T>>;
54
55    /// Delete a value
56    async fn delete(&self, key: &str) -> AofResult<()>;
57
58    /// List all keys
59    async fn list_keys(&self) -> AofResult<Vec<String>>;
60
61    /// Clear all memory
62    async fn clear(&self) -> AofResult<()>;
63}
64
65/// Memory entry with metadata
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct MemoryEntry {
68    /// Entry key
69    pub key: String,
70
71    /// Entry value (JSON)
72    pub value: serde_json::Value,
73
74    /// Timestamp (Unix epoch ms)
75    pub timestamp: u64,
76
77    /// Entry metadata
78    #[serde(default)]
79    pub metadata: HashMap<String, String>,
80
81    /// TTL (seconds, optional)
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub ttl: Option<u64>,
84}
85
86impl MemoryEntry {
87    /// Create new entry
88    pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
89        Self {
90            key: key.into(),
91            value,
92            timestamp: std::time::SystemTime::now()
93                .duration_since(std::time::UNIX_EPOCH)
94                .unwrap()
95                .as_millis() as u64,
96            metadata: HashMap::new(),
97            ttl: None,
98        }
99    }
100
101    /// Add metadata
102    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
103        self.metadata.insert(key.into(), value.into());
104        self
105    }
106
107    /// Set TTL
108    pub fn with_ttl(mut self, ttl_secs: u64) -> Self {
109        self.ttl = Some(ttl_secs);
110        self
111    }
112
113    /// Check if entry is expired
114    pub fn is_expired(&self) -> bool {
115        if let Some(ttl) = self.ttl {
116            let now = std::time::SystemTime::now()
117                .duration_since(std::time::UNIX_EPOCH)
118                .unwrap()
119                .as_millis() as u64;
120            let expiry = self.timestamp + (ttl * 1000);
121            now > expiry
122        } else {
123            false
124        }
125    }
126}
127
128/// Memory query for searching
129#[derive(Debug, Clone, Default, Serialize, Deserialize)]
130pub struct MemoryQuery {
131    /// Key prefix filter
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub prefix: Option<String>,
134
135    /// Metadata filters (key-value pairs)
136    #[serde(default)]
137    pub metadata: HashMap<String, String>,
138
139    /// Limit results
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub limit: Option<usize>,
142
143    /// Include expired entries
144    #[serde(default)]
145    pub include_expired: bool,
146}
147
148impl MemoryQuery {
149    /// Check if entry matches query
150    pub fn matches(&self, entry: &MemoryEntry) -> bool {
151        // Check expiry
152        if !self.include_expired && entry.is_expired() {
153            return false;
154        }
155
156        // Check metadata filters
157        for (key, value) in &self.metadata {
158            if entry.metadata.get(key) != Some(value) {
159                return false;
160            }
161        }
162
163        true
164    }
165}
166
167/// Reference-counted memory backend
168pub type MemoryBackendRef = Arc<dyn MemoryBackend>;
169
170/// Reference-counted memory
171pub type MemoryRef = Arc<dyn Memory>;
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry keyed `"key"` with a JSON `null` payload.
    fn null_entry() -> MemoryEntry {
        MemoryEntry::new("key", serde_json::json!(null))
    }

    /// Moves an entry's timestamp `millis` milliseconds into the past.
    fn backdate(mut entry: MemoryEntry, millis: u64) -> MemoryEntry {
        entry.timestamp -= millis;
        entry
    }

    #[test]
    fn test_memory_entry_new() {
        let payload = serde_json::json!({"value": 42});
        let entry = MemoryEntry::new("test-key", payload.clone());

        assert_eq!(entry.key, "test-key");
        assert_eq!(entry.value, payload);
        assert!(entry.timestamp > 0);
        assert!(entry.metadata.is_empty());
        assert!(entry.ttl.is_none());
    }

    #[test]
    fn test_memory_entry_with_metadata() {
        let entry = null_entry()
            .with_metadata("type", "session")
            .with_metadata("agent", "test-agent");

        for (name, expected) in [("type", "session"), ("agent", "test-agent")] {
            assert_eq!(entry.metadata.get(name), Some(&expected.to_string()));
        }
    }

    #[test]
    fn test_memory_entry_with_ttl() {
        assert_eq!(null_entry().with_ttl(3600).ttl, Some(3600));
    }

    #[test]
    fn test_memory_entry_is_expired() {
        // No TTL: never expires.
        assert!(!null_entry().is_expired());

        // One-hour TTL: still fresh right after creation.
        assert!(!null_entry().with_ttl(3600).is_expired());

        // One-second TTL, created two seconds ago: already expired.
        assert!(backdate(null_entry().with_ttl(1), 2000).is_expired());
    }

    #[test]
    fn test_memory_entry_serialization() {
        let original = MemoryEntry::new("my-key", serde_json::json!({"data": "test"}))
            .with_metadata("source", "api")
            .with_ttl(60);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MemoryEntry = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.key, "my-key");
        assert_eq!(decoded.value, serde_json::json!({"data": "test"}));
        assert_eq!(decoded.metadata.get("source"), Some(&"api".to_string()));
        assert_eq!(decoded.ttl, Some(60));
    }

    #[test]
    fn test_memory_query_default() {
        let MemoryQuery {
            prefix,
            metadata,
            limit,
            include_expired,
        } = MemoryQuery::default();

        assert!(prefix.is_none());
        assert!(metadata.is_empty());
        assert!(limit.is_none());
        assert!(!include_expired);
    }

    #[test]
    fn test_memory_query_matches_basic() {
        assert!(MemoryQuery::default().matches(&null_entry()));
    }

    #[test]
    fn test_memory_query_matches_metadata() {
        let mut query = MemoryQuery::default();
        query.metadata.insert("type".to_string(), "session".to_string());

        // Required tag missing entirely: no match.
        assert!(!query.matches(&null_entry()));
        // Same tag, same value: match.
        assert!(query.matches(&null_entry().with_metadata("type", "session")));
        // Same tag, different value: no match.
        assert!(!query.matches(&null_entry().with_metadata("type", "permanent")));
    }

    #[test]
    fn test_memory_query_matches_expired() {
        let stale = backdate(null_entry().with_ttl(1), 2000);

        // The default query filters out expired entries.
        assert!(!MemoryQuery::default().matches(&stale));

        // Opting in keeps them.
        let permissive = MemoryQuery {
            include_expired: true,
            ..Default::default()
        };
        assert!(permissive.matches(&stale));
    }

    #[test]
    fn test_memory_query_serialization() {
        let mut metadata = HashMap::new();
        metadata.insert("type".to_string(), "context".to_string());
        let query = MemoryQuery {
            prefix: Some("agent:".to_string()),
            metadata,
            limit: Some(100),
            include_expired: true,
        };

        let round_tripped: MemoryQuery =
            serde_json::from_str(&serde_json::to_string(&query).unwrap()).unwrap();

        assert_eq!(round_tripped.prefix, Some("agent:".to_string()));
        assert_eq!(round_tripped.limit, Some(100));
        assert!(round_tripped.include_expired);
    }
}