use crate::TursoStorage;
use crate::cache::invalidation::{
InvalidationConfig, InvalidationEvent, InvalidationManager, InvalidationRuleBuilder,
InvalidationTarget,
};
use crate::cache::query_cache::{
AdvancedCacheStats, AdvancedQueryCache, AdvancedQueryCacheConfig, QueryKey, TableDependency,
};
use anyhow::Result;
use do_memory_core::{Episode, Pattern};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
/// Read-through query cache layered over a [`TursoStorage`] backend.
///
/// Query results are cached in an [`AdvancedQueryCache`]. The invalidation manager
/// and the invalidation-event sender are only populated when the value is built via
/// `CachedQueryStorage::with_invalidation`; the other constructors leave them `None`.
pub struct CachedQueryStorage {
    /// Shared handle to the wrapped storage backend.
    storage: std::sync::Arc<TursoStorage>,
    /// Cache holding serialized query results.
    cache: AdvancedQueryCache,
    /// Rule-based invalidation manager; `None` unless built with invalidation support.
    invalidation_manager: Option<InvalidationManager>,
    /// Sender used to publish invalidation events; `None` unless built with
    /// invalidation support.
    event_tx: Option<mpsc::UnboundedSender<InvalidationEvent>>,
}
impl CachedQueryStorage {
    /// Wraps `storage` with a query cache using [`AdvancedQueryCacheConfig::default`].
    ///
    /// No invalidation manager or event sender is attached; use
    /// [`CachedQueryStorage::with_invalidation`] for that.
    ///
    /// Returns the wrapper and the [`InvalidationMessage`] receiver produced by
    /// [`AdvancedQueryCache::new`].
    pub fn new(storage: TursoStorage) -> (Self, mpsc::UnboundedReceiver<InvalidationMessage>) {
        let cache_config = AdvancedQueryCacheConfig::default();
        Self::with_config(storage, cache_config)
    }

    /// Wraps `storage` with a query cache built from `cache_config`.
    ///
    /// No invalidation manager or event sender is attached. Returns the wrapper and
    /// the [`InvalidationMessage`] receiver produced by [`AdvancedQueryCache::new`].
    pub fn with_config(
        storage: TursoStorage,
        cache_config: AdvancedQueryCacheConfig,
    ) -> (Self, mpsc::UnboundedReceiver<InvalidationMessage>) {
        let (cache, invalidation_rx) = AdvancedQueryCache::new(cache_config);
        let cached_storage = Self {
            storage: Arc::new(storage),
            // `cache` is not used after this point, so move it rather than cloning.
            cache,
            invalidation_manager: None,
            event_tx: None,
        };
        (cached_storage, invalidation_rx)
    }

    /// Wraps `storage` with a query cache and an [`InvalidationManager`] preloaded
    /// with the default rules (see `setup_default_rules`).
    ///
    /// Returns the wrapper, the cache's [`InvalidationMessage`] receiver, and a sender
    /// for [`InvalidationEvent`]s. A clone of that sender is also kept internally and
    /// used by [`CachedQueryStorage::invalidate_table`] and
    /// [`CachedQueryStorage::invalidate_all`]. Within this type, the manager is only
    /// driven by [`CachedQueryStorage::start_invalidation_manager`].
    pub fn with_invalidation(
        storage: TursoStorage,
        cache_config: AdvancedQueryCacheConfig,
        invalidation_config: InvalidationConfig,
    ) -> (
        Self,
        mpsc::UnboundedReceiver<InvalidationMessage>,
        mpsc::UnboundedSender<InvalidationEvent>,
    ) {
        let (cache, invalidation_rx) = AdvancedQueryCache::new(cache_config);
        let (manager, event_tx) = InvalidationManager::new(invalidation_config, cache.clone());
        Self::setup_default_rules(&manager);
        let cached_storage = Self {
            storage: Arc::new(storage),
            cache,
            invalidation_manager: Some(manager),
            event_tx: Some(event_tx.clone()),
        };
        (cached_storage, invalidation_rx, event_tx)
    }

