rexis_rag/storage/
in_memory.rs

//! # In-Memory Storage Implementation
//!
//! Fast, thread-safe in-memory storage backed by a `HashMap` guarded by a Tokio `RwLock`.
4
5use super::memory::{Memory, MemoryQuery, MemoryStats, MemoryValue};
6use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Configuration for `InMemoryStorage`.
///
/// Defaults: at most 100 000 keys, a 1 GB (10^9 bytes) memory budget, eviction disabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InMemoryConfig {
    /// Maximum number of keys allowed; `None` means unlimited.
    pub max_keys: Option<usize>,

    /// Maximum memory usage in bytes.
    ///
    /// NOTE: `InMemoryStorage` currently only reports this value through `stats()`;
    /// it is not enforced on writes.
    pub max_memory_bytes: Option<u64>,

    /// Enable automatic eviction when limits are reached.
    ///
    /// NOTE: `InMemoryStorage` does not currently consult this flag; a write that would
    /// exceed `max_keys` is rejected with an error instead of evicting anything.
    pub enable_eviction: bool,
}

impl Default for InMemoryConfig {
    /// 100 000 keys, a 1 GB (10^9 bytes) memory budget, eviction disabled.
    fn default() -> Self {
        Self {
            max_keys: Some(100_000),
            max_memory_bytes: Some(1_000_000_000),
            enable_eviction: false,
        }
    }
}
34
35/// Entry stored in memory with metadata
36#[derive(Debug, Clone)]
37struct MemoryEntry {
38    value: MemoryValue,
39    created_at: chrono::DateTime<chrono::Utc>,
40    accessed_at: chrono::DateTime<chrono::Utc>,
41}
42
43/// In-memory storage implementation
44pub struct InMemoryStorage {
45    /// Internal storage
46    data: Arc<RwLock<HashMap<String, MemoryEntry>>>,
47
48    /// Configuration
49    config: InMemoryConfig,
50}
51
52impl InMemoryStorage {
53    /// Create a new in-memory storage with default configuration
54    pub fn new() -> Self {
55        Self {
56            data: Arc::new(RwLock::new(HashMap::new())),
57            config: InMemoryConfig::default(),
58        }
59    }
60
61    /// Create a new in-memory storage with custom configuration
62    pub fn with_config(config: InMemoryConfig) -> Self {
63        Self {
64            data: Arc::new(RwLock::new(HashMap::new())),
65            config,
66        }
67    }
68
69    /// Check if we're within limits
70    async fn check_limits(&self) -> RragResult<()> {
71        let data = self.data.read().await;
72
73        if let Some(max_keys) = self.config.max_keys {
74            if data.len() >= max_keys {
75                return Err(RragError::storage(
76                    "memory_limit",
77                    std::io::Error::new(
78                        std::io::ErrorKind::OutOfMemory,
79                        format!("Exceeded maximum keys: {}", max_keys),
80                    ),
81                ));
82            }
83        }
84
85        Ok(())
86    }
87
88    /// Check if a key matches the query pattern
89    fn matches_query(&self, key: &str, query: &MemoryQuery) -> bool {
90        // Check key pattern
91        if let Some(pattern) = &query.key_pattern {
92            if !key.starts_with(pattern) {
93                return false;
94            }
95        }
96
97        // Check namespace (keys can be prefixed with namespace::)
98        if let Some(namespace) = &query.namespace {
99            let expected_prefix = format!("{}::", namespace);
100            if !key.starts_with(&expected_prefix) {
101                return false;
102            }
103        }
104
105        true
106    }
107
108    /// Estimate memory usage (rough calculation)
109    fn estimate_memory_usage(&self, data: &HashMap<String, MemoryEntry>) -> u64 {
110        let mut total = 0u64;
111
112        for (key, entry) in data.iter() {
113            // Key size
114            total += key.len() as u64;
115
116            // Value size (rough estimate)
117            total += match &entry.value {
118                MemoryValue::String(s) => s.len() as u64,
119                MemoryValue::Integer(_) => 8,
120                MemoryValue::Float(_) => 8,
121                MemoryValue::Boolean(_) => 1,
122                MemoryValue::Json(j) => j.to_string().len() as u64,
123                MemoryValue::Bytes(b) => b.len() as u64,
124                MemoryValue::List(l) => l.len() as u64 * 64, // Rough estimate
125                MemoryValue::Map(m) => m.len() as u64 * 128, // Rough estimate
126            };
127
128            // Metadata overhead
129            total += 64;
130        }
131
132        total
133    }
134}
135
136impl Default for InMemoryStorage {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142#[async_trait]
143impl Memory for InMemoryStorage {
144    fn backend_name(&self) -> &str {
145        "in_memory"
146    }
147
148    async fn set(&self, key: &str, value: MemoryValue) -> RragResult<()> {
149        self.check_limits().await?;
150
151        let mut data = self.data.write().await;
152        let now = chrono::Utc::now();
153
154        data.insert(
155            key.to_string(),
156            MemoryEntry {
157                value,
158                created_at: now,
159                accessed_at: now,
160            },
161        );
162
163        Ok(())
164    }
165
166    async fn get(&self, key: &str) -> RragResult<Option<MemoryValue>> {
167        let mut data = self.data.write().await;
168
169        if let Some(entry) = data.get_mut(key) {
170            entry.accessed_at = chrono::Utc::now();
171            Ok(Some(entry.value.clone()))
172        } else {
173            Ok(None)
174        }
175    }
176
177    async fn delete(&self, key: &str) -> RragResult<bool> {
178        let mut data = self.data.write().await;
179        Ok(data.remove(key).is_some())
180    }
181
182    async fn exists(&self, key: &str) -> RragResult<bool> {
183        let data = self.data.read().await;
184        Ok(data.contains_key(key))
185    }
186
187    async fn keys(&self, query: &MemoryQuery) -> RragResult<Vec<String>> {
188        let data = self.data.read().await;
189
190        let mut keys: Vec<String> = data
191            .keys()
192            .filter(|key| self.matches_query(key, query))
193            .cloned()
194            .collect();
195
196        // Apply offset
197        if let Some(offset) = query.offset {
198            if offset < keys.len() {
199                keys = keys.into_iter().skip(offset).collect();
200            } else {
201                keys.clear();
202            }
203        }
204
205        // Apply limit
206        if let Some(limit) = query.limit {
207            keys.truncate(limit);
208        }
209
210        Ok(keys)
211    }
212
213    async fn mget(&self, keys: &[String]) -> RragResult<Vec<Option<MemoryValue>>> {
214        let mut data = self.data.write().await;
215        let now = chrono::Utc::now();
216
217        let mut results = Vec::with_capacity(keys.len());
218
219        for key in keys {
220            if let Some(entry) = data.get_mut(key) {
221                entry.accessed_at = now;
222                results.push(Some(entry.value.clone()));
223            } else {
224                results.push(None);
225            }
226        }
227
228        Ok(results)
229    }
230
231    async fn mset(&self, pairs: &[(String, MemoryValue)]) -> RragResult<()> {
232        self.check_limits().await?;
233
234        let mut data = self.data.write().await;
235        let now = chrono::Utc::now();
236
237        for (key, value) in pairs {
238            data.insert(
239                key.clone(),
240                MemoryEntry {
241                    value: value.clone(),
242                    created_at: now,
243                    accessed_at: now,
244                },
245            );
246        }
247
248        Ok(())
249    }
250
251    async fn mdelete(&self, keys: &[String]) -> RragResult<usize> {
252        let mut data = self.data.write().await;
253        let mut deleted = 0;
254
255        for key in keys {
256            if data.remove(key).is_some() {
257                deleted += 1;
258            }
259        }
260
261        Ok(deleted)
262    }
263
264    async fn clear(&self, namespace: Option<&str>) -> RragResult<()> {
265        let mut data = self.data.write().await;
266
267        if let Some(ns) = namespace {
268            let prefix = format!("{}::", ns);
269            data.retain(|key, _| !key.starts_with(&prefix));
270        } else {
271            data.clear();
272        }
273
274        Ok(())
275    }
276
277    async fn count(&self, namespace: Option<&str>) -> RragResult<usize> {
278        let data = self.data.read().await;
279
280        if let Some(ns) = namespace {
281            let prefix = format!("{}::", ns);
282            Ok(data.keys().filter(|key| key.starts_with(&prefix)).count())
283        } else {
284            Ok(data.len())
285        }
286    }
287
288    async fn health_check(&self) -> RragResult<bool> {
289        // Try to read the data
290        let _data = self.data.read().await;
291        Ok(true)
292    }
293
294    async fn stats(&self) -> RragResult<MemoryStats> {
295        let data = self.data.read().await;
296
297        let memory_bytes = self.estimate_memory_usage(&data);
298
299        // Count namespaces (keys with :: separator)
300        let namespace_count = data
301            .keys()
302            .filter_map(|key| key.split_once("::").map(|(ns, _)| ns))
303            .collect::<std::collections::HashSet<_>>()
304            .len();
305
306        let mut extra = std::collections::HashMap::new();
307        extra.insert(
308            "max_keys".to_string(),
309            serde_json::json!(self.config.max_keys),
310        );
311        extra.insert(
312            "max_memory_bytes".to_string(),
313            serde_json::json!(self.config.max_memory_bytes),
314        );
315
316        Ok(MemoryStats {
317            total_keys: data.len(),
318            memory_bytes,
319            backend_type: "in_memory".to_string(),
320            namespace_count,
321            last_updated: chrono::Utc::now(),
322            extra,
323        })
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_basic() {
        let storage = InMemoryStorage::new();

        // Test set and get
        storage
            .set("test_key", MemoryValue::String("test_value".to_string()))
            .await
            .unwrap();

        let value = storage.get("test_key").await.unwrap();
        assert!(value.is_some());
        assert_eq!(value.unwrap().as_string().unwrap(), "test_value");

        // Test exists
        assert!(storage.exists("test_key").await.unwrap());
        assert!(!storage.exists("nonexistent").await.unwrap());

        // Test delete; deleting again reports that nothing was removed
        assert!(storage.delete("test_key").await.unwrap());
        assert!(!storage.exists("test_key").await.unwrap());
        assert!(!storage.delete("test_key").await.unwrap());
        assert!(storage.get("test_key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_bulk_operations() {
        let storage = InMemoryStorage::new();

        // Test mset
        let pairs = vec![
            ("key1".to_string(), MemoryValue::Integer(1)),
            ("key2".to_string(), MemoryValue::Integer(2)),
            ("key3".to_string(), MemoryValue::Integer(3)),
        ];
        storage.mset(&pairs).await.unwrap();

        // Test mget
        let keys = vec!["key1".to_string(), "key2".to_string(), "key3".to_string()];
        let values = storage.mget(&keys).await.unwrap();
        assert_eq!(values.len(), 3);
        assert!(values.iter().all(|v| v.is_some()));

        // mget keeps positions aligned with the request and reports misses as None
        let mixed = vec!["key1".to_string(), "missing".to_string()];
        let values = storage.mget(&mixed).await.unwrap();
        assert_eq!(values.len(), 2);
        assert!(values[0].is_some());
        assert!(values[1].is_none());

        // Test mdelete
        let deleted = storage.mdelete(&keys).await.unwrap();
        assert_eq!(deleted, 3);
        assert_eq!(storage.count(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_namespace() {
        let storage = InMemoryStorage::new();

        // Add keys with namespace
        storage
            .set("ns1::key1", MemoryValue::String("value1".to_string()))
            .await
            .unwrap();
        storage
            .set("ns1::key2", MemoryValue::String("value2".to_string()))
            .await
            .unwrap();
        storage
            .set("ns2::key1", MemoryValue::String("value3".to_string()))
            .await
            .unwrap();

        // Count by namespace
        assert_eq!(storage.count(Some("ns1")).await.unwrap(), 2);
        assert_eq!(storage.count(Some("ns2")).await.unwrap(), 1);

        // Clear namespace
        storage.clear(Some("ns1")).await.unwrap();
        assert_eq!(storage.count(Some("ns1")).await.unwrap(), 0);
        assert_eq!(storage.count(Some("ns2")).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_max_keys_rejects_new_key() {
        let storage = InMemoryStorage::with_config(InMemoryConfig {
            max_keys: Some(2),
            ..InMemoryConfig::default()
        });

        storage.set("a", MemoryValue::Integer(1)).await.unwrap();
        storage.set("b", MemoryValue::Integer(2)).await.unwrap();

        // A third distinct key exceeds the limit and must not be stored
        assert!(storage.set("c", MemoryValue::Integer(3)).await.is_err());
        assert!(!storage.exists("c").await.unwrap());
        assert_eq!(storage.count(None).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_in_memory_stats_health_and_clear_all() {
        let storage = InMemoryStorage::new();
        assert!(storage.health_check().await.unwrap());

        storage
            .set("ns1::a", MemoryValue::Boolean(true))
            .await
            .unwrap();
        storage
            .set("ns2::b", MemoryValue::Boolean(false))
            .await
            .unwrap();
        storage.set("plain", MemoryValue::Integer(7)).await.unwrap();

        let stats = storage.stats().await.unwrap();
        assert_eq!(stats.total_keys, 3);
        assert_eq!(stats.namespace_count, 2);
        assert_eq!(stats.backend_type, "in_memory");
        assert!(stats.memory_bytes > 0);

        // Clearing without a namespace removes everything
        storage.clear(None).await.unwrap();
        assert_eq!(storage.count(None).await.unwrap(), 0);
    }
}