use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use thiserror::Error;
use tokio::fs::{self, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Failure modes reported by [`CacheWarmer`].
#[derive(Debug, Error)]
pub enum WarmingError {
    /// An underlying filesystem operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// JSON encoding or decoding of the access records failed.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// The configuration was rejected; the message describes the violated rule.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
}
/// Scoring policy used to rank warming candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WarmingStrategy {
    /// Score is the raw access count.
    FrequencyBased,

    /// Score is `1 / (1 + hours since last access)`.
    RecencyBased,

    /// Score is `0.7 * access count + 0.3 * recency score * 100`.
    Hybrid,

    /// Score is accesses per day over the observed lifetime, multiplied by a
    /// boost for recently accessed items; falls back to the raw access count
    /// when the observed lifetime is shorter than 0.01 days.
    Predictive,
}
/// Settings that control how a [`CacheWarmer`] ranks and limits candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmingConfig {
    /// Scoring policy applied to every access record.
    pub strategy: WarmingStrategy,

    /// Upper bound on the number of candidates returned; must be non-zero.
    pub max_items: usize,

    /// Upper bound on the summed `size_bytes` of returned candidates; must be
    /// non-zero.
    pub max_bytes: u64,

    /// File that `persist` writes to and `load` reads from.
    pub access_log_path: PathBuf,

    /// NOTE(review): not read anywhere in this file; presumably consulted by
    /// the caller to decide whether to warm the cache at startup — confirm.
    pub warmup_on_startup: bool,
}
impl Default for WarmingConfig {
fn default() -> Self {
Self {
strategy: WarmingStrategy::Hybrid,
max_items: 100,
max_bytes: 100 * 1024 * 1024, access_log_path: PathBuf::from("/tmp/chie_access.log"),
warmup_on_startup: true,
}
}
}
/// Per-content access bookkeeping; this is the unit stored in the access log.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AccessRecord {
    /// Content identifier; also used as the key in the warmer's record map.
    cid: String,

    /// Size reported on the first recorded access.
    size_bytes: u64,

    /// Number of recorded accesses.
    access_count: u64,

    /// Time of the most recent access, in milliseconds since the Unix epoch.
    last_access_ms: u64,

    /// Time of the first access, in milliseconds since the Unix epoch.
    first_access_ms: u64,
}
/// A scored item proposed for cache warming.
#[derive(Debug, Clone)]
pub struct WarmingCandidate {
    /// Content identifier of the item.
    pub cid: String,

    /// Size of the item in bytes, as recorded for it.
    pub size_bytes: u64,

    /// Score assigned by the configured strategy; higher ranks first.
    pub score: f64,

    /// Number of recorded accesses.
    pub access_count: u64,

    /// Time of the most recent access, in milliseconds since the Unix epoch.
    pub last_access_ms: u64,
}
/// Tracks content accesses in memory and derives a ranked, size-limited list
/// of items worth pre-loading into a cache.
pub struct CacheWarmer {
    /// Strategy and limits, validated by `new`.
    config: WarmingConfig,

