use super::query_cache::{
AdvancedCacheStats, AdvancedQueryCache, InvalidationMessage, QueryType, TableDependency,
};
#[path = "invalidation_types.rs"]
mod types;
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::time::interval;
use tracing::{debug, info, trace, warn};
pub use types::{
CrudOperation, InvalidationConfig, InvalidationEvent, InvalidationMetrics, InvalidationRule,
InvalidationRuleBuilder, InvalidationStrategy, InvalidationTarget, SchemaChangeType, utils,
};
/// Applies invalidations to an `AdvancedQueryCache` in response to
/// `InvalidationEvent`s, either immediately (`handle_event`) or in batches
/// (`queue_event` / `process_batch`), and records `InvalidationMetrics`.
pub struct InvalidationManager {
    /// Settings read by the manager (rule limit, batch size, cleanup interval,
    /// event-listening / background-cleanup flags, strategy).
    config: InvalidationConfig,
    /// The cache that invalidations are applied to.
    cache: AdvancedQueryCache,
    /// Registered rules; `add_rule` keeps them ordered by descending priority.
    /// Shared (via `Arc`) with clones of this manager.
    rules: Arc<RwLock<Vec<InvalidationRule>>>,
    /// Receiving end of the event channel that `run` drains.
    event_rx: mpsc::UnboundedReceiver<InvalidationEvent>,
    /// A sender for `event_rx` kept by the manager itself. Because of it, the
    /// event channel stays open for as long as the manager exists.
    event_tx: mpsc::UnboundedSender<InvalidationEvent>,
    /// Receiver polled by `run` for `InvalidationMessage`s. Both `new` and
    /// `clone` drop its sender immediately, so it starts out closed.
    invalidation_rx: mpsc::UnboundedReceiver<InvalidationMessage>,
    /// Counters updated as invalidations are applied; shared with clones.
    metrics: Arc<RwLock<InvalidationMetrics>>,
    /// Events queued by `queue_event` and waiting for `process_batch`;
    /// shared with clones.
    pending: Arc<RwLock<VecDeque<InvalidationEvent>>>,
    /// When `process_batch` last flushed a non-empty queue (initialised to the
    /// creation time); shared with clones.
    last_batch_time: Arc<RwLock<Instant>>,
}
impl InvalidationManager {
    /// If more than this much time has passed since the last processed batch,
    /// `queue_event` flushes the queue even when fewer than
    /// `config.batch_size` events are pending.
    const BATCH_MAX_AGE: Duration = Duration::from_secs(5);

    /// Lower bound applied to `config.cleanup_interval` in `run`, because
    /// `tokio::time::interval` panics when given a zero period.
    const MIN_CLEANUP_INTERVAL: Duration = Duration::from_millis(1);

    /// Creates a manager for `cache` and returns it together with a sender for
    /// the `InvalidationEvent` channel that `run` consumes.
    ///
    /// The manager keeps its own clone of that sender, so the event channel
    /// does not close while the manager is alive.
    pub fn new(
        config: InvalidationConfig,
        cache: AdvancedQueryCache,
    ) -> (Self, mpsc::UnboundedSender<InvalidationEvent>) {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        // NOTE(review): the sender half is dropped immediately, so this
        // receiver is already closed and `run` never gets an
        // `InvalidationMessage`. Presumably it should be wired to the cache's
        // invalidation notifications — confirm against `AdvancedQueryCache`.
        let invalidation_rx = {
            let (_tx, rx) = mpsc::unbounded_channel();
            rx
        };
        let manager = Self {
            config,
            cache,
            rules: Arc::new(RwLock::new(Vec::new())),
            event_rx,
            event_tx: event_tx.clone(),
            invalidation_rx,
            metrics: Arc::new(RwLock::new(InvalidationMetrics::default())),
            pending: Arc::new(RwLock::new(VecDeque::new())),
            last_batch_time: Arc::new(RwLock::new(Instant::now())),
        };
        (manager, event_tx)
    }

    /// Same as `new` with `InvalidationConfig::default()`.
    pub fn default(cache: AdvancedQueryCache) -> (Self, mpsc::UnboundedSender<InvalidationEvent>) {
        Self::new(InvalidationConfig::default(), cache)
    }

