use anyhow::Result;
use rusqlite::{params, OptionalExtension};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::debug;
use super::types::{ContextBundle, ExpandOptions};
use crate::db::SqliteStore;
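/// Tunables for the context cache: an on/off switch, entry TTL in seconds,
/// the maximum number of entries, and the LRU eviction batch size.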
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
pub enabled: bool,
pub ttl_seconds: i32,
pub max_entries: i32,
pub evict_batch_size: i32,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
            ttl_seconds: 3600,
            max_entries: 1000,
evict_batch_size: 100,
}
}
}
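/// Cache key for one chunk expanded with one set of options; persisted in
/// SQLite as the string `"{chunk_id}:{options_hash}"`.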
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
pub chunk_id: i64,
pub options_hash: String,
}
impl CacheKey {
pub fn new(chunk_id: i64, options: &ExpandOptions) -> Self {
Self {
chunk_id,
options_hash: hash_options(options),
}
}
}
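/// SHA-256 hex digest of the JSON-serialized [`ExpandOptions`], used as the
/// options half of a [`CacheKey`].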
pub fn hash_options(options: &ExpandOptions) -> String {
let json = serde_json::to_string(options).expect("ExpandOptions should serialize");
let mut hasher = Sha256::new();
hasher.update(json.as_bytes());
let result = hasher.finalize();
format!("{:x}", result)
}
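/// In-process counters for cache activity, updated with `Ordering::Relaxed`;
/// intended for observability rather than exact accounting.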
#[derive(Debug, Default)]
pub struct CacheStats {
pub hits: AtomicU64,
pub misses: AtomicU64,
pub puts: AtomicU64,
pub invalidations: AtomicU64,
pub ttl_evictions: AtomicU64,
pub lru_evictions: AtomicU64,
}
impl CacheStats {
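    /// Hit rate as a percentage in `0.0..=100.0`; returns `0.0` when no
    /// lookups have been recorded yet.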
pub fn hit_rate(&self) -> f64 {
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total = hits + misses;
if total == 0 {
0.0
} else {
(hits as f64 / total as f64) * 100.0
}
}
pub fn total_operations(&self) -> u64 {
self.hits.load(Ordering::Relaxed) + self.misses.load(Ordering::Relaxed)
}
pub fn reset(&self) {
self.hits.store(0, Ordering::Relaxed);
self.misses.store(0, Ordering::Relaxed);
self.puts.store(0, Ordering::Relaxed);
self.invalidations.store(0, Ordering::Relaxed);
self.ttl_evictions.store(0, Ordering::Relaxed);
self.lru_evictions.store(0, Ordering::Relaxed);
}
}
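/// SQLite-backed cache for expanded [`ContextBundle`]s.
///
/// Entries are keyed by chunk id plus a hash of the [`ExpandOptions`], expire
/// after `ttl_seconds`, and are evicted least-recently-used first (based on
/// the `accessed_at` column) once `max_entries` is exceeded.
///
/// Illustrative read-through flow (a sketch only; `expand_bundle` stands in
/// for whatever actually builds the bundle):
///
/// ```ignore
/// let cache = ContextCache::new(store, CacheConfig::default());
/// if let Some(bundle) = cache.get(chunk_id, &options).await? {
///     return Ok(bundle);
/// }
/// let bundle = expand_bundle(chunk_id, &options).await?;
/// cache.put(chunk_id, &options, &bundle).await?;
/// Ok(bundle)
/// ```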
pub struct ContextCache {
store: Arc<SqliteStore>,
config: CacheConfig,
stats: Arc<CacheStats>,
}
impl ContextCache {
pub fn new(store: Arc<SqliteStore>, config: CacheConfig) -> Self {
Self {
store,
config,
stats: Arc::new(CacheStats::default()),
}
}
pub fn config(&self) -> &CacheConfig {
&self.config
}
pub fn stats(&self) -> Arc<CacheStats> {
Arc::clone(&self.stats)
}
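    /// Looks up a cached bundle for `chunk_id` expanded with `options`.
    ///
    /// Expired entries are treated as misses; on a hit the entry's
    /// `accessed_at` timestamp is refreshed so LRU eviction sees it as recent.
    /// Always returns `Ok(None)` when the cache is disabled.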
pub async fn get(
&self,
chunk_id: i64,
options: &ExpandOptions,
) -> Result<Option<ContextBundle>> {
if !self.config.enabled {
return Ok(None);
}
let key = CacheKey::new(chunk_id, options);
let cache_key = format!("{}:{}", key.chunk_id, key.options_hash);
let cache_key_clone = cache_key.clone();
let result = self
.store
.run(move |conn| {
let bundle_json: Option<String> = conn
.query_row(
"SELECT bundle_json FROM context_cache
WHERE cache_key = ?1 AND expires_at > datetime('now')",
params![cache_key],
|row| row.get(0),
)
.optional()?;
Ok(bundle_json)
})
.await?;
match result {
Some(json) => {
                // Refresh accessed_at so LRU eviction keeps recently used entries around.
                self.store.run(move |conn| {
                    conn.execute(
                        "UPDATE context_cache SET accessed_at = datetime('now') WHERE cache_key = ?1",
                        params![cache_key_clone],
)?;
Ok(())
}).await?;
let bundle: ContextBundle = serde_json::from_str(&json)?;
self.stats.hits.fetch_add(1, Ordering::Relaxed);
debug!("Cache hit for chunk_id={}", chunk_id);
Ok(Some(bundle))
}
None => {
self.stats.misses.fetch_add(1, Ordering::Relaxed);
debug!("Cache miss for chunk_id={}", chunk_id);
Ok(None)
}
}
}
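    /// Serializes `bundle` and stores it under the key for `chunk_id` plus
    /// `options`, evicting least-recently-used entries first if the table is
    /// over `max_entries`. `INSERT OR REPLACE` means re-putting an existing
    /// key refreshes its TTL.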
pub async fn put(
&self,
chunk_id: i64,
options: &ExpandOptions,
bundle: &ContextBundle,
) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
self.evict_lru_if_needed().await?;
let key = CacheKey::new(chunk_id, options);
let cache_key = format!("{}:{}", key.chunk_id, key.options_hash);
let bundle_json = serde_json::to_string(bundle)?;
let ttl_seconds = self.config.ttl_seconds;
self.store.run(move |conn| {
conn.execute(
"INSERT OR REPLACE INTO context_cache
(cache_key, bundle_json, created_at, expires_at, accessed_at)
VALUES (?1, ?2, datetime('now'), datetime('now', '+' || ?3 || ' seconds'), datetime('now'))",
params![cache_key, bundle_json, ttl_seconds],
)?;
Ok(())
}).await?;
self.stats.puts.fetch_add(1, Ordering::Relaxed);
debug!("Cached bundle for chunk_id={}", chunk_id);
Ok(())
}
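    /// Deletes every cached entry for `chunk_id`, regardless of which options
    /// it was expanded with. Returns the number of rows removed.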
pub async fn invalidate(&self, chunk_id: i64) -> Result<u64> {
if !self.config.enabled {
return Ok(0);
}
let prefix = format!("{}:", chunk_id);
let count = self
.store
.run(move |conn| {
let deleted = conn.execute(
"DELETE FROM context_cache WHERE cache_key LIKE ?1 || '%'",
params![prefix],
)?;
Ok(deleted as u64)
})
.await?;
if count > 0 {
self.stats.invalidations.fetch_add(count, Ordering::Relaxed);
debug!(
"Invalidated {} cache entries for chunk_id={}",
count, chunk_id
);
}
Ok(count)
}
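    /// Deletes cached entries for each chunk in `chunk_ids`, returning the
    /// total number of rows removed.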
pub async fn invalidate_many(&self, chunk_ids: &[i64]) -> Result<u64> {
if !self.config.enabled || chunk_ids.is_empty() {
return Ok(0);
}
let prefixes: Vec<String> = chunk_ids.iter().map(|id| format!("{}:", id)).collect();
let count = self
.store
.run(move |conn| {
let mut total_deleted = 0u64;
for prefix in prefixes {
let deleted = conn.execute(
"DELETE FROM context_cache WHERE cache_key LIKE ?1 || '%'",
params![prefix],
)?;
total_deleted += deleted as u64;
}
Ok(total_deleted)
})
.await?;
if count > 0 {
self.stats.invalidations.fetch_add(count, Ordering::Relaxed);
debug!(
"Invalidated {} cache entries for {} chunks",
count,
chunk_ids.len()
);
}
Ok(count)
}
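    /// Deletes all cache entries and resets the in-memory counters.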
pub async fn clear(&self) -> Result<u64> {
let count = self
.store
.run(|conn| {
let deleted = conn.execute("DELETE FROM context_cache", [])?;
Ok(deleted as u64)
})
.await?;
self.stats.reset();
debug!("Cleared {} cache entries", count);
Ok(count)
}
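    /// Deletes entries whose `expires_at` has already passed, returning the
    /// number removed.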
pub async fn evict_expired(&self) -> Result<u64> {
if !self.config.enabled {
return Ok(0);
}
let count = self
.store
.run(|conn| {
let deleted = conn.execute(
"DELETE FROM context_cache WHERE expires_at < datetime('now')",
[],
)?;
Ok(deleted as u64)
})
.await?;
if count > 0 {
self.stats.ttl_evictions.fetch_add(count, Ordering::Relaxed);
debug!("Evicted {} expired cache entries", count);
}
Ok(count)
}
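    /// When the entry count exceeds `max_entries`, deletes a batch of the
    /// least-recently-accessed entries (oldest `accessed_at` first).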
async fn evict_lru_if_needed(&self) -> Result<u64> {
if !self.config.enabled {
return Ok(0);
}
let max_entries = self.config.max_entries;
let evict_batch_size = self.config.evict_batch_size;
let count = self
.store
.run(move |conn| {
let current_count: i32 =
conn.query_row("SELECT COUNT(*) FROM context_cache", [], |row| row.get(0))?;
if current_count <= max_entries {
return Ok(0u64);
}
                // The second operand is always >= evict_batch_size once the cap
                // is exceeded, so this evicts one batch of LRU entries per call.
                let to_evict = std::cmp::min(
                    evict_batch_size,
                    current_count - max_entries + evict_batch_size,
                );
let deleted = conn.execute(
"DELETE FROM context_cache WHERE cache_key IN (
SELECT cache_key FROM context_cache
ORDER BY accessed_at ASC
LIMIT ?1
)",
params![to_evict],
)?;
Ok(deleted as u64)
})
.await?;
if count > 0 {
self.stats.lru_evictions.fetch_add(count, Ordering::Relaxed);
debug!("LRU evicted {} cache entries", count);
}
Ok(count)
}
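    /// Computes aggregate statistics about the cache table directly in SQLite.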
pub async fn get_db_stats(&self) -> Result<DbCacheStats> {
self.store.run(|conn| {
let (total_entries, total_size_bytes): (i64, i64) = conn
.query_row(
"SELECT COUNT(*), COALESCE(SUM(LENGTH(bundle_json)), 0) FROM context_cache",
[],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.unwrap_or((0, 0));
let entries_last_hour: i64 = conn
.query_row(
"SELECT COUNT(*) FROM context_cache WHERE created_at > datetime('now', '-1 hour')",
[],
|row| row.get(0),
)
.unwrap_or(0);
let entries_last_day: i64 = conn
.query_row(
"SELECT COUNT(*) FROM context_cache WHERE created_at > datetime('now', '-1 day')",
[],
|row| row.get(0),
)
.unwrap_or(0);
let entries_last_week: i64 = conn
.query_row(
"SELECT COUNT(*) FROM context_cache WHERE created_at > datetime('now', '-7 day')",
[],
|row| row.get(0),
)
.unwrap_or(0);
Ok(DbCacheStats {
total_entries,
total_size_bytes,
                // Per-entry access counts are not tracked here; report zeros.
                avg_access_count: 0.0,
                max_access_count: 0,
                entries_last_hour,
entries_last_day,
entries_last_week,
})
}).await
}
}
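/// Aggregate statistics about the `context_cache` table, as returned by
/// [`ContextCache::get_db_stats`].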
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbCacheStats {
pub total_entries: i64,
pub total_size_bytes: i64,
pub avg_access_count: f64,
pub max_access_count: i32,
pub entries_last_hour: i64,
pub entries_last_day: i64,
pub entries_last_week: i64,
}
#[cfg(test)]
mod tests {
use super::*;
    use crate::context::types::{ContextItem, LineRange};
    use crate::db::traits::StoreMigration;
    use std::sync::atomic::AtomicUsize;
#[test]
fn test_hash_options_deterministic() {
let options1 = ExpandOptions::with_common();
let options2 = ExpandOptions::with_common();
assert_eq!(hash_options(&options1), hash_options(&options2));
}
#[test]
fn test_hash_options_different() {
let options1 = ExpandOptions::with_common();
let options2 = ExpandOptions::with_all();
assert_ne!(hash_options(&options1), hash_options(&options2));
}
#[test]
fn test_cache_key_creation() {
let options = ExpandOptions::primary_only();
let key = CacheKey::new(123, &options);
assert_eq!(key.chunk_id, 123);
assert!(!key.options_hash.is_empty());
}
#[test]
fn test_cache_stats_hit_rate() {
let stats = CacheStats::default();
assert_eq!(stats.hit_rate(), 0.0);
stats.hits.store(60, Ordering::Relaxed);
stats.misses.store(40, Ordering::Relaxed);
assert_eq!(stats.hit_rate(), 60.0);
stats.hits.store(80, Ordering::Relaxed);
stats.misses.store(20, Ordering::Relaxed);
assert_eq!(stats.hit_rate(), 80.0);
}
#[test]
fn test_cache_stats_total_operations() {
let stats = CacheStats::default();
assert_eq!(stats.total_operations(), 0);
stats.hits.store(100, Ordering::Relaxed);
stats.misses.store(50, Ordering::Relaxed);
assert_eq!(stats.total_operations(), 150);
}
#[test]
fn test_cache_config_default() {
let config = CacheConfig::default();
assert!(config.enabled);
assert_eq!(config.ttl_seconds, 3600);
assert_eq!(config.max_entries, 1000);
assert_eq!(config.evict_batch_size, 100);
}
#[test]
fn test_cache_stats_reset() {
let stats = CacheStats::default();
stats.hits.store(100, Ordering::Relaxed);
stats.misses.store(50, Ordering::Relaxed);
stats.puts.store(75, Ordering::Relaxed);
stats.reset();
assert_eq!(stats.hits.load(Ordering::Relaxed), 0);
assert_eq!(stats.misses.load(Ordering::Relaxed), 0);
assert_eq!(stats.puts.load(Ordering::Relaxed), 0);
}
static TEST_DB_COUNTER: AtomicUsize = AtomicUsize::new(0);
async fn setup_test_store() -> Arc<crate::db::SqliteStore> {
let counter = TEST_DB_COUNTER.fetch_add(1, Ordering::SeqCst);
let db_name = format!("file:memdb_cache_test_{}?mode=memory&cache=shared", counter);
let store = crate::db::SqliteStore::connect(&db_name)
.await
.expect("Failed to create test store");
store.migrate().await.expect("Failed to run migrations");
Arc::new(store)
}
fn create_test_bundle() -> ContextBundle {
let mut bundle = ContextBundle::new();
bundle.add_item(ContextItem {
relpath: "test.rs".to_string(),
range: LineRange::new(1, 10),
role: "primary".to_string(),
reason: "Test item".to_string(),
content: "fn test() {}".to_string(),
tokens: 5,
});
bundle
}
#[tokio::test]
async fn test_cache_put_and_get() {
let store = setup_test_store().await;
let config = CacheConfig::default();
let cache = ContextCache::new(store, config);
let options = ExpandOptions::with_common();
let bundle = create_test_bundle();
let result = cache.get(123, &options).await.unwrap();
assert!(result.is_none());
assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
cache.put(123, &options, &bundle).await.unwrap();
assert_eq!(cache.stats().puts.load(Ordering::Relaxed), 1);
let result = cache.get(123, &options).await.unwrap();
assert!(result.is_some());
assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 1);
let retrieved = result.unwrap();
assert_eq!(retrieved.items.len(), 1);
assert_eq!(retrieved.items[0].relpath, "test.rs");
}
#[tokio::test]
async fn test_cache_invalidate() {
let store = setup_test_store().await;
let config = CacheConfig::default();
let cache = ContextCache::new(store, config);
let options = ExpandOptions::with_common();
let bundle = create_test_bundle();
cache.put(123, &options, &bundle).await.unwrap();
let result = cache.get(123, &options).await.unwrap();
assert!(result.is_some());
let count = cache.invalidate(123).await.unwrap();
assert_eq!(count, 1);
let result = cache.get(123, &options).await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_cache_clear() {
let store = setup_test_store().await;
let config = CacheConfig::default();
let cache = ContextCache::new(store, config);
let options = ExpandOptions::with_common();
let bundle = create_test_bundle();
cache.put(100, &options, &bundle).await.unwrap();
cache.put(200, &options, &bundle).await.unwrap();
cache.put(300, &options, &bundle).await.unwrap();
let count = cache.clear().await.unwrap();
assert_eq!(count, 3);
assert!(cache.get(100, &options).await.unwrap().is_none());
assert!(cache.get(200, &options).await.unwrap().is_none());
assert!(cache.get(300, &options).await.unwrap().is_none());
}
#[tokio::test]
async fn test_cache_db_stats() {
let store = setup_test_store().await;
let config = CacheConfig::default();
let cache = ContextCache::new(store, config);
let options = ExpandOptions::with_common();
let bundle = create_test_bundle();
let stats = cache.get_db_stats().await.unwrap();
assert_eq!(stats.total_entries, 0);
cache.put(100, &options, &bundle).await.unwrap();
cache.put(200, &options, &bundle).await.unwrap();
let stats = cache.get_db_stats().await.unwrap();
assert_eq!(stats.total_entries, 2);
assert!(stats.total_size_bytes > 0);
assert_eq!(stats.entries_last_hour, 2);
}
#[tokio::test]
async fn test_cache_disabled() {
let store = setup_test_store().await;
        let config = CacheConfig {
            enabled: false,
            ..CacheConfig::default()
        };
let cache = ContextCache::new(store, config);
let options = ExpandOptions::with_common();
let bundle = create_test_bundle();
cache.put(123, &options, &bundle).await.unwrap();
let result = cache.get(123, &options).await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_cache_different_options() {
let store = setup_test_store().await;
let config = CacheConfig::default();
let cache = ContextCache::new(store, config);
let options1 = ExpandOptions::with_common();
let options2 = ExpandOptions::with_all();
let bundle = create_test_bundle();
cache.put(123, &options1, &bundle).await.unwrap();
let result = cache.get(123, &options1).await.unwrap();
assert!(result.is_some());
let result = cache.get(123, &options2).await.unwrap();
assert!(result.is_none());
}
}