ceylon_memory/
in_memory.rs1use async_trait::async_trait;
2use ceylon_core::error::Result;
3use ceylon_core::memory::{Memory, MemoryEntry, MemoryQuery};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tracing::trace;
7
8pub struct InMemoryBackend {
10 store: Arc<DashMap<String, MemoryEntry>>,
11 max_entries: Option<usize>,
12 default_ttl_seconds: Option<i64>,
13}
14
15impl InMemoryBackend {
16 pub fn new() -> Self {
18 Self {
19 store: Arc::new(DashMap::new()),
20 max_entries: None,
21 default_ttl_seconds: None,
22 }
23 }
24
25 pub fn with_max_entries(mut self, max: usize) -> Self {
27 self.max_entries = Some(max);
28 self
29 }
30
31 pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
33 self.default_ttl_seconds = Some(seconds);
34 self
35 }
36
37 fn is_expired(&self, entry: &MemoryEntry) -> bool {
39 entry.is_expired()
40 }
41
42 fn cleanup_expired(&self) {
44 self.store.retain(|_, entry| !self.is_expired(entry));
45 }
46
47 fn evict_oldest(&self) {
49 if let Some(oldest) = self
50 .store
51 .iter()
52 .min_by_key(|e| e.value().created_at)
53 .map(|e| e.key().clone())
54 {
55 self.store.remove(&oldest);
56 }
57 }
58}
59
60impl Default for InMemoryBackend {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[async_trait]
67impl Memory for InMemoryBackend {
68 async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
69 if entry.expires_at.is_none() {
71 if let Some(ttl) = self.default_ttl_seconds {
72 entry = entry.with_ttl_seconds(ttl);
73 }
74 }
75
76 self.cleanup_expired();
78
79 if let Some(max) = self.max_entries {
81 while self.store.len() >= max {
82 self.evict_oldest();
83 }
84 }
85
86 let id = entry.id.clone();
87 self.store.insert(id.clone(), entry);
88 trace!(id = %id, "stored memory entry");
89 Ok(id)
90 }
91
92 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
93 if let Some(entry) = self.store.get(id) {
94 if self.is_expired(&entry) {
95 drop(entry); self.store.remove(id);
98 trace!(id = %id, "memory miss (expired)");
99 Ok(None)
100 } else {
101 trace!(id = %id, "memory hit");
102 Ok(Some(entry.clone()))
103 }
104 } else {
105 trace!(id = %id, "memory miss");
106 Ok(None)
107 }
108 }
109
110 async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
111 self.cleanup_expired();
113
114 let mut results: Vec<MemoryEntry> = self
115 .store
116 .iter()
117 .map(|e| e.value().clone())
118 .filter(|entry| {
119 let matches_filters = query
121 .filters
122 .iter()
123 .all(|(key, value)| entry.metadata.get(key) == Some(value));
124
125 let matches_query = if let Some(ref q) = query.semantic_query {
127 entry.content.to_lowercase().contains(&q.to_lowercase())
128 } else {
129 true
130 };
131
132 matches_filters && matches_query
133 })
134 .collect();
135
136 results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
138
139 if let Some(limit) = query.limit {
141 results.truncate(limit);
142 }
143
144 Ok(results)
145 }
146
147 async fn delete(&self, id: &str) -> Result<bool> {
148 Ok(self.store.remove(id).is_some())
149 }
150
151 async fn clear(&self) -> Result<()> {
152 self.store.clear();
153 Ok(())
154 }
155
156 async fn count(&self) -> Result<usize> {
157 self.cleanup_expired();
158 Ok(self.store.len())
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_storage() {
        let memory = InMemoryBackend::new();

        let stored_id = memory.store(MemoryEntry::new("test content")).await.unwrap();

        let found = memory
            .get(&stored_id)
            .await
            .unwrap()
            .expect("freshly stored entry should be retrievable");
        assert_eq!(found.content, "test content");
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let memory = InMemoryBackend::new();

        let short_lived = MemoryEntry::new("expires soon").with_ttl_seconds(1);
        let stored_id = memory.store(short_lived).await.unwrap();

        // Still live right after storing.
        assert!(memory.get(&stored_id).await.unwrap().is_some());

        // Wait past the 1-second TTL.
        sleep(Duration::from_secs(2)).await;

        assert!(memory.get(&stored_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_metadata_search() {
        let memory = InMemoryBackend::new();

        let user_entry = MemoryEntry::new("user message")
            .with_metadata("type", json!("user"))
            .with_metadata("user_id", json!("123"));
        let system_entry =
            MemoryEntry::new("system message").with_metadata("type", json!("system"));

        for entry in [user_entry, system_entry] {
            memory.store(entry).await.unwrap();
        }

        let hits = memory
            .search(MemoryQuery::new().with_filter("type", json!("user")))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    #[tokio::test]
    async fn test_max_entries() {
        let memory = InMemoryBackend::new().with_max_entries(2);

        // Space the inserts out so each gets a distinct `created_at`.
        let labels = ["first", "second", "third"];
        for (position, label) in labels.iter().enumerate() {
            if position > 0 {
                sleep(Duration::from_millis(10)).await;
            }
            memory.store(MemoryEntry::new(*label)).await.unwrap();
        }

        assert_eq!(memory.count().await.unwrap(), 2);

        // The oldest entry is the one that should have been evicted.
        let survivors = memory.search(MemoryQuery::new()).await.unwrap();
        assert!(survivors.iter().all(|e| e.content != "first"));
    }

    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryBackend::new();

        let stored_id = memory.store(MemoryEntry::new("to delete")).await.unwrap();

        let first_delete = memory.delete(&stored_id).await.unwrap();
        assert!(first_delete);
        assert!(memory.get(&stored_id).await.unwrap().is_none());

        // A second delete has nothing left to remove.
        let second_delete = memory.delete(&stored_id).await.unwrap();
        assert!(!second_delete);
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = InMemoryBackend::new();

        for label in &["entry 1", "entry 2"] {
            memory.store(MemoryEntry::new(*label)).await.unwrap();
        }
        assert_eq!(memory.count().await.unwrap(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}