use crate::models::{Area, Project, Task};
use anyhow::Result;
use chrono::{DateTime, Utc};
use moka::future::Cache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
use uuid::Uuid;
/// Tunable parameters shared by all four query caches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCacheConfig {
/// Maximum number of cached query results per cache (tasks, projects, areas, search each).
pub max_queries: u64,
/// Time-to-live: entries are evicted this long after insertion.
pub ttl: Duration,
/// Time-to-idle: entries are evicted after going unread for this long.
pub tti: Duration,
/// When true, cached entries are flagged as compressed.
/// NOTE(review): no actual compression of the data is visible in this file —
/// only the `compressed` flag is stored; confirm whether compression happens elsewhere.
pub enable_compression: bool,
/// Results whose JSON-serialized size exceeds this many bytes are not cached.
pub max_result_size: usize,
}
impl Default for QueryCacheConfig {
    /// Defaults: 1000 entries per cache, 30-minute TTL, 5-minute TTI,
    /// compression flag enabled, 1 MiB maximum cacheable result.
    fn default() -> Self {
        Self {
            max_queries: 1000,
            ttl: Duration::from_secs(30 * 60),
            tti: Duration::from_secs(5 * 60),
            enable_compression: true,
            max_result_size: 1024 * 1024,
        }
    }
}
/// A cached query result together with the metadata needed for freshness
/// checks, statistics, and (potential) dependency-based invalidation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedQueryResult<T> {
/// The materialized query result.
pub data: T,
/// When the underlying query ran.
pub executed_at: DateTime<Utc>,
/// Wall-clock expiry; checked by `is_expired`.
pub expires_at: DateTime<Utc>,
/// How long the underlying query took, in milliseconds.
pub execution_time_ms: u64,
/// Hash of the query parameters; a lookup counts as a hit only when it matches.
pub params_hash: String,
/// Tables/entities whose mutations should invalidate this entry.
/// NOTE(review): recorded at insert time but never consulted during
/// invalidation in this file — confirm intended use.
pub dependencies: Vec<QueryDependency>,
/// JSON-serialized size of `data`, in bytes.
pub result_size: usize,
/// Mirrors `QueryCacheConfig::enable_compression` at insert time.
/// NOTE(review): no compression of `data` is visible in this file.
pub compressed: bool,
}
/// Identifies a table (and optionally one entity in it) whose listed
/// operations should invalidate a cached query result.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct QueryDependency {
/// Source table name (e.g. "TMTask").
pub table: String,
/// Specific entity this result depends on; `None` means the whole table.
pub entity_id: Option<Uuid>,
/// Operation names (e.g. "task_updated") that invalidate the result.
pub invalidating_operations: Vec<String>,
}
/// Aggregate statistics across all query caches.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QueryCacheStats {
/// Number of query results successfully inserted into a cache.
pub total_queries: u64,
/// Lookups served from cache.
pub hits: u64,
/// Lookups that ran the fetcher (including results too large to cache).
pub misses: u64,
/// hits / (hits + misses); refreshed by `calculate_hit_rate`.
pub hit_rate: f64,
/// Sum of serialized sizes of all inserted results, in bytes.
pub total_size_bytes: u64,
/// Running mean of fetcher execution time, in milliseconds.
pub average_execution_time_ms: f64,
/// Inserts recorded with the compressed flag set.
pub compressed_queries: u64,
/// Inserts recorded with the compressed flag clear.
pub uncompressed_queries: u64,
}
impl QueryCacheStats {
    /// Recompute `hit_rate` from the current `hits`/`misses` counters.
    /// Yields 0.0 when no lookups have been recorded yet.
    pub fn calculate_hit_rate(&mut self) {
        let lookups = self.hits + self.misses;
        if lookups == 0 {
            self.hit_rate = 0.0;
        } else {
            #[allow(clippy::cast_precision_loss)]
            let rate = self.hits as f64 / lookups as f64;
            self.hit_rate = rate;
        }
    }
}
/// In-memory query-result cache with separate moka caches for task, project,
/// area, and search queries, plus shared hit/miss statistics.
pub struct QueryCache {
// Each cache maps a query key to its cached result set.
tasks_cache: Cache<String, CachedQueryResult<Vec<Task>>>,
projects_cache: Cache<String, CachedQueryResult<Vec<Project>>>,
areas_cache: Cache<String, CachedQueryResult<Vec<Area>>>,
// Search queries also resolve to task lists.
search_cache: Cache<String, CachedQueryResult<Vec<Task>>>,
// Shared counters, guarded by a parking_lot RwLock.
stats: Arc<RwLock<QueryCacheStats>>,
config: QueryCacheConfig,
}
impl QueryCache {
#[must_use]
pub fn new(config: QueryCacheConfig) -> Self {
let tasks_cache = Cache::builder()
.max_capacity(config.max_queries)
.time_to_live(config.ttl)
.time_to_idle(config.tti)
.build();
let projects_cache = Cache::builder()
.max_capacity(config.max_queries)
.time_to_live(config.ttl)
.time_to_idle(config.tti)
.build();
let areas_cache = Cache::builder()
.max_capacity(config.max_queries)
.time_to_live(config.ttl)
.time_to_idle(config.tti)
.build();
let search_cache = Cache::builder()
.max_capacity(config.max_queries)
.time_to_live(config.ttl)
.time_to_idle(config.tti)
.build();
Self {
tasks_cache,
projects_cache,
areas_cache,
search_cache,
stats: Arc::new(RwLock::new(QueryCacheStats::default())),
config,
}
}
#[must_use]
/// Convenience constructor using [`QueryCacheConfig::default`].
pub fn new_default() -> Self {
    Self::new(Default::default())
}
/// Return the cached task list for `query_key`, or run `fetcher` and cache
/// its output.
///
/// A cached entry is reused only when it is not expired and was produced with
/// the same `params_hash`. Results whose JSON-serialized size exceeds
/// `max_result_size` are returned to the caller but not cached.
///
/// # Errors
/// Propagates any error returned by `fetcher`.
pub async fn cache_tasks_query<F, Fut>(
    &self,
    query_key: &str,
    params_hash: &str,
    fetcher: F,
) -> Result<Vec<Task>>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<Vec<Task>>>,
{
    if let Some(cached) = self.tasks_cache.get(query_key).await {
        if !cached.is_expired() && cached.params_hash == params_hash {
            self.record_hit();
            debug!("Query cache hit for tasks: {}", query_key);
            return Ok(cached.data);
        }
    }
    let start_time = std::time::Instant::now();
    let data = fetcher().await?;
    #[allow(clippy::cast_possible_truncation)]
    let execution_time = start_time.elapsed().as_millis() as u64;
    let result_size = Self::calculate_result_size(&data);
    if result_size > self.config.max_result_size {
        warn!("Query result too large to cache: {} bytes", result_size);
        self.record_miss();
        return Ok(data);
    }
    let dependencies = Self::create_task_dependencies(&data);
    let compressed = self.config.enable_compression;
    let cached_result = CachedQueryResult {
        data: data.clone(),
        executed_at: Utc::now(),
        expires_at: Utc::now()
            + chrono::Duration::from_std(self.config.ttl).unwrap_or_default(),
        execution_time_ms: execution_time,
        params_hash: params_hash.to_string(),
        dependencies,
        result_size,
        compressed,
    };
    self.tasks_cache
        .insert(query_key.to_string(), cached_result)
        .await;
    // Fix: previously always passed `false`, so `compressed_queries` in the
    // stats could never match the `compressed` flag stored on the entry.
    self.update_stats(result_size, execution_time, compressed);
    self.record_miss();
    debug!(
        "Cached tasks query: {} ({}ms, {} bytes)",
        query_key, execution_time, result_size
    );
    Ok(data)
}
/// Return the cached project list for `query_key`, or run `fetcher` and
/// cache its output.
///
/// A cached entry is reused only when it is not expired and was produced with
/// the same `params_hash`. Results whose JSON-serialized size exceeds
/// `max_result_size` are returned to the caller but not cached.
///
/// # Errors
/// Propagates any error returned by `fetcher`.
pub async fn cache_projects_query<F, Fut>(
    &self,
    query_key: &str,
    params_hash: &str,
    fetcher: F,
) -> Result<Vec<Project>>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<Vec<Project>>>,
{
    if let Some(cached) = self.projects_cache.get(query_key).await {
        if !cached.is_expired() && cached.params_hash == params_hash {
            self.record_hit();
            debug!("Query cache hit for projects: {}", query_key);
            return Ok(cached.data);
        }
    }
    let start_time = std::time::Instant::now();
    let data = fetcher().await?;
    #[allow(clippy::cast_possible_truncation)]
    let execution_time = start_time.elapsed().as_millis() as u64;
    let result_size = Self::calculate_result_size(&data);
    if result_size > self.config.max_result_size {
        warn!("Query result too large to cache: {} bytes", result_size);
        self.record_miss();
        return Ok(data);
    }
    let dependencies = Self::create_project_dependencies(&data);
    let compressed = self.config.enable_compression;
    let cached_result = CachedQueryResult {
        data: data.clone(),
        executed_at: Utc::now(),
        expires_at: Utc::now()
            + chrono::Duration::from_std(self.config.ttl).unwrap_or_default(),
        execution_time_ms: execution_time,
        params_hash: params_hash.to_string(),
        dependencies,
        result_size,
        compressed,
    };
    self.projects_cache
        .insert(query_key.to_string(), cached_result)
        .await;
    // Fix: previously always passed `false`, keeping compression stats stuck at 0.
    self.update_stats(result_size, execution_time, compressed);
    self.record_miss();
    debug!(
        "Cached projects query: {} ({}ms, {} bytes)",
        query_key, execution_time, result_size
    );
    Ok(data)
}
/// Return the cached area list for `query_key`, or run `fetcher` and cache
/// its output.
///
/// A cached entry is reused only when it is not expired and was produced with
/// the same `params_hash`. Results whose JSON-serialized size exceeds
/// `max_result_size` are returned to the caller but not cached.
///
/// # Errors
/// Propagates any error returned by `fetcher`.
pub async fn cache_areas_query<F, Fut>(
    &self,
    query_key: &str,
    params_hash: &str,
    fetcher: F,
) -> Result<Vec<Area>>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<Vec<Area>>>,
{
    if let Some(cached) = self.areas_cache.get(query_key).await {
        if !cached.is_expired() && cached.params_hash == params_hash {
            self.record_hit();
            debug!("Query cache hit for areas: {}", query_key);
            return Ok(cached.data);
        }
    }
    let start_time = std::time::Instant::now();
    let data = fetcher().await?;
    #[allow(clippy::cast_possible_truncation)]
    let execution_time = start_time.elapsed().as_millis() as u64;
    let result_size = Self::calculate_result_size(&data);
    if result_size > self.config.max_result_size {
        warn!("Query result too large to cache: {} bytes", result_size);
        self.record_miss();
        return Ok(data);
    }
    let dependencies = Self::create_area_dependencies(&data);
    let compressed = self.config.enable_compression;
    let cached_result = CachedQueryResult {
        data: data.clone(),
        executed_at: Utc::now(),
        expires_at: Utc::now()
            + chrono::Duration::from_std(self.config.ttl).unwrap_or_default(),
        execution_time_ms: execution_time,
        params_hash: params_hash.to_string(),
        dependencies,
        result_size,
        compressed,
    };
    self.areas_cache
        .insert(query_key.to_string(), cached_result)
        .await;
    // Fix: previously always passed `false`, keeping compression stats stuck at 0.
    self.update_stats(result_size, execution_time, compressed);
    self.record_miss();
    debug!(
        "Cached areas query: {} ({}ms, {} bytes)",
        query_key, execution_time, result_size
    );
    Ok(data)
}
/// Return the cached search result for `query_key`, or run `fetcher` and
/// cache its output. Search results are task lists and share the task
/// dependency model.
///
/// A cached entry is reused only when it is not expired and was produced with
/// the same `params_hash`. Results whose JSON-serialized size exceeds
/// `max_result_size` are returned to the caller but not cached.
///
/// # Errors
/// Propagates any error returned by `fetcher`.
pub async fn cache_search_query<F, Fut>(
    &self,
    query_key: &str,
    params_hash: &str,
    fetcher: F,
) -> Result<Vec<Task>>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<Vec<Task>>>,
{
    if let Some(cached) = self.search_cache.get(query_key).await {
        if !cached.is_expired() && cached.params_hash == params_hash {
            self.record_hit();
            debug!("Query cache hit for search: {}", query_key);
            return Ok(cached.data);
        }
    }
    let start_time = std::time::Instant::now();
    let data = fetcher().await?;
    #[allow(clippy::cast_possible_truncation)]
    let execution_time = start_time.elapsed().as_millis() as u64;
    let result_size = Self::calculate_result_size(&data);
    if result_size > self.config.max_result_size {
        warn!("Query result too large to cache: {} bytes", result_size);
        self.record_miss();
        return Ok(data);
    }
    let dependencies = Self::create_task_dependencies(&data);
    let compressed = self.config.enable_compression;
    let cached_result = CachedQueryResult {
        data: data.clone(),
        executed_at: Utc::now(),
        expires_at: Utc::now()
            + chrono::Duration::from_std(self.config.ttl).unwrap_or_default(),
        execution_time_ms: execution_time,
        params_hash: params_hash.to_string(),
        dependencies,
        result_size,
        compressed,
    };
    self.search_cache
        .insert(query_key.to_string(), cached_result)
        .await;
    // Fix: previously always passed `false`, keeping compression stats stuck at 0.
    self.update_stats(result_size, execution_time, compressed);
    self.record_miss();
    debug!(
        "Cached search query: {} ({}ms, {} bytes)",
        query_key, execution_time, result_size
    );
    Ok(data)
}
/// Invalidate cached queries after an entity change.
///
/// NOTE(review): despite the name, `entity_type` and `entity_id` are used
/// only for logging — every cache is cleared wholesale. The per-entity
/// `QueryDependency` data recorded at insert time is never consulted here;
/// confirm whether targeted invalidation was intended.
pub fn invalidate_by_entity(&self, entity_type: &str, entity_id: Option<&Uuid>) {
self.tasks_cache.invalidate_all();
self.projects_cache.invalidate_all();
self.areas_cache.invalidate_all();
self.search_cache.invalidate_all();
debug!(
"Invalidated all query caches due to entity change: {} {:?}",
entity_type, entity_id
);
}
/// Invalidate the caches affected by a specific mutation `operation`.
///
/// Task mutations clear task and search results; project mutations also
/// clear task results; area mutations cascade to projects and tasks.
/// Unrecognized operations conservatively clear everything.
pub fn invalidate_by_operation(&self, operation: &str) {
    match operation {
        "task_created" | "task_updated" | "task_deleted" | "task_completed" => {
            self.search_cache.invalidate_all();
            self.tasks_cache.invalidate_all();
        }
        "project_created" | "project_updated" | "project_deleted" => {
            self.tasks_cache.invalidate_all();
            self.projects_cache.invalidate_all();
        }
        "area_created" | "area_updated" | "area_deleted" => {
            self.tasks_cache.invalidate_all();
            self.projects_cache.invalidate_all();
            self.areas_cache.invalidate_all();
        }
        _ => self.invalidate_all(),
    }
    debug!("Invalidated query caches due to operation: {}", operation);
}
/// Drop every cached query result across all four caches.
pub fn invalidate_all(&self) {
    self.search_cache.invalidate_all();
    self.areas_cache.invalidate_all();
    self.projects_cache.invalidate_all();
    self.tasks_cache.invalidate_all();
}
#[must_use]
/// Snapshot the current statistics, with `hit_rate` freshly recomputed.
pub fn get_stats(&self) -> QueryCacheStats {
    let mut snapshot = self.stats.read().clone();
    snapshot.calculate_hit_rate();
    snapshot
}
fn calculate_result_size<T>(data: &T) -> usize
where
T: Serialize,
{
serde_json::to_vec(data).map_or(0, |bytes| bytes.len())
}
fn create_task_dependencies(tasks: &[Task]) -> Vec<QueryDependency> {
let mut dependencies = Vec::new();
dependencies.push(QueryDependency {
table: "TMTask".to_string(),
entity_id: None,
invalidating_operations: vec![
"task_created".to_string(),
"task_updated".to_string(),
"task_deleted".to_string(),
"task_completed".to_string(),
],
});
for task in tasks {
dependencies.push(QueryDependency {
table: "TMTask".to_string(),
entity_id: Some(task.uuid),
invalidating_operations: vec![
"task_updated".to_string(),
"task_deleted".to_string(),
"task_completed".to_string(),
],
});
if let Some(project_uuid) = task.project_uuid {
dependencies.push(QueryDependency {
table: "TMProject".to_string(),
entity_id: Some(project_uuid),
invalidating_operations: vec![
"project_updated".to_string(),
"project_deleted".to_string(),
],
});
}
if let Some(area_uuid) = task.area_uuid {
dependencies.push(QueryDependency {
table: "TMArea".to_string(),
entity_id: Some(area_uuid),
invalidating_operations: vec![
"area_updated".to_string(),
"area_deleted".to_string(),
],
});
}
}
dependencies
}
fn create_project_dependencies(projects: &[Project]) -> Vec<QueryDependency> {
let mut dependencies = Vec::new();
dependencies.push(QueryDependency {
table: "TMProject".to_string(),
entity_id: None,
invalidating_operations: vec![
"project_created".to_string(),
"project_updated".to_string(),
"project_deleted".to_string(),
],
});
for project in projects {
dependencies.push(QueryDependency {
table: "TMProject".to_string(),
entity_id: Some(project.uuid),
invalidating_operations: vec![
"project_updated".to_string(),
"project_deleted".to_string(),
],
});
if let Some(area_uuid) = project.area_uuid {
dependencies.push(QueryDependency {
table: "TMArea".to_string(),
entity_id: Some(area_uuid),
invalidating_operations: vec![
"area_updated".to_string(),
"area_deleted".to_string(),
],
});
}
}
dependencies
}
/// Build the dependency list for a set of areas: a table-level dependency on
/// `TMArea` followed by one dependency per area.
fn create_area_dependencies(areas: &[Area]) -> Vec<QueryDependency> {
    let table_dep = QueryDependency {
        table: "TMArea".to_string(),
        entity_id: None,
        invalidating_operations: vec![
            "area_created".to_string(),
            "area_updated".to_string(),
            "area_deleted".to_string(),
        ],
    };
    std::iter::once(table_dep)
        .chain(areas.iter().map(|area| QueryDependency {
            table: "TMArea".to_string(),
            entity_id: Some(area.uuid),
            invalidating_operations: vec![
                "area_updated".to_string(),
                "area_deleted".to_string(),
            ],
        }))
        .collect()
}
/// Bump the cache-hit counter.
fn record_hit(&self) {
    self.stats.write().hits += 1;
}
/// Bump the cache-miss counter.
fn record_miss(&self) {
    self.stats.write().misses += 1;
}
#[allow(clippy::cast_precision_loss)]
/// Record one cached insert: size, running-mean execution time, and whether
/// the entry was flagged as compressed.
fn update_stats(&self, result_size: usize, execution_time_ms: u64, compressed: bool) {
    let mut stats = self.stats.write();
    stats.total_queries += 1;
    stats.total_size_bytes += result_size as u64;
    // Incremental mean: new_avg = ((n - 1) * old_avg + x) / n.
    let n = stats.total_queries as f64;
    stats.average_execution_time_ms =
        ((n - 1.0) * stats.average_execution_time_ms + execution_time_ms as f64) / n;
    if compressed {
        stats.compressed_queries += 1;
    } else {
        stats.uncompressed_queries += 1;
    }
}
}
impl<T> CachedQueryResult<T> {
    /// True once the wall clock has passed `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::TaskStatus;
use crate::test_utils::create_mock_tasks;
#[tokio::test]
async fn test_query_cache_basic_operations() {
let cache = QueryCache::new_default();
let tasks = create_mock_tasks();
let query_key = "test_tasks_query";
let params_hash = "test_params_hash";
// First call: miss — the fetcher runs and the result is cached.
let result = cache
.cache_tasks_query(query_key, params_hash, || async { Ok(tasks.clone()) })
.await
.unwrap();
assert_eq!(result.len(), tasks.len());
// Second call with identical key + params: must be served from cache,
// so the fetcher must never execute.
let cached_result = cache
.cache_tasks_query(query_key, params_hash, || async {
panic!("Should not execute fetcher on cache hit");
})
.await
.unwrap();
assert_eq!(cached_result.len(), tasks.len());
// Same key but a different params hash counts as a miss and refetches.
let different_params = "different_params_hash";
let _ = cache
.cache_tasks_query(query_key, different_params, || async {
Ok(create_mock_tasks())
})
.await
.unwrap();
// At least one hit (call 2) and one miss (calls 1 and 3) were recorded.
let stats = cache.get_stats();
assert!(stats.hits >= 1);
assert!(stats.misses >= 1);
}
#[tokio::test]
async fn test_query_cache_invalidation() {
let cache = QueryCache::new_default();
let tasks = create_mock_tasks();
// Populate the tasks cache (first miss).
cache
.cache_tasks_query("test_query", "params", || async { Ok(tasks.clone()) })
.await
.unwrap();
// A task mutation must clear the tasks cache...
cache.invalidate_by_operation("task_updated");
// ...so an identical lookup misses again and re-runs the fetcher.
let _ = cache
.cache_tasks_query("test_query", "params", || async { Ok(create_mock_tasks()) })
.await
.unwrap();
let stats = cache.get_stats();
assert!(stats.misses >= 2);
}
#[tokio::test]
async fn test_query_cache_dependencies() {
    let _cache = QueryCache::new_default();
    let mock_tasks = create_mock_tasks();
    let deps = QueryCache::create_task_dependencies(&mock_tasks);
    // At minimum the table-level TMTask dependency must be present.
    assert!(!deps.is_empty());
    let has_task_table_dep = deps.iter().any(|d| d.table == "TMTask");
    assert!(has_task_table_dep);
}
#[tokio::test]
async fn test_query_cache_projects_query() {
let cache = QueryCache::new_default();
// One minimal project fixture.
let projects = vec![Project {
uuid: Uuid::new_v4(),
title: "Project 1".to_string(),
area_uuid: Some(Uuid::new_v4()),
created: Utc::now(),
modified: Utc::now(),
status: TaskStatus::Incomplete,
notes: Some("Notes".to_string()),
deadline: None,
start_date: None,
tags: vec![],
tasks: vec![],
}];
let query_key = "test_projects_query";
let params_hash = "test_params";
// First call: miss — fetcher runs and the result is cached.
let result = cache
.cache_projects_query(query_key, params_hash, || async { Ok(projects.clone()) })
.await
.unwrap();
assert_eq!(result.len(), projects.len());
// Second identical call: served from cache, fetcher must not run.
let cached_result = cache
.cache_projects_query(query_key, params_hash, || async {
panic!("Should not execute fetcher on cache hit");
})
.await
.unwrap();
assert_eq!(cached_result.len(), projects.len());
}
#[tokio::test]
async fn test_query_cache_config_default() {
let config = QueryCacheConfig::default();
assert_eq!(config.max_queries, 1000);
assert_eq!(config.ttl, Duration::from_secs(1800));
assert_eq!(config.tti, Duration::from_secs(300));
assert!(config.enable_compression);
assert_eq!(config.max_result_size, 1024 * 1024);
}
#[tokio::test]
async fn test_cached_query_result_creation() {
// Verifies that a manually built CachedQueryResult round-trips its fields.
let tasks = create_mock_tasks();
let now = Utc::now();
let expires_at = now + chrono::Duration::seconds(1800);
let dependency = QueryDependency {
table: "TMTask".to_string(),
entity_id: None,
invalidating_operations: vec![
"INSERT".to_string(),
"UPDATE".to_string(),
"DELETE".to_string(),
],
};
let result = CachedQueryResult {
data: tasks.clone(),
executed_at: now,
expires_at,
execution_time_ms: 100,
params_hash: "test_hash".to_string(),
result_size: 1024,
dependencies: vec![dependency.clone()],
compressed: false,
};
assert_eq!(result.data.len(), tasks.len());
assert_eq!(result.execution_time_ms, 100);
assert_eq!(result.result_size, 1024);
assert_eq!(result.params_hash, "test_hash");
assert_eq!(result.dependencies, vec![dependency]);
assert!(!result.compressed);
}
#[tokio::test]
async fn test_query_cache_areas_query() {
let cache = QueryCache::new_default();
// One minimal area fixture.
let areas = vec![Area {
uuid: Uuid::new_v4(),
title: "Area 1".to_string(),
created: Utc::now(),
modified: Utc::now(),
notes: Some("Notes".to_string()),
tags: vec![],
projects: vec![],
}];
let query_key = "test_areas_query";
let params_hash = "test_params";
// First call: miss — fetcher runs and the result is cached.
let result = cache
.cache_areas_query(query_key, params_hash, || async { Ok(areas.clone()) })
.await
.unwrap();
assert_eq!(result.len(), areas.len());
// Second identical call: served from cache, fetcher must not run.
let cached_result = cache
.cache_areas_query(query_key, params_hash, || async {
panic!("Should not execute fetcher on cache hit");
})
.await
.unwrap();
assert_eq!(cached_result.len(), areas.len());
}
#[tokio::test]
async fn test_query_cache_expiration() {
// Very short TTL/TTI so expiry can be observed within the test.
// NOTE(review): this is wall-clock based and could flake on a heavily
// loaded machine; the 20ms sleep gives a 2x margin over the 10ms TTL.
let config = QueryCacheConfig {
max_queries: 100,
ttl: Duration::from_millis(10), tti: Duration::from_millis(5),
enable_compression: false,
max_result_size: 1024,
};
let cache = QueryCache::new(config);
let tasks = create_mock_tasks();
let query_key = "test_expiration";
let params_hash = "test_params";
// Populate the cache.
let _result = cache
.cache_tasks_query(query_key, params_hash, || async { Ok(tasks.clone()) })
.await
.unwrap();
// Wait past the TTL so the entry expires.
tokio::time::sleep(Duration::from_millis(20)).await;
let mut fetcher_called = false;
// The expired entry must not be served; the fetcher must run again.
let _expired_result = cache
.cache_tasks_query(query_key, params_hash, || async {
fetcher_called = true;
Ok(tasks.clone())
})
.await
.unwrap();
assert!(fetcher_called);
}
#[tokio::test]
async fn test_query_cache_size_limit() {
    // max_queries = 2, so inserting a third entry can force an eviction;
    // the stats counters must still reflect every executed query.
    let config = QueryCacheConfig {
        max_queries: 2,
        ttl: Duration::from_secs(300),
        tti: Duration::from_secs(60),
        enable_compression: false,
        max_result_size: 1024,
    };
    let cache = QueryCache::new(config);
    let tasks = create_mock_tasks();
    for (key, params) in [("key1", "params1"), ("key2", "params2"), ("key3", "params3")] {
        let _ = cache
            .cache_tasks_query(key, params, || async { Ok(tasks.clone()) })
            .await
            .unwrap();
    }
    let stats = cache.get_stats();
    // Fix: the original `total_queries <= 10` assertion was vacuous.
    // All three calls use distinct key/params pairs, so each is recorded as
    // exactly one miss regardless of size or eviction, and none can hit.
    assert_eq!(stats.misses, 3);
    assert_eq!(stats.hits, 0);
    // total_queries counts only results small enough to be cached, so it is
    // bounded by the number of executed queries.
    assert!(stats.total_queries <= 3);
}
#[tokio::test]
async fn test_query_cache_concurrent_access() {
// Exercise the cache from ten concurrent tokio tasks, each with its own
// key/params pair, to check for deadlocks or panics under contention.
let cache = Arc::new(QueryCache::new_default());
let tasks = create_mock_tasks();
let mut handles = vec![];
for i in 0..10 {
let cache_clone = cache.clone();
let tasks_clone = tasks.clone();
let handle = tokio::spawn(async move {
let key = format!("concurrent_key_{i}");
let params = format!("params_{i}");
let result = cache_clone
.cache_tasks_query(&key, &params, || async { Ok(tasks_clone.clone()) })
.await
.unwrap();
assert!(!result.is_empty());
});
handles.push(handle);
}
// Join every task; a panic inside a task surfaces here.
for handle in handles {
handle.await.unwrap();
}
}
#[tokio::test]
async fn test_query_cache_error_handling() {
    let cache = QueryCache::new_default();
    // A failing fetcher must propagate its error rather than cache anything.
    let outcome = cache
        .cache_tasks_query("error_test", "test_params", || async {
            Err(anyhow::anyhow!("Test error"))
        })
        .await;
    assert!(outcome.is_err());
}
#[tokio::test]
async fn test_query_cache_compression() {
// NOTE(review): `enable_compression` only sets the `compressed` flag on the
// cached entry — no actual compression of the data is visible in this file,
// so this test can only verify hit/miss behavior with the flag enabled.
let config = QueryCacheConfig {
max_queries: 100,
ttl: Duration::from_secs(300),
tti: Duration::from_secs(60),
enable_compression: true,
max_result_size: 1024 * 1024,
};
let cache = QueryCache::new(config);
let tasks = create_mock_tasks();
let query_key = "compression_test";
let params_hash = "test_params";
// First call: miss — fetcher runs and the result is cached.
let result = cache
.cache_tasks_query(query_key, params_hash, || async { Ok(tasks.clone()) })
.await
.unwrap();
assert_eq!(result.len(), tasks.len());
// Second identical call: served from cache, fetcher must not run.
let cached_result = cache
.cache_tasks_query(query_key, params_hash, || async {
panic!("Should not execute fetcher on cache hit");
})
.await
.unwrap();
assert_eq!(cached_result.len(), tasks.len());
}
}