use super::system::CacheSystem;
use anyhow::Result;
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info};
/// An event that should cause cached data to be invalidated.
///
/// NOTE(review): each variant appears to mirror one of the
/// `CacheInvalidator::on_*` entry points — confirm against callers.
#[derive(Clone, Debug)]
pub enum InvalidationTrigger {
    /// A file changed; presumably carries the file's path — TODO confirm.
    FileChange(String),
    /// A repository re-index; presumably carries the repository id — TODO confirm.
    ReIndex(i64),
    /// An explicit, user-requested invalidation with no payload.
    Manual,
    /// A pattern-based invalidation request (e.g. a search term) — TODO confirm semantics.
    Pattern(String),
}
/// Applies invalidation policies to a shared [`CacheSystem`].
///
/// The cache is held behind an `Arc`, so the same instance can be shared with
/// other owners while this invalidator clears its layers.
pub struct CacheInvalidator {
    /// The cache whose layers this invalidator clears.
    cache: Arc<CacheSystem>,
}
impl CacheInvalidator {
pub fn new(cache: Arc<CacheSystem>) -> Self {
Self { cache }
}
pub async fn on_file_changed(&self, file_path: &Path) -> Result<InvalidationStats> {
let path_str = file_path.to_string_lossy();
info!("Invalidating cache for file change: {}", path_str);
let mut stats = InvalidationStats::new();
self.cache.invalidate_parse_tree(&path_str).await;
stats.parse_tree_invalidated += 1;
self.cache.clear_l3().await;
stats.context_invalidated = 1;
self.cache.clear_l1().await;
stats.query_invalidated = 1;
info!(
"File change invalidation complete: parse_tree={}, context={}, query={}",
stats.parse_tree_invalidated, stats.context_invalidated, stats.query_invalidated
);
Ok(stats)
}
pub async fn on_reindex(&self, repo_id: i64) -> Result<InvalidationStats> {
info!("Invalidating all caches for repo re-index: {}", repo_id);
let mut stats = InvalidationStats::new();
self.cache.clear_all().await;
stats.query_invalidated = 1;
stats.embedding_invalidated = 1;
stats.context_invalidated = 1;
stats.parse_tree_invalidated = 1;
info!("Re-index invalidation complete: all caches cleared");
Ok(stats)
}
pub async fn on_manual(&self) -> Result<InvalidationStats> {
info!("Manual cache invalidation requested");
let mut stats = InvalidationStats::new();
self.cache.clear_all().await;
stats.query_invalidated = 1;
stats.embedding_invalidated = 1;
stats.context_invalidated = 1;
stats.parse_tree_invalidated = 1;
info!("Manual invalidation complete: all caches cleared");
Ok(stats)
}
pub async fn on_pattern(&self, pattern: &str) -> Result<InvalidationStats> {
info!("Pattern-based cache invalidation: {}", pattern);
let mut stats = InvalidationStats::new();
self.cache.clear_l1().await;
stats.query_invalidated = 1;
info!("Pattern invalidation complete: query_cache cleared");
Ok(stats)
}
pub async fn invalidate_layers(&self, layers: &[CacheLayer]) -> Result<InvalidationStats> {
let mut stats = InvalidationStats::new();
for layer in layers {
match layer {
CacheLayer::L1Query => {
self.cache.clear_l1().await;
stats.query_invalidated = 1;
debug!("Invalidated L1 query cache");
}
CacheLayer::L2Embedding => {
self.cache.clear_l2().await;
stats.embedding_invalidated = 1;
debug!("Invalidated L2 embedding cache");
}
CacheLayer::L3Context => {
self.cache.clear_l3().await;
stats.context_invalidated = 1;
debug!("Invalidated L3 context cache");
}
CacheLayer::ParseTree => {
self.cache.clear_parse_tree().await;
stats.parse_tree_invalidated = 1;
debug!("Invalidated parse tree cache");
}
CacheLayer::All => {
self.cache.clear_all().await;
stats.query_invalidated = 1;
stats.embedding_invalidated = 1;
stats.context_invalidated = 1;
stats.parse_tree_invalidated = 1;
debug!("Invalidated all cache layers");
}
}
}
info!("Layer invalidation complete: {:?}", layers);
Ok(stats)
}
pub fn cache(&self) -> &Arc<CacheSystem> {
&self.cache
}
}
/// A cache layer (or every layer at once) that can be targeted for invalidation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLayer {
    /// The L1 query cache.
    L1Query,
    /// The L2 embedding cache.
    L2Embedding,
    /// The L3 context cache.
    L3Context,
    /// The parse-tree cache.
    ParseTree,
    /// Every cache layer.
    All,
}

