ceylon_memory/
in_memory.rs

1use async_trait::async_trait;
2use ceylon_core::error::Result;
3use ceylon_core::memory::{Memory, MemoryEntry, MemoryQuery};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tracing::trace;
7
8/// In-memory storage backend with TTL support
9pub struct InMemoryBackend {
10    store: Arc<DashMap<String, MemoryEntry>>,
11    max_entries: Option<usize>,
12    default_ttl_seconds: Option<i64>,
13}
14
15impl InMemoryBackend {
16    /// Create a new in-memory backend
17    pub fn new() -> Self {
18        Self {
19            store: Arc::new(DashMap::new()),
20            max_entries: None,
21            default_ttl_seconds: None,
22        }
23    }
24
25    /// Create with maximum entry limit
26    pub fn with_max_entries(mut self, max: usize) -> Self {
27        self.max_entries = Some(max);
28        self
29    }
30
31    /// Create with default TTL
32    pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
33        self.default_ttl_seconds = Some(seconds);
34        self
35    }
36
37    /// Check if entry is expired
38    fn is_expired(&self, entry: &MemoryEntry) -> bool {
39        entry.is_expired()
40    }
41
42    /// Remove all expired entries
43    fn cleanup_expired(&self) {
44        self.store.retain(|_, entry| !self.is_expired(entry));
45    }
46
47    /// Evict oldest entry to make room
48    fn evict_oldest(&self) {
49        if let Some(oldest) = self
50            .store
51            .iter()
52            .min_by_key(|e| e.value().created_at)
53            .map(|e| e.key().clone())
54        {
55            self.store.remove(&oldest);
56        }
57    }
58}
59
60impl Default for InMemoryBackend {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66#[async_trait]
67impl Memory for InMemoryBackend {
68    async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
69        // Apply default TTL if not set
70        if entry.expires_at.is_none() {
71            if let Some(ttl) = self.default_ttl_seconds {
72                entry = entry.with_ttl_seconds(ttl);
73            }
74        }
75
76        // Cleanup expired entries first
77        self.cleanup_expired();
78
79        // Enforce max entries limit
80        if let Some(max) = self.max_entries {
81            while self.store.len() >= max {
82                self.evict_oldest();
83            }
84        }
85
86        let id = entry.id.clone();
87        self.store.insert(id.clone(), entry);
88        trace!(id = %id, "stored memory entry");
89        Ok(id)
90    }
91
92    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
93        if let Some(entry) = self.store.get(id) {
94            if self.is_expired(&entry) {
95                // Remove expired entry
96                drop(entry); // Release read lock
97                self.store.remove(id);
98                trace!(id = %id, "memory miss (expired)");
99                Ok(None)
100            } else {
101                trace!(id = %id, "memory hit");
102                Ok(Some(entry.clone()))
103            }
104        } else {
105            trace!(id = %id, "memory miss");
106            Ok(None)
107        }
108    }
109
110    async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
111        // Cleanup expired entries
112        self.cleanup_expired();
113
114        let mut results: Vec<MemoryEntry> = self
115            .store
116            .iter()
117            .map(|e| e.value().clone())
118            .filter(|entry| {
119                // Apply metadata filters - all filters must match
120                let matches_filters = query
121                    .filters
122                    .iter()
123                    .all(|(key, value)| entry.metadata.get(key) == Some(value));
124
125                // Apply semantic query (keyword search) if present
126                let matches_query = if let Some(ref q) = query.semantic_query {
127                    entry.content.to_lowercase().contains(&q.to_lowercase())
128                } else {
129                    true
130                };
131
132                matches_filters && matches_query
133            })
134            .collect();
135
136        // Sort by creation time (newest first) for consistent ordering
137        results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
138
139        // Apply limit
140        if let Some(limit) = query.limit {
141            results.truncate(limit);
142        }
143
144        Ok(results)
145    }
146
147    async fn delete(&self, id: &str) -> Result<bool> {
148        Ok(self.store.remove(id).is_some())
149    }
150
151    async fn clear(&self) -> Result<()> {
152        self.store.clear();
153        Ok(())
154    }
155
156    async fn count(&self) -> Result<usize> {
157        self.cleanup_expired();
158        Ok(self.store.len())
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Round trip: an entry passed to `store` is returned by `get` under the same id.
    #[tokio::test]
    async fn test_basic_storage() {
        let memory = InMemoryBackend::new();
        let stored_id = memory
            .store(MemoryEntry::new("test content"))
            .await
            .unwrap();

        let fetched = memory.get(&stored_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "test content");
    }

    /// An entry with a 1-second TTL is visible immediately and gone two seconds later.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let memory = InMemoryBackend::new();
        let short_lived = MemoryEntry::new("expires soon").with_ttl_seconds(1);
        let stored_id = memory.store(short_lived).await.unwrap();

        let before = memory.get(&stored_id).await.unwrap();
        assert!(before.is_some());

        // Sleep comfortably past the TTL, which is given in whole seconds.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let after = memory.get(&stored_id).await.unwrap();
        assert!(after.is_none());
    }

    /// A metadata filter selects only entries whose metadata value matches exactly.
    #[tokio::test]
    async fn test_metadata_search() {
        let memory = InMemoryBackend::new();

        let entries = vec![
            MemoryEntry::new("user message")
                .with_metadata("type", json!("user"))
                .with_metadata("user_id", json!("123")),
            MemoryEntry::new("system message").with_metadata("type", json!("system")),
        ];
        for entry in entries {
            memory.store(entry).await.unwrap();
        }

        let user_only = MemoryQuery::new().with_filter("type", json!("user"));
        let hits = memory.search(user_only).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    /// With a cap of 2, storing a third entry evicts the oldest one.
    #[tokio::test]
    async fn test_max_entries() {
        let memory = InMemoryBackend::new().with_max_entries(2);
        let pause = tokio::time::Duration::from_millis(10);

        // Space the inserts out so each entry gets a distinct `created_at`.
        for (position, label) in ["first", "second", "third"].iter().enumerate() {
            if position > 0 {
                tokio::time::sleep(pause).await;
            }
            memory.store(MemoryEntry::new(*label)).await.unwrap();
        }

        assert_eq!(memory.count().await.unwrap(), 2);

        let remaining = memory.search(MemoryQuery::new()).await.unwrap();
        let first_survived = remaining.iter().any(|e| e.content == "first");
        assert!(!first_survived);
    }

    /// Deleting reports `true` once, after which the entry is gone and a second delete
    /// reports `false`.
    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryBackend::new();
        let stored_id = memory.store(MemoryEntry::new("to delete")).await.unwrap();

        let first_delete = memory.delete(&stored_id).await.unwrap();
        assert!(first_delete);
        assert!(memory.get(&stored_id).await.unwrap().is_none());

        let second_delete = memory.delete(&stored_id).await.unwrap();
        assert!(!second_delete);
    }

    /// `clear` removes every stored entry.
    #[tokio::test]
    async fn test_clear() {
        let memory = InMemoryBackend::new();
        for content in ["entry 1", "entry 2"].iter() {
            memory.store(MemoryEntry::new(*content)).await.unwrap();
        }
        assert_eq!(memory.count().await.unwrap(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}