use crate::db::schema::CozoDb;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// One cached value held in the in-memory tier, together with the data needed
/// to decide whether it is still usable.
#[derive(Clone)]
struct CacheEntry {
    /// The cached value, serialised as JSON text.
    value_json: String,
    /// Creation time, in whole seconds since the Unix epoch.
    created_at: i64,
    /// Lifetime in seconds; `0` marks the entry as already expired.
    ttl_seconds: i64,
}

impl CacheEntry {
    /// Returns `true` when the entry must no longer be served.
    ///
    /// A TTL of `0` always counts as expired. Otherwise the entry is expired
    /// once strictly more than `ttl_seconds` seconds have elapsed since
    /// `created_at`. If the system clock reads earlier than the Unix epoch the
    /// current time is taken to be `0`.
    fn is_expired(&self) -> bool {
        match self.ttl_seconds {
            0 => true,
            ttl => {
                let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
                    Ok(elapsed) => elapsed.as_secs() as i64,
                    Err(_) => 0,
                };
                let age = now_secs - self.created_at;
                age > ttl
            }
        }
    }
}
/// Two-tier cache of JSON-serialised values.
///
/// Reads are answered from an in-process map when possible and otherwise from
/// the `query_cache` relation in the database; writes go to both tiers.
/// Clones share the same database handle and the same in-memory map.
///
/// Database errors are treated as best-effort throughout: a failed read is a
/// cache miss and a failed write or delete is ignored.
#[derive(Clone)]
pub struct PersistentCache {
    /// Handle to the database that holds the `query_cache` relation.
    db: Arc<CozoDb>,
    /// In-process tier, keyed by cache key.
    memory: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// TTL, in seconds, stamped on every entry written through `insert`.
    default_ttl: u64,
}

impl PersistentCache {
    /// Creates a cache on top of `db` whose entries live for `default_ttl`
    /// seconds. A TTL of `0` makes every entry expire immediately.
    pub fn new(db: Arc<CozoDb>, default_ttl: u64) -> Self {
        Self {
            db,
            memory: Arc::new(RwLock::new(HashMap::new())),
            default_ttl,
        }
    }

    /// Alias for [`PersistentCache::new`].
    pub fn with_ttl(db: Arc<CozoDb>, ttl_secs: u64) -> Self {
        Self::new(db, ttl_secs)
    }

    /// Looks up `key` and deserialises the cached JSON into `V`.
    ///
    /// The in-memory tier is consulted first. A live in-memory entry is
    /// authoritative: if it cannot be deserialised into `V` the result is
    /// `None` and the database is not consulted. An expired in-memory entry is
    /// evicted, then the database is tried; a live row found there is copied
    /// into memory with its *stored* creation time and TTL, so reloading a row
    /// never extends its lifetime.
    ///
    /// Returns `None` when the key is absent, expired, or not deserialisable
    /// into `V`.
    pub async fn get<V: DeserializeOwned>(&self, key: &str) -> Option<V> {
        let mut stale_in_memory = false;
        {
            let memory = self.memory.read().await;
            if let Some(entry) = memory.get(key) {
                if !entry.is_expired() {
                    return serde_json::from_str(&entry.value_json).ok();
                }
                stale_in_memory = true;
            }
        }

        if stale_in_memory {
            // Re-check under the write lock so that an entry freshly inserted
            // by another task in the meantime is not thrown away.
            let mut memory = self.memory.write().await;
            if memory.get(key).map_or(false, |entry| entry.is_expired()) {
                memory.remove(key);
            }
        }

        let (value_json, created_at, ttl_seconds) = self.load_from_db(key).await?;
        let value = serde_json::from_str::<V>(&value_json).ok()?;
        self.memory.write().await.insert(
            key.to_string(),
            CacheEntry {
                value_json,
                created_at,
                ttl_seconds,
            },
        );
        Some(value)
    }

    /// Stores `value` under `key` in both tiers, stamped with the current time
    /// and the cache's default TTL.
    ///
    /// The type parameter `K` is unused; it is kept so that existing call
    /// sites of the form `insert::<K, V>(..)` continue to compile.
    ///
    /// If `value` cannot be serialised to JSON nothing is stored and any
    /// previous value for `key` is dropped, so later reads miss rather than
    /// return stale or unparsable data. A failed database write is ignored;
    /// the in-memory entry is still kept.
    pub async fn insert<K: Serialize, V: Serialize>(&self, key: String, value: V) {
        let value_json = match serde_json::to_string(&value) {
            Ok(json) => json,
            Err(_) => {
                self.invalidate(&key).await;
                return;
            }
        };
        let now = Self::unix_now();
        self.memory.write().await.insert(
            key.clone(),
            CacheEntry {
                value_json: value_json.clone(),
                created_at: now,
                ttl_seconds: self.ttl_secs(),
            },
        );
        self.save_to_db(&key, &value_json, now).await.ok();
    }

