1use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::path::PathBuf;
40use thiserror::Error;
41use tokio::fs::{self, OpenOptions};
42use tokio::io::{AsyncReadExt, AsyncWriteExt};
43
44#[derive(Debug, Error)]
46pub enum WarmingError {
47 #[error("IO error: {0}")]
48 Io(#[from] std::io::Error),
49
50 #[error("Serialization error: {0}")]
51 Serialization(#[from] serde_json::Error),
52
53 #[error("Invalid configuration: {0}")]
54 InvalidConfig(String),
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum WarmingStrategy {
60 FrequencyBased,
62 RecencyBased,
64 Hybrid,
66 Predictive,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct WarmingConfig {
73 pub strategy: WarmingStrategy,
75 pub max_items: usize,
77 pub max_bytes: u64,
79 pub access_log_path: PathBuf,
81 pub warmup_on_startup: bool,
83}
84
85impl Default for WarmingConfig {
86 fn default() -> Self {
87 Self {
88 strategy: WarmingStrategy::Hybrid,
89 max_items: 100,
90 max_bytes: 100 * 1024 * 1024, access_log_path: PathBuf::from("/tmp/chie_access.log"),
92 warmup_on_startup: true,
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99struct AccessRecord {
100 cid: String,
101 size_bytes: u64,
102 access_count: u64,
103 last_access_ms: u64,
104 first_access_ms: u64,
105}
106
/// A piece of content selected for cache warming, together with the score
/// that determined its rank.
#[derive(Debug, Clone)]
pub struct WarmingCandidate {
    /// Content identifier.
    pub cid: String,
    /// Content size in bytes.
    pub size_bytes: u64,
    /// Score assigned by the configured strategy; higher ranks first.
    pub score: f64,
    /// Number of recorded accesses.
    pub access_count: u64,
    /// Time of the most recent access, in milliseconds since the Unix epoch.
    pub last_access_ms: u64,
}
121
122pub struct CacheWarmer {
124 config: WarmingConfig,
125 access_records: HashMap<String, AccessRecord>,
126}
127
128impl CacheWarmer {
129 #[inline]
131 pub fn new(config: WarmingConfig) -> Result<Self, WarmingError> {
132 if config.max_items == 0 {
133 return Err(WarmingError::InvalidConfig(
134 "max_items must be > 0".to_string(),
135 ));
136 }
137 if config.max_bytes == 0 {
138 return Err(WarmingError::InvalidConfig(
139 "max_bytes must be > 0".to_string(),
140 ));
141 }
142
143 Ok(Self {
144 config,
145 access_records: HashMap::new(),
146 })
147 }
148
149 #[inline]
151 pub async fn record_access(&mut self, cid: String, size_bytes: u64) {
152 let now_ms = Self::current_timestamp_ms();
153
154 self.access_records
155 .entry(cid.clone())
156 .and_modify(|record| {
157 record.access_count += 1;
158 record.last_access_ms = now_ms;
159 })
160 .or_insert_with(|| AccessRecord {
161 cid,
162 size_bytes,
163 access_count: 1,
164 last_access_ms: now_ms,
165 first_access_ms: now_ms,
166 });
167 }
168
169 pub async fn persist(&self) -> Result<(), WarmingError> {
171 let records: Vec<&AccessRecord> = self.access_records.values().collect();
172 let json = serde_json::to_string_pretty(&records)?;
173
174 let mut file = OpenOptions::new()
175 .write(true)
176 .create(true)
177 .truncate(true)
178 .open(&self.config.access_log_path)
179 .await?;
180
181 file.write_all(json.as_bytes()).await?;
182 file.flush().await?;
183 Ok(())
184 }
185
186 pub async fn load(&mut self) -> Result<(), WarmingError> {
188 if !self.config.access_log_path.exists() {
189 return Ok(()); }
191
192 let mut file = fs::File::open(&self.config.access_log_path).await?;
193 let mut contents = String::new();
194 file.read_to_string(&mut contents).await?;
195
196 let records: Vec<AccessRecord> = serde_json::from_str(&contents)?;
197
198 self.access_records.clear();
199 for record in records {
200 self.access_records.insert(record.cid.clone(), record);
201 }
202
203 Ok(())
204 }
205
206 pub fn get_warming_candidates(&self) -> Result<Vec<WarmingCandidate>, WarmingError> {
208 let mut candidates: Vec<WarmingCandidate> = self
209 .access_records
210 .values()
211 .map(|record| {
212 let score = self.calculate_score(record);
213 WarmingCandidate {
214 cid: record.cid.clone(),
215 size_bytes: record.size_bytes,
216 score,
217 access_count: record.access_count,
218 last_access_ms: record.last_access_ms,
219 }
220 })
221 .collect();
222
223 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
225
226 self.apply_constraints(&mut candidates);
228
229 Ok(candidates)
230 }
231
232 #[inline]
234 fn calculate_score(&self, record: &AccessRecord) -> f64 {
235 match self.config.strategy {
236 WarmingStrategy::FrequencyBased => {
237 record.access_count as f64
239 }
240 WarmingStrategy::RecencyBased => {
241 let now = Self::current_timestamp_ms();
243 let age_ms = now.saturating_sub(record.last_access_ms);
244 let age_hours = age_ms as f64 / (1000.0 * 3600.0);
245
246 1.0 / (1.0 + age_hours)
248 }
249 WarmingStrategy::Hybrid => {
250 let frequency_score = record.access_count as f64;
252
253 let now = Self::current_timestamp_ms();
254 let age_ms = now.saturating_sub(record.last_access_ms);
255 let age_hours = age_ms as f64 / (1000.0 * 3600.0);
256 let recency_score = 1.0 / (1.0 + age_hours);
257
258 0.7 * frequency_score + 0.3 * recency_score * 100.0
260 }
261 WarmingStrategy::Predictive => {
262 let frequency = record.access_count as f64;
264 let lifetime_days =
265 (record.last_access_ms - record.first_access_ms) as f64 / (1000.0 * 86400.0);
266
267 if lifetime_days < 0.01 {
268 return frequency;
270 }
271
272 let access_rate = frequency / lifetime_days;
274
275 let now = Self::current_timestamp_ms();
277 let age_hours =
278 (now.saturating_sub(record.last_access_ms)) as f64 / (1000.0 * 3600.0);
279 let recency_boost = if age_hours < 24.0 {
280 2.0 } else if age_hours < 168.0 {
282 1.5
284 } else {
285 1.0
286 };
287
288 access_rate * recency_boost
289 }
290 }
291 }
292
293 #[inline]
295 fn apply_constraints(&self, candidates: &mut Vec<WarmingCandidate>) {
296 let mut total_bytes = 0u64;
297 let mut keep_count = 0usize;
298
299 for candidate in candidates.iter() {
300 if keep_count >= self.config.max_items {
301 break;
302 }
303 if total_bytes + candidate.size_bytes > self.config.max_bytes {
304 break;
305 }
306
307 total_bytes += candidate.size_bytes;
308 keep_count += 1;
309 }
310
311 candidates.truncate(keep_count);
312 }
313
314 #[must_use]
316 #[inline]
317 pub fn warming_stats(&self) -> WarmingStats {
318 let candidates = self.get_warming_candidates().unwrap_or_default();
319
320 let total_items = candidates.len();
321 let total_bytes: u64 = candidates.iter().map(|c| c.size_bytes).sum();
322 let avg_score = if !candidates.is_empty() {
323 candidates.iter().map(|c| c.score).sum::<f64>() / candidates.len() as f64
324 } else {
325 0.0
326 };
327
328 WarmingStats {
329 total_items,
330 total_bytes,
331 avg_score,
332 strategy: self.config.strategy,
333 }
334 }
335
336 #[inline]
338 pub fn clear(&mut self) {
339 self.access_records.clear();
340 }
341
342 #[inline]
344 fn current_timestamp_ms() -> u64 {
345 std::time::SystemTime::now()
346 .duration_since(std::time::UNIX_EPOCH)
347 .unwrap()
348 .as_millis() as u64
349 }
350}
351
352#[derive(Debug, Clone)]
354pub struct WarmingStats {
355 pub total_items: usize,
357 pub total_bytes: u64,
359 pub avg_score: f64,
361 pub strategy: WarmingStrategy,
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a log path in the OS temp directory that is unique per test tag
    /// and per process, so concurrent test runs do not share files.
    fn temp_log_path(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!("chie_warming_{}_{}.log", tag, std::process::id()))
    }

    /// Frequency-based warmer limited to 10 items / 1 MiB.
    fn create_test_warmer() -> CacheWarmer {
        let config = WarmingConfig {
            strategy: WarmingStrategy::FrequencyBased,
            max_items: 10,
            max_bytes: 1024 * 1024,
            access_log_path: temp_log_path("unit"),
            warmup_on_startup: false,
        };
        CacheWarmer::new(config).unwrap()
    }

    #[tokio::test]
    async fn test_record_access() {
        let mut warmer = create_test_warmer();

        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;

        assert_eq!(warmer.access_records.len(), 2);
        assert_eq!(warmer.access_records["QmTest1"].access_count, 2);
        assert_eq!(warmer.access_records["QmTest2"].access_count, 1);
        assert_eq!(warmer.access_records["QmTest1"].size_bytes, 1024);
        assert_eq!(warmer.access_records["QmTest2"].size_bytes, 2048);
    }

    #[tokio::test]
    async fn test_frequency_based_warming() {
        let mut warmer = create_test_warmer();

        for _ in 0..10 {
            warmer.record_access("QmFrequent".to_string(), 100).await;
        }
        for _ in 0..3 {
            warmer.record_access("QmMedium".to_string(), 100).await;
        }
        warmer.record_access("QmRare".to_string(), 100).await;

        let candidates = warmer.get_warming_candidates().unwrap();

        assert_eq!(candidates.len(), 3);
        assert_eq!(candidates[0].cid, "QmFrequent");
        assert_eq!(candidates[1].cid, "QmMedium");
        assert_eq!(candidates[2].cid, "QmRare");
    }

    #[tokio::test]
    async fn test_max_items_constraint() {
        let mut warmer = create_test_warmer();

        for i in 0..20 {
            warmer.record_access(format!("QmTest{}", i), 100).await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();

        assert_eq!(candidates.len(), 10);
    }

    #[tokio::test]
    async fn test_max_bytes_constraint() {
        let mut warmer = create_test_warmer();

        for i in 0..10 {
            warmer
                .record_access(format!("QmTest{}", i), 200 * 1024)
                .await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();

        // Five 200 KiB items fit into 1 MiB; the sixth would exceed it.
        assert_eq!(candidates.len(), 5);
        let total_bytes: u64 = candidates.iter().map(|c| c.size_bytes).sum();
        assert!(total_bytes <= 1024 * 1024);
    }

    #[tokio::test]
    async fn test_persist_and_load() {
        let log_path = temp_log_path("persist");

        let mut warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path.clone(),
            ..Default::default()
        })
        .unwrap();

        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;

        warmer.persist().await.unwrap();

        let mut new_warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path.clone(),
            ..Default::default()
        })
        .unwrap();

        new_warmer.load().await.unwrap();

        assert_eq!(new_warmer.access_records.len(), 2);
        assert_eq!(new_warmer.access_records["QmTest1"].access_count, 2);
        assert_eq!(new_warmer.access_records["QmTest1"].size_bytes, 1024);
        assert_eq!(new_warmer.access_records["QmTest2"].access_count, 1);
        assert_eq!(new_warmer.access_records["QmTest2"].size_bytes, 2048);

        let _ = std::fs::remove_file(log_path);
    }

    #[tokio::test]
    async fn test_load_missing_file_is_ok() {
        let log_path = temp_log_path("missing");
        let _ = std::fs::remove_file(&log_path);

        let mut warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path,
            ..Default::default()
        })
        .unwrap();

        warmer.load().await.unwrap();

        assert!(warmer.access_records.is_empty());
    }

    #[tokio::test]
    async fn test_hybrid_strategy() {
        let config = WarmingConfig {
            strategy: WarmingStrategy::Hybrid,
            max_items: 10,
            max_bytes: 1024 * 1024,
            access_log_path: temp_log_path("hybrid"),
            warmup_on_startup: false,
        };

        let mut warmer = CacheWarmer::new(config).unwrap();

        for _ in 0..100 {
            warmer.record_access("QmOldFrequent".to_string(), 100).await;
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        for _ in 0..5 {
            warmer.record_access("QmRecentRare".to_string(), 100).await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();

        // Both were accessed moments ago, so the frequency term decides.
        assert_eq!(candidates.len(), 2);
        assert_eq!(candidates[0].cid, "QmOldFrequent");
        assert_eq!(candidates[1].cid, "QmRecentRare");
        assert!(candidates[0].score > candidates[1].score);
    }

    #[test]
    fn test_warming_stats() {
        let warmer = create_test_warmer();

        let stats = warmer.warming_stats();
        assert_eq!(stats.total_items, 0);
        assert_eq!(stats.total_bytes, 0);
        assert!(stats.avg_score.abs() < f64::EPSILON);
        assert_eq!(stats.strategy, WarmingStrategy::FrequencyBased);
    }

    #[tokio::test]
    async fn test_warming_stats_populated() {
        let mut warmer = create_test_warmer();

        warmer.record_access("QmA".to_string(), 100).await;
        warmer.record_access("QmA".to_string(), 100).await;
        warmer.record_access("QmB".to_string(), 50).await;

        let stats = warmer.warming_stats();

        assert_eq!(stats.total_items, 2);
        assert_eq!(stats.total_bytes, 150);
        // Frequency scores are 2.0 and 1.0, so the mean is 1.5.
        assert!((stats.avg_score - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_invalid_config() {
        let zero_items = WarmingConfig {
            max_items: 0,
            ..Default::default()
        };
        assert!(matches!(
            CacheWarmer::new(zero_items),
            Err(WarmingError::InvalidConfig(_))
        ));

        let zero_bytes = WarmingConfig {
            max_bytes: 0,
            ..Default::default()
        };
        assert!(matches!(
            CacheWarmer::new(zero_bytes),
            Err(WarmingError::InvalidConfig(_))
        ));
    }

    #[tokio::test]
    async fn test_clear() {
        let mut warmer = create_test_warmer();

        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;

        assert_eq!(warmer.access_records.len(), 2);

        warmer.clear();

        assert_eq!(warmer.access_records.len(), 0);
    }
}