// kaccy_db/cache_warming.rs

//! Cache warming strategies for improved performance.
//!
//! This module provides:
//! - Preload hot data on startup
//! - Background refresh for expiring keys
//! - Predictive cache loading
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use tokio::time::interval;
use tracing::{debug, error, info};

use crate::cache::RedisCache;
use crate::error::Result;
19
/// Selects which cache-warming behaviors are active.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WarmingStrategy {
    /// Warming disabled entirely.
    None,
    /// Load a known set of hot keys once at startup.
    Preload,
    /// Periodically re-fetch keys that are close to expiring.
    BackgroundRefresh,
    /// Load keys predicted to be needed soon, based on observed access patterns.
    Predictive,
    /// Preload, background refresh, and predictive loading combined.
    All,
}
34
35/// Configuration for cache warming
36#[derive(Debug, Clone)]
37pub struct CacheWarmingConfig {
38    /// Warming strategy to use
39    pub strategy: WarmingStrategy,
40    /// Interval for background refresh (seconds)
41    pub refresh_interval_secs: u64,
42    /// Refresh keys when they have less than this TTL remaining (seconds)
43    pub refresh_threshold_secs: u64,
44    /// Maximum number of keys to warm at once
45    pub batch_size: usize,
46    /// Enable predictive loading
47    pub enable_prediction: bool,
48}
49
50impl Default for CacheWarmingConfig {
51    fn default() -> Self {
52        Self {
53            strategy: WarmingStrategy::All,
54            refresh_interval_secs: 60,
55            refresh_threshold_secs: 300, // Refresh when < 5 minutes remaining
56            batch_size: 100,
57            enable_prediction: true,
58        }
59    }
60}
61
/// Source of truth the cache warmer pulls data from.
///
/// Implementors expose which keys are "hot" (worth keeping cached) and how to
/// load the value for any single key.
#[async_trait]
pub trait CacheDataSource: Send + Sync {
    /// Return the keys that should be kept warm (preloaded and refreshed).
    async fn get_hot_keys(&self) -> Result<Vec<String>>;

    /// Load the value for `key`.
    ///
    /// Returns `Ok(Some((value, ttl_secs)))` when data exists, or `Ok(None)`
    /// when there is nothing to cache for this key.
    async fn load_data(&self, key: &str) -> Result<Option<(String, u64)>>;
}
71
/// Coordinates cache-warming strategies (preload, background refresh,
/// predictive loading) against a shared Redis cache.
pub struct CacheWarmer {
    /// Shared handle to the cache being warmed.
    cache: Arc<RedisCache>,
    /// Strategy selection and tuning knobs; see [`CacheWarmingConfig`].
    config: CacheWarmingConfig,
    /// Per-key access statistics backing predictive loading.
    access_tracker: Arc<RwLock<AccessTracker>>,
}
78
impl CacheWarmer {
    /// Create a new cache warmer over `cache` using `config`.
    pub fn new(cache: Arc<RedisCache>, config: CacheWarmingConfig) -> Self {
        Self {
            cache,
            config,
            access_tracker: Arc::new(RwLock::new(AccessTracker::new())),
        }
    }

    /// Preload hot keys on startup.
    ///
    /// Returns `Ok(0)` immediately unless the configured strategy is
    /// `Preload` or `All`. Failures on individual keys are logged and
    /// skipped; only a failure of `get_hot_keys` aborts the whole pass.
    /// Returns the number of keys actually written to the cache.
    pub async fn preload<T: CacheDataSource>(&self, source: Arc<T>) -> Result<usize> {
        if !matches!(
            self.config.strategy,
            WarmingStrategy::Preload | WarmingStrategy::All
        ) {
            return Ok(0);
        }

        info!("Starting cache preload");

        let hot_keys = source.get_hot_keys().await?;
        let mut loaded_count = 0;

        // NOTE(review): keys are walked in `batch_size` chunks but each key
        // is still loaded sequentially — the chunking adds no concurrency.
        for chunk in hot_keys.chunks(self.config.batch_size) {
            for key in chunk {
                match source.load_data(key).await {
                    Ok(Some((value, ttl))) => {
                        if let Err(e) = self.cache.set(key, &value, ttl).await {
                            error!(key = %key, error = %e, "Failed to preload key");
                        } else {
                            loaded_count += 1;
                            debug!(key = %key, "Preloaded key");
                        }
                    }
                    Ok(None) => {
                        // The source has nothing for this key; skip it.
                        debug!(key = %key, "No data for key");
                    }
                    Err(e) => {
                        error!(key = %key, error = %e, "Failed to load data");
                    }
                }
            }
        }

        info!(count = loaded_count, "Cache preload completed");
        Ok(loaded_count)
    }

