use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use chrono::Utc;
use lru::LruCache;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
#[cfg(feature = "persistence")]
use sled;
use crate::context::{Context, ContextDomain, ContextId, ContextQuery};
use crate::error::{ContextError, Result};
/// Settings used to construct a [`ContextStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Maximum number of contexts held in the in-memory LRU cache.
    /// Must be non-zero: `ContextStore::new` rejects 0.
    pub memory_cache_size: usize,
    /// Filesystem path of the on-disk database. When `None`, `ContextStore::new`
    /// falls back to `./data/context_store`.
    pub persist_path: Option<PathBuf>,
    /// Whether expired contexts should be cleaned up automatically.
    /// NOTE(review): not read by anything visible here — presumably consumed by a
    /// background task elsewhere; confirm.
    pub auto_cleanup: bool,
    /// Interval between automatic cleanups, in seconds (per the field name).
    /// NOTE(review): not read by anything visible here; confirm the consumer.
    pub cleanup_interval_secs: u64,
    /// Whether `ContextStore::new` opens the on-disk database. This only has an effect
    /// when the crate is built with the `persistence` feature.
    pub enable_persistence: bool,
}
impl Default for StorageConfig {
    /// Returns a configuration with these settings:
    /// - a 10 000-entry memory cache;
    /// - persistence enabled at the default location (`persist_path` is `None`);
    /// - `auto_cleanup` on, with `cleanup_interval_secs` set to 3600.
    fn default() -> Self {
        let memory_cache_size = 10_000;
        let cleanup_interval_secs = 3600;
        Self {
            memory_cache_size,
            persist_path: None,
            auto_cleanup: true,
            cleanup_interval_secs,
            enable_persistence: true,
        }
    }
}
impl StorageConfig {
    /// Returns a configuration that keeps at most `cache_size` contexts in memory and
    /// never opens a database. Every other setting is taken from [`StorageConfig::default`].
    pub fn memory_only(cache_size: usize) -> Self {
        Self {
            memory_cache_size: cache_size,
            enable_persistence: false,
            ..Self::default()
        }
    }