    /// Registers the built-in invalidation rules on `manager`:
    ///
    /// - `%episodes%`: episodes, steps (priority 10)
    /// - `%patterns%`: patterns (priority 10)
    /// - `%count%`: episodes, patterns, steps (TTL 30 s, priority 5)
    /// - `%search%`: episodes, patterns, embeddings (TTL 120 s, priority 8)
    // NOTE(review): the `%...%` strings look like SQL LIKE-style patterns; confirm how
    // `InvalidationRuleBuilder::new` interprets them.
    fn setup_default_rules(manager: &InvalidationManager) {
        manager.add_rule(
            InvalidationRuleBuilder::new("%episodes%")
                .depends_on(TableDependency::Episodes)
                .depends_on(TableDependency::Steps)
                .with_priority(10)
                .build(),
        );
        manager.add_rule(
            InvalidationRuleBuilder::new("%patterns%")
                .depends_on(TableDependency::Patterns)
                .with_priority(10)
                .build(),
        );
        manager.add_rule(
            InvalidationRuleBuilder::new("%count%")
                .depends_on(TableDependency::Episodes)
                .depends_on(TableDependency::Patterns)
                .depends_on(TableDependency::Steps)
                .with_ttl(Duration::from_secs(30))
                .with_priority(5)
                .build(),
        );
        manager.add_rule(
            InvalidationRuleBuilder::new("%search%")
                .depends_on(TableDependency::Episodes)
                .depends_on(TableDependency::Patterns)
                .depends_on(TableDependency::Embeddings)
                .with_ttl(Duration::from_secs(120))
                .with_priority(8)
                .build(),
        );
    }

    /// Returns the result for `sql`/`params` from the cache, or runs `fetch_fn` and
    /// caches what it returns.
    ///
    /// Results are stored postcard-encoded under a [`QueryKey`] built from `sql` and
    /// `params`. Each entry is tagged with the tables that
    /// [`TableDependency::from_query`] detects in `sql`.
    ///
    /// A cached entry that fails to decode as `T` is treated like a miss. `fetch_fn`
    /// runs and its result is written back, overwriting the undecodable entry so
    /// later hits do not keep failing. If the fresh result cannot be serialized, the
    /// failure is logged and the result is returned uncached.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fetch_fn`.
    pub async fn query_cached<F, Fut, T>(
        &self,
        sql: &str,
        params: &[&dyn ToString],
        fetch_fn: F,
    ) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let key = QueryKey::new(sql, params);

        if let Some(cached_data) = self.cache.get(&key) {
            debug!("Cache hit for query: {}", sql);
            match postcard::from_bytes::<T>(&cached_data) {
                Ok(result) => return Ok(result),
                Err(e) => {
                    // Fall through to the miss path below so the freshly fetched
                    // result replaces the corrupt/incompatible entry.
                    warn!("Failed to deserialize cached result: {}", e);
                }
            }
        } else {
            debug!("Cache miss for query: {}", sql);
        }

