1use super::memory::{Memory, MemoryQuery, MemoryStats, MemoryValue};
6use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Tunable limits for the in-memory storage backend.
#[derive(Clone, Debug)]
pub struct InMemoryConfig {
    /// Upper bound on the number of keys held at once; `None` disables the bound.
    pub max_keys: Option<usize>,

    /// Configured memory budget in bytes; `None` means no budget is configured.
    ///
    /// NOTE(review): within this file the value is only reported through
    /// `stats`; confirm whether enforcement is expected elsewhere.
    pub max_memory_bytes: Option<u64>,

    /// Flag requesting eviction of entries.
    ///
    /// NOTE(review): nothing in this file reads this flag; confirm intent.
    pub enable_eviction: bool,
}
24
25impl Default for InMemoryConfig {
26 fn default() -> Self {
27 Self {
28 max_keys: Some(100_000),
29 max_memory_bytes: Some(1_000_000_000), enable_eviction: false,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37struct MemoryEntry {
38 value: MemoryValue,
39 created_at: chrono::DateTime<chrono::Utc>,
40 accessed_at: chrono::DateTime<chrono::Utc>,
41}
42
43pub struct InMemoryStorage {
45 data: Arc<RwLock<HashMap<String, MemoryEntry>>>,
47
48 config: InMemoryConfig,
50}
51
52impl InMemoryStorage {
53 pub fn new() -> Self {
55 Self {
56 data: Arc::new(RwLock::new(HashMap::new())),
57 config: InMemoryConfig::default(),
58 }
59 }
60
61 pub fn with_config(config: InMemoryConfig) -> Self {
63 Self {
64 data: Arc::new(RwLock::new(HashMap::new())),
65 config,
66 }
67 }
68
69 async fn check_limits(&self) -> RragResult<()> {
71 let data = self.data.read().await;
72
73 if let Some(max_keys) = self.config.max_keys {
74 if data.len() >= max_keys {
75 return Err(RragError::storage(
76 "memory_limit",
77 std::io::Error::new(
78 std::io::ErrorKind::OutOfMemory,
79 format!("Exceeded maximum keys: {}", max_keys),
80 ),
81 ));
82 }
83 }
84
85 Ok(())
86 }
87
88 fn matches_query(&self, key: &str, query: &MemoryQuery) -> bool {
90 if let Some(pattern) = &query.key_pattern {
92 if !key.starts_with(pattern) {
93 return false;
94 }
95 }
96
97 if let Some(namespace) = &query.namespace {
99 let expected_prefix = format!("{}::", namespace);
100 if !key.starts_with(&expected_prefix) {
101 return false;
102 }
103 }
104
105 true
106 }
107
108 fn estimate_memory_usage(&self, data: &HashMap<String, MemoryEntry>) -> u64 {
110 let mut total = 0u64;
111
112 for (key, entry) in data.iter() {
113 total += key.len() as u64;
115
116 total += match &entry.value {
118 MemoryValue::String(s) => s.len() as u64,
119 MemoryValue::Integer(_) => 8,
120 MemoryValue::Float(_) => 8,
121 MemoryValue::Boolean(_) => 1,
122 MemoryValue::Json(j) => j.to_string().len() as u64,
123 MemoryValue::Bytes(b) => b.len() as u64,
124 MemoryValue::List(l) => l.len() as u64 * 64, MemoryValue::Map(m) => m.len() as u64 * 128, };
127
128 total += 64;
130 }
131
132 total
133 }
134}
135
136impl Default for InMemoryStorage {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[async_trait]
143impl Memory for InMemoryStorage {
144 fn backend_name(&self) -> &str {
145 "in_memory"
146 }
147
148 async fn set(&self, key: &str, value: MemoryValue) -> RragResult<()> {
149 self.check_limits().await?;
150
151 let mut data = self.data.write().await;
152 let now = chrono::Utc::now();
153
154 data.insert(
155 key.to_string(),
156 MemoryEntry {
157 value,
158 created_at: now,
159 accessed_at: now,
160 },
161 );
162
163 Ok(())
164 }
165
166 async fn get(&self, key: &str) -> RragResult<Option<MemoryValue>> {
167 let mut data = self.data.write().await;
168
169 if let Some(entry) = data.get_mut(key) {
170 entry.accessed_at = chrono::Utc::now();
171 Ok(Some(entry.value.clone()))
172 } else {
173 Ok(None)
174 }
175 }
176
177 async fn delete(&self, key: &str) -> RragResult<bool> {
178 let mut data = self.data.write().await;
179 Ok(data.remove(key).is_some())
180 }
181
182 async fn exists(&self, key: &str) -> RragResult<bool> {
183 let data = self.data.read().await;
184 Ok(data.contains_key(key))
185 }
186
187 async fn keys(&self, query: &MemoryQuery) -> RragResult<Vec<String>> {
188 let data = self.data.read().await;
189
190 let mut keys: Vec<String> = data
191 .keys()
192 .filter(|key| self.matches_query(key, query))
193 .cloned()
194 .collect();
195
196 if let Some(offset) = query.offset {
198 if offset < keys.len() {
199 keys = keys.into_iter().skip(offset).collect();
200 } else {
201 keys.clear();
202 }
203 }
204
205 if let Some(limit) = query.limit {
207 keys.truncate(limit);
208 }
209
210 Ok(keys)
211 }
212
213 async fn mget(&self, keys: &[String]) -> RragResult<Vec<Option<MemoryValue>>> {
214 let mut data = self.data.write().await;
215 let now = chrono::Utc::now();
216
217 let mut results = Vec::with_capacity(keys.len());
218
219 for key in keys {
220 if let Some(entry) = data.get_mut(key) {
221 entry.accessed_at = now;
222 results.push(Some(entry.value.clone()));
223 } else {
224 results.push(None);
225 }
226 }
227
228 Ok(results)
229 }
230
231 async fn mset(&self, pairs: &[(String, MemoryValue)]) -> RragResult<()> {
232 self.check_limits().await?;
233
234 let mut data = self.data.write().await;
235 let now = chrono::Utc::now();
236
237 for (key, value) in pairs {
238 data.insert(
239 key.clone(),
240 MemoryEntry {
241 value: value.clone(),
242 created_at: now,
243 accessed_at: now,
244 },
245 );
246 }
247
248 Ok(())
249 }
250
251 async fn mdelete(&self, keys: &[String]) -> RragResult<usize> {
252 let mut data = self.data.write().await;
253 let mut deleted = 0;
254
255 for key in keys {
256 if data.remove(key).is_some() {
257 deleted += 1;
258 }
259 }
260
261 Ok(deleted)
262 }
263
264 async fn clear(&self, namespace: Option<&str>) -> RragResult<()> {
265 let mut data = self.data.write().await;
266
267 if let Some(ns) = namespace {
268 let prefix = format!("{}::", ns);
269 data.retain(|key, _| !key.starts_with(&prefix));
270 } else {
271 data.clear();
272 }
273
274 Ok(())
275 }
276
277 async fn count(&self, namespace: Option<&str>) -> RragResult<usize> {
278 let data = self.data.read().await;
279
280 if let Some(ns) = namespace {
281 let prefix = format!("{}::", ns);
282 Ok(data.keys().filter(|key| key.starts_with(&prefix)).count())
283 } else {
284 Ok(data.len())
285 }
286 }
287
288 async fn health_check(&self) -> RragResult<bool> {
289 let _data = self.data.read().await;
291 Ok(true)
292 }
293
294 async fn stats(&self) -> RragResult<MemoryStats> {
295 let data = self.data.read().await;
296
297 let memory_bytes = self.estimate_memory_usage(&data);
298
299 let namespace_count = data
301 .keys()
302 .filter_map(|key| key.split_once("::").map(|(ns, _)| ns))
303 .collect::<std::collections::HashSet<_>>()
304 .len();
305
306 let mut extra = std::collections::HashMap::new();
307 extra.insert(
308 "max_keys".to_string(),
309 serde_json::json!(self.config.max_keys),
310 );
311 extra.insert(
312 "max_memory_bytes".to_string(),
313 serde_json::json!(self.config.max_memory_bytes),
314 );
315
316 Ok(MemoryStats {
317 total_keys: data.len(),
318 memory_bytes,
319 backend_type: "in_memory".to_string(),
320 namespace_count,
321 last_updated: chrono::Utc::now(),
322 extra,
323 })
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a string payload.
    fn text(s: &str) -> MemoryValue {
        MemoryValue::String(s.to_string())
    }

    /// Single-key round trip: set, get, exists, delete.
    #[tokio::test]
    async fn test_in_memory_basic() {
        let storage = InMemoryStorage::new();

        storage.set("test_key", text("test_value")).await.unwrap();

        let value = storage.get("test_key").await.unwrap();
        assert!(value.is_some());
        assert_eq!(value.unwrap().as_string().unwrap(), "test_value");

        assert!(storage.exists("test_key").await.unwrap());
        assert!(!storage.exists("nonexistent").await.unwrap());

        assert!(storage.delete("test_key").await.unwrap());
        assert!(!storage.exists("test_key").await.unwrap());
    }

    /// Batch variants: mset, mget, mdelete over three integer entries.
    #[tokio::test]
    async fn test_in_memory_bulk_operations() {
        let storage = InMemoryStorage::new();

        let pairs: Vec<(String, MemoryValue)> = (1..=3)
            .map(|i| (format!("key{}", i), MemoryValue::Integer(i)))
            .collect();
        storage.mset(&pairs).await.unwrap();

        let keys: Vec<String> = pairs.iter().map(|(key, _)| key.clone()).collect();
        let values = storage.mget(&keys).await.unwrap();
        assert_eq!(values.len(), 3);
        assert!(values.iter().all(|v| v.is_some()));

        let deleted = storage.mdelete(&keys).await.unwrap();
        assert_eq!(deleted, 3);
    }

    /// Namespace-scoped count and clear only affect keys under `"<ns>::"`.
    #[tokio::test]
    async fn test_in_memory_namespace() {
        let storage = InMemoryStorage::new();

        let fixtures = [
            ("ns1::key1", "value1"),
            ("ns1::key2", "value2"),
            ("ns2::key1", "value3"),
        ];
        for (key, value) in fixtures {
            storage.set(key, text(value)).await.unwrap();
        }

        assert_eq!(storage.count(Some("ns1")).await.unwrap(), 2);
        assert_eq!(storage.count(Some("ns2")).await.unwrap(), 1);

        storage.clear(Some("ns1")).await.unwrap();
        assert_eq!(storage.count(Some("ns1")).await.unwrap(), 0);
        assert_eq!(storage.count(Some("ns2")).await.unwrap(), 1);
    }
}