1use crate::error::{CacheError, Result};
11use crate::multi_tier::{CacheKey, CacheValue, MultiTierCache};
12use async_trait::async_trait;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
/// Pluggable policy that decides which cache keys are warmed and in
/// what order.
///
/// Driven by `CacheWarmer`: it repeatedly calls `next_batch` until
/// `is_complete` returns true or an empty batch is produced.
#[async_trait]
pub trait WarmingStrategy: Send + Sync {
    /// Returns up to `batch_size` keys that should be loaded next.
    /// May return fewer (or none) when the strategy is nearly/fully done.
    async fn next_batch(&mut self, batch_size: usize) -> Result<Vec<CacheKey>>;

    /// True once every key managed by this strategy has been handed out.
    async fn is_complete(&self) -> bool;

    /// Fraction of keys handed out so far, in `[0.0, 1.0]`.
    async fn progress(&self) -> f64;

    /// Rewinds the strategy so warming can start over from the beginning.
    async fn reset(&mut self);
}
32
/// Warming strategy that hands out keys in the exact order supplied.
pub struct SequentialWarming {
    /// Keys to warm, front to back.
    keys: Vec<CacheKey>,
    /// Index of the next key to hand out.
    position: usize,
}

impl SequentialWarming {
    /// Creates a strategy that yields `keys` in their given order.
    pub fn new(keys: Vec<CacheKey>) -> Self {
        Self { keys, position: 0 }
    }
}
48
49#[async_trait]
50impl WarmingStrategy for SequentialWarming {
51 async fn next_batch(&mut self, batch_size: usize) -> Result<Vec<CacheKey>> {
52 let end = (self.position + batch_size).min(self.keys.len());
53 let batch = self.keys[self.position..end].to_vec();
54 self.position = end;
55 Ok(batch)
56 }
57
58 async fn is_complete(&self) -> bool {
59 self.position >= self.keys.len()
60 }
61
62 async fn progress(&self) -> f64 {
63 if self.keys.is_empty() {
64 1.0
65 } else {
66 self.position as f64 / self.keys.len() as f64
67 }
68 }
69
70 async fn reset(&mut self) {
71 self.position = 0;
72 }
73}
74
/// Warming strategy that hands out keys in a randomized order.
pub struct RandomWarming {
    /// Keys to warm; ordering may be rearranged while batches are drawn.
    keys: Vec<CacheKey>,
    /// Number of keys handed out so far.
    warmed_count: usize,
}

impl RandomWarming {
    /// Creates a strategy over `keys`; batch order is randomized.
    pub fn new(keys: Vec<CacheKey>) -> Self {
        Self {
            keys,
            warmed_count: 0,
        }
    }
}
93
94#[async_trait]
95impl WarmingStrategy for RandomWarming {
96 async fn next_batch(&mut self, batch_size: usize) -> Result<Vec<CacheKey>> {
97 fastrand::seed(42);
99 let remaining = self.keys.len().saturating_sub(self.warmed_count);
100 let batch_size = batch_size.min(remaining);
101
102 let mut batch = Vec::with_capacity(batch_size);
103 let mut indices: Vec<usize> = (0..self.keys.len()).collect();
104
105 for i in (1..indices.len()).rev() {
107 let j = fastrand::usize(0..=i);
108 indices.swap(i, j);
109 }
110
111 for i in 0..batch_size {
112 if let Some(&idx) = indices.get(i) {
113 if let Some(key) = self.keys.get(idx) {
114 batch.push(key.clone());
115 }
116 }
117 }
118
119 self.warmed_count += batch.len();
120 Ok(batch)
121 }
122
123 async fn is_complete(&self) -> bool {
124 self.warmed_count >= self.keys.len()
125 }
126
127 async fn progress(&self) -> f64 {
128 if self.keys.is_empty() {
129 1.0
130 } else {
131 self.warmed_count as f64 / self.keys.len() as f64
132 }
133 }
134
135 async fn reset(&mut self) {
136 self.warmed_count = 0;
137 }
138}
139
/// Warming strategy that hands out keys in descending priority order.
pub struct PriorityWarming {
    /// (key, priority) pairs, sorted highest priority first.
    keys_with_priority: Vec<(CacheKey, f64)>,
    /// Index of the next pair to hand out.
    position: usize,
}

