use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Configuration for [`SmartTtlManager`].
///
/// Holds the base TTL for each keyword-detected query category, the knobs for
/// volatility learning, user-supplied substring overrides, and whether caller-provided
/// cache hints take precedence over everything else.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmartTtlConfig {
/// Fallback TTL when no category keyword matches the query.
pub default_ttl: Duration,
/// TTL for queries that look like user/profile/account lookups.
pub user_profile_ttl: Duration,
/// TTL for rarely-changing data (categories, tags, settings, config, metadata).
pub static_content_ttl: Duration,
/// TTL for queries mentioning live/current/realtime data.
pub real_time_data_ttl: Duration,
/// TTL for counts, sums, aggregates, statistics and analytics.
pub aggregated_data_ttl: Duration,
/// TTL for list/paginated queries (or query types ending in `s`).
pub list_query_ttl: Duration,
/// TTL for single-item lookups (by id / get / find).
pub item_query_ttl: Duration,
/// When `true`, query results are recorded and learned volatility can adjust TTLs.
pub auto_detect_volatility: bool,
/// Number of recorded results a query pattern needs before its learned volatility is used;
/// also the scale of the confidence ramp.
pub min_observations: usize,
/// Multiplier applied to the base TTL of very stable queries (volatility score below 0.1).
pub max_adjustment_factor: f64,
/// Substring -> TTL overrides, consulted before keyword detection and volatility learning.
pub custom_patterns: HashMap<String, Duration>,
/// When `true`, an explicit cache hint passed to `calculate_ttl` wins over every other strategy.
pub respect_cache_hints: bool,
}
impl Default for SmartTtlConfig {
    /// Returns the stock configuration.
    ///
    /// The defaults are: minutes-scale TTLs for ordinary, list and item queries; 15 minutes
    /// for profiles; a day for static content; 5 seconds for real-time data; and 30 minutes
    /// for aggregates. Volatility learning is enabled (10 observations, up to 2x stretch),
    /// there are no custom patterns, and cache hints are honored.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        let secs = Duration::from_secs;

        Self {
            default_ttl: secs(5 * MINUTE),
            user_profile_ttl: secs(15 * MINUTE),
            static_content_ttl: secs(24 * HOUR),
            real_time_data_ttl: secs(5),
            aggregated_data_ttl: secs(30 * MINUTE),
            list_query_ttl: secs(10 * MINUTE),
            item_query_ttl: secs(5 * MINUTE),
            auto_detect_volatility: true,
            min_observations: 10,
            max_adjustment_factor: 2.0,
            custom_patterns: HashMap::new(),
            respect_cache_hints: true,
        }
    }
}
/// Chooses cache TTLs for queries from cache hints, custom patterns, keyword heuristics
/// and learned result volatility.
///
/// Learned state lives behind `tokio::sync::RwLock`s, so methods that touch it are `async`.
#[derive(Debug)]
pub struct SmartTtlManager {
/// Settings supplied at construction; never mutated afterwards.
config: SmartTtlConfig,
/// Result-change statistics keyed by the normalized pattern from `extract_query_pattern`.
query_stats: Arc<RwLock<HashMap<String, QueryVolatilityStats>>>,
/// Mutation bookkeeping written by `record_mutation`.
mutation_tracker: Arc<RwLock<MutationTracker>>,
}
/// How often the result of one query pattern has changed across recorded executions.
#[derive(Debug, Clone)]
struct QueryVolatilityStats {
/// Number of results recorded for this pattern.
hit_count: u64,
/// Number of recorded results whose hash differed from the previously recorded one.
change_count: u64,
/// Hash of the most recently recorded result.
last_result_hash: u64,
/// `change_count / hit_count`, in `[0.0, 1.0]`.
volatility_score: f64,
/// TTL recomputed on every recorded result from the pattern's keyword base and the volatility score.
recommended_ttl: Duration,
/// When this entry was last recorded; `cleanup_old_stats` evicts entries older than its `max_age`.
last_update: Instant,
}
/// Bookkeeping populated by `record_mutation`.
///
/// NOTE(review): nothing in this file reads this data when computing TTLs; presumably it is
/// intended for future invalidation logic — confirm before relying on it.
#[derive(Debug, Clone, Default)]
struct MutationTracker {
/// Query identifiers reported as affected, per mutation type.
mutation_to_queries: HashMap<String, Vec<String>>,
/// Occurrence counters per mutation type.
mutation_frequency: HashMap<String, MutationStats>,
}
/// Counters for one mutation type.
#[derive(Debug, Clone)]
struct MutationStats {
/// How many times this mutation type has been recorded.
count: u64,
/// When this mutation type was last recorded.
last_mutation: Instant,
/// Average time between mutations. It is set to a placeholder when the entry is created and
/// is never read in this file, hence the `dead_code` allowance.
#[allow(dead_code)] avg_interval: Duration,
}
impl SmartTtlManager {
    /// Creates a manager with `config` and no learned statistics.
    pub fn new(config: SmartTtlConfig) -> Self {
        Self {
            config,
            query_stats: Arc::new(RwLock::new(HashMap::new())),
            mutation_tracker: Arc::new(RwLock::new(MutationTracker::default())),
        }
    }