        let result = fetch_fn().await?;
        match postcard::to_allocvec(&result) {
            Ok(data) => {
                let dependencies = TableDependency::from_query(sql);
                self.cache.put(key, data, dependencies);
            }
            Err(e) => {
                warn!("Failed to serialize result for caching: {}", e);
            }
        }
        Ok(result)
    }

    /// Cached episode query. Currently a placeholder.
    ///
    /// The fetch closure always yields an empty `Vec` and never reads from the
    /// wrapped storage, so a cache miss stores and returns an empty result.
    // NOTE(review): the cached empty result will be served for this key until the
    // entry is invalidated; wire this to a real storage query before relying on it.
    pub async fn query_episodes_cached(
        &self,
        sql: &str,
        params: &[&dyn ToString],
    ) -> Result<Vec<Episode>> {
        self.query_cached(sql, params, || async { Ok(Vec::new()) })
            .await
    }

    /// Cached pattern query. Currently a placeholder.
    ///
    /// The fetch closure always yields an empty `Vec` and never reads from the
    /// wrapped storage, so a cache miss stores and returns an empty result.
    // NOTE(review): the cached empty result will be served for this key until the
    // entry is invalidated; wire this to a real storage query before relying on it.
    pub async fn query_patterns_cached(
        &self,
        sql: &str,
        params: &[&dyn ToString],
    ) -> Result<Vec<Pattern>> {
        self.query_cached(sql, params, || async { Ok(Vec::new()) })
            .await
    }

    /// Drops cached entries that depend on `table`.
    ///
    /// When an event sender is attached, this also publishes a
    /// [`InvalidationEvent::TableModified`] event (reported as an `Update` with
    /// `affected_rows: 0`). Publishing is best-effort: if the receiving side is gone,
    /// the event is dropped and a debug message is logged.
    pub fn invalidate_table(&self, table: TableDependency) {
        self.cache.invalidate_by_table(&table);
        if let Some(tx) = &self.event_tx {
            let sent = tx.send(InvalidationEvent::TableModified {
                table,
                operation: crate::cache::invalidation::CrudOperation::Update,
                affected_rows: 0,
            });
            if sent.is_err() {
                debug!("Invalidation event receiver closed; TableModified event dropped");
            }
        }
    }

    /// Clears the whole cache.
    ///
    /// When an event sender is attached, this also publishes a
    /// [`InvalidationEvent::ManualInvalidation`] event targeting
    /// [`InvalidationTarget::All`]. Publishing is best-effort: if the receiving side
    /// is gone, the event is dropped and a debug message is logged.
    pub fn invalidate_all(&self) {
        self.cache.clear();
        if let Some(tx) = &self.event_tx {
            let sent = tx.send(InvalidationEvent::ManualInvalidation {
                target: InvalidationTarget::All,
                reason: "Manual cache clear".to_string(),
            });
            if sent.is_err() {
                debug!("Invalidation event receiver closed; ManualInvalidation event dropped");
            }
        }
    }

    /// Returns the cache's statistics snapshot.
    pub fn cache_stats(&self) -> AdvancedCacheStats {
        self.cache.stats()
    }

    /// Returns the number of entries currently held by the cache.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }

    /// Removes expired cache entries and returns the count reported by
    /// [`AdvancedQueryCache::clear_expired`].
    pub fn clear_expired(&self) -> usize {
        self.cache.clear_expired()
    }

    /// Borrows the wrapped storage backend.
    pub fn storage(&self) -> &TursoStorage {
        &self.storage
    }

    /// Borrows the underlying query cache.
    pub fn cache(&self) -> &AdvancedQueryCache {
        &self.cache
    }

    /// Consumes `self` and awaits [`InvalidationManager::run`] when a manager is
    /// configured. Returns immediately for instances built without invalidation
    /// support.
    ///
    /// Because this takes `self` by value, callers that still need to query should
    /// call it on a clone.
    pub async fn start_invalidation_manager(self) {
        if let Some(manager) = self.invalidation_manager {
            info!("Starting invalidation manager");
            manager.run().await;
        }
    }
}
/// Field-wise clone: the storage handle is shared through its `Arc`, and every other
/// field is duplicated with its own `Clone` implementation.
impl Clone for CachedQueryStorage {
    fn clone(&self) -> Self {
        let Self {
            storage,
            cache,
            invalidation_manager,
            event_tx,
        } = self;

        Self {
            storage: Arc::clone(storage),
            cache: cache.clone(),
            invalidation_manager: invalidation_manager.clone(),
            event_tx: event_tx.clone(),
        }
    }
}
pub use crate::cache::query_cache::InvalidationMessage;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::QueryType;

    /// Placeholder for a construction test.
    // NOTE(review): this test asserts nothing; building a `CachedQueryStorage` needs a
    // `TursoStorage` instance, which is not set up here.
    #[test]
    fn test_cached_query_storage_creation() {}

    /// A SELECT against the episodes table is classified as an episode query.
    #[test]
    fn test_query_key_creation() {
        let query_key = QueryKey::new("SELECT * FROM episodes WHERE domain = ?", &[&"test_domain"]);
        assert_eq!(query_key.query_type, QueryType::Episode);
    }

    /// A join over episodes and steps reports both tables as dependencies.
    #[test]
    fn test_table_dependency_detection() {
        let detected = TableDependency::from_query(
            "SELECT e.*, s.* FROM episodes e JOIN steps s ON e.episode_id = s.episode_id",
        );
        for table in [TableDependency::Episodes, TableDependency::Steps] {
            assert!(detected.contains(&table));
        }
    }
}