impl PriorityWarming {
    /// Creates a strategy from unsorted (key, priority) pairs.
    ///
    /// Pairs are stably sorted by priority, highest first. NaN
    /// priorities compare as equal, so their relative order (and that of
    /// exact ties) is preserved from the input.
    pub fn new(mut keys_with_priority: Vec<(CacheKey, f64)>) -> Self {
        keys_with_priority
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        Self {
            keys_with_priority,
            position: 0,
        }
    }
}
162
163#[async_trait]
164impl WarmingStrategy for PriorityWarming {
165 async fn next_batch(&mut self, batch_size: usize) -> Result<Vec<CacheKey>> {
166 let end = (self.position + batch_size).min(self.keys_with_priority.len());
167 let batch = self.keys_with_priority[self.position..end]
168 .iter()
169 .map(|(key, _)| key.clone())
170 .collect();
171 self.position = end;
172 Ok(batch)
173 }
174
175 async fn is_complete(&self) -> bool {
176 self.position >= self.keys_with_priority.len()
177 }
178
179 async fn progress(&self) -> f64 {
180 if self.keys_with_priority.is_empty() {
181 1.0
182 } else {
183 self.position as f64 / self.keys_with_priority.len() as f64
184 }
185 }
186
187 async fn reset(&mut self) {
188 self.position = 0;
189 }
190}
191
/// Snapshot of how far a cache-warming run has progressed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WarmingProgress {
    /// Total number of keys the run intends to warm.
    pub total_keys: usize,
    /// Keys successfully loaded into the cache so far.
    pub warmed_keys: usize,
    /// `warmed_keys / total_keys` expressed as a percentage (0.0-100.0).
    pub progress_percent: f64,
    /// Estimated seconds until completion; `None` until a rate exists.
    pub estimated_time_remaining: Option<i64>,
    /// Keys warmed per second (0.0 until at least one second elapsed).
    pub warming_rate: f64,
    /// UTC timestamp at which tracking started.
    pub start_time: chrono::DateTime<chrono::Utc>,
    /// Whole seconds elapsed since `start_time` as of the last update.
    pub elapsed_seconds: i64,
}
210
211impl WarmingProgress {
212 pub fn new(total_keys: usize) -> Self {
214 Self {
215 total_keys,
216 warmed_keys: 0,
217 progress_percent: 0.0,
218 estimated_time_remaining: None,
219 warming_rate: 0.0,
220 start_time: chrono::Utc::now(),
221 elapsed_seconds: 0,
222 }
223 }
224
225 pub fn update(&mut self, warmed_keys: usize) {
227 self.warmed_keys = warmed_keys;
228 self.progress_percent = if self.total_keys > 0 {
229 (warmed_keys as f64 / self.total_keys as f64) * 100.0
230 } else {
231 100.0
232 };
233
234 let now = chrono::Utc::now();
235 self.elapsed_seconds = (now - self.start_time).num_seconds();
236
237 if self.elapsed_seconds > 0 {
238 self.warming_rate = warmed_keys as f64 / self.elapsed_seconds as f64;
239
240 let remaining_keys = self.total_keys.saturating_sub(warmed_keys);
241 if self.warming_rate > 0.0 {
242 self.estimated_time_remaining =
243 Some((remaining_keys as f64 / self.warming_rate) as i64);
244 }
245 }
246 }
247
248 pub fn is_complete(&self) -> bool {
250 self.warmed_keys >= self.total_keys
251 }
252}
253
/// Backing store that cache values are (re)loaded from during warming.
#[async_trait]
pub trait DataSource: Send + Sync {
    /// Loads the value for `key`, or an error if it is not available.
    async fn load(&self, key: &CacheKey) -> Result<CacheValue>;

    /// Whether a value for `key` exists in the source.
    async fn exists(&self, key: &CacheKey) -> Result<bool>;