impl CacheLayer {
    /// Parses a layer name case-insensitively, returning `None` if it is unknown.
    ///
    /// Accepted aliases:
    /// - `l1` / `query` / `l1_query`
    /// - `l2` / `embedding` / `l2_embedding`
    /// - `l3` / `context` / `l3_context`
    /// - `parse` / `parse_tree` / `parsetree`
    /// - `all`
    ///
    /// Prefer `s.parse::<CacheLayer>()` (via [`std::str::FromStr`]) in new
    /// code. This `Option`-returning inherent method is kept for existing
    /// callers. Clippy's `should_implement_trait` is silenced because the
    /// trait *is* implemented separately and this name must stay for
    /// backward compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::lookup(s)
    }

    /// Returns a human-readable name for the layer.
    pub fn name(&self) -> &'static str {
        match self {
            Self::L1Query => "L1 Query Cache",
            Self::L2Embedding => "L2 Embedding Cache",
            Self::L3Context => "L3 Context Cache",
            Self::ParseTree => "Parse Tree Cache",
            Self::All => "All Caches",
        }
    }

    /// Single source of truth for the alias table, shared by the inherent
    /// `from_str` and the `FromStr` impl.
    fn lookup(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "l1" | "query" | "l1_query" => Some(Self::L1Query),
            "l2" | "embedding" | "l2_embedding" => Some(Self::L2Embedding),
            "l3" | "context" | "l3_context" => Some(Self::L3Context),
            "parse" | "parse_tree" | "parsetree" => Some(Self::ParseTree),
            "all" => Some(Self::All),
            _ => None,
        }
    }
}

/// Error returned when a string does not name any [`CacheLayer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseCacheLayerError {
    input: String,
}

impl ParseCacheLayerError {
    /// The unrecognised input that caused the error.
    pub fn input(&self) -> &str {
        &self.input
    }
}

impl std::fmt::Display for ParseCacheLayerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unknown cache layer: {:?}", self.input)
    }
}

impl std::error::Error for ParseCacheLayerError {}

impl std::str::FromStr for CacheLayer {
    type Err = ParseCacheLayerError;

    /// Parses the same aliases as the inherent `CacheLayer::from_str`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        CacheLayer::lookup(s).ok_or_else(|| ParseCacheLayerError {
            input: s.to_owned(),
        })
    }
}
/// Record of which cache layers an invalidation touched.
///
/// Within this module, `CacheInvalidator` sets a field to `1` when it clears
/// the corresponding layer, so the fields behave as per-layer flags rather
/// than entry counts.
#[derive(Debug, Default, Clone)]
pub struct InvalidationStats {
    /// L1 query layer.
    pub query_invalidated: usize,
    /// L2 embedding layer.
    pub embedding_invalidated: usize,
    /// L3 context layer.
    pub context_invalidated: usize,
    /// Parse-tree layer.
    pub parse_tree_invalidated: usize,
}

impl InvalidationStats {
    /// Returns stats with every field at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the sum of all four per-layer fields.
    pub fn total_invalidated(&self) -> usize {
        let per_layer = [
            self.query_invalidated,
            self.embedding_invalidated,
            self.context_invalidated,
            self.parse_tree_invalidated,
        ];
        per_layer.iter().sum()
    }