    /// Start the background refresh task.
    ///
    /// No-op unless the strategy is `BackgroundRefresh` or `All`. Spawns a
    /// detached Tokio task (no handle is returned, so it runs until the
    /// runtime shuts down) that wakes every `refresh_interval_secs` and
    /// re-loads any hot key whose remaining TTL is below
    /// `refresh_threshold_secs`.
    pub async fn start_background_refresh<T: CacheDataSource + 'static>(
        self: Arc<Self>,
        source: Arc<T>,
    ) {
        if !matches!(
            self.config.strategy,
            WarmingStrategy::BackgroundRefresh | WarmingStrategy::All
        ) {
            return;
        }

        info!(
            interval_secs = self.config.refresh_interval_secs,
            "Starting background cache refresh"
        );

        // NOTE(review): tokio's `interval` completes its first tick
        // immediately, so the first refresh pass runs right after spawn,
        // not one full interval later — confirm this is intended.
        let mut refresh_interval = interval(Duration::from_secs(self.config.refresh_interval_secs));

        tokio::spawn(async move {
            loop {
                refresh_interval.tick().await;

                debug!("Running background cache refresh");

                match source.get_hot_keys().await {
                    Ok(keys) => {
                        for chunk in keys.chunks(self.config.batch_size) {
                            for key in chunk {
                                // Check if key needs refresh
                                match self.cache.ttl(key).await {
                                    Ok(ttl)
                                        if ttl > 0
                                            && ttl < self.config.refresh_threshold_secs as i64 =>
                                    {
                                        // Key is expiring soon, refresh it
                                        match source.load_data(key).await {
                                            Ok(Some((value, new_ttl))) => {
                                                if let Err(e) =
                                                    self.cache.set(key, &value, new_ttl).await
                                                {
                                                    error!(key = %key, error = %e, "Failed to refresh key");
                                                } else {
                                                    debug!(key = %key, ttl = ttl, "Refreshed expiring key");
                                                }
                                            }
                                            Ok(None) => {}
                                            Err(e) => {
                                                error!(key = %key, error = %e, "Failed to load data for refresh");
                                            }
                                        }
                                    }
                                    Ok(_) => {
                                        // TTL is comfortable (>= threshold) or non-positive
                                        // (presumably missing key / no expiry, per Redis
                                        // convention — confirm RedisCache::ttl semantics).
                                    }
                                    Err(e) => {
                                        error!(key = %key, error = %e, "Failed to get TTL");
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        error!(error = %e, "Failed to get hot keys for refresh");
                    }
                }
            }
        });
    }

    /// Record an access to `key` for predictive loading.
    ///
    /// No-op when `enable_prediction` is false. Takes the tracker's write
    /// lock briefly.
    pub fn track_access(&self, key: &str) {
        if !self.config.enable_prediction {
            return;
        }

        self.access_tracker.write().record_access(key);
    }

    /// Return up to `limit` keys ranked most likely to be needed next based
    /// on recorded access patterns. Empty when prediction is disabled.
    pub fn get_predicted_keys(&self, limit: usize) -> Vec<String> {
        if !self.config.enable_prediction {
            return Vec::new();
        }

        self.access_tracker.read().get_top_keys(limit)
    }

    /// Predictively load up to `limit` keys based on access patterns.
    ///
    /// Returns `Ok(0)` unless the strategy is `Predictive` or `All`. Keys
    /// already present in the cache are skipped; individual load/store
    /// failures are logged and skipped. Returns the number of keys stored.
    pub async fn predictive_load<T: CacheDataSource>(
        &self,
        source: Arc<T>,
        limit: usize,
    ) -> Result<usize> {
        if !matches!(
            self.config.strategy,
            WarmingStrategy::Predictive | WarmingStrategy::All
        ) {
            return Ok(0);
        }

        let predicted_keys = self.get_predicted_keys(limit);
        let mut loaded_count = 0;

        for key in predicted_keys {
            // Skip keys already cached. Note an Err from `exists` is treated
            // as "not cached" and falls through to a reload attempt.
            if let Ok(true) = self.cache.exists(&key).await {
                continue;
            }

            // Load and cache the key
            match source.load_data(&key).await {
                Ok(Some((value, ttl))) => {
                    if let Err(e) = self.cache.set(&key, &value, ttl).await {
                        error!(key = %key, error = %e, "Failed to predictively load key");
                    } else {
                        loaded_count += 1;
                        debug!(key = %key, "Predictively loaded key");
                    }
                }
                Ok(None) => {}
                Err(e) => {
                    error!(key = %key, error = %e, "Failed to load data for prediction");
                }
            }
        }

        if loaded_count > 0 {
            info!(count = loaded_count, "Predictive cache loading completed");
        }

        Ok(loaded_count)
    }
}
262
/// Access tracker for predictive loading.
///
/// Records per-key access statistics so the warmer can rank keys by how
/// often they are requested.
#[derive(Debug, Clone)]
struct AccessTracker {
    /// Per-key access statistics.
    counts: HashMap<String, AccessCount>,
    /// Total accesses recorded across all keys.
    total_accesses: u64,
}

/// Access statistics for a single key.
#[derive(Debug, Clone)]
struct AccessCount {
    /// Number of accesses recorded.
    count: u64,
    /// When this key was first seen (monotonic clock).
    first_access: Instant,
    /// When this key was most recently accessed (monotonic clock).
    last_access: Instant,
    /// Average access rate in accesses per hour over the key's observed
    /// lifetime; 0.0 until measurable time has elapsed since first access.
    frequency: f64,
}

impl AccessTracker {
    /// Create an empty tracker.
    fn new() -> Self {
        Self {
            counts: HashMap::new(),
            total_accesses: 0,
        }
    }

    /// Record one access to `key`, updating its count and average rate.
    ///
    /// Fixes two defects in the previous implementation:
    /// - frequency was `count / hours_since_LAST_access`, which divides the
    ///   lifetime total by a single inter-access gap — it grows without bound
    ///   for hot keys instead of measuring accesses per hour. It is now
    ///   `count / hours_since_FIRST_access`, a true lifetime average.
    /// - elapsed time was measured with wall-clock `Utc::now()`, which can
    ///   jump backwards on clock adjustment; the monotonic `Instant` clock
    ///   is used instead.
    fn record_access(&mut self, key: &str) {
        let now = Instant::now();
        self.total_accesses += 1;

        self.counts
            .entry(key.to_string())
            .and_modify(|stats| {
                stats.count += 1;
                stats.last_access = now;
                let hours_since_first =
                    now.duration_since(stats.first_access).as_secs_f64() / 3600.0;
                if hours_since_first > 0.0 {
                    stats.frequency = stats.count as f64 / hours_since_first;
                }
            })
            .or_insert(AccessCount {
                count: 1,
                first_access: now,
                last_access: now,
                frequency: 0.0,
            });
    }

    /// Return up to `limit` keys ranked by access frequency, descending,
    /// with ties broken by raw access count.
    fn get_top_keys(&self, limit: usize) -> Vec<String> {
        let mut keys: Vec<(String, u64, f64)> = self
            .counts
            .iter()
            .map(|(k, v)| (k.clone(), v.count, v.frequency))
            .collect();

        // Frequencies are finite (never NaN), so the partial_cmp fallback to
        // Equal is only a formality.
        keys.sort_by(|a, b| {
            b.2.partial_cmp(&a.2)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| b.1.cmp(&a.1))
        });

        keys.into_iter().take(limit).map(|(k, _, _)| k).collect()
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory data source for future async warming tests.
    #[allow(dead_code)]
    struct MockDataSource {
        keys: Vec<String>,
    }

    #[allow(dead_code)]
    #[async_trait]
    impl CacheDataSource for MockDataSource {
        async fn get_hot_keys(&self) -> Result<Vec<String>> {
            Ok(self.keys.clone())
        }

        async fn load_data(&self, key: &str) -> Result<Option<(String, u64)>> {
            Ok(Some((format!("data:{}", key), 3600)))
        }
    }

    #[test]
    fn test_warming_config_default() {
        let config = CacheWarmingConfig::default();
        assert_eq!(config.strategy, WarmingStrategy::All);
        assert_eq!(config.refresh_interval_secs, 60);
        assert_eq!(config.refresh_threshold_secs, 300);
        assert_eq!(config.batch_size, 100);
        assert!(config.enable_prediction);
    }

    #[test]
    fn test_access_tracker_record() {
        let mut tracker = AccessTracker::new();

        tracker.record_access("key1");
        tracker.record_access("key1");
        tracker.record_access("key2");

        assert_eq!(tracker.total_accesses, 3);
        assert_eq!(tracker.counts.get("key1").unwrap().count, 2);
        assert_eq!(tracker.counts.get("key2").unwrap().count, 1);
    }

    #[test]
    fn test_access_tracker_top_keys() {
        let mut tracker = AccessTracker::new();

        for _ in 0..10 {
            tracker.record_access("key1");
        }
        for _ in 0..5 {
            tracker.record_access("key2");
        }
        tracker.record_access("key3");

        // Previously this test only checked the length; now it also pins
        // WHICH keys are returned. key3 (single access, frequency 0.0) must
        // never outrank the multiply-accessed keys.
        let top_keys = tracker.get_top_keys(2);
        assert_eq!(top_keys.len(), 2);
        assert!(top_keys.contains(&"key1".to_string()));
        assert!(top_keys.contains(&"key2".to_string()));
        assert!(!top_keys.contains(&"key3".to_string()));

        // A limit larger than the population returns everything tracked.
        assert_eq!(tracker.get_top_keys(10).len(), 3);

        // An empty tracker yields no keys.
        assert!(AccessTracker::new().get_top_keys(5).is_empty());
    }

    #[test]
    fn test_warming_strategy_equality() {
        assert_eq!(WarmingStrategy::None, WarmingStrategy::None);
        assert_ne!(WarmingStrategy::None, WarmingStrategy::Preload);
    }
}