ceylon_runtime/memory/backends/in_memory.rs

1use crate::core::error::Result;
2use crate::core::memory::{Memory, MemoryEntry, MemoryQuery};
3use async_trait::async_trait;
4use dashmap::DashMap;
5use std::sync::Arc;
6
7/// In-memory storage backend with TTL support
8pub struct InMemoryBackend {
9    store: Arc<DashMap<String, MemoryEntry>>,
10    max_entries: Option<usize>,
11    default_ttl_seconds: Option<i64>,
12}
13
14impl InMemoryBackend {
15    /// Create a new in-memory backend
16    pub fn new() -> Self {
17        Self {
18            store: Arc::new(DashMap::new()),
19            max_entries: None,
20            default_ttl_seconds: None,
21        }
22    }
23
24    /// Create with maximum entry limit
25    pub fn with_max_entries(mut self, max: usize) -> Self {
26        self.max_entries = Some(max);
27        self
28    }
29
30    /// Create with default TTL
31    pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
32        self.default_ttl_seconds = Some(seconds);
33        self
34    }
35
36    /// Check if entry is expired
37    fn is_expired(&self, entry: &MemoryEntry) -> bool {
38        entry.is_expired()
39    }
40
41    /// Remove all expired entries
42    fn cleanup_expired(&self) {
43        self.store.retain(|_, entry| !self.is_expired(entry));
44    }
45
46    /// Evict oldest entry to make room
47    fn evict_oldest(&self) {
48        if let Some(oldest) = self
49            .store
50            .iter()
51            .min_by_key(|e| e.value().created_at)
52            .map(|e| e.key().clone())
53        {
54            self.store.remove(&oldest);
55        }
56    }
57}
58
59impl Default for InMemoryBackend {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65#[async_trait]
66impl Memory for InMemoryBackend {
67    async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
68        // Apply default TTL if not set
69        if entry.expires_at.is_none() {
70            if let Some(ttl) = self.default_ttl_seconds {
71                entry = entry.with_ttl_seconds(ttl);
72            }
73        }
74
75        // Cleanup expired entries first
76        self.cleanup_expired();
77
78        // Enforce max entries limit
79        if let Some(max) = self.max_entries {
80            while self.store.len() >= max {
81                self.evict_oldest();
82            }
83        }
84
85        let id = entry.id.clone();
86        self.store.insert(id.clone(), entry);
87        crate::metrics::metrics().record_memory_write();
88        Ok(id)
89    }
90
91    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
92        if let Some(entry) = self.store.get(id) {
93            if self.is_expired(&entry) {
94                // Remove expired entry
95                drop(entry); // Release read lock
96                self.store.remove(id);
97                crate::metrics::metrics().record_memory_miss();
98                Ok(None)
99            } else {
100                crate::metrics::metrics().record_memory_hit();
101                Ok(Some(entry.clone()))
102            }
103        } else {
104            crate::metrics::metrics().record_memory_miss();
105            Ok(None)
106        }
107    }
108
109    async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
110        // Cleanup expired entries
111        self.cleanup_expired();
112
113        let mut results: Vec<MemoryEntry> = self
114            .store
115            .iter()
116            .map(|e| e.value().clone())
117            .filter(|entry| {
118                // Apply metadata filters - all filters must match
119                let matches_filters = query
120                    .filters
121                    .iter()
122                    .all(|(key, value)| entry.metadata.get(key) == Some(value));
123
124                // Apply semantic query (keyword search) if present
125                let matches_query = if let Some(ref q) = query.semantic_query {
126                    entry.content.to_lowercase().contains(&q.to_lowercase())
127                } else {
128                    true
129                };
130
131                matches_filters && matches_query
132            })
133            .collect();
134
135        // Sort by creation time (newest first) for consistent ordering
136        results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
137
138        // Apply limit
139        if let Some(limit) = query.limit {
140            results.truncate(limit);
141        }
142
143        Ok(results)
144    }
145
146    async fn delete(&self, id: &str) -> Result<bool> {
147        Ok(self.store.remove(id).is_some())
148    }
149
150    async fn clear(&self) -> Result<()> {
151        self.store.clear();
152        Ok(())
153    }
154
155    async fn count(&self) -> Result<usize> {
156        self.cleanup_expired();
157        Ok(self.store.len())
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An entry read back by id carries the content it was stored with.
    #[tokio::test]
    async fn test_basic_storage() {
        let backend = InMemoryBackend::new();

        let id = backend
            .store(MemoryEntry::new("test content"))
            .await
            .unwrap();

        let fetched = backend
            .get(&id)
            .await
            .unwrap()
            .expect("entry should be retrievable right after store");
        assert_eq!(fetched.content, "test content");
    }

    /// An entry with a 1-second TTL is visible immediately and gone after 2 seconds.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let backend = InMemoryBackend::new();

        let id = backend
            .store(MemoryEntry::new("expires soon").with_ttl_seconds(1))
            .await
            .unwrap();

        let before = backend.get(&id).await.unwrap();
        assert!(before.is_some(), "entry should be live before its TTL elapses");

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let after = backend.get(&id).await.unwrap();
        assert!(after.is_none(), "entry should be gone once its TTL elapses");
    }

    /// A metadata filter returns only the entries whose metadata matches.
    #[tokio::test]
    async fn test_metadata_search() {
        let backend = InMemoryBackend::new();

        let user_msg = MemoryEntry::new("user message")
            .with_metadata("type", json!("user"))
            .with_metadata("user_id", json!("123"));
        let system_msg = MemoryEntry::new("system message").with_metadata("type", json!("system"));

        for entry in [user_msg, system_msg] {
            backend.store(entry).await.unwrap();
        }

        let hits = backend
            .search(MemoryQuery::new().with_filter("type", json!("user")))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    /// With a cap of 2, storing a third entry evicts the oldest one.
    #[tokio::test]
    async fn test_max_entries() {
        let backend = InMemoryBackend::new().with_max_entries(2);
        let pause = tokio::time::Duration::from_millis(10);

        // Space the inserts out so eviction (which picks the smallest `created_at`)
        // has a well-defined oldest entry.
        backend.store(MemoryEntry::new("first")).await.unwrap();
        tokio::time::sleep(pause).await;
        backend.store(MemoryEntry::new("second")).await.unwrap();
        tokio::time::sleep(pause).await;
        backend.store(MemoryEntry::new("third")).await.unwrap();

        assert_eq!(backend.count().await.unwrap(), 2);

        let remaining = backend.search(MemoryQuery::new()).await.unwrap();
        assert!(remaining.iter().all(|e| e.content != "first"));
    }

    /// Deleting reports presence, and a deleted entry can no longer be fetched.
    #[tokio::test]
    async fn test_delete() {
        let backend = InMemoryBackend::new();
        let id = backend.store(MemoryEntry::new("to delete")).await.unwrap();

        let first_delete = backend.delete(&id).await.unwrap();
        assert!(first_delete);

        let lookup = backend.get(&id).await.unwrap();
        assert!(lookup.is_none());

        let second_delete = backend.delete(&id).await.unwrap();
        assert!(!second_delete, "a second delete should find nothing");
    }

    /// `clear` empties the backend.
    #[tokio::test]
    async fn test_clear() {
        let backend = InMemoryBackend::new();

        for content in ["entry 1", "entry 2"] {
            backend.store(MemoryEntry::new(content)).await.unwrap();
        }
        assert_eq!(backend.count().await.unwrap(), 2);

        backend.clear().await.unwrap();
        assert_eq!(backend.count().await.unwrap(), 0);
    }
}