use super::{FilterBackend, FilterResult};
use async_trait::async_trait;
#[cfg(feature = "caching")]
use reinhardt_utils::cache::Cache;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
pub fn generate_cache_key(query_params: &HashMap<String, String>, sql: &str) -> String {
let mut hasher = Sha256::new();
let mut sorted_params: Vec<_> = query_params.iter().collect();
sorted_params.sort_by_key(|(k, _)| *k);
for (key, value) in sorted_params {
hasher.update(key.as_bytes());
hasher.update(b"=");
hasher.update(value.as_bytes());
hasher.update(b"&");
}
hasher.update(sql.as_bytes());
let result = hasher.finalize();
format!("filter_cache:{}", hex::encode(result))
}
/// Payload written to the cache for a single filter invocation.
///
/// Only the resulting SQL string is stored: the output of the inner backend,
/// or the unmodified input SQL when no inner backend is configured.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedResult {
    /// The SQL returned to the caller for this invocation.
    sql: String,
}
/// A `FilterBackend` decorator that memoises filtered SQL in a `Cache`.
///
/// Each call to `filter_queryset` is keyed by `generate_cache_key` over the
/// query parameters and incoming SQL. Results are stored with a fixed
/// time-to-live.
///
/// The `C: Cache` bound is placed on the `impl` blocks rather than on the
/// struct definition, following the usual Rust convention for generic types.
#[cfg(feature = "caching")]
pub struct CachedFilterBackend<C> {
    /// Store used for cached filter results (shared behind an `Arc`).
    cache: Arc<C>,
    /// Time-to-live passed to the cache for every entry written.
    ttl: Duration,
    /// Backend whose output is cached; when `None`, the input SQL is returned
    /// (and cached) unchanged.
    inner: Option<Arc<dyn FilterBackend>>,
}
#[cfg(feature = "caching")]
impl<C: Cache> CachedFilterBackend<C> {
    /// Creates a caching backend without an inner filter.
    ///
    /// Until [`with_inner`](Self::with_inner) is called, filtering returns the
    /// SQL it receives unchanged.
    ///
    /// * `cache` - storage for cached results; it is wrapped in an [`Arc`].
    /// * `ttl` - time-to-live applied to every entry this backend writes.
    pub fn new(cache: C, ttl: Duration) -> Self {
        let shared_cache = Arc::new(cache);
        Self {
            inner: None,
            ttl,
            cache: shared_cache,
        }
    }

    /// Sets the backend whose output will be cached, replacing any previous one.
    pub fn with_inner(self, inner: Box<dyn FilterBackend>) -> Self {
        Self {
            inner: Some(Arc::from(inner)),
            ..self
        }
    }

    /// Returns the time-to-live applied to cached entries.
    pub fn ttl(&self) -> Duration {
        self.ttl
    }
}
#[cfg(feature = "caching")]
#[async_trait]
impl<C: Cache> FilterBackend for CachedFilterBackend<C> {
    /// Returns the filtered SQL for `query_params`, consulting the cache first.
    ///
    /// On a cache hit the stored SQL is returned without calling the inner
    /// backend. On a miss, or when the cache lookup itself fails, the inner
    /// backend (if any) produces the SQL; without an inner backend the input
    /// SQL is used as-is. The result is then written to the cache with this
    /// backend's TTL.
    ///
    /// # Errors
    ///
    /// Errors from the inner backend are propagated, and nothing is cached in
    /// that case. Errors from the cache write are reported on stderr and
    /// otherwise ignored, so caching remains best-effort.
    ///
    /// NOTE(review): the key covers only `query_params` and `sql`, not the inner
    /// backend's configuration. If one underlying cache store is shared by
    /// several `CachedFilterBackend`s with different inner backends, they could
    /// serve each other's results. Confirm whether that can happen.
    async fn filter_queryset(
        &self,
        query_params: &HashMap<String, String>,
        sql: String,
    ) -> FilterResult<String> {
        let key = generate_cache_key(query_params, &sql);

        // Lookup errors are deliberately treated the same as a miss.
        if let Ok(Some(CachedResult { sql: hit })) =
            self.cache.get::<CachedResult>(&key).await
        {
            return Ok(hit);
        }

        let filtered = match self.inner.as_ref() {
            Some(backend) => backend.filter_queryset(query_params, sql).await?,
            None => sql,
        };

        let entry = CachedResult { sql: filtered };
        if let Err(e) = self.cache.set(&key, &entry, Some(self.ttl)).await {
            eprintln!("Failed to cache filter result: {:?}", e);
        }

        Ok(entry.sql)
    }
}
/// Counters describing how a cache has been used.
///
/// `Default` yields all-zero counters; `PartialEq`/`Eq` allow stats snapshots
/// to be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: u64,
    /// Number of lookups that found no cached entry.
    pub misses: u64,
    /// Number of cache entries. Presumably the entries currently stored;
    /// confirm with whatever produces these stats.
    pub entries: u64,
}
impl CacheStats {
    /// Returns the fraction of lookups that were hits: `hits / (hits + misses)`.
    ///
    /// The result lies in `[0.0, 1.0]`. It is `0.0` when no lookups have been
    /// recorded. `entries` does not affect the result.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        // Sum in f64 rather than u64: `hits + misses` can overflow u64, which
        // panics in debug builds and wraps to a wrong total in release builds.
        let hits = self.hits as f64;
        hits / (hits + self.misses as f64)
    }
}
#[cfg(all(test, feature = "caching"))]
mod tests {
    // Unit tests for cache-key generation, the caching backend, and CacheStats.
    use super::*;
    use crate::filters::backend::SimpleSearchBackend;
    use reinhardt_utils::cache::InMemoryCache;