    /// Registers `rule`, keeping the rule list ordered by descending priority.
    ///
    /// When the list already holds `config.max_rules` rules, the first rule
    /// with the lowest priority is evicted before `rule` is added. Rules with
    /// equal priority keep their relative order (the sort is stable).
    pub fn add_rule(&self, rule: InvalidationRule) {
        let mut rules = self.rules.write();
        if rules.len() >= self.config.max_rules {
            warn!("Max rules reached, removing lowest priority rule");
            // `min_by_key` yields the first of several equal minima, matching
            // "stable ascending sort, then remove index 0" without the extra
            // sort, and without panicking on an empty list (max_rules == 0).
            let lowest = rules
                .iter()
                .enumerate()
                .min_by_key(|(_, existing)| existing.priority)
                .map(|(index, _)| index);
            if let Some(index) = lowest {
                rules.remove(index);
            }
        }
        rules.push(rule);
        rules.sort_by_key(|existing| std::cmp::Reverse(existing.priority));
        debug!("Added invalidation rule, total rules: {}", rules.len());
    }

    /// Removes every registered rule.
    pub fn clear_rules(&self) {
        self.rules.write().clear();
        debug!("Cleared all invalidation rules");
    }

    /// Returns clones of the rules whose `matches(sql)` is true, in the stored
    /// (descending-priority) order.
    pub fn get_matching_rules(&self, sql: &str) -> Vec<InvalidationRule> {
        self.rules
            .read()
            .iter()
            .filter(|rule| rule.matches(sql))
            .cloned()
            .collect()
    }

    /// Applies `event` to the cache immediately and updates the metrics.
    ///
    /// - `TableModified`: invalidates the table; bumps `total_invalidations`,
    ///   `by_table`, `by_operation` and adds `affected_rows` to
    ///   `entries_invalidated`.
    /// - `BatchCompleted`: invalidates every listed table; bumps `batch_count`
    ///   and adds `operation_count` to `entries_invalidated`.
    /// - `SchemaChanged`: invalidates the table (no counters changed).
    /// - `ManualInvalidation`: dispatches on the target.
    ///
    /// Afterwards the elapsed time is folded into `avg_invalidation_time_us`.
    pub fn handle_event(&self, event: InvalidationEvent) {
        let start = Instant::now();
        match &event {
            InvalidationEvent::TableModified {
                table,
                operation,
                affected_rows,
            } => {
                debug!(
                    "Handling table modification: {:?} {:?}, {} rows affected",
                    table, operation, affected_rows
                );
                self.cache.invalidate_by_table(table);
                let mut metrics = self.metrics.write();
                metrics.total_invalidations += 1;
                *metrics.by_table.entry(table.clone()).or_default() += 1;
                *metrics.by_operation.entry(*operation).or_default() += 1;
                metrics.entries_invalidated += *affected_rows;
            }
            InvalidationEvent::BatchCompleted {
                tables,
                operation_count,
            } => {
                debug!(
                    "Handling batch completion: {} tables, {} operations",
                    tables.len(),
                    operation_count
                );
                for table in tables {
                    self.cache.invalidate_by_table(table);
                }
                // One lock acquisition for both counters.
                let mut metrics = self.metrics.write();
                metrics.batch_count += 1;
                metrics.entries_invalidated += *operation_count;
            }
            InvalidationEvent::SchemaChanged { table, change_type } => {
                info!("Schema changed on {:?}: {:?}", table, change_type);
                self.cache.invalidate_by_table(table);
            }
            InvalidationEvent::ManualInvalidation { target, reason } => {
                info!("Manual invalidation: {} - {:?}", reason, target);
                self.handle_manual_invalidation(target);
            }
        }
        self.record_invalidation_time(start);
    }

    /// Folds the time elapsed since `start` into `avg_invalidation_time_us`.
    ///
    /// The running mean is weighted by `total_invalidations`, which only some
    /// event kinds increment, so it is an approximation. `total` is clamped to
    /// at least 1: with a count of 0 (e.g. the first event is a schema change)
    /// the old `total - 1` underflowed, panicking in debug builds.
    fn record_invalidation_time(&self, start: Instant) {
        let elapsed = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
        let mut metrics = self.metrics.write();
        let total = metrics.total_invalidations.max(1);
        // Widen to u128 so `avg * (total - 1)` cannot overflow.
        let weighted = u128::from(metrics.avg_invalidation_time_us) * u128::from(total - 1)
            + u128::from(elapsed);
        // The quotient is at most max(avg, elapsed), so it always fits in u64.
        metrics.avg_invalidation_time_us =
            u64::try_from(weighted / u128::from(total)).unwrap_or(u64::MAX);
    }