    /// Chooses a TTL for `query`, trying strategies in priority order:
    ///
    /// 1. `cache_hint`, when `respect_cache_hints` is enabled;
    /// 2. the longest entry of `custom_patterns` contained in `query` (ties go to the
    ///    lexicographically smallest pattern, so the choice is deterministic);
    /// 3. when volatility detection is enabled and the query's pattern has at least
    ///    `min_observations` recorded results, the keyword-based base TTL scaled by the
    ///    factor derived from the observed volatility score;
    /// 4. the keyword-based base TTL for `query` and `query_type`.
    pub async fn calculate_ttl(
        &self,
        query: &str,
        query_type: &str,
        cache_hint: Option<Duration>,
    ) -> TtlResult {
        let start = Instant::now();

        if self.config.respect_cache_hints {
            if let Some(hint) = cache_hint {
                return TtlResult {
                    ttl: hint,
                    strategy: TtlStrategy::CacheHint,
                    confidence: 1.0,
                    calculation_time: start.elapsed(),
                };
            }
        }

        // `HashMap` iteration order is unspecified. If several patterns match, pick the most
        // specific (longest) one and break ties lexicographically so the result is stable.
        let custom_match = self
            .config
            .custom_patterns
            .iter()
            .filter(|(pattern, _)| query.contains(pattern.as_str()))
            .max_by(|(a, _), (b, _)| {
                a.len()
                    .cmp(&b.len())
                    .then_with(|| b.as_str().cmp(a.as_str()))
            });
        if let Some((pattern, ttl)) = custom_match {
            return TtlResult {
                ttl: *ttl,
                strategy: TtlStrategy::CustomPattern(pattern.clone()),
                confidence: 1.0,
                calculation_time: start.elapsed(),
            };
        }

        let base_ttl = self.detect_query_type_ttl(query, query_type);
        if self.config.auto_detect_volatility {
            if let Some(adjusted) = self.get_volatility_adjusted_ttl(query, base_ttl).await {
                return TtlResult {
                    ttl: adjusted.ttl,
                    strategy: TtlStrategy::VolatilityBased {
                        base_ttl,
                        volatility_score: adjusted.volatility_score,
                    },
                    confidence: adjusted.confidence,
                    calculation_time: start.elapsed(),
                };
            }
        }

        TtlResult {
            ttl: base_ttl,
            strategy: TtlStrategy::QueryType(query_type.to_string()),
            confidence: 0.8,
            calculation_time: start.elapsed(),
        }
    }

    /// Maps `query` (matched case-insensitively by substring) and `query_type` to one of the
    /// configured base TTLs.
    ///
    /// Rules are checked in order and the first match wins: real-time, static, aggregated,
    /// list, user-profile, single-item, then `default_ttl`.
    fn detect_query_type_ttl(&self, query: &str, query_type: &str) -> Duration {
        fn contains_any(haystack: &str, needles: &[&str]) -> bool {
            needles.iter().any(|&needle| haystack.contains(needle))
        }

        let query_lower = query.to_lowercase();

        if contains_any(&query_lower, &["live", "current", "realtime"]) {
            return self.config.real_time_data_ttl;
        }
        if contains_any(
            &query_lower,
            &["categories", "tags", "settings", "config", "metadata"],
        ) {
            return self.config.static_content_ttl;
        }
        // "account" contains "count". Blank it out for this rule only, so account lookups are
        // not classified as aggregates (which made the "account" profile rule unreachable).
        let without_account = query_lower.replace("account", " ");
        if contains_any(
            &without_account,
            &["count", "sum", "aggregate", "statistics", "analytics"],
        ) {
            return self.config.aggregated_data_ttl;
        }
        if contains_any(&query_lower, &["list", "page", "offset", "limit"])
            || query_type.ends_with('s')
        {
            return self.config.list_query_ttl;
        }
        // `me` is matched as a standalone field followed by a selection set, so fields such
        // as `name {` or `theme {` are not mistaken for the current-user query.
        if contains_any(
            &query_lower,
            &["profile", "user ", "user{", "user(", "account"],
        ) || Self::has_selection(&query_lower, "me")
        {
            return self.config.user_profile_ttl;
        }
        if contains_any(&query_lower, &["byid", "get", "find"]) {
            return self.config.item_query_ttl;
        }
        self.config.default_ttl
    }

