1use std::collections::HashMap;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::Utc;
13use lru::LruCache;
14use serde::{Deserialize, Serialize};
15use tokio::sync::RwLock;
16
17#[cfg(feature = "persistence")]
18use sled;
19
20use crate::context::{Context, ContextDomain, ContextId, ContextQuery};
21use crate::error::{ContextError, Result};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct StorageConfig {
26 pub memory_cache_size: usize,
28 pub persist_path: Option<PathBuf>,
30 pub auto_cleanup: bool,
32 pub cleanup_interval_secs: u64,
34 pub enable_persistence: bool,
36}
37
38impl Default for StorageConfig {
39 fn default() -> Self {
40 Self {
41 memory_cache_size: 10_000,
42 persist_path: None,
43 auto_cleanup: true,
44 cleanup_interval_secs: 3600,
45 enable_persistence: true,
46 }
47 }
48}
49
50impl StorageConfig {
51 pub fn memory_only(cache_size: usize) -> Self {
53 Self {
54 memory_cache_size: cache_size,
55 persist_path: None,
56 auto_cleanup: true,
57 cleanup_interval_secs: 3600,
58 enable_persistence: false,
59 }
60 }
61
62 pub fn with_persistence(cache_size: usize, path: impl Into<PathBuf>) -> Self {
64 Self {
65 memory_cache_size: cache_size,
66 persist_path: Some(path.into()),
67 auto_cleanup: true,
68 cleanup_interval_secs: 3600,
69 enable_persistence: true,
70 }
71 }
72}
73
74pub struct ContextStore {
76 memory_cache: Arc<RwLock<LruCache<ContextId, Context>>>,
78 #[cfg(feature = "persistence")]
80 disk_store: Option<sled::Db>,
81 domain_index: Arc<RwLock<HashMap<ContextDomain, Vec<ContextId>>>>,
83 tag_index: Arc<RwLock<HashMap<String, Vec<ContextId>>>>,
85 config: StorageConfig,
87}
88
89impl ContextStore {
90 pub fn new(config: StorageConfig) -> Result<Self> {
92 let memory_cache = Arc::new(RwLock::new(LruCache::new(
93 std::num::NonZeroUsize::new(config.memory_cache_size)
94 .ok_or_else(|| ContextError::Config("Cache size must be > 0".into()))?,
95 )));
96
97 #[cfg(feature = "persistence")]
98 let disk_store = if config.enable_persistence {
99 let path = config
100 .persist_path
101 .clone()
102 .unwrap_or_else(|| PathBuf::from("./data/context_store"));
103
104 if let Some(parent) = path.parent() {
106 std::fs::create_dir_all(parent)?;
107 }
108
109 Some(sled::open(&path)?)
110 } else {
111 None
112 };
113
114 #[cfg(not(feature = "persistence"))]
115 let _disk_store = ();
116
117 Ok(Self {
118 memory_cache,
119 #[cfg(feature = "persistence")]
120 disk_store,
121 domain_index: Arc::new(RwLock::new(HashMap::new())),
122 tag_index: Arc::new(RwLock::new(HashMap::new())),
123 config,
124 })
125 }
126
127 pub async fn store(&self, context: Context) -> Result<ContextId> {
129 let id = context.id.clone();
130
131 {
133 let mut domain_idx = self.domain_index.write().await;
134 domain_idx
135 .entry(context.domain.clone())
136 .or_default()
137 .push(id.clone());
138 }
139
140 {
141 let mut tag_idx = self.tag_index.write().await;
142 for tag in &context.metadata.tags {
143 tag_idx.entry(tag.clone()).or_default().push(id.clone());
144 }
145 }
146
147 {
149 let mut cache = self.memory_cache.write().await;
150 cache.put(id.clone(), context.clone());
151 }
152
153 #[cfg(feature = "persistence")]
155 if let Some(ref db) = self.disk_store {
156 let serialized = serde_json::to_vec(&context)?;
157 db.insert(id.as_str().as_bytes(), serialized)?;
158 db.flush_async().await?;
159 }
160
161 Ok(id)
162 }
163
164 pub async fn get(&self, id: &ContextId) -> Result<Option<Context>> {
166 {
168 let mut cache = self.memory_cache.write().await;
169 if let Some(ctx) = cache.get_mut(id) {
170 ctx.mark_accessed();
171 return Ok(Some(ctx.clone()));
172 }
173 }
174
175 #[cfg(feature = "persistence")]
177 if let Some(ref db) = self.disk_store {
178 if let Some(data) = db.get(id.as_str().as_bytes())? {
179 let mut context: Context = serde_json::from_slice(&data)?;
180 context.mark_accessed();
181
182 let mut cache = self.memory_cache.write().await;
184 cache.put(id.clone(), context.clone());
185
186 return Ok(Some(context));
187 }
188 }
189
190 Ok(None)
191 }
192
193 pub async fn delete(&self, id: &ContextId) -> Result<bool> {
195 let mut found = false;
196
197 let context_data = self.get(id).await?;
199
200 {
202 let mut cache = self.memory_cache.write().await;
203 if cache.pop(id).is_some() {
204 found = true;
205 }
206 }
207
208 #[cfg(feature = "persistence")]
210 if let Some(ref db) = self.disk_store {
211 if db.remove(id.as_str().as_bytes())?.is_some() {
212 found = true;
213 }
214 }
215
216 if let Some(ctx) = context_data {
218 {
220 let mut domain_idx = self.domain_index.write().await;
221 if let Some(ids) = domain_idx.get_mut(&ctx.domain) {
222 ids.retain(|stored_id| stored_id != id);
223 if ids.is_empty() {
225 domain_idx.remove(&ctx.domain);
226 }
227 }
228 }
229
230 {
232 let mut tag_idx = self.tag_index.write().await;
233 for tag in &ctx.metadata.tags {
234 if let Some(ids) = tag_idx.get_mut(tag) {
235 ids.retain(|stored_id| stored_id != id);
236 if ids.is_empty() {
238 tag_idx.remove(tag);
239 }
240 }
241 }
242 }
243 }
244
245 Ok(found)
246 }
247
248 pub async fn query(&self, query: &ContextQuery) -> Result<Vec<Context>> {
250 let mut results = Vec::new();
251
252 let candidate_ids = self.get_candidate_ids(query).await;
254
255 for id in candidate_ids {
257 if let Some(ctx) = self.get(&id).await? {
258 if self.matches_query(&ctx, query) {
259 results.push(ctx);
260 }
261
262 if results.len() >= query.limit {
263 break;
264 }
265 }
266 }
267
268 results.sort_by(|a, b| {
270 let importance_cmp = b
271 .metadata
272 .importance
273 .partial_cmp(&a.metadata.importance)
274 .unwrap_or(std::cmp::Ordering::Equal);
275
276 if importance_cmp == std::cmp::Ordering::Equal {
277 b.accessed_at.cmp(&a.accessed_at)
278 } else {
279 importance_cmp
280 }
281 });
282
283 results.truncate(query.limit);
284 Ok(results)
285 }
286
287 pub async fn retrieve_context(
289 &self,
290 query_text: &str,
291 limit: usize,
292 domain_filter: Option<&ContextDomain>,
293 ) -> Result<Vec<Context>> {
294 let _ctx_query = ContextQuery::new().with_limit(limit);
296
297 if let Some(_domain) = domain_filter {
298 }
300
301 let query_lower = query_text.to_lowercase();
304 let mut results = Vec::new();
305
306 let cache = self.memory_cache.read().await;
307 for (_, ctx) in cache.iter() {
308 if ctx.content.to_lowercase().contains(&query_lower) {
309 if let Some(domain) = domain_filter {
310 if &ctx.domain != domain {
311 continue;
312 }
313 }
314 results.push(ctx.clone());
315 if results.len() >= limit {
316 break;
317 }
318 }
319 }
320
321 results.sort_by(|a, b| {
323 b.metadata
324 .importance
325 .partial_cmp(&a.metadata.importance)
326 .unwrap_or(std::cmp::Ordering::Equal)
327 });
328
329 Ok(results)
330 }
331
332 async fn get_candidate_ids(&self, query: &ContextQuery) -> Vec<ContextId> {
334 let mut candidates = Vec::new();
335
336 if let Some(ref domain) = query.domain_filter {
338 let domain_idx = self.domain_index.read().await;
339 if let Some(ids) = domain_idx.get(domain) {
340 candidates.extend(ids.iter().cloned());
341 }
342 }
343
344 if let Some(ref tags) = query.tag_filter {
346 let tag_idx = self.tag_index.read().await;
347 for tag in tags {
348 if let Some(ids) = tag_idx.get(tag) {
349 candidates.extend(ids.iter().cloned());
350 }
351 }
352 }
353
354 if candidates.is_empty() && query.domain_filter.is_none() && query.tag_filter.is_none() {
356 let cache = self.memory_cache.read().await;
357 candidates = cache.iter().map(|(id, _)| id.clone()).collect();
358 }
359
360 candidates.sort();
362 candidates.dedup();
363
364 candidates
365 }
366
367 fn matches_query(&self, ctx: &Context, query: &ContextQuery) -> bool {
369 if ctx.is_expired() {
371 return false;
372 }
373
374 if let Some(ref domain) = query.domain_filter {
376 if &ctx.domain != domain {
377 return false;
378 }
379 }
380
381 if let Some(ref source) = query.source_filter {
383 if &ctx.metadata.source != source {
384 return false;
385 }
386 }
387
388 if let Some(min_importance) = query.min_importance {
390 if ctx.metadata.importance < min_importance {
391 return false;
392 }
393 }
394
395 if let Some(max_age) = query.max_age_seconds {
397 if ctx.age_seconds() > max_age {
398 return false;
399 }
400 }
401
402 if query.verified_only && !ctx.metadata.verified {
404 return false;
405 }
406
407 if let Some(ref text) = query.query {
409 if !ctx.content.to_lowercase().contains(&text.to_lowercase()) {
410 return false;
411 }
412 }
413
414 true
415 }
416
417 pub async fn stats(&self) -> StorageStats {
419 let cache = self.memory_cache.read().await;
420 let memory_count = cache.len();
421
422 #[cfg(feature = "persistence")]
423 let disk_count = self.disk_store.as_ref().map(|db| db.len()).unwrap_or(0);
424
425 #[cfg(not(feature = "persistence"))]
426 let disk_count = 0;
427
428 StorageStats {
429 memory_count,
430 disk_count,
431 cache_capacity: self.config.memory_cache_size,
432 }
433 }
434
435 pub async fn cleanup_expired(&self) -> Result<usize> {
437 let mut removed = 0;
438 let now = Utc::now();
439
440 let expired_ids: Vec<ContextId> = {
442 let cache = self.memory_cache.read().await;
443 cache
444 .iter()
445 .filter(|(_, ctx)| ctx.expires_at.map(|exp| now > exp).unwrap_or(false))
446 .map(|(id, _)| id.clone())
447 .collect()
448 };
449
450 for id in expired_ids {
452 if self.delete(&id).await? {
453 removed += 1;
454 }
455 }
456
457 Ok(removed)
458 }
459}
460
461#[derive(Debug, Clone, Serialize, Deserialize)]
463pub struct StorageStats {
464 pub memory_count: usize,
466 pub disk_count: usize,
468 pub cache_capacity: usize,
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a memory-only store with room for 100 entries.
    fn memory_store() -> ContextStore {
        ContextStore::new(StorageConfig::memory_only(100)).unwrap()
    }

    /// A stored context can be fetched back by its id with its content intact.
    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = memory_store();

        let original = Context::new("Test content", ContextDomain::Code);
        let key = original.id.clone();
        store.store(original).await.unwrap();

        let fetched = store.get(&key).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Test content");
    }

    /// A domain-filtered query returns only the context from that domain.
    #[tokio::test]
    async fn test_query_by_domain() {
        let store = memory_store();

        let code_ctx = Context::new("Code content", ContextDomain::Code);
        let doc_ctx = Context::new("Doc content", ContextDomain::Documentation);
        store.store(code_ctx).await.unwrap();
        store.store(doc_ctx).await.unwrap();

        let by_domain = ContextQuery::new().with_domain(ContextDomain::Code);
        let hits = store.query(&by_domain).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].domain, ContextDomain::Code);
    }
}