    /// Access history keyed by content identifier.
    access_records: HashMap<String, AccessRecord>,
}
impl CacheWarmer {
/// Creates a warmer with an empty access history.
///
/// # Errors
///
/// Returns [`WarmingError::InvalidConfig`] when `max_items` or `max_bytes`
/// is zero; `max_items` is checked first.
#[inline]
pub fn new(config: WarmingConfig) -> Result<Self, WarmingError> {
    // Pick the first violated rule, if any, in the documented order.
    let violation = if config.max_items == 0 {
        Some("max_items must be > 0")
    } else if config.max_bytes == 0 {
        Some("max_bytes must be > 0")
    } else {
        None
    };

    match violation {
        Some(reason) => Err(WarmingError::InvalidConfig(reason.to_string())),
        None => Ok(Self {
            config,
            access_records: HashMap::new(),
        }),
    }
}
/// Records one access to `cid`.
///
/// The first access creates a record with the given `size_bytes`; later
/// accesses only bump the counter and refresh the last-access timestamp, so
/// the size from the first call is the one that is kept.
#[inline]
pub async fn record_access(&mut self, cid: String, size_bytes: u64) {
    let now_ms = Self::current_timestamp_ms();

    // Look up by reference first so the common "already known" path does not
    // allocate a copy of the key.
    match self.access_records.get_mut(&cid) {
        Some(record) => {
            // Saturate rather than overflow: counts restored from disk are
            // not guaranteed to be small.
            record.access_count = record.access_count.saturating_add(1);
            record.last_access_ms = now_ms;
        }
        None => {
            let record = AccessRecord {
                cid: cid.clone(),
                size_bytes,
                access_count: 1,
                last_access_ms: now_ms,
                first_access_ms: now_ms,
            };
            self.access_records.insert(cid, record);
        }
    }
}
pub async fn persist(&self) -> Result<(), WarmingError> {
let records: Vec<&AccessRecord> = self.access_records.values().collect();
let json = serde_json::to_string_pretty(&records)?;
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&self.config.access_log_path)
.await?;
file.write_all(json.as_bytes()).await?;
file.flush().await?;
Ok(())
}
/// Replaces the in-memory access history with the contents of
/// `access_log_path`.
///
/// A missing log file is not an error: the current history is left untouched
/// and `Ok(())` is returned. The history is only replaced after the file has
/// been read and parsed successfully.
///
/// # Errors
///
/// Returns [`WarmingError::Io`] if the file exists but cannot be opened or
/// read, and [`WarmingError::Serialization`] if it is not valid JSON for a
/// list of access records.
pub async fn load(&mut self) -> Result<(), WarmingError> {
    // Open directly instead of checking existence first: that avoids a
    // blocking filesystem call on the async executor, closes the
    // check-then-open race, and surfaces errors such as "permission denied"
    // instead of treating them as "no log".
    let mut file = match fs::File::open(&self.config.access_log_path).await {
        Ok(file) => file,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
        Err(err) => return Err(err.into()),
    };

    let mut contents = String::new();
    file.read_to_string(&mut contents).await?;
    let records: Vec<AccessRecord> = serde_json::from_str(&contents)?;

    // If the file lists the same cid twice, the later entry wins.
    self.access_records = records
        .into_iter()
        .map(|record| (record.cid.clone(), record))
        .collect();
    Ok(())
}
/// Returns the items to warm, best score first, limited by `max_items` and
/// `max_bytes`.
///
/// Items with equal scores are ordered by `cid`, so the result does not
/// depend on hash-map iteration order.
///
/// # Errors
///
/// Currently always returns `Ok`; the `Result` is kept for interface
/// stability.
pub fn get_warming_candidates(&self) -> Result<Vec<WarmingCandidate>, WarmingError> {
    let mut candidates: Vec<WarmingCandidate> = self
        .access_records
        .values()
        .map(|record| WarmingCandidate {
            cid: record.cid.clone(),
            size_bytes: record.size_bytes,
            score: self.calculate_score(record),
            access_count: record.access_count,
            last_access_ms: record.last_access_ms,
        })
        .collect();

    // `total_cmp` is a total order on f64, so there is no panic path here;
    // the cid tie-break makes the comparator total, which also makes the
    // unstable sort deterministic.
    candidates.sort_unstable_by(|a, b| {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.cid.cmp(&b.cid))
    });

    self.apply_constraints(&mut candidates);
    Ok(candidates)
}
/// Scores `record` under the configured strategy; higher means "warm first".
///
/// * `FrequencyBased`: the access count.
/// * `RecencyBased`: `1 / (1 + hours since last access)`.
/// * `Hybrid`: `0.7 * access count + 0.3 * recency * 100`.
/// * `Predictive`: accesses per day over the record's observed lifetime,
///   times 2.0 if last accessed within 24 hours, 1.5 within 168 hours, 1.0
///   otherwise. Lifetimes under 0.01 days fall back to the access count.
#[inline]
fn calculate_score(&self, record: &AccessRecord) -> f64 {
    const MS_PER_HOUR: f64 = 1000.0 * 3600.0;
    const MS_PER_DAY: f64 = 1000.0 * 86400.0;

    let frequency = record.access_count as f64;

    // Evaluated lazily so the frequency-only strategy never reads the clock.
    let age_hours = || {
        let now = Self::current_timestamp_ms();
        now.saturating_sub(record.last_access_ms) as f64 / MS_PER_HOUR
    };

    match self.config.strategy {
        WarmingStrategy::FrequencyBased => frequency,
        WarmingStrategy::RecencyBased => 1.0 / (1.0 + age_hours()),
        WarmingStrategy::Hybrid => {
            let recency = 1.0 / (1.0 + age_hours());
            0.7 * frequency + 0.3 * recency * 100.0
        }
        WarmingStrategy::Predictive => {
            // Saturating: a record loaded from disk, or one touched after the
            // wall clock stepped backwards, can have last < first. A plain
            // subtraction would panic in debug builds and wrap in release.
            let lifetime_ms = record.last_access_ms.saturating_sub(record.first_access_ms);
            let lifetime_days = lifetime_ms as f64 / MS_PER_DAY;
            if lifetime_days < 0.01 {
                // Too little history for a meaningful rate.
                return frequency;
            }

            let access_rate = frequency / lifetime_days;
            let hours_idle = age_hours();
            let recency_boost = if hours_idle < 24.0 {
                2.0
            } else if hours_idle < 168.0 {
                1.5
            } else {
                1.0
            };
            access_rate * recency_boost
        }
    }
}
/// Truncates `candidates` (already sorted best-first) to the longest prefix
/// that fits both `max_items` and `max_bytes`.
///
/// The scan stops at the first candidate that would exceed the byte budget;
/// smaller, lower-ranked candidates after it are not considered.
#[inline]
fn apply_constraints(&self, candidates: &mut Vec<WarmingCandidate>) {
    let mut total_bytes = 0u64;
    let mut keep_count = 0usize;

    for candidate in candidates.iter().take(self.config.max_items) {
        // `checked_add` treats an overflowing running total as "over budget"
        // instead of panicking (debug) or wrapping around (release).
        match total_bytes.checked_add(candidate.size_bytes) {
            Some(sum) if sum <= self.config.max_bytes => {
                total_bytes = sum;
                keep_count += 1;
            }
            _ => break,
        }
    }

    candidates.truncate(keep_count);
}
/// Summarises the current candidate selection: how many items, how many
/// bytes, and their mean score (0.0 when nothing is selected).
#[must_use]
#[inline]
pub fn warming_stats(&self) -> WarmingStats {
    let selected = self.get_warming_candidates().unwrap_or_default();

    let total_items = selected.len();
    let total_bytes = selected.iter().map(|item| item.size_bytes).sum::<u64>();
    let avg_score = match total_items {
        0 => 0.0,
        count => selected.iter().map(|item| item.score).sum::<f64>() / count as f64,
    };

    WarmingStats {
        total_items,
        total_bytes,
        avg_score,
        strategy: self.config.strategy,
    }
}
/// Discards every in-memory access record. The persisted log file is not
/// touched.
#[inline]
pub fn clear(&mut self) {
    self.access_records.clear();
}
/// Returns the wall-clock time in milliseconds since the Unix epoch.
///
/// Never panics: a clock set before the epoch yields 0, and a millisecond
/// count that does not fit in `u64` saturates to `u64::MAX`.
#[inline]
fn current_timestamp_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
}
/// Snapshot of the current warming selection, as produced by
/// `CacheWarmer::warming_stats`.
#[derive(Debug, Clone)]
pub struct WarmingStats {
    /// Number of selected candidates.
    pub total_items: usize,

