1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use parking_lot::RwLock;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::time::interval;
15use tracing::{debug, error, info};
16
17use crate::cache::RedisCache;
18use crate::error::Result;
19
/// Selects which cache-warming mechanisms a [`CacheWarmer`] is allowed to run.
///
/// Each warming entry point checks the configured strategy first and turns
/// into a no-op when its mechanism is not selected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WarmingStrategy {
    /// No mechanism is selected; every strategy-gated entry point is a no-op.
    None,
    /// Enables `CacheWarmer::preload` only.
    Preload,
    /// Enables `CacheWarmer::start_background_refresh` only.
    BackgroundRefresh,
    /// Enables `CacheWarmer::predictive_load` only.
    Predictive,
    /// Enables preload, background refresh and predictive loading together.
    All,
}
34
35#[derive(Debug, Clone)]
37pub struct CacheWarmingConfig {
38 pub strategy: WarmingStrategy,
40 pub refresh_interval_secs: u64,
42 pub refresh_threshold_secs: u64,
44 pub batch_size: usize,
46 pub enable_prediction: bool,
48}
49
50impl Default for CacheWarmingConfig {
51 fn default() -> Self {
52 Self {
53 strategy: WarmingStrategy::All,
54 refresh_interval_secs: 60,
55 refresh_threshold_secs: 300, batch_size: 100,
57 enable_prediction: true,
58 }
59 }
60}
61
62#[async_trait]
64pub trait CacheDataSource: Send + Sync {
65 async fn get_hot_keys(&self) -> Result<Vec<String>>;
67
68 async fn load_data(&self, key: &str) -> Result<Option<(String, u64)>>;
70}
71
72pub struct CacheWarmer {
74 cache: Arc<RedisCache>,
75 config: CacheWarmingConfig,
76 access_tracker: Arc<RwLock<AccessTracker>>,
77}
78
79impl CacheWarmer {
80 pub fn new(cache: Arc<RedisCache>, config: CacheWarmingConfig) -> Self {
82 Self {
83 cache,
84 config,
85 access_tracker: Arc::new(RwLock::new(AccessTracker::new())),
86 }
87 }
88
89 pub async fn preload<T: CacheDataSource>(&self, source: Arc<T>) -> Result<usize> {
91 if !matches!(
92 self.config.strategy,
93 WarmingStrategy::Preload | WarmingStrategy::All
94 ) {
95 return Ok(0);
96 }
97
98 info!("Starting cache preload");
99
100 let hot_keys = source.get_hot_keys().await?;
101 let mut loaded_count = 0;
102
103 for chunk in hot_keys.chunks(self.config.batch_size) {
104 for key in chunk {
105 match source.load_data(key).await {
106 Ok(Some((value, ttl))) => {
107 if let Err(e) = self.cache.set(key, &value, ttl).await {
108 error!(key = %key, error = %e, "Failed to preload key");
109 } else {
110 loaded_count += 1;
111 debug!(key = %key, "Preloaded key");
112 }
113 }
114 Ok(None) => {
115 debug!(key = %key, "No data for key");
116 }
117 Err(e) => {
118 error!(key = %key, error = %e, "Failed to load data");
119 }
120 }
121 }
122 }
123
124 info!(count = loaded_count, "Cache preload completed");
125 Ok(loaded_count)
126 }
127
128 pub async fn start_background_refresh<T: CacheDataSource + 'static>(
130 self: Arc<Self>,
131 source: Arc<T>,
132 ) {
133 if !matches!(
134 self.config.strategy,
135 WarmingStrategy::BackgroundRefresh | WarmingStrategy::All
136 ) {
137 return;
138 }
139
140 info!(
141 interval_secs = self.config.refresh_interval_secs,
142 "Starting background cache refresh"
143 );
144
145 let mut refresh_interval = interval(Duration::from_secs(self.config.refresh_interval_secs));
146
147 tokio::spawn(async move {
148 loop {
149 refresh_interval.tick().await;
150
151 debug!("Running background cache refresh");
152
153 match source.get_hot_keys().await {
154 Ok(keys) => {
155 for chunk in keys.chunks(self.config.batch_size) {
156 for key in chunk {
157 match self.cache.ttl(key).await {
159 Ok(ttl)
160 if ttl > 0
161 && ttl < self.config.refresh_threshold_secs as i64 =>
162 {
163 match source.load_data(key).await {
165 Ok(Some((value, new_ttl))) => {
166 if let Err(e) =
167 self.cache.set(key, &value, new_ttl).await
168 {
169 error!(key = %key, error = %e, "Failed to refresh key");
170 } else {
171 debug!(key = %key, ttl = ttl, "Refreshed expiring key");
172 }
173 }
174 Ok(None) => {}
175 Err(e) => {
176 error!(key = %key, error = %e, "Failed to load data for refresh");
177 }
178 }
179 }
180 Ok(_) => {
181 }
183 Err(e) => {
184 error!(key = %key, error = %e, "Failed to get TTL");
185 }
186 }
187 }
188 }
189 }
190 Err(e) => {
191 error!(error = %e, "Failed to get hot keys for refresh");
192 }
193 }
194 }
195 });
196 }
197
198 pub fn track_access(&self, key: &str) {
200 if !self.config.enable_prediction {
201 return;
202 }
203
204 self.access_tracker.write().record_access(key);
205 }
206
207 pub fn get_predicted_keys(&self, limit: usize) -> Vec<String> {
209 if !self.config.enable_prediction {
210 return Vec::new();
211 }
212
213 self.access_tracker.read().get_top_keys(limit)
214 }
215
216 pub async fn predictive_load<T: CacheDataSource>(
218 &self,
219 source: Arc<T>,
220 limit: usize,
221 ) -> Result<usize> {
222 if !matches!(
223 self.config.strategy,
224 WarmingStrategy::Predictive | WarmingStrategy::All
225 ) {
226 return Ok(0);
227 }
228
229 let predicted_keys = self.get_predicted_keys(limit);
230 let mut loaded_count = 0;
231
232 for key in predicted_keys {
233 if let Ok(true) = self.cache.exists(&key).await {
235 continue;
236 }
237
238 match source.load_data(&key).await {
240 Ok(Some((value, ttl))) => {
241 if let Err(e) = self.cache.set(&key, &value, ttl).await {
242 error!(key = %key, error = %e, "Failed to predictively load key");
243 } else {
244 loaded_count += 1;
245 debug!(key = %key, "Predictively loaded key");
246 }
247 }
248 Ok(None) => {}
249 Err(e) => {
250 error!(key = %key, error = %e, "Failed to load data for prediction");
251 }
252 }
253 }
254
255 if loaded_count > 0 {
256 info!(count = loaded_count, "Predictive cache loading completed");
257 }
258
259 Ok(loaded_count)
260 }
261}
262
263#[derive(Debug, Clone)]
265struct AccessTracker {
266 counts: HashMap<String, AccessCount>,
268 total_accesses: u64,
270}
271
272#[derive(Debug, Clone)]
273struct AccessCount {
274 count: u64,
276 last_access: DateTime<Utc>,
278 frequency: f64,
280}
281
282impl AccessTracker {
283 fn new() -> Self {
284 Self {
285 counts: HashMap::new(),
286 total_accesses: 0,
287 }
288 }
289
290 fn record_access(&mut self, key: &str) {
291 let now = Utc::now();
292 self.total_accesses += 1;
293
294 self.counts
295 .entry(key.to_string())
296 .and_modify(|count| {
297 count.count += 1;
298 let hours_since_last =
299 now.signed_duration_since(count.last_access).num_seconds() as f64 / 3600.0;
300 if hours_since_last > 0.0 {
301 count.frequency = count.count as f64 / hours_since_last;
302 }
303 count.last_access = now;
304 })
305 .or_insert(AccessCount {
306 count: 1,
307 last_access: now,
308 frequency: 0.0,
309 });
310 }
311
312 fn get_top_keys(&self, limit: usize) -> Vec<String> {
313 let mut keys: Vec<(String, u64, f64)> = self
314 .counts
315 .iter()
316 .map(|(k, v)| (k.clone(), v.count, v.frequency))
317 .collect();
318
319 keys.sort_by(|a, b| {
321 b.2.partial_cmp(&a.2)
322 .unwrap_or(std::cmp::Ordering::Equal)
323 .then_with(|| b.1.cmp(&a.1))
324 });
325
326 keys.into_iter().take(limit).map(|(k, _, _)| k).collect()
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed-response data source. No test in this module constructs it yet,
    // hence the `dead_code` allowances.
    #[allow(dead_code)]
    struct MockDataSource {
        keys: Vec<String>,
    }

    #[allow(dead_code)]
    #[async_trait]
    impl CacheDataSource for MockDataSource {
        async fn get_hot_keys(&self) -> Result<Vec<String>> {
            Ok(self.keys.clone())
        }

        async fn load_data(&self, key: &str) -> Result<Option<(String, u64)>> {
            Ok(Some((format!("data:{}", key), 3600)))
        }
    }

    #[test]
    fn test_warming_config_default() {
        let config = CacheWarmingConfig::default();
        assert_eq!(config.strategy, WarmingStrategy::All);
        assert_eq!(config.refresh_interval_secs, 60);
        assert_eq!(config.refresh_threshold_secs, 300);
        assert_eq!(config.batch_size, 100);
        assert!(config.enable_prediction);
    }

    #[test]
    fn test_access_tracker_record() {
        let mut tracker = AccessTracker::new();

        tracker.record_access("key1");
        tracker.record_access("key1");
        tracker.record_access("key2");

        assert_eq!(tracker.total_accesses, 3);
        assert_eq!(tracker.counts.get("key1").unwrap().count, 2);
        assert_eq!(tracker.counts.get("key2").unwrap().count, 1);
    }

    #[test]
    fn test_access_tracker_top_keys() {
        let mut tracker = AccessTracker::new();

        for _ in 0..10 {
            tracker.record_access("key1");
        }
        for _ in 0..5 {
            tracker.record_access("key2");
        }
        tracker.record_access("key3");

        let top_keys = tracker.get_top_keys(2);
        assert_eq!(top_keys.len(), 2);

        // The relative order of key1 and key2 can flip if a wall-clock second
        // boundary falls inside the loops above (that changes their
        // frequency), so only membership is asserted. key3 is accessed once,
        // keeps a frequency of 0.0 and the lowest count, and so always ranks
        // last.
        assert!(top_keys.iter().any(|k| k == "key1"));
        assert!(top_keys.iter().any(|k| k == "key2"));
        assert!(!top_keys.iter().any(|k| k == "key3"));
    }

    #[test]
    fn test_access_tracker_top_keys_limits() {
        let mut tracker = AccessTracker::new();
        assert!(tracker.get_top_keys(5).is_empty());

        tracker.record_access("key1");
        tracker.record_access("key2");
        tracker.record_access("key3");

        // A limit of zero yields nothing; a limit above the number of tracked
        // keys yields every key exactly once.
        assert!(tracker.get_top_keys(0).is_empty());
        assert_eq!(tracker.get_top_keys(10).len(), 3);
    }

    #[test]
    fn test_warming_strategy_equality() {
        assert_eq!(WarmingStrategy::None, WarmingStrategy::None);
        assert_ne!(WarmingStrategy::None, WarmingStrategy::Preload);
    }
}