    #[test]
    fn test_generate_cache_key_consistent() {
        let mut params1 = HashMap::new();
        params1.insert("search".to_string(), "rust".to_string());
        params1.insert("ordering".to_string(), "-created_at".to_string());
        let mut params2 = HashMap::new();
        params2.insert("ordering".to_string(), "-created_at".to_string());
        params2.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles";
        let key1 = generate_cache_key(&params1, sql);
        let key2 = generate_cache_key(&params2, sql);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_generate_cache_key_different_params() {
        let mut params1 = HashMap::new();
        params1.insert("search".to_string(), "rust".to_string());
        let mut params2 = HashMap::new();
        params2.insert("search".to_string(), "python".to_string());
        let sql = "SELECT * FROM articles";
        let key1 = generate_cache_key(&params1, sql);
        let key2 = generate_cache_key(&params2, sql);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_generate_cache_key_different_sql() {
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let key1 = generate_cache_key(&params, "SELECT * FROM articles");
        let key2 = generate_cache_key(&params, "SELECT * FROM users");
        assert_ne!(key1, key2);
    }

    #[tokio::test]
    async fn test_cached_filter_backend_simple() {
        let cache = InMemoryCache::new();
        let backend = CachedFilterBackend::new(cache, Duration::from_secs(300));
        let params = HashMap::new();
        let sql = "SELECT * FROM users".to_string();
        let result1 = backend.filter_queryset(&params, sql.clone()).await.unwrap();
        let result2 = backend.filter_queryset(&params, sql).await.unwrap();
        assert_eq!(result1, result2);
    }

    #[tokio::test]
    async fn test_cached_filter_backend_with_inner() {
        let cache = InMemoryCache::new();
        let search_backend = SimpleSearchBackend::new("search").with_field("title");
        let backend = CachedFilterBackend::new(cache, Duration::from_secs(300))
            .with_inner(Box::new(search_backend));
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result1 = backend.filter_queryset(&params, sql.clone()).await.unwrap();
        let result2 = backend.filter_queryset(&params, sql).await.unwrap();
        assert_eq!(result1, result2);
        assert!(result1.contains("WHERE"));
        assert!(result1.contains("`title` LIKE '%rust%'"));
    }

    #[tokio::test]
    async fn test_cached_filter_backend_ttl() {
        let cache = InMemoryCache::new();
        let backend = CachedFilterBackend::new(cache, Duration::from_millis(100));
        assert_eq!(backend.ttl(), Duration::from_millis(100));
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            entries: 50,
        };
        assert_eq!(stats.hit_rate(), 0.75);
    }

    #[test]
    fn test_cache_stats_hit_rate_zero() {
        let stats = CacheStats {
            hits: 0,
            misses: 0,
            entries: 0,
        };
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_stats_hit_rate_perfect() {
        let stats = CacheStats {
            hits: 100,
            misses: 0,
            entries: 50,
        };
        assert_eq!(stats.hit_rate(), 1.0);
    }
}