    /// Summed size of the selected candidates, in bytes.
    pub total_bytes: u64,

    /// Mean score of the selected candidates; 0.0 when none are selected.
    pub avg_score: f64,

    /// Strategy that produced the scores.
    pub strategy: WarmingStrategy,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a log path under the platform temp directory that is unique to
    /// this test process and `tag`, so tests neither assume `/tmp` exists nor
    /// collide with concurrent runs.
    fn temp_log_path(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "chie_cache_warmer_{}_{}.log",
            std::process::id(),
            tag
        ))
    }

    /// Frequency-based warmer limited to 10 items / 1 MiB; never touches disk.
    fn create_test_warmer() -> CacheWarmer {
        let config = WarmingConfig {
            strategy: WarmingStrategy::FrequencyBased,
            max_items: 10,
            max_bytes: 1024 * 1024,
            access_log_path: temp_log_path("unused"),
            warmup_on_startup: false,
        };
        CacheWarmer::new(config).unwrap()
    }

    #[tokio::test]
    async fn test_record_access() {
        let mut warmer = create_test_warmer();
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;

        assert_eq!(warmer.access_records.len(), 2);
        assert_eq!(warmer.access_records["QmTest1"].access_count, 2);
        assert_eq!(warmer.access_records["QmTest2"].access_count, 1);
        assert_eq!(warmer.access_records["QmTest1"].size_bytes, 1024);
        assert_eq!(warmer.access_records["QmTest2"].size_bytes, 2048);
    }

    #[tokio::test]
    async fn test_frequency_based_warming() {
        let mut warmer = create_test_warmer();
        for _ in 0..10 {
            warmer.record_access("QmFrequent".to_string(), 100).await;
        }
        for _ in 0..3 {
            warmer.record_access("QmMedium".to_string(), 100).await;
        }
        warmer.record_access("QmRare".to_string(), 100).await;

        let candidates = warmer.get_warming_candidates().unwrap();
        assert_eq!(candidates.len(), 3);
        assert_eq!(candidates[0].cid, "QmFrequent");
        assert_eq!(candidates[1].cid, "QmMedium");
        assert_eq!(candidates[2].cid, "QmRare");
    }

    #[tokio::test]
    async fn test_max_items_constraint() {
        let mut warmer = create_test_warmer();
        for i in 0..20 {
            warmer.record_access(format!("QmTest{}", i), 100).await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();
        assert_eq!(candidates.len(), 10);
    }

    #[tokio::test]
    async fn test_max_bytes_constraint() {
        let mut warmer = create_test_warmer();
        for i in 0..10 {
            warmer
                .record_access(format!("QmTest{}", i), 200 * 1024)
                .await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();
        let total_bytes: u64 = candidates.iter().map(|c| c.size_bytes).sum();
        assert!(total_bytes <= 1024 * 1024);
        // Five 200 KiB items fit into 1 MiB; the sixth would exceed it.
        assert_eq!(candidates.len(), 5);
    }

    #[tokio::test]
    async fn test_persist_and_load() {
        let log_path = temp_log_path("persist");
        let mut warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path.clone(),
            ..Default::default()
        })
        .unwrap();
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;
        warmer.persist().await.unwrap();

        let mut new_warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path.clone(),
            ..Default::default()
        })
        .unwrap();
        let loaded = new_warmer.load().await;

        // Remove the file before asserting so a failure does not leak it.
        let _ = std::fs::remove_file(&log_path);

        loaded.unwrap();
        assert_eq!(new_warmer.access_records.len(), 2);
        assert!(new_warmer.access_records.contains_key("QmTest1"));
        assert!(new_warmer.access_records.contains_key("QmTest2"));
        assert_eq!(new_warmer.access_records["QmTest1"].size_bytes, 1024);
        assert_eq!(new_warmer.access_records["QmTest2"].size_bytes, 2048);
    }

    #[tokio::test]
    async fn test_load_missing_file_is_ok() {
        let log_path = temp_log_path("missing");
        let _ = std::fs::remove_file(&log_path);

        let mut warmer = CacheWarmer::new(WarmingConfig {
            access_log_path: log_path,
            ..Default::default()
        })
        .unwrap();
        warmer.record_access("QmKept".to_string(), 1).await;

        // A missing log is not an error and leaves the history untouched.
        warmer.load().await.unwrap();
        assert_eq!(warmer.access_records.len(), 1);
    }

    #[tokio::test]
    async fn test_hybrid_strategy() {
        let config = WarmingConfig {
            strategy: WarmingStrategy::Hybrid,
            max_items: 10,
            max_bytes: 1024 * 1024,
            access_log_path: temp_log_path("hybrid"),
            warmup_on_startup: false,
        };
        let mut warmer = CacheWarmer::new(config).unwrap();
        for _ in 0..100 {
            warmer.record_access("QmOldFrequent".to_string(), 100).await;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        for _ in 0..5 {
            warmer.record_access("QmRecentRare".to_string(), 100).await;
        }

        let candidates = warmer.get_warming_candidates().unwrap();
        assert_eq!(candidates.len(), 2);
        // Both items are only milliseconds old, so frequency dominates.
        assert_eq!(candidates[0].cid, "QmOldFrequent");
        assert_eq!(candidates[1].cid, "QmRecentRare");
    }

    #[test]
    fn test_warming_stats() {
        let warmer = create_test_warmer();
        let stats = warmer.warming_stats();
        assert_eq!(stats.total_items, 0);
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.avg_score, 0.0);
        assert_eq!(stats.strategy, WarmingStrategy::FrequencyBased);
    }

    #[test]
    fn test_invalid_config() {
        let zero_items = WarmingConfig {
            max_items: 0,
            ..Default::default()
        };
        assert!(matches!(
            CacheWarmer::new(zero_items),
            Err(WarmingError::InvalidConfig(_))
        ));

        let zero_bytes = WarmingConfig {
            max_bytes: 0,
            ..Default::default()
        };
        assert!(matches!(
            CacheWarmer::new(zero_bytes),
            Err(WarmingError::InvalidConfig(_))
        ));
    }

    #[tokio::test]
    async fn test_clear() {
        let mut warmer = create_test_warmer();
        warmer.record_access("QmTest1".to_string(), 1024).await;
        warmer.record_access("QmTest2".to_string(), 2048).await;
        assert_eq!(warmer.access_records.len(), 2);

        warmer.clear();
        assert_eq!(warmer.access_records.len(), 0);
    }
}