    /// Returns `true` when the total across all layers is non-zero.
    pub fn has_invalidations(&self) -> bool {
        self.total_invalidated() != 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::CacheConfig;

    /// Builds a default cache plus an invalidator that shares it.
    ///
    /// The cache handle is returned too so tests can keep it alive and compare
    /// pointers.
    fn setup() -> (Arc<CacheSystem>, CacheInvalidator) {
        let shared = Arc::new(CacheSystem::new(CacheConfig::default()));
        let invalidator = CacheInvalidator::new(Arc::clone(&shared));
        (shared, invalidator)
    }

    #[tokio::test]
    async fn test_cache_invalidator_creation() {
        let (shared, invalidator) = setup();
        assert!(Arc::ptr_eq(invalidator.cache(), &shared));
    }

    #[tokio::test]
    async fn test_on_file_changed() {
        let (_shared, invalidator) = setup();
        let stats = invalidator
            .on_file_changed(Path::new("test.rs"))
            .await
            .unwrap();
        assert!(stats.has_invalidations());
        assert!(stats.total_invalidated() > 0);
    }

    #[tokio::test]
    async fn test_on_reindex() {
        let (_shared, invalidator) = setup();
        let stats = invalidator.on_reindex(123).await.unwrap();
        let per_layer = [
            stats.query_invalidated,
            stats.embedding_invalidated,
            stats.context_invalidated,
            stats.parse_tree_invalidated,
        ];
        assert_eq!(per_layer, [1, 1, 1, 1]);
        assert_eq!(stats.total_invalidated(), 4);
    }

    #[tokio::test]
    async fn test_on_manual() {
        let (_shared, invalidator) = setup();
        let stats = invalidator.on_manual().await.unwrap();
        assert_eq!(stats.total_invalidated(), 4);
    }

    #[tokio::test]
    async fn test_on_pattern() {
        let (_shared, invalidator) = setup();
        let stats = invalidator.on_pattern("search_term").await.unwrap();
        assert!(stats.query_invalidated > 0);
    }

    #[tokio::test]
    async fn test_invalidate_specific_layers() {
        let (_shared, invalidator) = setup();
        let stats = invalidator
            .invalidate_layers(&[CacheLayer::L1Query, CacheLayer::L2Embedding])
            .await
            .unwrap();
        let per_layer = [
            stats.query_invalidated,
            stats.embedding_invalidated,
            stats.context_invalidated,
            stats.parse_tree_invalidated,
        ];
        assert_eq!(per_layer, [1, 1, 0, 0]);
    }

    #[tokio::test]
    async fn test_invalidate_all_layers() {
        let (_shared, invalidator) = setup();
        let stats = invalidator
            .invalidate_layers(&[CacheLayer::All])
            .await
            .unwrap();
        assert_eq!(stats.total_invalidated(), 4);
    }

    #[test]
    fn test_cache_layer_from_str() {
        let cases = [
            ("l1", Some(CacheLayer::L1Query)),
            ("query", Some(CacheLayer::L1Query)),
            ("l2", Some(CacheLayer::L2Embedding)),
            ("embedding", Some(CacheLayer::L2Embedding)),
            ("l3", Some(CacheLayer::L3Context)),
            ("context", Some(CacheLayer::L3Context)),
            ("parse", Some(CacheLayer::ParseTree)),
            ("all", Some(CacheLayer::All)),
            ("invalid", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(CacheLayer::from_str(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_cache_layer_name() {
        let cases = [
            (CacheLayer::L1Query, "L1 Query Cache"),
            (CacheLayer::L2Embedding, "L2 Embedding Cache"),
            (CacheLayer::L3Context, "L3 Context Cache"),
            (CacheLayer::ParseTree, "Parse Tree Cache"),
            (CacheLayer::All, "All Caches"),
        ];
        for &(layer, expected) in &cases {
            assert_eq!(layer.name(), expected, "layer: {:?}", layer);
        }
    }

    #[test]
    fn test_invalidation_stats() {
        let mut stats = InvalidationStats::new();
        assert_eq!(stats.total_invalidated(), 0);
        assert!(!stats.has_invalidations());

        stats.query_invalidated = 5;
        stats.context_invalidated = 3;
        assert_eq!(stats.total_invalidated(), 8);
        assert!(stats.has_invalidations());
    }
}