    /// Returns a configuration that caches up to `cache_size` contexts in memory and
    /// persists them at `path`. Persistence only takes effect when the crate is built
    /// with the `persistence` feature. Every other setting is taken from
    /// [`StorageConfig::default`].
    pub fn with_persistence(cache_size: usize, path: impl Into<PathBuf>) -> Self {
        Self {
            memory_cache_size: cache_size,
            persist_path: Some(path.into()),
            enable_persistence: true,
            ..Self::default()
        }
    }
}
/// Context storage made of three parts:
/// - a bounded in-memory LRU cache;
/// - an optional sled database, only with the `persistence` feature;
/// - in-memory secondary indexes by domain and by tag.
pub struct ContextStore {
    /// LRU cache of recently stored or loaded contexts, keyed by id.
    memory_cache: Arc<RwLock<LruCache<ContextId, Context>>>,
    /// On-disk store of JSON-serialized contexts, keyed by the bytes of the id string.
    /// `None` when the config disables persistence.
    #[cfg(feature = "persistence")]
    disk_store: Option<sled::Db>,
    /// Ids of stored contexts grouped by domain. May still list ids whose context
    /// has since been evicted from the cache.
    domain_index: Arc<RwLock<HashMap<ContextDomain, Vec<ContextId>>>>,
    /// Ids of stored contexts grouped by tag. The same caveat as `domain_index` applies.
    tag_index: Arc<RwLock<HashMap<String, Vec<ContextId>>>>,
    /// Configuration the store was built with.
    config: StorageConfig,
}
impl ContextStore {
pub fn new(config: StorageConfig) -> Result<Self> {
let memory_cache = Arc::new(RwLock::new(LruCache::new(
std::num::NonZeroUsize::new(config.memory_cache_size)
.ok_or_else(|| ContextError::Config("Cache size must be > 0".into()))?,
)));
#[cfg(feature = "persistence")]
let disk_store = if config.enable_persistence {
let path = config
.persist_path
.clone()
.unwrap_or_else(|| PathBuf::from("./data/context_store"));
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
Some(sled::open(&path)?)
} else {
None
};
#[cfg(not(feature = "persistence"))]
let _disk_store = ();
Ok(Self {
memory_cache,
#[cfg(feature = "persistence")]
disk_store,
domain_index: Arc::new(RwLock::new(HashMap::new())),
tag_index: Arc::new(RwLock::new(HashMap::new())),
config,
})
}
pub async fn store(&self, context: Context) -> Result<ContextId> {
let id = context.id.clone();
{
let mut domain_idx = self.domain_index.write().await;
domain_idx
.entry(context.domain.clone())
.or_default()
.push(id.clone());
}
{
let mut tag_idx = self.tag_index.write().await;
for tag in &context.metadata.tags {
tag_idx.entry(tag.clone()).or_default().push(id.clone());
}
}
{
let mut cache = self.memory_cache.write().await;
cache.put(id.clone(), context.clone());
}
#[cfg(feature = "persistence")]
if let Some(ref db) = self.disk_store {
let serialized = serde_json::to_vec(&context)?;
db.insert(id.as_str().as_bytes(), serialized)?;
db.flush_async().await?;
}
Ok(id)
}
pub async fn get(&self, id: &ContextId) -> Result<Option<Context>> {
{
let mut cache = self.memory_cache.write().await;
if let Some(ctx) = cache.get_mut(id) {
ctx.mark_accessed();
return Ok(Some(ctx.clone()));
}
}
#[cfg(feature = "persistence")]
if let Some(ref db) = self.disk_store {
if let Some(data) = db.get(id.as_str().as_bytes())? {
let mut context: Context = serde_json::from_slice(&data)?;
context.mark_accessed();
let mut cache = self.memory_cache.write().await;
cache.put(id.clone(), context.clone());
return Ok(Some(context));
}
}
Ok(None)
}
pub async fn delete(&self, id: &ContextId) -> Result<bool> {
let mut found = false;
let context_data = self.get(id).await?;
{
let mut cache = self.memory_cache.write().await;
if cache.pop(id).is_some() {
found = true;
}
}
#[cfg(feature = "persistence")]
if let Some(ref db) = self.disk_store {
if db.remove(id.as_str().as_bytes())?.is_some() {
found = true;
}
}
if let Some(ctx) = context_data {
{
let mut domain_idx = self.domain_index.write().await;
if let Some(ids) = domain_idx.get_mut(&ctx.domain) {
ids.retain(|stored_id| stored_id != id);
if ids.is_empty() {
domain_idx.remove(&ctx.domain);
}
}
}
{
let mut tag_idx = self.tag_index.write().await;
for tag in &ctx.metadata.tags {
if let Some(ids) = tag_idx.get_mut(tag) {
ids.retain(|stored_id| stored_id != id);
if ids.is_empty() {
tag_idx.remove(tag);
}
}
}
}
}
Ok(found)
}
pub async fn query(&self, query: &ContextQuery) -> Result<Vec<Context>> {
let mut results = Vec::new();
let candidate_ids = self.get_candidate_ids(query).await;
for id in candidate_ids {
if let Some(ctx) = self.get(&id).await? {
if self.matches_query(&ctx, query) {
results.push(ctx);
}
if results.len() >= query.limit {
break;
}
}
}
results.sort_by(|a, b| {
let importance_cmp = b
.metadata
.importance
.partial_cmp(&a.metadata.importance)
.unwrap_or(std::cmp::Ordering::Equal);
if importance_cmp == std::cmp::Ordering::Equal {
b.accessed_at.cmp(&a.accessed_at)
} else {
importance_cmp
}
});
results.truncate(query.limit);
Ok(results)
}
pub async fn retrieve_context(
&self,
query_text: &str,
limit: usize,
domain_filter: Option<&ContextDomain>,
) -> Result<Vec<Context>> {
let _ctx_query = ContextQuery::new().with_limit(limit);
if let Some(_domain) = domain_filter {
}
let query_lower = query_text.to_lowercase();
let mut results = Vec::new();
let cache = self.memory_cache.read().await;
for (_, ctx) in cache.iter() {
if ctx.content.to_lowercase().contains(&query_lower) {
if let Some(domain) = domain_filter {
if &ctx.domain != domain {
continue;
}
}
results.push(ctx.clone());
if results.len() >= limit {
break;
}
}
}
results.sort_by(|a, b| {
b.metadata
.importance
.partial_cmp(&a.metadata.importance)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
async fn get_candidate_ids(&self, query: &ContextQuery) -> Vec<ContextId> {
let mut candidates = Vec::new();
if let Some(ref domain) = query.domain_filter {
let domain_idx = self.domain_index.read().await;
if let Some(ids) = domain_idx.get(domain) {
candidates.extend(ids.iter().cloned());
}
}
if let Some(ref tags) = query.tag_filter {
let tag_idx = self.tag_index.read().await;
for tag in tags {
if let Some(ids) = tag_idx.get(tag) {
candidates.extend(ids.iter().cloned());
}
}
}
if candidates.is_empty() && query.domain_filter.is_none() && query.tag_filter.is_none() {
let cache = self.memory_cache.read().await;
candidates = cache.iter().map(|(id, _)| id.clone()).collect();
}
candidates.sort();
candidates.dedup();
candidates
}
fn matches_query(&self, ctx: &Context, query: &ContextQuery) -> bool {
if ctx.is_expired() {
return false;
}
if let Some(ref domain) = query.domain_filter {
if &ctx.domain != domain {
return false;
}
}
if let Some(ref source) = query.source_filter {
if &ctx.metadata.source != source {
return false;
}
}
if let Some(min_importance) = query.min_importance {
if ctx.metadata.importance < min_importance {
return false;
}
}
if let Some(max_age) = query.max_age_seconds {
if ctx.age_seconds() > max_age {
return false;
}
}
if query.verified_only && !ctx.metadata.verified {
return false;
}
if let Some(ref text) = query.query {
if !ctx.content.to_lowercase().contains(&text.to_lowercase()) {
return false;
}
}
true
}
pub async fn stats(&self) -> StorageStats {
let cache = self.memory_cache.read().await;
let memory_count = cache.len();
#[cfg(feature = "persistence")]
let disk_count = self.disk_store.as_ref().map(|db| db.len()).unwrap_or(0);
#[cfg(not(feature = "persistence"))]
let disk_count = 0;
StorageStats {
memory_count,
disk_count,
cache_capacity: self.config.memory_cache_size,
}
}
pub async fn cleanup_expired(&self) -> Result<usize> {
let mut removed = 0;
let now = Utc::now();
let expired_ids: Vec<ContextId> = {
let cache = self.memory_cache.read().await;
cache
.iter()
.filter(|(_, ctx)| ctx.expires_at.map(|exp| now > exp).unwrap_or(false))
.map(|(id, _)| id.clone())
.collect()
};
for id in expired_ids {
if self.delete(&id).await? {
removed += 1;
}
}
Ok(removed)
}
}
/// Snapshot of store occupancy, as returned by `ContextStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Number of contexts currently held in the memory cache.
    pub memory_count: usize,
    /// Number of records in the on-disk database. Reported as 0 when persistence is
    /// disabled or not compiled in.
    pub disk_count: usize,
    /// Configured maximum number of cached contexts (`StorageConfig::memory_cache_size`).
    pub cache_capacity: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a memory-only store with room for 100 contexts.
    fn memory_store() -> ContextStore {
        ContextStore::new(StorageConfig::memory_only(100))
            .expect("a non-zero memory-only config must build")
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = memory_store();
        let ctx = Context::new("Test content", ContextDomain::Code);
        let id = ctx.id.clone();
        store.store(ctx).await.unwrap();

        let fetched = store
            .get(&id)
            .await
            .unwrap()
            .expect("a freshly stored context must be retrievable");
        assert_eq!(fetched.content, "Test content");
    }

    #[tokio::test]
    async fn test_query_by_domain() {
        let store = memory_store();
        for ctx in [
            Context::new("Code content", ContextDomain::Code),
            Context::new("Doc content", ContextDomain::Documentation),
        ] {
            store.store(ctx).await.unwrap();
        }

        let code_only = ContextQuery::new().with_domain(ContextDomain::Code);
        let hits = store.query(&code_only).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].domain, ContextDomain::Code);
    }
}