context_mcp/
storage.rs

1//! Multi-tier storage for context entries
2//!
3//! Implements a tiered storage system:
4//! 1. In-memory LRU cache for hot data
5//! 2. Sled embedded database for persistence
6//! 3. Optional vector index for similarity search
7
8use std::collections::HashMap;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::Utc;
13use lru::LruCache;
14use serde::{Deserialize, Serialize};
15use tokio::sync::RwLock;
16
17#[cfg(feature = "persistence")]
18use sled;
19
20use crate::context::{Context, ContextDomain, ContextId, ContextQuery};
21use crate::error::{ContextError, Result};
22
/// Configuration for [`ContextStore`].
///
/// Controls the capacity of the in-memory LRU cache and whether (and where)
/// entries are persisted to disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Maximum number of entries held in the in-memory LRU cache.
    ///
    /// Must be greater than zero: `ContextStore::new` rejects `0` with a
    /// configuration error.
    pub memory_cache_size: usize,
    /// Path of the on-disk database.
    ///
    /// When this is `None` and persistence is enabled, `ContextStore::new`
    /// falls back to `./data/context_store`.
    pub persist_path: Option<PathBuf>,
    /// Enable automatic cleanup of expired contexts.
    // NOTE(review): nothing in this file reads this flag; presumably a caller
    // uses it to schedule `ContextStore::cleanup_expired` — confirm.
    pub auto_cleanup: bool,
    /// Cleanup interval in seconds.
    // NOTE(review): not read in this file either; presumably paired with
    // `auto_cleanup` by the scheduling code — confirm.
    pub cleanup_interval_secs: u64,
    /// Enable disk persistence.
    ///
    /// Only consulted by `ContextStore::new` when the crate is built with the
    /// `persistence` feature.
    pub enable_persistence: bool,
}
37
38impl Default for StorageConfig {
39    fn default() -> Self {
40        Self {
41            memory_cache_size: 10_000,
42            persist_path: None,
43            auto_cleanup: true,
44            cleanup_interval_secs: 3600,
45            enable_persistence: true,
46        }
47    }
48}
49
50impl StorageConfig {
51    /// Create config for in-memory only storage
52    pub fn memory_only(cache_size: usize) -> Self {
53        Self {
54            memory_cache_size: cache_size,
55            persist_path: None,
56            auto_cleanup: true,
57            cleanup_interval_secs: 3600,
58            enable_persistence: false,
59        }
60    }
61
62    /// Create config with disk persistence
63    pub fn with_persistence(cache_size: usize, path: impl Into<PathBuf>) -> Self {
64        Self {
65            memory_cache_size: cache_size,
66            persist_path: Some(path.into()),
67            auto_cleanup: true,
68            cleanup_interval_secs: 3600,
69            enable_persistence: true,
70        }
71    }
72}
73
/// Multi-tier context storage.
///
/// Entries are kept in an in-memory LRU cache and, when the `persistence`
/// feature is compiled in and enabled in the configuration, also written to a
/// sled database. Two in-memory secondary indices (by domain and by tag) are
/// maintained alongside to narrow down query candidates.
pub struct ContextStore {
    /// In-memory LRU cache keyed by context ID, guarded by an async `RwLock`.
    memory_cache: Arc<RwLock<LruCache<ContextId, Context>>>,
    /// Persistent storage (sled). `None` when persistence is disabled in the
    /// configuration; the field only exists with the `persistence` feature.
    #[cfg(feature = "persistence")]
    disk_store: Option<sled::Db>,
    /// Domain index for fast filtering: domain -> IDs stored under it.
    /// Starts empty on construction and is updated by `store` and `delete`.
    domain_index: Arc<RwLock<HashMap<ContextDomain, Vec<ContextId>>>>,
    /// Tag index for fast filtering: tag -> IDs of contexts carrying it.
    /// Starts empty on construction and is updated by `store` and `delete`.
    tag_index: Arc<RwLock<HashMap<String, Vec<ContextId>>>>,
    /// Configuration the store was created with.
    config: StorageConfig,
}
88
89impl ContextStore {
90    /// Create a new context store
91    pub fn new(config: StorageConfig) -> Result<Self> {
92        let memory_cache = Arc::new(RwLock::new(LruCache::new(
93            std::num::NonZeroUsize::new(config.memory_cache_size)
94                .ok_or_else(|| ContextError::Config("Cache size must be > 0".into()))?,
95        )));
96
97        #[cfg(feature = "persistence")]
98        let disk_store = if config.enable_persistence {
99            let path = config
100                .persist_path
101                .clone()
102                .unwrap_or_else(|| PathBuf::from("./data/context_store"));
103
104            // Ensure directory exists
105            if let Some(parent) = path.parent() {
106                std::fs::create_dir_all(parent)?;
107            }
108
109            Some(sled::open(&path)?)
110        } else {
111            None
112        };
113
114        #[cfg(not(feature = "persistence"))]
115        let _disk_store = ();
116
117        Ok(Self {
118            memory_cache,
119            #[cfg(feature = "persistence")]
120            disk_store,
121            domain_index: Arc::new(RwLock::new(HashMap::new())),
122            tag_index: Arc::new(RwLock::new(HashMap::new())),
123            config,
124        })
125    }
126
127    /// Store a context entry
128    pub async fn store(&self, context: Context) -> Result<ContextId> {
129        let id = context.id.clone();
130
131        // Update indices
132        {
133            let mut domain_idx = self.domain_index.write().await;
134            domain_idx
135                .entry(context.domain.clone())
136                .or_default()
137                .push(id.clone());
138        }
139
140        {
141            let mut tag_idx = self.tag_index.write().await;
142            for tag in &context.metadata.tags {
143                tag_idx.entry(tag.clone()).or_default().push(id.clone());
144            }
145        }
146
147        // Store in memory cache
148        {
149            let mut cache = self.memory_cache.write().await;
150            cache.put(id.clone(), context.clone());
151        }
152
153        // Persist to disk if enabled
154        #[cfg(feature = "persistence")]
155        if let Some(ref db) = self.disk_store {
156            let serialized = serde_json::to_vec(&context)?;
157            db.insert(id.as_str().as_bytes(), serialized)?;
158            db.flush_async().await?;
159        }
160
161        Ok(id)
162    }
163
164    /// Retrieve a context by ID
165    pub async fn get(&self, id: &ContextId) -> Result<Option<Context>> {
166        // Check memory cache first
167        {
168            let mut cache = self.memory_cache.write().await;
169            if let Some(ctx) = cache.get_mut(id) {
170                ctx.mark_accessed();
171                return Ok(Some(ctx.clone()));
172            }
173        }
174
175        // Check disk storage
176        #[cfg(feature = "persistence")]
177        if let Some(ref db) = self.disk_store {
178            if let Some(data) = db.get(id.as_str().as_bytes())? {
179                let mut context: Context = serde_json::from_slice(&data)?;
180                context.mark_accessed();
181
182                // Promote to memory cache
183                let mut cache = self.memory_cache.write().await;
184                cache.put(id.clone(), context.clone());
185
186                return Ok(Some(context));
187            }
188        }
189
190        Ok(None)
191    }
192
193    /// Delete a context by ID
194    pub async fn delete(&self, id: &ContextId) -> Result<bool> {
195        let mut found = false;
196
197        // First, get the context to extract domain and tags before deletion
198        let context_data = self.get(id).await?;
199
200        // Remove from memory cache
201        {
202            let mut cache = self.memory_cache.write().await;
203            if cache.pop(id).is_some() {
204                found = true;
205            }
206        }
207
208        // Remove from disk
209        #[cfg(feature = "persistence")]
210        if let Some(ref db) = self.disk_store {
211            if db.remove(id.as_str().as_bytes())?.is_some() {
212                found = true;
213            }
214        }
215
216        // Clean up indices if context was found
217        if let Some(ctx) = context_data {
218            // Remove from domain index
219            {
220                let mut domain_idx = self.domain_index.write().await;
221                if let Some(ids) = domain_idx.get_mut(&ctx.domain) {
222                    ids.retain(|stored_id| stored_id != id);
223                    // Remove empty domain entries to prevent unbounded growth
224                    if ids.is_empty() {
225                        domain_idx.remove(&ctx.domain);
226                    }
227                }
228            }
229
230            // Remove from tag index
231            {
232                let mut tag_idx = self.tag_index.write().await;
233                for tag in &ctx.metadata.tags {
234                    if let Some(ids) = tag_idx.get_mut(tag) {
235                        ids.retain(|stored_id| stored_id != id);
236                        // Remove empty tag entries to prevent unbounded growth
237                        if ids.is_empty() {
238                            tag_idx.remove(tag);
239                        }
240                    }
241                }
242            }
243        }
244
245        Ok(found)
246    }
247
248    /// Query contexts based on criteria
249    pub async fn query(&self, query: &ContextQuery) -> Result<Vec<Context>> {
250        let mut results = Vec::new();
251
252        // Get candidate IDs from indices
253        let candidate_ids = self.get_candidate_ids(query).await;
254
255        // Fetch and filter contexts
256        for id in candidate_ids {
257            if let Some(ctx) = self.get(&id).await? {
258                if self.matches_query(&ctx, query) {
259                    results.push(ctx);
260                }
261
262                if results.len() >= query.limit {
263                    break;
264                }
265            }
266        }
267
268        // Sort by importance and recency
269        results.sort_by(|a, b| {
270            let importance_cmp = b
271                .metadata
272                .importance
273                .partial_cmp(&a.metadata.importance)
274                .unwrap_or(std::cmp::Ordering::Equal);
275
276            if importance_cmp == std::cmp::Ordering::Equal {
277                b.accessed_at.cmp(&a.accessed_at)
278            } else {
279                importance_cmp
280            }
281        });
282
283        results.truncate(query.limit);
284        Ok(results)
285    }
286
287    /// Retrieve relevant context for RAG
288    pub async fn retrieve_context(
289        &self,
290        query_text: &str,
291        limit: usize,
292        domain_filter: Option<&ContextDomain>,
293    ) -> Result<Vec<Context>> {
294        // Build query
295        let _ctx_query = ContextQuery::new().with_limit(limit);
296
297        if let Some(_domain) = domain_filter {
298            // ctx_query = ctx_query.with_domain(domain.clone());
299        }
300
301        // For now, simple text matching
302        // TODO: Implement vector similarity when embeddings are available
303        let query_lower = query_text.to_lowercase();
304        let mut results = Vec::new();
305
306        let cache = self.memory_cache.read().await;
307        for (_, ctx) in cache.iter() {
308            if ctx.content.to_lowercase().contains(&query_lower) {
309                if let Some(domain) = domain_filter {
310                    if &ctx.domain != domain {
311                        continue;
312                    }
313                }
314                results.push(ctx.clone());
315                if results.len() >= limit {
316                    break;
317                }
318            }
319        }
320
321        // Sort by importance
322        results.sort_by(|a, b| {
323            b.metadata
324                .importance
325                .partial_cmp(&a.metadata.importance)
326                .unwrap_or(std::cmp::Ordering::Equal)
327        });
328
329        Ok(results)
330    }
331
332    /// Get candidate IDs from indices based on query filters
333    async fn get_candidate_ids(&self, query: &ContextQuery) -> Vec<ContextId> {
334        let mut candidates = Vec::new();
335
336        // If domain filter specified, use domain index
337        if let Some(ref domain) = query.domain_filter {
338            let domain_idx = self.domain_index.read().await;
339            if let Some(ids) = domain_idx.get(domain) {
340                candidates.extend(ids.iter().cloned());
341            }
342        }
343
344        // If tag filter specified, use tag index
345        if let Some(ref tags) = query.tag_filter {
346            let tag_idx = self.tag_index.read().await;
347            for tag in tags {
348                if let Some(ids) = tag_idx.get(tag) {
349                    candidates.extend(ids.iter().cloned());
350                }
351            }
352        }
353
354        // If no filters, get all from cache
355        if candidates.is_empty() && query.domain_filter.is_none() && query.tag_filter.is_none() {
356            let cache = self.memory_cache.read().await;
357            candidates = cache.iter().map(|(id, _)| id.clone()).collect();
358        }
359
360        // Deduplicate
361        candidates.sort();
362        candidates.dedup();
363
364        candidates
365    }
366
367    /// Check if a context matches the query criteria
368    fn matches_query(&self, ctx: &Context, query: &ContextQuery) -> bool {
369        // Check expiration
370        if ctx.is_expired() {
371            return false;
372        }
373
374        // Check domain
375        if let Some(ref domain) = query.domain_filter {
376            if &ctx.domain != domain {
377                return false;
378            }
379        }
380
381        // Check source
382        if let Some(ref source) = query.source_filter {
383            if &ctx.metadata.source != source {
384                return false;
385            }
386        }
387
388        // Check importance
389        if let Some(min_importance) = query.min_importance {
390            if ctx.metadata.importance < min_importance {
391                return false;
392            }
393        }
394
395        // Check age
396        if let Some(max_age) = query.max_age_seconds {
397            if ctx.age_seconds() > max_age {
398                return false;
399            }
400        }
401
402        // Check verified status
403        if query.verified_only && !ctx.metadata.verified {
404            return false;
405        }
406
407        // Check text query (simple contains for now)
408        if let Some(ref text) = query.query {
409            if !ctx.content.to_lowercase().contains(&text.to_lowercase()) {
410                return false;
411            }
412        }
413
414        true
415    }
416
417    /// Get storage statistics
418    pub async fn stats(&self) -> StorageStats {
419        let cache = self.memory_cache.read().await;
420        let memory_count = cache.len();
421
422        #[cfg(feature = "persistence")]
423        let disk_count = self.disk_store.as_ref().map(|db| db.len()).unwrap_or(0);
424
425        #[cfg(not(feature = "persistence"))]
426        let disk_count = 0;
427
428        StorageStats {
429            memory_count,
430            disk_count,
431            cache_capacity: self.config.memory_cache_size,
432        }
433    }
434
435    /// Cleanup expired contexts
436    pub async fn cleanup_expired(&self) -> Result<usize> {
437        let mut removed = 0;
438        let now = Utc::now();
439
440        // Collect expired IDs
441        let expired_ids: Vec<ContextId> = {
442            let cache = self.memory_cache.read().await;
443            cache
444                .iter()
445                .filter(|(_, ctx)| ctx.expires_at.map(|exp| now > exp).unwrap_or(false))
446                .map(|(id, _)| id.clone())
447                .collect()
448        };
449
450        // Remove expired contexts
451        for id in expired_ids {
452            if self.delete(&id).await? {
453                removed += 1;
454            }
455        }
456
457        Ok(removed)
458    }
459}
460
/// Storage statistics, as reported by `ContextStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Number of items currently held in the memory cache.
    pub memory_count: usize,
    /// Number of items on disk; `0` when no disk store is in use.
    pub disk_count: usize,
    /// Memory cache capacity, taken from the store's configuration.
    pub cache_capacity: usize,
}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a memory-only store with room for 100 entries.
    fn memory_store() -> ContextStore {
        ContextStore::new(StorageConfig::memory_only(100)).unwrap()
    }

    /// A stored context can be fetched back by ID with its content intact.
    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = memory_store();

        let entry = Context::new("Test content", ContextDomain::Code);
        let entry_id = entry.id.clone();
        store.store(entry).await.unwrap();

        let fetched = store.get(&entry_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Test content");
    }

    /// A domain-filtered query returns only contexts from that domain.
    #[tokio::test]
    async fn test_query_by_domain() {
        let store = memory_store();

        let code_entry = Context::new("Code content", ContextDomain::Code);
        let doc_entry = Context::new("Doc content", ContextDomain::Documentation);
        store.store(code_entry).await.unwrap();
        store.store(doc_entry).await.unwrap();

        let code_only = ContextQuery::new().with_domain(ContextDomain::Code);
        let matches = store.query(&code_only).await.unwrap();

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].domain, ContextDomain::Code);
    }
}
508}