ceylon_runtime/memory/backends/
in_memory.rs1use crate::core::error::Result;
2use crate::core::memory::{Memory, MemoryEntry, MemoryQuery};
3use async_trait::async_trait;
4use dashmap::DashMap;
5use std::sync::Arc;
6
7pub struct InMemoryBackend {
9 store: Arc<DashMap<String, MemoryEntry>>,
10 max_entries: Option<usize>,
11 default_ttl_seconds: Option<i64>,
12}
13
14impl InMemoryBackend {
15 pub fn new() -> Self {
17 Self {
18 store: Arc::new(DashMap::new()),
19 max_entries: None,
20 default_ttl_seconds: None,
21 }
22 }
23
24 pub fn with_max_entries(mut self, max: usize) -> Self {
26 self.max_entries = Some(max);
27 self
28 }
29
30 pub fn with_ttl_seconds(mut self, seconds: i64) -> Self {
32 self.default_ttl_seconds = Some(seconds);
33 self
34 }
35
36 fn is_expired(&self, entry: &MemoryEntry) -> bool {
38 entry.is_expired()
39 }
40
41 fn cleanup_expired(&self) {
43 self.store.retain(|_, entry| !self.is_expired(entry));
44 }
45
46 fn evict_oldest(&self) {
48 if let Some(oldest) = self
49 .store
50 .iter()
51 .min_by_key(|e| e.value().created_at)
52 .map(|e| e.key().clone())
53 {
54 self.store.remove(&oldest);
55 }
56 }
57}
58
59impl Default for InMemoryBackend {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65#[async_trait]
66impl Memory for InMemoryBackend {
67 async fn store(&self, mut entry: MemoryEntry) -> Result<String> {
68 if entry.expires_at.is_none() {
70 if let Some(ttl) = self.default_ttl_seconds {
71 entry = entry.with_ttl_seconds(ttl);
72 }
73 }
74
75 self.cleanup_expired();
77
78 if let Some(max) = self.max_entries {
80 while self.store.len() >= max {
81 self.evict_oldest();
82 }
83 }
84
85 let id = entry.id.clone();
86 self.store.insert(id.clone(), entry);
87 crate::metrics::metrics().record_memory_write();
88 Ok(id)
89 }
90
91 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
92 if let Some(entry) = self.store.get(id) {
93 if self.is_expired(&entry) {
94 drop(entry); self.store.remove(id);
97 crate::metrics::metrics().record_memory_miss();
98 Ok(None)
99 } else {
100 crate::metrics::metrics().record_memory_hit();
101 Ok(Some(entry.clone()))
102 }
103 } else {
104 crate::metrics::metrics().record_memory_miss();
105 Ok(None)
106 }
107 }
108
109 async fn search(&self, query: MemoryQuery) -> Result<Vec<MemoryEntry>> {
110 self.cleanup_expired();
112
113 let mut results: Vec<MemoryEntry> = self
114 .store
115 .iter()
116 .map(|e| e.value().clone())
117 .filter(|entry| {
118 let matches_filters = query
120 .filters
121 .iter()
122 .all(|(key, value)| entry.metadata.get(key) == Some(value));
123
124 let matches_query = if let Some(ref q) = query.semantic_query {
126 entry.content.to_lowercase().contains(&q.to_lowercase())
127 } else {
128 true
129 };
130
131 matches_filters && matches_query
132 })
133 .collect();
134
135 results.sort_by(|a, b| b.created_at.cmp(&a.created_at));
137
138 if let Some(limit) = query.limit {
140 results.truncate(limit);
141 }
142
143 Ok(results)
144 }
145
146 async fn delete(&self, id: &str) -> Result<bool> {
147 Ok(self.store.remove(id).is_some())
148 }
149
150 async fn clear(&self) -> Result<()> {
151 self.store.clear();
152 Ok(())
153 }
154
155 async fn count(&self) -> Result<usize> {
156 self.cleanup_expired();
157 Ok(self.store.len())
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_basic_storage() {
        let memory = InMemoryBackend::new();

        let stored_id = memory.store(MemoryEntry::new("test content")).await.unwrap();

        let fetched = memory.get(&stored_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "test content");
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let memory = InMemoryBackend::new();

        let short_lived = MemoryEntry::new("expires soon").with_ttl_seconds(1);
        let stored_id = memory.store(short_lived).await.unwrap();

        // Still live right after insertion.
        assert!(memory.get(&stored_id).await.unwrap().is_some());

        // Wait past the one-second TTL; the entry must now read as absent.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(memory.get(&stored_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_metadata_search() {
        let memory = InMemoryBackend::new();

        let user_entry = MemoryEntry::new("user message")
            .with_metadata("type", json!("user"))
            .with_metadata("user_id", json!("123"));
        let system_entry =
            MemoryEntry::new("system message").with_metadata("type", json!("system"));

        for entry in vec![user_entry, system_entry] {
            memory.store(entry).await.unwrap();
        }

        let only_users = MemoryQuery::new().with_filter("type", json!("user"));
        let hits = memory.search(only_users).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "user message");
    }

    #[tokio::test]
    async fn test_max_entries() {
        let memory = InMemoryBackend::new().with_max_entries(2);
        // Spacing the inserts keeps their `created_at` values distinct.
        let pause = tokio::time::Duration::from_millis(10);

        memory.store(MemoryEntry::new("first")).await.unwrap();
        tokio::time::sleep(pause).await;

        memory.store(MemoryEntry::new("second")).await.unwrap();
        tokio::time::sleep(pause).await;

        memory.store(MemoryEntry::new("third")).await.unwrap();

        assert_eq!(memory.count().await.unwrap(), 2);

        // The oldest entry is the one that must have been evicted.
        let remaining = memory.search(MemoryQuery::new()).await.unwrap();
        assert!(remaining.iter().all(|e| e.content != "first"));
    }

    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryBackend::new();

        let stored_id = memory.store(MemoryEntry::new("to delete")).await.unwrap();

        let removed_first = memory.delete(&stored_id).await.unwrap();
        assert!(removed_first);

        assert!(memory.get(&stored_id).await.unwrap().is_none());

        let removed_again = memory.delete(&stored_id).await.unwrap();
        assert!(!removed_again);
    }

    #[tokio::test]
    async fn test_clear() {
        let memory = InMemoryBackend::new();

        for text in ["entry 1", "entry 2"].iter() {
            memory.store(MemoryEntry::new(*text)).await.unwrap();
        }
        assert_eq!(memory.count().await.unwrap(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}