    /// Returns `true` when `ident` occurs in `text` as a whole identifier (not preceded by an
    /// alphanumeric character or `_`) and is followed, after optional whitespace, by `{`.
    fn has_selection(text: &str, ident: &str) -> bool {
        text.match_indices(ident).any(|(pos, _)| {
            let preceded_by_ident_char = matches!(
                text[..pos].chars().next_back(),
                Some(c) if c.is_alphanumeric() || c == '_'
            );
            !preceded_by_ident_char && text[pos + ident.len()..].trim_start().starts_with('{')
        })
    }

    /// Returns the volatility-adjusted TTL for `query` once its pattern has at least
    /// `min_observations` recorded results, or `None` otherwise.
    ///
    /// The adjustment is applied to `base_ttl`, the keyword-based TTL the caller reports
    /// alongside the result, so the two always agree.
    async fn get_volatility_adjusted_ttl(
        &self,
        query: &str,
        base_ttl: Duration,
    ) -> Option<VolatilityAdjustedTtl> {
        let query_pattern = self.extract_query_pattern(query);
        let stats = self.query_stats.read().await;
        let observed = stats.get(&query_pattern)?;
        if observed.hit_count < self.config.min_observations as u64 {
            return None;
        }
        Some(VolatilityAdjustedTtl {
            ttl: self.scale_ttl(base_ttl, observed.volatility_score),
            volatility_score: observed.volatility_score,
            confidence: self.calculate_confidence(observed.hit_count),
        })
    }

    /// Records a result hash for `query`, updating its pattern's change statistics.
    ///
    /// This is a no-op unless `auto_detect_volatility` is enabled. Results are grouped by
    /// `extract_query_pattern`, so executions that differ only in argument values share one entry.
    pub async fn record_query_result(&self, query: &str, result_hash: u64) {
        if !self.config.auto_detect_volatility {
            return;
        }
        let query_pattern = self.extract_query_pattern(query);
        let mut stats = self.query_stats.write().await;
        let volatility_stats = stats
            .entry(query_pattern.clone())
            .or_insert_with(|| QueryVolatilityStats {
                hit_count: 0,
                change_count: 0,
                // Seeded with the first hash so the first observation never counts as a change.
                last_result_hash: result_hash,
                volatility_score: 0.0,
                recommended_ttl: self.config.default_ttl,
                last_update: Instant::now(),
            });
        volatility_stats.hit_count += 1;
        if volatility_stats.last_result_hash != result_hash {
            volatility_stats.change_count += 1;
            volatility_stats.last_result_hash = result_hash;
        }
        volatility_stats.volatility_score =
            volatility_stats.change_count as f64 / volatility_stats.hit_count as f64;
        volatility_stats.recommended_ttl =
            self.calculate_recommended_ttl(&query_pattern, volatility_stats.volatility_score);
        volatility_stats.last_update = Instant::now();
    }

    /// Records one occurrence of `mutation_type` and the queries it affects.
    ///
    /// Affected query names are stored once per mutation type. Repeated reports do not grow
    /// the list, so memory stays bounded by the number of distinct (mutation, query) pairs.
    pub async fn record_mutation(&self, mutation_type: &str, affected_queries: Vec<String>) {
        let mut tracker = self.mutation_tracker.write().await;
        let mutation_stats = tracker
            .mutation_frequency
            .entry(mutation_type.to_string())
            .or_insert_with(|| MutationStats {
                count: 0,
                last_mutation: Instant::now(),
                avg_interval: Duration::from_secs(3600),
            });
        mutation_stats.count += 1;
        mutation_stats.last_mutation = Instant::now();

        let known_queries = tracker
            .mutation_to_queries
            .entry(mutation_type.to_string())
            .or_default();
        for query in affected_queries {
            if !known_queries.contains(&query) {
                known_queries.push(query);
            }
        }
    }

