ceylon_memory/
in_memory.rs

1use async_trait::async_trait;
2use ceylon_core::error::Result;
3use ceylon_core::memory::{Memory, MemoryEntry, MemoryQuery};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tracing::trace;
7
8/// In-memory storage backend with TTL support
9pub struct InMemoryBackend {
10    store: Arc<DashMap<String, MemoryEntry>>,
11    max_entries: Option<usize>,
12    default_ttl_seconds: Option<i64>,
13}
14
15impl InMemoryBackend {
16    /// Create a new in-memory backend
17    pub fn new() -> Self {
18        Self {
19            store: Arc::new(DashMap::new()),
20            max_entries: None,
21            default_ttl_seconds: None,
22        }
23    }
24
25    /// Create with maximum entry limit
26    pub fn with_max_entries(mut self, max: usize) -> Self {
27        self.max_entries = Some(max);
28        self
29    }
30
31    /// Create with default TTL
32    pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
33        self.default_ttl_seconds = Some(seconds);
34        self
35    }
36
37    /// Check if entry is expired
38    fn is_expired(&self, entry: &MemoryEntry) -> bool {
39        entry.is_expired()
40    }
41
42    /// Remove all expired entries
43    fn cleanup_expired(&self) {
44        self.store.retain(|_, entry| !self.is_expired(entry));
45    }
46
47    /// Evict oldest entry to make room
48    fn evict_oldest(&self) {
49        if let Some(oldest) = self
50            .store
51            .iter()
52            .min_by_key(|e| e.value().created_at)
53            .map(|e| e.key().clone())
54        {
55            self.store.remove(&oldest);
56        }
57    }
58}
59
60impl Default for InMemoryBackend {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66#[async_trait]
67impl Memory for InMemoryBackend {
68    async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
69        // Apply default TTL if not set
70        if entry.expires_at.is_none() {
71            if let Some(ttl) = self.default_ttl_seconds {
72                entry = entry.with_ttl_seconds(ttl);
73            }
74        }
75
76        // Cleanup expired entries first
77        self.cleanup_expired();
78
79        // Enforce max entries limit
80        if let Some(max) = self.max_entries {
81            while self.store.len() >= max {
82                self.evict_oldest();
83            }
84        }
85
86        let id = entry.id.clone();
87        self.store.insert(id.clone(), entry);
88        trace!(id = %id, "stored memory entry");
89        Ok(id)
90    }
91
92    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
93        if let Some(entry) = self.store.get(id) {
94            if self.is_expired(&entry) {
95                // Remove expired entry
96                drop(entry); // Release read lock
97                self.store.remove(id);
98                trace!(id = %id, "memory miss (expired)");
99                Ok(None)
100            } else {
101                trace!(id = %id, "memory hit");
102                Ok(Some(entry.clone()))
103            }
104        } else {
105            trace!(id = %id, "memory miss");
106            Ok(None)
107        }
108    }
109
110    async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
111        // Cleanup expired entries
112        self.cleanup_expired();
113
114        let mut results: Vec<MemoryEntry> = self
115            .store
116            .iter()
117            .map(|e| e.value().clone())
118            .filter(|entry| {
119                // Apply metadata filters - all filters must match
120                let matches_filters = query
121                    .filters
122                    .iter()
123                    .all(|(key, value)| entry.metadata.get(key) == Some(value));
124
125                // Apply semantic query (keyword search) if present
126                // This performs keyword-based search by splitting the query into words
127                // and matching if any significant keyword is found in the content
128                let matches_query = if let Some(ref q) = query.semantic_query {
129                    let content_lower = entry.content.to_lowercase();
130                    let query_lower = q.to_lowercase();
131
132                    // Stop words to filter out (common words that don't add meaning)
133                    let stop_words: std::collections::HashSet<&str> = [
134                        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
135                        "have", "has", "had", "do", "does", "did", "will", "would", "could",
136                        "should", "may", "might", "can", "what", "which", "who", "whom", "this",
137                        "that", "these", "those", "am", "of", "for", "with", "at", "by", "from",
138                        "to", "in", "on", "how", "does", "it", "its", "and", "or", "but", "not",
139                        "each", "other", "some", "any", "no", "all", "most", "more",
140                    ]
141                    .into_iter()
142                    .collect();
143
144                    // Extract keywords from query (words with 2+ chars, not stop words)
145                    let keywords: Vec<&str> = query_lower
146                        .split(|c: char| !c.is_alphanumeric())
147                        .filter(|word| word.len() >= 2 && !stop_words.contains(word))
148                        .collect();
149
150                    // Match if any keyword is found in the content
151                    keywords
152                        .iter()
153                        .any(|keyword| content_lower.contains(keyword))
154                } else {
155                    true
156                };
157
158                matches_filters && matches_query
159            })
160            .collect();
161
162        // Sort by creation time (newest first) for consistent ordering
163        results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
164
165        // Apply limit
166        if let Some(limit) = query.limit {
167            results.truncate(limit);
168        }
169
170        Ok(results)
171    }
172
173    async fn delete(&self, id: &str) -> Result<bool> {
174        Ok(self.store.remove(id).is_some())
175    }
176
177    async fn clear(&self) -> Result<()> {
178        self.store.clear();
179        Ok(())
180    }
181
182    async fn count(&self) -> Result<usize> {
183        self.cleanup_expired();
184        Ok(self.store.len())
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::time::{sleep, Duration};

    /// An entry can be read back by the id `store` returns.
    #[tokio::test]
    async fn test_basic_storage() {
        let memory = InMemoryBackend::new();

        let stored_id = memory
            .store(MemoryEntry::new("test content"))
            .await
            .unwrap();
        let fetched = memory.get(&stored_id).await.unwrap();

        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "test content");
    }

    /// An entry with a one-second TTL is visible at first and gone after it lapses.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let memory = InMemoryBackend::new();
        let stored_id = memory
            .store(MemoryEntry::new("expires soon").with_ttl_seconds(1))
            .await
            .unwrap();

        assert!(memory.get(&stored_id).await.unwrap().is_some());

        // Wait comfortably past the TTL.
        sleep(Duration::from_secs(2)).await;

        assert!(memory.get(&stored_id).await.unwrap().is_none());
    }

    /// Metadata filters select only entries whose metadata matches.
    #[tokio::test]
    async fn test_metadata_search() {
        let memory = InMemoryBackend::new();

        let user_entry = MemoryEntry::new("user message")
            .with_metadata("type", json!("user"))
            .with_metadata("user_id", json!("123"));
        let system_entry =
            MemoryEntry::new("system message").with_metadata("type", json!("system"));

        for entry in [user_entry, system_entry] {
            memory.store(entry).await.unwrap();
        }

        let user_only = MemoryQuery::new().with_filter("type", json!("user"));
        let hits = memory.search(user_only).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    /// Storing past the capacity evicts the oldest entry.
    #[tokio::test]
    async fn test_max_entries() {
        let memory = InMemoryBackend::new().with_max_entries(2);

        for label in ["first", "second"] {
            memory.store(MemoryEntry::new(label)).await.unwrap();
            // Space the entries out so their creation timestamps differ.
            sleep(Duration::from_millis(10)).await;
        }
        memory.store(MemoryEntry::new("third")).await.unwrap();

        assert_eq!(memory.count().await.unwrap(), 2);

        let remaining = memory.search(MemoryQuery::new()).await.unwrap();
        assert!(remaining.iter().all(|e| e.content != "first"));
    }

    /// `delete` reports whether something was removed.
    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryBackend::new();
        let stored_id = memory.store(MemoryEntry::new("to delete")).await.unwrap();

        assert!(memory.delete(&stored_id).await.unwrap());
        assert!(memory.get(&stored_id).await.unwrap().is_none());
        // The second delete finds nothing to remove.
        assert!(!memory.delete(&stored_id).await.unwrap());
    }

    /// `clear` empties the backend.
    #[tokio::test]
    async fn test_clear() {
        let memory = InMemoryBackend::new();

        for text in ["entry 1", "entry 2"] {
            memory.store(MemoryEntry::new(text)).await.unwrap();
        }
        assert_eq!(memory.count().await.unwrap(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}