1use async_trait::async_trait;
2use enki_core::error::Result;
3use enki_core::memory::{Memory, MemoryEntry, MemoryQuery};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tracing::trace;
7
8pub struct InMemoryBackend {
10 store: Arc<DashMap<String, MemoryEntry>>,
11 max_entries: Option<usize>,
12 default_ttl_seconds: Option<i64>,
13}
14
15impl InMemoryBackend {
16 pub fn new() -> Self {
18 Self {
19 store: Arc::new(DashMap::new()),
20 max_entries: None,
21 default_ttl_seconds: None,
22 }
23 }
24
25 pub fn with_max_entries(mut self, max: usize) -> Self {
27 self.max_entries = Some(max);
28 self
29 }
30
31 pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
33 self.default_ttl_seconds = Some(seconds);
34 self
35 }
36
37 fn is_expired(&self, entry: &MemoryEntry) -> bool {
39 entry.is_expired()
40 }
41
42 fn cleanup_expired(&self) {
44 self.store.retain(|_, entry| !self.is_expired(entry));
45 }
46
47 fn evict_oldest(&self) {
49 if let Some(oldest) = self
50 .store
51 .iter()
52 .min_by_key(|e| e.value().created_at)
53 .map(|e| e.key().clone())
54 {
55 self.store.remove(&oldest);
56 }
57 }
58}
59
60impl Default for InMemoryBackend {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[async_trait]
67impl Memory for InMemoryBackend {
68 async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
69 if entry.expires_at.is_none() {
71 if let Some(ttl) = self.default_ttl_seconds {
72 entry = entry.with_ttl_seconds(ttl);
73 }
74 }
75
76 self.cleanup_expired();
78
79 if let Some(max) = self.max_entries {
81 while self.store.len() >= max {
82 self.evict_oldest();
83 }
84 }
85
86 let id = entry.id.clone();
87 self.store.insert(id.clone(), entry);
88 trace!(id = %id, "stored memory entry");
89 Ok(id)
90 }
91
92 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
93 if let Some(entry) = self.store.get(id) {
94 if self.is_expired(&entry) {
95 drop(entry); self.store.remove(id);
98 trace!(id = %id, "memory miss (expired)");
99 Ok(None)
100 } else {
101 trace!(id = %id, "memory hit");
102 Ok(Some(entry.clone()))
103 }
104 } else {
105 trace!(id = %id, "memory miss");
106 Ok(None)
107 }
108 }
109
110 async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
111 self.cleanup_expired();
113
114 let mut results: Vec<MemoryEntry> = self
115 .store
116 .iter()
117 .map(|e| e.value().clone())
118 .filter(|entry| {
119 let matches_filters = query
121 .filters
122 .iter()
123 .all(|(key, value)| entry.metadata.get(key) == Some(value));
124
125 let matches_query = if let Some(ref q) = query.semantic_query {
129 let content_lower = entry.content.to_lowercase();
130 let query_lower = q.to_lowercase();
131
132 let stop_words: std::collections::HashSet<&str> = [
134 "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
135 "have", "has", "had", "do", "does", "did", "will", "would", "could",
136 "should", "may", "might", "can", "what", "which", "who", "whom", "this",
137 "that", "these", "those", "am", "of", "for", "with", "at", "by", "from",
138 "to", "in", "on", "how", "does", "it", "its", "and", "or", "but", "not",
139 "each", "other", "some", "any", "no", "all", "most", "more",
140 ]
141 .into_iter()
142 .collect();
143
144 let keywords: Vec<&str> = query_lower
146 .split(|c: char| !c.is_alphanumeric())
147 .filter(|word| word.len() >= 2 && !stop_words.contains(word))
148 .collect();
149
150 keywords
152 .iter()
153 .any(|keyword| content_lower.contains(keyword))
154 } else {
155 true
156 };
157
158 matches_filters && matches_query
159 })
160 .collect();
161
162 results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
164
165 if let Some(limit) = query.limit {
167 results.truncate(limit);
168 }
169
170 Ok(results)
171 }
172
173 async fn delete(&self, id: &str) -> Result<bool> {
174 Ok(self.store.remove(id).is_some())
175 }
176
177 async fn clear(&self) -> Result<()> {
178 self.store.clear();
179 Ok(())
180 }
181
182 async fn count(&self) -> Result<usize> {
183 self.cleanup_expired();
184 Ok(self.store.len())
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::time::{sleep, Duration};

    /// An entry written with `store` can be read back using the id `store` returned.
    #[tokio::test]
    async fn test_basic_storage() {
        let memory = InMemoryBackend::new();

        let id = memory.store(MemoryEntry::new("test content")).await.unwrap();
        let fetched = memory.get(&id).await.unwrap();

        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "test content");
    }

    /// An entry with a 1-second TTL is visible immediately and gone two seconds later.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let memory = InMemoryBackend::new();
        let id = memory
            .store(MemoryEntry::new("expires soon").with_ttl_seconds(1))
            .await
            .unwrap();

        assert!(memory.get(&id).await.unwrap().is_some());

        sleep(Duration::from_secs(2)).await;

        assert!(memory.get(&id).await.unwrap().is_none());
    }

    /// A metadata filter selects only the entries whose metadata value matches.
    #[tokio::test]
    async fn test_metadata_search() {
        let memory = InMemoryBackend::new();

        let user_entry = MemoryEntry::new("user message")
            .with_metadata("type", json!("user"))
            .with_metadata("user_id", json!("123"));
        let system_entry =
            MemoryEntry::new("system message").with_metadata("type", json!("system"));

        for entry in [user_entry, system_entry] {
            memory.store(entry).await.unwrap();
        }

        let hits = memory
            .search(MemoryQuery::new().with_filter("type", json!("user")))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    /// With a cap of two, storing a third entry evicts the oldest one.
    #[tokio::test]
    async fn test_max_entries() {
        let memory = InMemoryBackend::new().with_max_entries(2);

        for (idx, content) in ["first", "second", "third"].iter().enumerate() {
            // Space the writes out so each entry gets a distinct `created_at`.
            if idx > 0 {
                sleep(Duration::from_millis(10)).await;
            }
            memory.store(MemoryEntry::new(*content)).await.unwrap();
        }

        assert_eq!(memory.count().await.unwrap(), 2);

        let remaining = memory.search(MemoryQuery::new()).await.unwrap();
        assert!(remaining.iter().all(|e| e.content != "first"));
    }

    /// `delete` reports whether an entry was removed, and the id can no longer be read.
    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryBackend::new();
        let id = memory.store(MemoryEntry::new("to delete")).await.unwrap();

        assert!(memory.delete(&id).await.unwrap());
        assert!(memory.get(&id).await.unwrap().is_none());
        assert!(!memory.delete(&id).await.unwrap());
    }

    /// `clear` removes every stored entry.
    #[tokio::test]
    async fn test_clear() {
        let memory = InMemoryBackend::new();

        for content in ["entry 1", "entry 2"].iter() {
            memory.store(MemoryEntry::new(*content)).await.unwrap();
        }
        assert_eq!(memory.count().await.unwrap(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}