    /// Dispatches a manual invalidation to the matching cache operation.
    fn handle_manual_invalidation(&self, target: &InvalidationTarget) {
        match target {
            InvalidationTarget::All => {
                self.cache.clear();
            }
            InvalidationTarget::Table(table) => {
                self.cache.invalidate_by_table(table);
            }
            InvalidationTarget::Query(key) => {
                self.cache.invalidate_key(key);
            }
            InvalidationTarget::Pattern(pattern) => {
                self.invalidate_by_pattern(pattern);
            }
            InvalidationTarget::Type(query_type) => {
                self.invalidate_by_type(*query_type);
            }
        }
    }

    /// Pattern matching is not implemented; conservatively clears the whole
    /// cache so no stale entry survives.
    fn invalidate_by_pattern(&self, pattern: &str) {
        warn!(
            "Pattern-based invalidation not fully implemented, clearing all: {}",
            pattern
        );
        self.cache.clear();
    }

    /// Type-based selection is not implemented; conservatively clears the
    /// whole cache (as `invalidate_by_pattern` does). Previously this only
    /// logged, silently leaving entries of that type in the cache.
    fn invalidate_by_type(&self, query_type: QueryType) {
        warn!(
            "Type-based invalidation not implemented, clearing all: {:?}",
            query_type
        );
        self.cache.clear();
    }

    /// Queues `event` for batched processing.
    ///
    /// Triggers `process_batch` once `config.batch_size` events are pending or
    /// more than `BATCH_MAX_AGE` has passed since the last processed batch.
    pub fn queue_event(&self, event: InvalidationEvent) {
        let should_process = {
            let mut pending = self.pending.write();
            pending.push_back(event);
            pending.len() >= self.config.batch_size
                || self.last_batch_time.read().elapsed() > Self::BATCH_MAX_AGE
        };
        if should_process {
            self.process_batch();
        }
    }

    /// Drains the pending queue and applies it.
    ///
    /// Events tied to a single table (`TableModified`, `SchemaChanged`) are
    /// grouped so each table is invalidated once; every such group counts as
    /// one invalidation, and each `TableModified` event contributes its own
    /// operation and row count to the metrics. Other events
    /// (`BatchCompleted`, `ManualInvalidation`) are applied via `handle_event`
    /// instead of being dropped. Does nothing if the queue is empty.
    pub fn process_batch(&self) {
        // Take the events and release the queue lock before touching the
        // cache, so concurrent `queue_event` callers are not blocked.
        let drained: Vec<InvalidationEvent> = {
            let mut pending = self.pending.write();
            if pending.is_empty() {
                return;
            }
            debug!("Processing {} pending invalidation events", pending.len());
            pending.drain(..).collect()
        };

        let mut by_table: HashMap<TableDependency, Vec<InvalidationEvent>> = HashMap::new();
        let mut without_table = Vec::new();
        for event in drained {
            match Self::get_event_table(&event) {
                Some(table) => by_table.entry(table).or_default().push(event),
                None => without_table.push(event),
            }
        }

        for (table, events) in by_table {
            self.cache.invalidate_by_table(&table);
            let mut metrics = self.metrics.write();
            metrics.total_invalidations += 1;
            for event in &events {
                if let InvalidationEvent::TableModified {
                    operation,
                    affected_rows,
                    ..
                } = event
                {
                    *metrics.by_operation.entry(*operation).or_default() += 1;
                    metrics.entries_invalidated += *affected_rows;
                }
            }
            *metrics.by_table.entry(table).or_default() += 1;
        }

        // No lock is held here; `handle_event` takes its own.
        for event in without_table {
            self.handle_event(event);
        }

        *self.last_batch_time.write() = Instant::now();
    }

    /// Returns the single table an event refers to, if it refers to exactly one.
    fn get_event_table(event: &InvalidationEvent) -> Option<TableDependency> {
        match event {
            InvalidationEvent::TableModified { table, .. } => Some(table.clone()),
            InvalidationEvent::SchemaChanged { table, .. } => Some(table.clone()),
            _ => None,
        }
    }

    /// Returns a snapshot (clone) of the current metrics.
    pub fn metrics(&self) -> InvalidationMetrics {
        self.metrics.read().clone()
    }

    /// Resets all metrics to `InvalidationMetrics::default()`.
    pub fn clear_metrics(&self) {
        *self.metrics.write() = InvalidationMetrics::default();
    }

    /// Returns the underlying cache's statistics.
    pub fn cache_stats(&self) -> AdvancedCacheStats {
        self.cache.stats()
    }