    /// Estimates a TTL from the query pattern alone, with a generic `"query"` type because the
    /// real query type is unknown when results are recorded. Stored as the pattern's recommendation.
    fn calculate_recommended_ttl(&self, query_pattern: &str, volatility_score: f64) -> Duration {
        let base_ttl = self.detect_query_type_ttl(query_pattern, "query");
        self.scale_ttl(base_ttl, volatility_score)
    }

    /// Multiplier applied to a base TTL for the given volatility score.
    ///
    /// Volatile results shrink the TTL and stable ones stretch it. Every factor is capped at
    /// `max_adjustment_factor`, so a moderately stable query never outlives a very stable one.
    fn adjustment_factor(&self, volatility_score: f64) -> f64 {
        let max_factor = self.config.max_adjustment_factor;
        let factor = if volatility_score > 0.7 {
            0.5
        } else if volatility_score > 0.3 {
            0.75
        } else if volatility_score < 0.1 {
            max_factor
        } else {
            1.5
        };
        factor.min(max_factor)
    }

    /// Scales `base_ttl` by the volatility factor, truncating to whole seconds with a 1s floor.
    fn scale_ttl(&self, base_ttl: Duration, volatility_score: f64) -> Duration {
        let scaled_secs = base_ttl.as_secs_f64() * self.adjustment_factor(volatility_score);
        // `as u64` saturates: negative or NaN products become 0 and are raised to the floor.
        Duration::from_secs((scaled_secs as u64).max(1))
    }

