use crate::{Result, VoirsError};
use chrono::{Datelike, Timelike};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Configuration for predictive cache warming.
///
/// Build with [`Default::default`] and refine with the `with_*` /
/// `enable_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmingConfig {
    /// Strategy used when generating warming predictions.
    pub strategy: WarmingStrategy,
    /// Maximum number of recent accesses retained for pattern analysis.
    pub pattern_window_size: usize,
    /// Minimum access frequency (share of the window) for a
    /// frequency-based prediction to be emitted, and for a prediction's
    /// confidence to count during a warming cycle.
    pub min_frequency_threshold: f64,
    /// Track hour-of-day / day-of-week access patterns.
    pub enable_time_patterns: bool,
    /// Consider text similarity when predicting future accesses.
    pub enable_text_similarity: bool,
    /// Upper bound on entries warmed (and predictions kept) per cycle.
    pub max_warm_entries: usize,
    /// Minimum time between warming cycles.
    pub warming_interval: Duration,
    /// Similarity required to treat two texts as related, in [0.0, 1.0].
    pub similarity_threshold: f64,
}
impl Default for WarmingConfig {
    /// Reasonable defaults: predictive strategy over a 1000-access window,
    /// warming at most 50 entries per 60-second cycle.
    fn default() -> Self {
        Self {
            strategy: WarmingStrategy::Predictive,
            pattern_window_size: 1000,
            min_frequency_threshold: 0.1,
            enable_time_patterns: true,
            enable_text_similarity: true,
            max_warm_entries: 50,
            warming_interval: Duration::from_secs(60),
            similarity_threshold: 0.7,
        }
    }
}
impl WarmingConfig {
    /// Replaces the warming strategy, returning the updated config.
    pub fn with_strategy(self, strategy: WarmingStrategy) -> Self {
        Self { strategy, ..self }
    }
    /// Sets how many recent accesses are kept for pattern analysis.
    pub fn with_pattern_window(self, size: usize) -> Self {
        Self {
            pattern_window_size: size,
            ..self
        }
    }
    /// Toggles hour-of-day / day-of-week pattern tracking.
    pub fn enable_time_patterns(self, enable: bool) -> Self {
        Self {
            enable_time_patterns: enable,
            ..self
        }
    }
    /// Sets the similarity threshold, clamped into the valid [0.0, 1.0] range.
    pub fn with_similarity_threshold(self, threshold: f64) -> Self {
        Self {
            similarity_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }
    /// Toggles text-similarity-based prediction.
    pub fn enable_text_similarity(self, enable: bool) -> Self {
        Self {
            enable_text_similarity: enable,
            ..self
        }
    }
}
/// Selects how the warmer chooses cache entries to pre-warm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WarmingStrategy {
    /// Warm the most frequently accessed (voice, text) pairs.
    FrequencyBased,
    /// Warm the most recently accessed entries.
    RecencyBased,
    /// Frequency-based plus text-similarity-based prediction.
    Predictive,
    /// Warm entries that historically occur at the current hour.
    TimeBased,
    /// Combination of frequency-based and time-based predictions.
    Hybrid,
}
/// A single recorded cache access; the unit of pattern analysis.
#[derive(Debug, Clone)]
struct AccessPattern {
    voice_id: String,
    text: String,
    // Monotonic capture time. NOTE(review): never read anywhere in this
    // file — presumably reserved for future recency math; confirm before
    // removing.
    timestamp: Instant,
    // 0-23, taken from the UTC clock at record time.
    hour_of_day: u8,
    // 0 = Monday .. 6 = Sunday (chrono's num_days_from_monday).
    day_of_week: u8,
}
/// Aggregated access-frequency metrics for a single cache key.
///
/// NOTE(review): not constructed or consumed anywhere in this file —
/// presumably part of a reporting API used elsewhere in the crate;
/// confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrequencyStats {
    /// Total number of recorded accesses.
    pub access_count: u64,
    /// Most recent access time, string-encoded (format not fixed here).
    pub last_access: String,
    /// Mean interval between accesses, in seconds.
    pub avg_interval_secs: f64,
    /// Relative frequency score (semantics defined by the producer).
    pub frequency_score: f64,
}
/// Result of [`CacheWarmer::analyze_patterns`]: aggregate counts over the
/// access window plus warming predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternAnalysis {
    /// Top voices by access count within the window, as (voice id, count).
    pub frequent_voices: Vec<(String, u64)>,
    /// Top texts by access count within the window, as (text, count).
    pub common_patterns: Vec<(String, u64)>,
    /// Access count per hour of day (0-23); empty when time patterns are disabled.
    pub hourly_patterns: HashMap<u8, u64>,
    /// Access count per day of week (0 = Monday); empty when time patterns are disabled.
    pub daily_patterns: HashMap<u8, u64>,
    /// Warming predictions produced by the configured strategy.
    pub predictions: Vec<WarmingPrediction>,
}
/// A single "this entry is likely to be needed soon" prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmingPrediction {
    /// Voice the prediction applies to.
    pub voice_id: String,
    /// Text expected to be requested again.
    pub text_pattern: String,
    /// Confidence in [0.0, 1.0]; higher sorts first when warming.
    pub confidence: f64,
    /// Human-readable explanation of why this prediction was made.
    pub reason: String,
}
/// Mutable warmer state, guarded by the `RwLock` inside [`CacheWarmer`].
#[derive(Debug)]
struct WarmerState {
    config: WarmingConfig,
    /// Sliding window of recent accesses (capped at `pattern_window_size`).
    access_history: Vec<AccessPattern>,
    /// Access count per combined voice/text key.
    frequency_map: HashMap<String, u64>,
    /// When the last warming cycle ran; `None` before the first cycle.
    last_warming: Option<Instant>,
    /// Number of completed warming cycles.
    warming_count: u64,
}
/// Records cache access patterns and predicts which entries to pre-warm.
///
/// Cheap to clone: all state lives behind a shared `Arc<RwLock<..>>`, so
/// clones observe and mutate the same history.
#[derive(Debug, Clone)]
pub struct CacheWarmer {
    state: Arc<RwLock<WarmerState>>,
}
impl CacheWarmer {
pub fn new(config: WarmingConfig) -> Self {
Self {
state: Arc::new(RwLock::new(WarmerState {
config,
access_history: Vec::new(),
frequency_map: HashMap::new(),
last_warming: None,
warming_count: 0,
})),
}
}
pub async fn record_access(&self, voice_id: &str, text: &str) -> Result<()> {
let mut state = self.state.write().await;
let now = chrono::Utc::now();
let pattern = AccessPattern {
voice_id: voice_id.to_string(),
text: text.to_string(),
timestamp: Instant::now(),
hour_of_day: now.hour() as u8,
day_of_week: now.weekday().num_days_from_monday() as u8,
};
state.access_history.push(pattern);
let key = format!("{}:{}", voice_id, text);
*state.frequency_map.entry(key).or_insert(0) += 1;
let window_size = state.config.pattern_window_size;
let current_len = state.access_history.len();
if current_len > window_size {
state.access_history.drain(0..current_len - window_size);
}
Ok(())
}
pub async fn analyze_patterns(&self) -> Result<PatternAnalysis> {
let state = self.state.read().await;
if state.access_history.is_empty() {
return Ok(PatternAnalysis {
frequent_voices: Vec::new(),
common_patterns: Vec::new(),
hourly_patterns: HashMap::new(),
daily_patterns: HashMap::new(),
predictions: Vec::new(),
});
}
let mut voice_counts: HashMap<String, u64> = HashMap::new();
for pattern in &state.access_history {
*voice_counts.entry(pattern.voice_id.clone()).or_insert(0) += 1;
}
let mut frequent_voices: Vec<_> = voice_counts.into_iter().collect();
frequent_voices.sort_by_key(|b| std::cmp::Reverse(b.1));
frequent_voices.truncate(10);
let mut text_counts: HashMap<String, u64> = HashMap::new();
for pattern in &state.access_history {
*text_counts.entry(pattern.text.clone()).or_insert(0) += 1;
}
let mut common_patterns: Vec<_> = text_counts.into_iter().collect();
common_patterns.sort_by_key(|b| std::cmp::Reverse(b.1));
common_patterns.truncate(10);
let mut hourly_patterns: HashMap<u8, u64> = HashMap::new();
let mut daily_patterns: HashMap<u8, u64> = HashMap::new();
if state.config.enable_time_patterns {
for pattern in &state.access_history {
*hourly_patterns.entry(pattern.hour_of_day).or_insert(0) += 1;
*daily_patterns.entry(pattern.day_of_week).or_insert(0) += 1;
}
}
let predictions = self.generate_predictions(&state).await?;
Ok(PatternAnalysis {
frequent_voices,
common_patterns,
hourly_patterns,
daily_patterns,
predictions,
})
}
pub async fn warm_cache(&self) -> Result<WarmingStats> {
let mut state = self.state.write().await;
if let Some(last) = state.last_warming {
if last.elapsed() < state.config.warming_interval {
return Ok(WarmingStats {
entries_warmed: 0,
predictions_used: 0,
warming_time: Duration::from_secs(0),
});
}
}
let start = Instant::now();
drop(state);
let analysis = self.analyze_patterns().await?;
let mut state = self.state.write().await;
let mut entries_warmed = 0;
let max_entries = state.config.max_warm_entries;
for prediction in &analysis.predictions {
if entries_warmed >= max_entries {
break;
}
if prediction.confidence >= state.config.min_frequency_threshold {
tracing::debug!(
"Warming cache entry: {} (confidence: {:.2})",
prediction.text_pattern,
prediction.confidence
);
entries_warmed += 1;
}
}
state.last_warming = Some(Instant::now());
state.warming_count += 1;
let warming_time = start.elapsed();
Ok(WarmingStats {
entries_warmed,
predictions_used: analysis.predictions.len(),
warming_time,
})
}
pub async fn get_stats(&self) -> Result<WarmingStatistics> {
let state = self.state.read().await;
Ok(WarmingStatistics {
total_accesses: state.access_history.len(),
unique_patterns: state.frequency_map.len(),
warming_cycles: state.warming_count,
last_warming: state.last_warming.map(|t| t.elapsed()),
})
}
pub async fn reset(&self) -> Result<()> {
let mut state = self.state.write().await;
state.access_history.clear();
state.frequency_map.clear();
state.last_warming = None;
state.warming_count = 0;
Ok(())
}
async fn generate_predictions(&self, state: &WarmerState) -> Result<Vec<WarmingPrediction>> {
let mut predictions = Vec::new();
match state.config.strategy {
WarmingStrategy::FrequencyBased => {
self.frequency_based_predictions(state, &mut predictions)?;
}
WarmingStrategy::RecencyBased => {
self.recency_based_predictions(state, &mut predictions)?;
}
WarmingStrategy::Predictive => {
self.predictive_predictions(state, &mut predictions)?;
}
WarmingStrategy::TimeBased => {
self.time_based_predictions(state, &mut predictions)?;
}
WarmingStrategy::Hybrid => {
self.frequency_based_predictions(state, &mut predictions)?;
self.time_based_predictions(state, &mut predictions)?;
}
}
predictions.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Less)
});
predictions.truncate(state.config.max_warm_entries);
Ok(predictions)
}
fn frequency_based_predictions(
&self,
state: &WarmerState,
predictions: &mut Vec<WarmingPrediction>,
) -> Result<()> {
let total_accesses = state.access_history.len() as f64;
for (key, count) in &state.frequency_map {
let frequency = *count as f64 / total_accesses;
if frequency >= state.config.min_frequency_threshold {
if let Some((voice_id, text)) = key.split_once(':') {
predictions.push(WarmingPrediction {
voice_id: voice_id.to_string(),
text_pattern: text.to_string(),
confidence: frequency,
reason: format!("Frequent access ({} times)", count),
});
}
}
}
Ok(())
}
fn recency_based_predictions(
&self,
state: &WarmerState,
predictions: &mut Vec<WarmingPrediction>,
) -> Result<()> {
let recent_count = 10.min(state.access_history.len());
let recent = &state.access_history[state.access_history.len() - recent_count..];
for pattern in recent {
predictions.push(WarmingPrediction {
voice_id: pattern.voice_id.clone(),
text_pattern: pattern.text.clone(),
confidence: 0.8,
reason: "Recent access".to_string(),
});
}
Ok(())
}
fn predictive_predictions(
&self,
state: &WarmerState,
predictions: &mut Vec<WarmingPrediction>,
) -> Result<()> {
self.frequency_based_predictions(state, predictions)?;
if state.config.enable_text_similarity && state.access_history.len() >= 2 {
let recent = &state.access_history[state.access_history.len() - 1];
for pattern in &state.access_history[..state.access_history.len() - 1] {
let similarity = self.text_similarity(&recent.text, &pattern.text);
if similarity >= state.config.similarity_threshold {
predictions.push(WarmingPrediction {
voice_id: pattern.voice_id.clone(),
text_pattern: pattern.text.clone(),
confidence: similarity,
reason: format!(
"Similar to recent access ({:.0}% match)",
similarity * 100.0
),
});
}
}
}
Ok(())
}
fn time_based_predictions(
&self,
state: &WarmerState,
predictions: &mut Vec<WarmingPrediction>,
) -> Result<()> {
if !state.config.enable_time_patterns {
return Ok(());
}
let now = chrono::Utc::now();
let current_hour = now.hour() as u8;
for pattern in &state.access_history {
if pattern.hour_of_day == current_hour {
predictions.push(WarmingPrediction {
voice_id: pattern.voice_id.clone(),
text_pattern: pattern.text.clone(),
confidence: 0.7,
reason: format!("Time-based pattern (hour {})", current_hour),
});
}
}
Ok(())
}
fn text_similarity(&self, text1: &str, text2: &str) -> f64 {
let words1: std::collections::HashSet<_> = text1.split_whitespace().collect();
let words2: std::collections::HashSet<_> = text2.split_whitespace().collect();
if words1.is_empty() && words2.is_empty() {
return 1.0;
}
let intersection = words1.intersection(&words2).count();
let union = words1.union(&words2).count();
if union == 0 {
0.0
} else {
intersection as f64 / union as f64
}
}
}
/// Outcome of a single [`CacheWarmer::warm_cache`] cycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmingStats {
    /// Number of entries warmed this cycle.
    pub entries_warmed: usize,
    /// Number of predictions considered this cycle.
    pub predictions_used: usize,
    /// Wall-clock duration of the cycle; zero when the cycle was skipped
    /// because the warming interval had not yet elapsed.
    pub warming_time: Duration,
}
/// Cumulative warmer statistics returned by [`CacheWarmer::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmingStatistics {
    /// Number of accesses currently held in the sliding window (not a
    /// lifetime total — old entries are dropped once the window is full).
    pub total_accesses: usize,
    /// Distinct voice/text keys tracked in the frequency map.
    pub unique_patterns: usize,
    /// Number of completed warming cycles.
    pub warming_cycles: u64,
    /// Time since the last warming cycle, if any has run.
    pub last_warming: Option<Duration>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // A newly created warmer has recorded nothing.
    #[tokio::test]
    async fn test_cache_warmer_creation() {
        let warmer = CacheWarmer::new(WarmingConfig::default());
        let stats = warmer.get_stats().await.unwrap();
        assert_eq!(stats.total_accesses, 0);
    }

    // Every access is counted individually; identical (voice, text) pairs
    // collapse into a single unique pattern entry.
    #[tokio::test]
    async fn test_record_access() {
        let warmer = CacheWarmer::new(WarmingConfig::default());
        warmer.record_access("voice1", "Hello world").await.unwrap();
        warmer.record_access("voice1", "Hello world").await.unwrap();
        warmer.record_access("voice2", "Goodbye").await.unwrap();
        let stats = warmer.get_stats().await.unwrap();
        assert_eq!(stats.total_accesses, 3);
        assert_eq!(stats.unique_patterns, 2);
    }

    // Voices and texts are ranked by access count, most frequent first.
    #[tokio::test]
    async fn test_pattern_analysis() {
        let warmer = CacheWarmer::new(WarmingConfig::default());
        for _ in 0..5 {
            warmer.record_access("voice1", "Hello").await.unwrap();
        }
        for _ in 0..3 {
            warmer.record_access("voice2", "World").await.unwrap();
        }
        let analysis = warmer.analyze_patterns().await.unwrap();
        assert_eq!(analysis.frequent_voices.len(), 2);
        assert_eq!(analysis.common_patterns.len(), 2);
        assert_eq!(analysis.frequent_voices[0].0, "voice1");
        assert_eq!(analysis.frequent_voices[0].1, 5);
    }

    // With the frequency strategy, only entries whose share of accesses
    // meets the threshold are predicted; the dominant entry ranks first.
    #[tokio::test]
    async fn test_frequency_based_predictions() {
        let config = WarmingConfig::default()
            .with_strategy(WarmingStrategy::FrequencyBased)
            .with_pattern_window(100);
        let warmer = CacheWarmer::new(config);
        for _ in 0..10 {
            warmer
                .record_access("voice1", "Frequent text")
                .await
                .unwrap();
        }
        warmer.record_access("voice2", "Rare text").await.unwrap();
        let analysis = warmer.analyze_patterns().await.unwrap();
        assert!(!analysis.predictions.is_empty());
        assert_eq!(analysis.predictions[0].text_pattern, "Frequent text");
    }

    // Word-set (Jaccard) similarity: order-insensitive, 1.0 for equal
    // sets, 0.0 for disjoint sets, strictly between otherwise.
    #[tokio::test]
    async fn test_text_similarity() {
        let warmer = CacheWarmer::new(WarmingConfig::default());
        let sim1 = warmer.text_similarity("hello world", "hello world");
        assert_eq!(sim1, 1.0);
        let sim2 = warmer.text_similarity("hello world", "world hello");
        assert_eq!(sim2, 1.0);
        let sim3 = warmer.text_similarity("hello", "goodbye");
        assert_eq!(sim3, 0.0);
        let sim4 = warmer.text_similarity("hello world", "hello earth");
        assert!(sim4 > 0.0 && sim4 < 1.0);
    }

    // A second warm_cache call within the warming interval is a no-op.
    #[tokio::test]
    async fn test_warming_interval() {
        let config = WarmingConfig::default().with_strategy(WarmingStrategy::FrequencyBased);
        let warmer = CacheWarmer::new(config);
        warmer.record_access("voice1", "Test").await.unwrap();
        let stats1 = warmer.warm_cache().await.unwrap();
        assert!(stats1.warming_time > Duration::from_secs(0));
        let stats2 = warmer.warm_cache().await.unwrap();
        assert_eq!(stats2.entries_warmed, 0);
    }

    // reset() clears history, frequency counts, and cycle counters.
    #[tokio::test]
    async fn test_warmer_reset() {
        let warmer = CacheWarmer::new(WarmingConfig::default());
        warmer.record_access("voice1", "Test").await.unwrap();
        warmer.reset().await.unwrap();
        let stats = warmer.get_stats().await.unwrap();
        assert_eq!(stats.total_accesses, 0);
        assert_eq!(stats.unique_patterns, 0);
    }
}