    /// Delegates to the cache's `clear_expired` and returns its result
    /// (logged as the number of cleared entries by `perform_cleanup`).
    pub fn clear_expired(&self) -> usize {
        self.cache.clear_expired()
    }

    /// Event loop: handles incoming events and messages, and runs periodic
    /// cleanup every `config.cleanup_interval` (clamped to at least
    /// `MIN_CLEANUP_INTERVAL`).
    ///
    /// Exits when the event channel closes or an
    /// `InvalidationMessage::Shutdown` arrives, then flushes the pending queue.
    ///
    /// NOTE(review): as constructed by `new`, the manager holds its own event
    /// sender and the message channel starts out closed, so neither exit
    /// condition can currently occur.
    pub async fn run(mut self) {
        info!(
            "Starting invalidation manager with {:?} strategy",
            self.config.strategy
        );
        let period = self.config.cleanup_interval.max(Self::MIN_CLEANUP_INTERVAL);
        let mut cleanup_interval = interval(period);
        // The interval branch is never disabled, so `select!` always has a
        // live branch; the exits are the explicit `break`s below.
        loop {
            tokio::select! {
                maybe_event = self.event_rx.recv() => match maybe_event {
                    Some(event) => {
                        if self.config.enable_event_listening {
                            self.handle_event(event);
                        }
                    }
                    // Every sender is gone: no further events can arrive.
                    None => break,
                },
                Some(message) = self.invalidation_rx.recv() => {
                    let is_shutdown = matches!(message, InvalidationMessage::Shutdown);
                    self.handle_invalidation_message(message);
                    if is_shutdown {
                        break;
                    }
                }
                _ = cleanup_interval.tick() => {
                    if self.config.enable_background_cleanup {
                        self.perform_cleanup();
                    }
                }
            }
        }
        // Apply anything still queued so it is not lost on exit.
        self.process_batch();
        info!("Invalidation manager shutting down");
    }

    /// Applies a single `InvalidationMessage` to the cache. `Shutdown` is only
    /// logged here; `run` is responsible for stopping the loop.
    fn handle_invalidation_message(&self, message: InvalidationMessage) {
        match message {
            InvalidationMessage::TableChanged(table) => {
                self.cache.invalidate_by_table(&table);
            }
            InvalidationMessage::InvalidateKey(key) => {
                self.cache.invalidate_key(&key);
            }
            InvalidationMessage::InvalidateAll => {
                self.cache.clear();
            }
            InvalidationMessage::Shutdown => {
                info!("Received shutdown signal");
            }
        }
    }

    /// Periodic maintenance: flushes the pending batch, clears expired cache
    /// entries and traces the headline counters.
    fn perform_cleanup(&self) {
        self.process_batch();
        let cleared = self.cache.clear_expired();
        if cleared > 0 {
            debug!("Cleared {} expired cache entries during cleanup", cleared);
        }
        // Read the two counters under the lock rather than cloning the whole
        // metrics struct (including its maps) on every tick.
        let metrics = self.metrics.read();
        if metrics.total_invalidations > 0 {
            trace!(
                "Invalidation metrics: {} total, {} entries invalidated",
                metrics.total_invalidations, metrics.entries_invalidated
            );
        }
    }
}
/// Produces a manager that shares `rules`, `metrics`, `pending` and
/// `last_batch_time` with `self` (via `Arc`), with cloned `config` and `cache`.
///
/// Channel receivers cannot be duplicated, so the clone gets a brand-new event
/// channel (whose only sender is stored in the clone itself) and an already
/// closed message channel. Events sent through the original's sender are
/// therefore never seen by the clone's `run` loop.
impl Clone for InvalidationManager {
    fn clone(&self) -> Self {
        let (own_event_tx, own_event_rx) = mpsc::unbounded_channel();
        // Sender half dropped at the end of this statement: the receiver is closed.
        let (_, closed_invalidation_rx) = mpsc::unbounded_channel();

        let shared_rules = Arc::clone(&self.rules);
        let shared_metrics = Arc::clone(&self.metrics);
        let shared_pending = Arc::clone(&self.pending);
        let shared_last_batch = Arc::clone(&self.last_batch_time);

        Self {
            config: self.config.clone(),
            cache: self.cache.clone(),
            rules: shared_rules,
            event_rx: own_event_rx,
            event_tx: own_event_tx,
            invalidation_rx: closed_invalidation_rx,
            metrics: shared_metrics,
            pending: shared_pending,
            last_batch_time: shared_last_batch,
        }
    }
}
#[cfg(test)]
#[path = "invalidation_tests.rs"]
mod tests;