    /// All keys currently available from the source.
    async fn keys(&self) -> Result<Vec<CacheKey>>;
}
266
/// Drives a `WarmingStrategy` to pre-populate a `MultiTierCache` with
/// values loaded from a `DataSource`.
pub struct CacheWarmer {
    /// Destination cache being pre-populated.
    cache: Arc<MultiTierCache>,
    /// Source of truth the values are loaded from.
    data_source: Arc<dyn DataSource>,
    /// Ordering policy for which keys to warm next.
    strategy: Arc<RwLock<Box<dyn WarmingStrategy>>>,
    /// Shared progress snapshot, readable while warming runs.
    progress: Arc<RwLock<WarmingProgress>>,
    /// Number of keys loaded concurrently per batch.
    batch_size: usize,
    /// Whether a warming run is currently in flight.
    is_active: Arc<RwLock<bool>>,
}
282
283impl CacheWarmer {
284 pub fn new(
286 cache: Arc<MultiTierCache>,
287 data_source: Arc<dyn DataSource>,
288 strategy: Box<dyn WarmingStrategy>,
289 total_keys: usize,
290 ) -> Self {
291 Self {
292 cache,
293 data_source,
294 strategy: Arc::new(RwLock::new(strategy)),
295 progress: Arc::new(RwLock::new(WarmingProgress::new(total_keys))),
296 batch_size: 10,
297 is_active: Arc::new(RwLock::new(false)),
298 }
299 }
300
301 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
303 self.batch_size = batch_size;
304 self
305 }
306
307 pub async fn start(&self) -> Result<()> {
309 let mut is_active = self.is_active.write().await;
310 if *is_active {
311 return Err(CacheError::Other("Warming already in progress".to_string()));
312 }
313 *is_active = true;
314 drop(is_active);
315
316 let mut warmed_count = 0;
317
318 loop {
319 let is_complete = {
321 let strategy = self.strategy.read().await;
322 strategy.is_complete().await
323 };
324
325 if is_complete {
326 break;
327 }
328
329 let batch = {
331 let mut strategy = self.strategy.write().await;
332 strategy.next_batch(self.batch_size).await?
333 };
334
335 if batch.is_empty() {
336 break;
337 }
338
339 let mut tasks: Vec<tokio::task::JoinHandle<std::result::Result<usize, CacheError>>> =
341 Vec::new();
342
343 for key in batch {
344 let data_source = Arc::clone(&self.data_source);
345 let cache = Arc::clone(&self.cache);
346
347 let task = tokio::spawn(async move {
348 if let Ok(value) = data_source.load(&key).await {
349 let _ = cache.put(key, value).await;
350 Ok::<usize, CacheError>(1)
351 } else {
352 Ok::<usize, CacheError>(0)
353 }
354 });
355
356 tasks.push(task);
357 }
358
359 for task in tasks {
361 if let Ok(Ok(count)) = task.await {
362 warmed_count += count;
363 }
364 }
365
366 let mut progress = self.progress.write().await;
368 progress.update(warmed_count);
369 }
370
371 let mut is_active = self.is_active.write().await;
372 *is_active = false;
373
374 Ok(())
375 }
376
377 pub fn start_background(self: Arc<Self>) -> tokio::task::JoinHandle<Result<()>> {
379 tokio::spawn(async move { self.start().await })
380 }
381
382 pub async fn stop(&self) -> Result<()> {
384 let mut is_active = self.is_active.write().await;
385 *is_active = false;
386 Ok(())
387 }
388
389 pub async fn progress(&self) -> WarmingProgress {
391 self.progress.read().await.clone()
392 }
393
394 pub async fn is_active(&self) -> bool {
396 *self.is_active.read().await
397 }
398
399 pub async fn reset(&self) -> Result<()> {
401 let mut strategy = self.strategy.write().await;
402 strategy.reset().await;
403
404 let mut progress = self.progress.write().await;
405 *progress = WarmingProgress::new(progress.total_keys);
406
407 Ok(())
408 }
409}
410
/// Simple `HashMap`-backed `DataSource`, mainly useful for tests.
pub struct InMemoryDataSource {
    /// Stored key/value pairs, guarded for concurrent access.
    data: Arc<RwLock<HashMap<CacheKey, CacheValue>>>,
}
416
417impl InMemoryDataSource {
418 pub fn new() -> Self {
420 Self {
421 data: Arc::new(RwLock::new(HashMap::new())),
422 }
423 }
424
425 pub async fn add(&self, key: CacheKey, value: CacheValue) {
427 let mut data = self.data.write().await;
428 data.insert(key, value);
429 }
430}
431
impl Default for InMemoryDataSource {
    /// Equivalent to `InMemoryDataSource::new`.
    fn default() -> Self {
        Self::new()
    }
}
437
438#[async_trait]
439impl DataSource for InMemoryDataSource {
440 async fn load(&self, key: &CacheKey) -> Result<CacheValue> {
441 let data = self.data.read().await;
442 data.get(key)
443 .cloned()
444 .ok_or_else(|| CacheError::KeyNotFound(key.clone()))
445 }
446
447 async fn exists(&self, key: &CacheKey) -> Result<bool> {
448 let data = self.data.read().await;
449 Ok(data.contains_key(key))
450 }
451
452 async fn keys(&self) -> Result<Vec<CacheKey>> {
453 let data = self.data.read().await;
454 Ok(data.keys().cloned().collect())
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CacheConfig;
    use crate::compression::DataType;
    use bytes::Bytes;

    // Sequential strategy hands out keys in insertion order and reports
    // proportional progress.
    #[tokio::test]
    async fn test_sequential_warming() {
        let keys: Vec<_> = (0..100).map(|i| format!("key{}", i)).collect();
        let mut strategy = SequentialWarming::new(keys.clone());

        let batch = strategy.next_batch(10).await.expect("next_batch failed");
        assert_eq!(batch.len(), 10);
        assert_eq!(batch[0], "key0");

        // 10 of 100 keys handed out -> 10% progress.
        let progress = strategy.progress().await;
        approx::assert_relative_eq!(progress, 0.1, epsilon = 0.01);
    }

    // Random strategy still yields full-size batches and tracks progress;
    // ordering is intentionally not asserted.
    #[tokio::test]
    async fn test_random_warming() {
        let keys: Vec<_> = (0..100).map(|i| format!("key{}", i)).collect();
        let mut strategy = RandomWarming::new(keys.clone());

        let batch = strategy.next_batch(10).await.expect("next_batch failed");
        assert_eq!(batch.len(), 10);

        let progress = strategy.progress().await;
        approx::assert_relative_eq!(progress, 0.1, epsilon = 0.01);
    }

    // Priority strategy yields the highest-priority key first.
    #[tokio::test]
    async fn test_priority_warming() {
        let mut keys_with_priority = Vec::new();
        for i in 0..100 {
            keys_with_priority.push((format!("key{}", i), i as f64));
        }

        let mut strategy = PriorityWarming::new(keys_with_priority);

        let batch = strategy.next_batch(10).await.expect("next_batch failed");
        assert_eq!(batch.len(), 10);

        // key99 has priority 99.0, the highest.
        assert_eq!(batch[0], "key99");
    }

    // End-to-end: the warmer drains a sequential strategy against an
    // in-memory source and reports completion.
    #[tokio::test]
    async fn test_cache_warmer() {
        let temp_dir = std::env::temp_dir().join("oxigdal_warmer_test");
        let config = CacheConfig {
            l1_size: 1024 * 1024,
            l2_size: 0,
            l3_size: 0,
            enable_compression: false,
            enable_prefetch: false,
            enable_distributed: false,
            cache_dir: Some(temp_dir.clone()),
        };

        let cache = Arc::new(
            MultiTierCache::new(config)
                .await
                .expect("cache creation failed"),
        );

        let data_source = Arc::new(InMemoryDataSource::new());

        // Populate the source with 10 loadable values.
        for i in 0..10 {
            let key = format!("key{}", i);
            let value = CacheValue::new(Bytes::from(format!("value{}", i)), DataType::Binary);
            data_source.add(key.clone(), value).await;
        }

        let keys: Vec<_> = (0..10).map(|i| format!("key{}", i)).collect();
        let strategy = Box::new(SequentialWarming::new(keys.clone()));

        let warmer = Arc::new(CacheWarmer::new(
            Arc::clone(&cache),
            data_source,
            strategy,
            10,
        ));

        warmer.start().await.expect("warming failed");

        let progress = warmer.progress().await;
        assert!(progress.is_complete());

        // Best-effort cleanup of the on-disk cache directory.
        let _ = tokio::fs::remove_dir_all(temp_dir).await;
    }

    // Progress math: percent tracks warmed/total and completion flips at
    // the boundary.
    #[test]
    fn test_warming_progress() {
        let mut progress = WarmingProgress::new(100);

        progress.update(50);
        approx::assert_relative_eq!(progress.progress_percent, 50.0, epsilon = 0.01);
        assert!(!progress.is_complete());

        progress.update(100);
        approx::assert_relative_eq!(progress.progress_percent, 100.0, epsilon = 0.01);
        assert!(progress.is_complete());
    }
}