    /// Normalizes `query` into the key used for volatility statistics.
    ///
    /// Parenthesized argument lists are removed, including nested ones and ones spanning several
    /// lines, so executions differing only in argument values share an entry. The selection set
    /// is kept, so queries selecting different fields do not. Inside argument lists, parentheses
    /// within double-quoted strings are ignored. Lines left blank are dropped.
    /// An unbalanced `(` drops the remainder of the query.
    fn extract_query_pattern(&self, query: &str) -> String {
        let mut stripped = String::with_capacity(query.len());
        let mut depth = 0usize;
        let mut in_string = false;
        let mut escaped = false;
        for c in query.chars() {
            if depth == 0 {
                if c == '(' {
                    depth = 1;
                } else {
                    stripped.push(c);
                }
            } else if in_string {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == '"' {
                    in_string = false;
                }
            } else {
                match c {
                    '"' => in_string = true,
                    '(' => depth += 1,
                    ')' => depth -= 1,
                    _ => {}
                }
            }
        }
        stripped
            .lines()
            .filter(|line| !line.trim().is_empty())
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Confidence in a learned TTL after `observations` recorded results.
    ///
    /// It rises linearly from 0.5 and reaches 1.0 at `11 * min_observations`. With
    /// `min_observations == 0` there is no ramp, so it is always 1.0 (avoids dividing by zero).
    fn calculate_confidence(&self, observations: u64) -> f64 {
        if self.config.min_observations == 0 {
            return 1.0;
        }
        let min = self.config.min_observations as f64;
        let confidence = (observations as f64 - min) / (min * 10.0);
        confidence.clamp(0.5, 1.0)
    }

    /// Summarizes the tracked query patterns: count, mean volatility score, mean recommended
    /// TTL (whole seconds, truncated), and how many patterns are highly volatile (> 0.7) or
    /// stable (< 0.1). Returns all-zero analytics when nothing has been recorded.
    pub async fn get_analytics(&self) -> SmartTtlAnalytics {
        let stats = self.query_stats.read().await;
        if stats.is_empty() {
            return SmartTtlAnalytics::default();
        }
        let total_queries = stats.len();
        let avg_volatility: f64 =
            stats.values().map(|s| s.volatility_score).sum::<f64>() / total_queries as f64;
        let avg_ttl_secs: u64 = stats
            .values()
            .map(|s| s.recommended_ttl.as_secs())
            .sum::<u64>()
            / total_queries as u64;
        let highly_volatile = stats.values().filter(|s| s.volatility_score > 0.7).count();
        let stable_queries = stats.values().filter(|s| s.volatility_score < 0.1).count();
        SmartTtlAnalytics {
            total_queries,
            avg_volatility_score: avg_volatility,
            avg_recommended_ttl: Duration::from_secs(avg_ttl_secs),
            highly_volatile_queries: highly_volatile,
            stable_queries,
        }
    }

    /// Drops volatility statistics not updated within `max_age`. Mutation data is not touched.
    pub async fn cleanup_old_stats(&self, max_age: Duration) {
        let mut stats = self.query_stats.write().await;
        let now = Instant::now();
        stats.retain(|_, stat| now.duration_since(stat.last_update) <= max_age);
    }
}
/// Outcome of [`SmartTtlManager::calculate_ttl`].
#[derive(Debug, Clone)]
pub struct TtlResult {
/// The TTL to use for the cached entry.
pub ttl: Duration,
/// Which strategy produced `ttl`.
pub strategy: TtlStrategy,
/// 1.0 for cache hints and custom patterns, 0.8 for keyword detection, and an
/// observation-based value in `[0.5, 1.0]` for learned TTLs.
pub confidence: f64,
/// Wall-clock time spent inside `calculate_ttl`.
pub calculation_time: Duration,
}
/// The rule that determined a [`TtlResult`]'s TTL.
#[derive(Debug, Clone)]
pub enum TtlStrategy {
/// The caller-supplied cache hint was used verbatim.
CacheHint,
/// A configured custom pattern (carried here) matched the query text.
CustomPattern(String),
/// Keyword detection; carries the `query_type` string passed by the caller.
QueryType(String),
/// Learned volatility adjusted the TTL.
VolatilityBased {
/// Keyword-based TTL for the query and query type.
base_ttl: Duration,
/// Fraction of recorded results that differed from the previous one.
volatility_score: f64,
},
}
/// Internal result of a volatility lookup: the learned TTL plus the data behind it.
#[derive(Debug, Clone)]
struct VolatilityAdjustedTtl {
/// TTL after volatility adjustment.
ttl: Duration,
/// Observed change fraction for the query pattern.
volatility_score: f64,
/// Confidence derived from the number of observations.
confidence: f64,
}
/// Aggregate view of the learned volatility statistics; all zero when nothing is tracked.
#[derive(Debug, Clone, Default)]
pub struct SmartTtlAnalytics {
/// Number of distinct query patterns being tracked.
pub total_queries: usize,
/// Mean volatility score across tracked patterns.
pub avg_volatility_score: f64,
/// Mean stored recommended TTL, truncated to whole seconds.
pub avg_recommended_ttl: Duration,
/// Patterns with a volatility score above 0.7.
pub highly_volatile_queries: usize,
/// Patterns with a volatility score below 0.1.
pub stable_queries: usize,
}
/// Extracts a `maxAge:` value, in whole seconds, from schema metadata such as
/// `@cacheControl(maxAge: 60)`.
///
/// Only the first `maxAge:` occurrence is considered. Whitespace after the colon is skipped
/// and the following run of ASCII digits is parsed. Returns `None` when the key is absent,
/// when no digits follow it, or when the value does not fit in a `u64`.
pub fn parse_cache_hint(schema_metadata: &str) -> Option<Duration> {
    const KEY: &str = "maxAge:";
    let value_start = schema_metadata.find(KEY)? + KEY.len();
    let rest = schema_metadata[value_start..].trim_start();
    // A value at the very end of the input has no terminating character, so fall back to the
    // end of the string rather than reporting "no value".
    let digits_end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    rest[..digits_end].parse::<u64>().ok().map(Duration::from_secs)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_detect_query_types() {
        // `item_query_ttl` and `default_ttl` are both 300s by default. Make them differ so the
        // by-id assertion below can only pass if the item rule actually matched.
        let config = SmartTtlConfig {
            item_query_ttl: Duration::from_secs(120),
            ..SmartTtlConfig::default()
        };
        assert_ne!(config.item_query_ttl, config.default_ttl);
        let manager = SmartTtlManager::new(config.clone());

        let res = manager
            .calculate_ttl("query { settings { theme } }", "settings", None)
            .await;
        assert_eq!(res.ttl, config.static_content_ttl);

        let res = manager
            .calculate_ttl(
                "query { stock(symbol: \"AAPL\") { realtime_price } }",
                "stock",
                None,
            )
            .await;
        assert_eq!(res.ttl, config.real_time_data_ttl);

        let res = manager
            .calculate_ttl("query { me { name } }", "me", None)
            .await;
        assert_eq!(res.ttl, config.user_profile_ttl);

        let res = manager
            .calculate_ttl("query { analytics { daily_count } }", "analytics", None)
            .await;
        assert_eq!(res.ttl, config.aggregated_data_ttl);

        let res = manager
            .calculate_ttl("query { users(limit: 10) { name } }", "users", None)
            .await;
        assert_eq!(res.ttl, config.list_query_ttl);

        // "productById" lower-cases to "productbyid", which contains the "byid" item keyword.
        let res = manager
            .calculate_ttl("query { productById(id: 1) { name } }", "product", None)
            .await;
        assert_eq!(res.ttl, config.item_query_ttl);

        let res = manager
            .calculate_ttl("query { generic_data { field } }", "generic", None)
            .await;
        assert_eq!(res.ttl, config.default_ttl);
    }

    #[tokio::test]
    async fn test_volatility_learning_stable() {
        let manager = SmartTtlManager::new(SmartTtlConfig::default());
        let query = "query { product(id: 1) { price } }";
        for _ in 0..15 {
            manager.record_query_result(query, 12345).await;
        }
        let result = manager.calculate_ttl(query, "product", None).await;
        if let TtlStrategy::VolatilityBased {
            base_ttl: _,
            volatility_score,
        } = result.strategy
        {
            assert!(volatility_score < 0.1);
            assert!(result.ttl >= Duration::from_secs(600));
        } else {
            panic!(
                "Expected VolatilityBased strategy, got {:?}",
                result.strategy
            );
        }
    }

    #[tokio::test]
    async fn test_volatility_learning_volatile() {
        let manager = SmartTtlManager::new(SmartTtlConfig::default());
        let query = "query { random_quote { text } }";
        for i in 0..15 {
            manager.record_query_result(query, i as u64).await;
        }
        let result = manager.calculate_ttl(query, "random_quote", None).await;
        if let TtlStrategy::VolatilityBased {
            base_ttl: _,
            volatility_score,
        } = result.strategy
        {
            assert!(volatility_score > 0.9);
            assert_eq!(result.ttl, Duration::from_secs(150));
        } else {
            panic!(
                "Expected VolatilityBased strategy, got {:?}",
                result.strategy
            );
        }
    }

    #[tokio::test]
    async fn test_cache_hint_priority() {
        let config = SmartTtlConfig::default();
        let manager = SmartTtlManager::new(config);
        let cache_hint = Some(Duration::from_secs(42));
        let result = manager
            .calculate_ttl("query { volatile }", "test", cache_hint)
            .await;
        assert_eq!(result.ttl, Duration::from_secs(42));
        assert!(matches!(result.strategy, TtlStrategy::CacheHint));
    }

    #[tokio::test]
    async fn test_custom_patterns() {
        let mut config = SmartTtlConfig::default();
        config
            .custom_patterns
            .insert("expensive_report".to_string(), Duration::from_secs(3600));
        let manager = SmartTtlManager::new(config);
        let query = "query { get_expensive_report { data } }";
        let result = manager.calculate_ttl(query, "report", None).await;
        assert_eq!(result.ttl, Duration::from_secs(3600));
        if let TtlStrategy::CustomPattern(p) = result.strategy {
            assert_eq!(p, "expensive_report");
        } else {
            panic!("Expected CustomPattern strategy");
        }
    }

    // Pure string parsing: no async runtime needed.
    #[test]
    fn test_parse_cache_hint() {
        assert_eq!(
            parse_cache_hint("@cacheControl(maxAge: 60)"),
            Some(Duration::from_secs(60))
        );
        assert_eq!(
            parse_cache_hint("type Query @cacheControl(maxAge: 300) {"),
            Some(Duration::from_secs(300))
        );
        assert_eq!(parse_cache_hint("no hint here"), None);
        assert_eq!(parse_cache_hint("@cacheControl(invalid)"), None);
    }

    #[tokio::test]
    async fn test_cleanup_old_stats() {
        let manager = SmartTtlManager::new(SmartTtlConfig::default());
        let query = "query { old }";
        manager.record_query_result(query, 1).await;
        {
            let stats = manager.query_stats.read().await;
            assert_eq!(stats.len(), 1);
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
        manager.cleanup_old_stats(Duration::from_millis(0)).await;
        {
            let stats = manager.query_stats.read().await;
            assert_eq!(stats.len(), 0);
        }
    }

    #[tokio::test]
    async fn test_analytics_aggregation() {
        let manager = SmartTtlManager::new(SmartTtlConfig::default());
        for _ in 0..15 {
            manager.record_query_result("query { stable }", 1).await;
        }
        for i in 0..15 {
            manager.record_query_result("query { volatile }", i).await;
        }
        let analytics = manager.get_analytics().await;
        assert_eq!(analytics.total_queries, 2);
        assert_eq!(analytics.stable_queries, 1);
        assert_eq!(analytics.highly_volatile_queries, 1);
        assert!(analytics.avg_volatility_score > 0.0 && analytics.avg_volatility_score < 1.0);
    }
}