    /// Removes `key` from both tiers. A failed database delete is ignored.
    pub async fn invalidate(&self, key: &str) {
        self.memory.write().await.remove(key);
        self.delete_from_db(key).await.ok();
    }

    /// Removes every entry whose key starts with `prefix` from both tiers.
    ///
    /// Keys are gathered from the in-memory map and, best-effort, from the
    /// database, so persisted rows that were never loaded into this process
    /// are invalidated as well.
    pub async fn invalidate_prefix(&self, prefix: &str) {
        let mut keys: Vec<String> = self
            .memory
            .read()
            .await
            .keys()
            .filter(|k| k.starts_with(prefix))
            .cloned()
            .collect();
        keys.extend(self.db_keys_with_prefix(prefix));
        keys.sort_unstable();
        keys.dedup();
        for key in &keys {
            self.invalidate(key).await;
        }
    }

    /// Current time in whole seconds since the Unix epoch, or `0` if the
    /// system clock reads earlier than the epoch.
    fn unix_now() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs() as i64)
            .unwrap_or(0)
    }

    /// The default TTL as an `i64`, clamped so that very large values do not
    /// wrap around to a negative number.
    fn ttl_secs(&self) -> i64 {
        self.default_ttl.min(i64::MAX as u64) as i64
    }

    /// Builds a script parameter map holding a single string parameter.
    fn string_param(
        name: &str,
        value: &str,
    ) -> std::collections::BTreeMap<String, serde_json::Value> {
        let mut params = std::collections::BTreeMap::new();
        params.insert(
            name.to_string(),
            serde_json::Value::String(value.to_string()),
        );
        params
    }

    /// Reads the row for `key` and returns `(value_json, created_at, ttl_seconds)`.
    ///
    /// Rows whose TTL is zero or negative, or whose age exceeds their TTL, are
    /// treated as expired — the same rule `CacheEntry::is_expired` applies in
    /// memory — and are deleted (best-effort). Returns `None` on a miss, on an
    /// expired row, on a row with unexpected column types, or on any database
    /// error.
    async fn load_from_db(&self, key: &str) -> Option<(String, i64, i64)> {
        // Named-field binding selects just the columns needed and pins
        // `cache_key` to the `$key` parameter.
        let query = r#"
            ?[value_json, created_at, ttl_seconds] :=
                *query_cache{cache_key: $key, value_json, created_at, ttl_seconds}
        "#;
        let result = self
            .db
            .run_script(query, Self::string_param("key", key))
            .ok()?;
        let row = result.rows.first()?;
        let created_at = row.get(1)?.as_i64()?;
        let ttl_seconds = row.get(2)?.as_i64()?;
        // Saturating: `created_at` comes from storage and is not trusted to be sane.
        let age = Self::unix_now().saturating_sub(created_at);
        if ttl_seconds <= 0 || age > ttl_seconds {
            self.delete_from_db(key).await.ok();
            return None;
        }
        let value_json = row.get(0)?.as_str()?.to_string();
        Some((value_json, created_at, ttl_seconds))
    }

    /// Inserts or replaces the row for `key`, stamped with `created_at` and the
    /// default TTL. The `tool_name`, `project_path` and `metadata` columns are
    /// filled with fixed placeholder values.
    async fn save_to_db(
        &self,
        key: &str,
        value_json: &str,
        created_at: i64,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let query = r#"
?[cache_key, value_json, created_at, ttl_seconds, tool_name, project_path, metadata]
<- [[ $key, $value_json, $created_at, $ttl_seconds, "unknown", "default", "{}" ]]
:put query_cache { cache_key, value_json, created_at, ttl_seconds, tool_name, project_path, metadata }
"#;
        let mut params = Self::string_param("key", key);
        params.insert(
            "value_json".to_string(),
            serde_json::Value::String(value_json.to_string()),
        );
        params.insert(
            "created_at".to_string(),
            serde_json::Value::Number(created_at.into()),
        );
        params.insert(
            "ttl_seconds".to_string(),
            serde_json::Value::Number(self.ttl_secs().into()),
        );
        self.db.run_script(query, params)?;
        Ok(())
    }

    /// Deletes the row for `key`, if any.
    async fn delete_from_db(&self, key: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // NOTE(review): assumes `cache_key` is the only key column of
        // `query_cache` — confirm against the relation's schema.
        let query = r#"
            ?[cache_key] <- [[$key]]
            :rm query_cache {cache_key}
        "#;
        self.db.run_script(query, Self::string_param("key", key))?;
        Ok(())
    }

    /// Returns the persisted keys that start with `prefix`, or an empty list
    /// if the query fails.
    fn db_keys_with_prefix(&self, prefix: &str) -> Vec<String> {
        let query = r#"
            ?[cache_key] := *query_cache{cache_key}, starts_with(cache_key, $prefix)
        "#;
        match self
            .db
            .run_script(query, Self::string_param("prefix", prefix))
        {
            Ok(result) => result
                .rows
                .iter()
                .filter_map(|row| row.get(0)?.as_str().map(String::from))
                .collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Number of entries currently held in the in-memory tier. Rows that exist
    /// only in the database are not counted.
    pub async fn len(&self) -> usize {
        self.memory.read().await.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Hands out a distinct number per call so each test in this process gets
    /// its own database file.
    static TEST_DB_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Opens a database at a per-test path under the system temp directory.
    ///
    /// The database is initialised once, dropped, its file removed, and then
    /// initialised again — presumably so that a file left behind by an earlier
    /// run does not leak rows into this test.
    fn create_test_db() -> CozoDb {
        let id = TEST_DB_COUNTER.fetch_add(1, Ordering::SeqCst);
        let file_name = format!("leankg_test_persistent_cache_{}.db", id);
        let db_path = std::env::temp_dir().join(file_name);

        drop(crate::db::schema::init_db(&db_path).unwrap());
        std::fs::remove_file(&db_path).ok();
        crate::db::schema::init_db(&db_path).unwrap()
    }

    /// Builds a cache over a new test database with the given default TTL.
    fn cache_with_default_ttl(ttl: u64) -> PersistentCache {
        PersistentCache::new(Arc::new(create_test_db()), ttl)
    }

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn test_persistent_cache_basic() {
        let cache = cache_with_default_ttl(300);
        let stored = vec!["value1".to_string(), "value2".to_string()];
        cache
            .insert::<String, Vec<String>>("test_key".to_string(), stored)
            .await;

        let fetched = cache.get::<Vec<String>>("test_key").await;
        assert!(fetched.is_some());
        let values = fetched.unwrap();
        assert_eq!(values.len(), 2);
        assert_eq!(values[0], "value1");
    }

    /// With a default TTL of zero, an inserted value is never served.
    #[tokio::test]
    async fn test_persistent_cache_expired() {
        let cache = cache_with_default_ttl(0);
        cache
            .insert::<String, Vec<String>>("expired_key".to_string(), vec!["value".to_string()])
            .await;

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(cache.get::<Vec<String>>("expired_key").await.is_none());
    }

    /// Prefix invalidation removes matching keys and leaves the rest alone.
    #[tokio::test]
    async fn test_persistent_cache_invalidate_prefix() {
        let cache = cache_with_default_ttl(300);
        cache
            .insert::<String, Vec<String>>("deps:src/main.rs".to_string(), vec!["lib.rs".to_string()])
            .await;
        cache
            .insert::<String, Vec<String>>("deps:src/lib.rs".to_string(), vec!["mod.rs".to_string()])
            .await;
        cache
            .insert::<String, String>("orch:context:src/main.rs".to_string(), "content".to_string())
            .await;

        cache.invalidate_prefix("deps:src/").await;

        assert!(cache.get::<Vec<String>>("deps:src/main.rs").await.is_none());
        assert!(cache.get::<Vec<String>>("deps:src/lib.rs").await.is_none());
        assert!(cache.get::<String>("orch:context:src/main.rs").await.is_some());
    }

    /// Invalidating a single key makes it unreadable.
    #[tokio::test]
    async fn test_persistent_cache_invalidate() {
        let cache = cache_with_default_ttl(300);
        cache
            .insert::<String, Vec<String>>("key1".to_string(), vec!["value1".to_string()])
            .await;

        cache.invalidate("key1").await;

        assert!(cache.get::<Vec<String>>("key1").await.is_none());
    }
}