use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs;
use tokio::sync::RwLock;
/// Settings that control where the cache lives and how large/old it may grow.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Directory that holds the cached content files and the `metadata.json` index.
    pub dir: PathBuf,

    /// Upper bound, in bytes, on the summed size of all entries.
    ///
    /// `None` turns the size check (and therefore eviction) off.
    pub max_size: Option<u64>,

    /// Maximum age an entry may reach before lookups treat it as expired.
    ///
    /// `None` means no configured limit; a per-entry max age can still apply.
    pub max_age: Option<Duration>,

    /// Whether the entry index is loaded from disk on start-up and written
    /// back after each insertion.
    pub persist_metadata: bool,
}
impl Default for CacheConfig {
    /// Builds the stock configuration: a `.cache/pulith-fetch` directory
    /// (relative to the working directory), a 1 GiB size cap, a seven-day
    /// maximum age, and metadata persistence switched on.
    fn default() -> Self {
        const ONE_GIB: u64 = 1024 * 1024 * 1024;
        const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * 60 * 60);

        Self {
            dir: PathBuf::from(".cache/pulith-fetch"),
            max_size: Some(ONE_GIB),
            max_age: Some(ONE_WEEK),
            persist_metadata: true,
        }
    }
}
/// Metadata recorded for one cached response.
///
/// The response body itself is stored in a separate file; this record is what
/// gets serialized into the metadata index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// URL the content was stored under (also the key in the entry index).
    pub url: String,

    /// Entity tag supplied when the entry was stored, if any.
    pub etag: Option<String>,

    /// Last-modified value supplied when the entry was stored, if any.
    /// NOTE(review): presumably Unix seconds — confirm against the caller.
    pub last_modified: Option<u64>,

    /// When the entry was stored, in seconds since the Unix epoch.
    pub cached_at: u64,

    /// Length of the cached content in bytes.
    pub size: u64,

    /// SHA-256 digest of the cached content.
    pub checksum: [u8; 32],

    /// Access counter; starts at 1 when stored and grows on each cache hit.
    pub access_count: u64,

    /// Time of the most recent access, in seconds since the Unix epoch.
    pub last_accessed: u64,

    /// Per-entry freshness lifetime in seconds, counted from `cached_at`.
    pub max_age: Option<u64>,

    /// When set, the entry always reports that it should be revalidated.
    pub no_cache: bool,
}
impl CacheEntry {
pub fn is_expired(&self, config_max_age: Option<Duration>) -> bool {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if let Some(server_max_age) = self.max_age
&& self.cached_at + server_max_age < now
{
return true;
}
if let Some(config_max_age) = config_max_age
&& self.cached_at + config_max_age.as_secs() < now
{
return true;
}
false
}
pub fn should_revalidate(&self) -> bool {
self.no_cache || self.etag.is_some() || self.last_modified.is_some()
}
}
/// Cache of downloaded content backed by files in a directory, with an
/// in-memory index of entry metadata.
pub struct Cache {
    /// Settings supplied at construction time.
    config: CacheConfig,

    /// Entry index, keyed by URL.
    entries: RwLock<HashMap<String, CacheEntry>>,

    /// Running total, in bytes, of the sizes of all indexed entries.
    current_size: RwLock<u64>,
}
impl Cache {
pub async fn new(config: CacheConfig) -> Result<Self> {
fs::create_dir_all(&config.dir)
.await
.map_err(|e| Error::Network(format!("Failed to create cache directory: {}", e)))?;
let cache = Self {
entries: RwLock::new(HashMap::new()),
current_size: RwLock::new(0),
config,
};
if cache.config.persist_metadata {
cache.load_metadata().await?;
}
Ok(cache)
}
pub async fn get(&self, url: &str) -> Result<Option<CacheEntry>> {
let entries = self.entries.read().await;
if let Some(entry) = entries.get(url) {
if entry.is_expired(self.config.max_age) {
return Ok(None);
}
drop(entries);
self.update_access(url).await;
let entries = self.entries.read().await;
Ok(entries.get(url).cloned())
} else {
Ok(None)
}
}
pub async fn put(
&self,
url: String,
content: &[u8],
etag: Option<String>,
last_modified: Option<u64>,
max_age: Option<u64>,
no_cache: bool,
) -> Result<()> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(content);
let checksum = hasher.finalize().into();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let entry = CacheEntry {
url: url.clone(),
etag,
last_modified,
cached_at: now,
size: content.len() as u64,
checksum,
access_count: 1,
last_accessed: now,
max_age,
no_cache,
};
if let Some(max_size) = self.config.max_size {
let current_size = *self.current_size.read().await;
if current_size + entry.size > max_size {
self.evict_lru(entry.size).await?;
}
}
let cache_file = self.cache_file_path(&url);
fs::write(&cache_file, content)
.await
.map_err(|e| Error::Network(format!("Failed to write cache file: {}", e)))?;
{
let mut entries = self.entries.write().await;
let mut current_size = self.current_size.write().await;
if let Some(old_entry) = entries.remove(&url) {
*current_size = current_size.saturating_sub(old_entry.size);
}
entries.insert(url.clone(), entry.clone());
*current_size += entry.size;
}
if self.config.persist_metadata {
self.save_metadata().await?;
}
Ok(())
}
pub async fn validate(
&self,
url: &str,
server_etag: Option<&str>,
server_last_modified: Option<u64>,
) -> Result<bool> {
let entries = self.entries.read().await;
if let Some(entry) = entries.get(url) {
if let (Some(cached_etag), Some(server_etag)) = (&entry.etag, server_etag)
&& cached_etag == server_etag
{
return Ok(true);
}
if let (Some(cached_modified), Some(server_modified)) =
(entry.last_modified, server_last_modified)
&& cached_modified >= server_modified
{
return Ok(true);
}
}
Ok(false)
}
fn cache_file_path(&self, url: &str) -> PathBuf {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(url.as_bytes());
let hash = hex::encode(hasher.finalize());
self.config.dir.join(format!("{}.cache", hash))
}
async fn update_access(&self, url: &str) {
let mut entries = self.entries.write().await;
if let Some(entry) = entries.get_mut(url) {
entry.access_count += 1;
entry.last_accessed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
}
}
async fn evict_lru(&self, needed_space: u64) -> Result<()> {
let mut entries = self.entries.write().await;
let mut current_size = self.current_size.write().await;
let mut sorted_entries: Vec<_> = entries.iter().collect();
sorted_entries.sort_by_key(|(_, entry)| entry.last_accessed);
let mut freed_space = 0u64;
let mut to_remove = Vec::new();
for (url, entry) in sorted_entries {
if freed_space >= needed_space {
break;
}
to_remove.push(url.clone());
freed_space += entry.size;
}
for url in to_remove {
if let Some(entry) = entries.remove(&url) {
*current_size = current_size.saturating_sub(entry.size);
let cache_file = self.cache_file_path(&url);
let _ = fs::remove_file(cache_file).await;
}
}
Ok(())
}
async fn load_metadata(&self) -> Result<()> {
let metadata_file = self.config.dir.join("metadata.json");
if !metadata_file.exists() {
return Ok(());
}
let content = fs::read_to_string(&metadata_file)
.await
.map_err(|e| Error::Network(format!("Failed to read metadata file: {}", e)))?;
let loaded_entries: HashMap<String, CacheEntry> = serde_json::from_str(&content)
.map_err(|e| Error::InvalidState(format!("Invalid metadata format: {}", e)))?;
let mut total_size = 0u64;
for entry in loaded_entries.values() {
total_size += entry.size;
}
*self.entries.write().await = loaded_entries;
*self.current_size.write().await = total_size;
Ok(())
}
async fn save_metadata(&self) -> Result<()> {
let metadata_file = self.config.dir.join("metadata.json");
let entries = self.entries.read().await;
let content = serde_json::to_string_pretty(&*entries)
.map_err(|e| Error::InvalidState(format!("Failed to serialize metadata: {}", e)))?;
fs::write(&metadata_file, content)
.await
.map_err(|e| Error::Network(format!("Failed to write metadata file: {}", e)))?;
Ok(())
}
pub async fn clear(&self) -> Result<()> {
let entries = self.entries.read().await;
for url in entries.keys() {
let cache_file = self.cache_file_path(url);
let _ = fs::remove_file(cache_file).await;
}
drop(entries);
self.entries.write().await.clear();
*self.current_size.write().await = 0;
let metadata_file = self.config.dir.join("metadata.json");
let _ = fs::remove_file(metadata_file).await;
Ok(())
}
pub async fn stats(&self) -> CacheStats {
let entries = self.entries.read().await;
let current_size = *self.current_size.read().await;
CacheStats {
entry_count: entries.len(),
total_size: current_size,
max_size: self.config.max_size,
hit_count: 0, miss_count: 0, }
}
}
/// Point-in-time summary of the cache's contents.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the index.
    pub entry_count: usize,

    /// Summed size of all entries, in bytes.
    pub total_size: u64,

    /// Configured size limit in bytes, if one is set.
    pub max_size: Option<u64>,

    /// Number of cache hits.
    pub hit_count: u64,

    /// Number of cache misses.
    pub miss_count: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a cache in a fresh temporary directory with a 1 KiB size limit
    /// and a one-hour configured max age. The `TempDir` is returned so the
    /// directory outlives the test body.
    async fn create_test_cache() -> (Cache, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let config = CacheConfig {
            dir: temp_dir.path().to_path_buf(),
            max_size: Some(1024),
            max_age: Some(Duration::from_secs(3600)),
            persist_metadata: true,
        };
        (Cache::new(config).await.unwrap(), temp_dir)
    }

    /// Stores `content` under `url` with no validators and no per-entry max age.
    async fn put_plain(cache: &Cache, url: &str, content: &[u8]) {
        cache
            .put(url.to_string(), content, None, None, None, false)
            .await
            .unwrap();
    }

    /// Moves the stored `cached_at` of `url` `secs` seconds into the past, so
    /// expiry can be tested without sleeping.
    async fn backdate(cache: &Cache, url: &str, secs: u64) {
        let mut entries = cache.entries.write().await;
        let entry = entries.get_mut(url).unwrap();
        entry.cached_at -= secs;
    }

    #[tokio::test]
    async fn test_cache_put_and_get() {
        let (cache, _temp_dir) = create_test_cache().await;
        let url = "https://example.com/test.txt";
        let content = b"Hello, World!";

        cache
            .put(
                url.to_string(),
                content,
                Some("\"etag123\"".to_string()),
                Some(1234567890),
                Some(3600),
                false,
            )
            .await
            .unwrap();

        let entry = cache.get(url).await.unwrap().unwrap();
        assert_eq!(entry.url, url);
        assert_eq!(entry.etag, Some("\"etag123\"".to_string()));
        assert_eq!(entry.last_modified, Some(1234567890));
        assert_eq!(entry.size, content.len() as u64);
        // One for the store, one for the lookup above.
        assert_eq!(entry.access_count, 2);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let (cache, _temp_dir) = create_test_cache().await;

        // Expiry driven by the entry's own max age (60 s).
        let url = "https://example.com/test.txt";
        cache
            .put(url.to_string(), b"Hello, World!", None, None, Some(60), false)
            .await
            .unwrap();
        assert!(cache.get(url).await.unwrap().is_some());
        backdate(&cache, url, 120).await;
        assert!(cache.get(url).await.unwrap().is_none());

        // Expiry driven by the configured max age (3600 s).
        let other = "https://example.com/other.txt";
        put_plain(&cache, other, b"Hello again!").await;
        assert!(cache.get(other).await.unwrap().is_some());
        backdate(&cache, other, 7200).await;
        assert!(cache.get(other).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cache_validation() {
        let (cache, _temp_dir) = create_test_cache().await;
        let url = "https://example.com/test.txt";

        cache
            .put(
                url.to_string(),
                b"Hello, World!",
                Some("\"etag123\"".to_string()),
                Some(1234567890),
                None,
                false,
            )
            .await
            .unwrap();

        // ETag comparison.
        assert!(cache.validate(url, Some("\"etag123\""), None).await.unwrap());
        assert!(!cache.validate(url, Some("\"etag456\""), None).await.unwrap());

        // Last-modified comparison.
        assert!(cache.validate(url, None, Some(1234567890)).await.unwrap());
        assert!(!cache.validate(url, None, Some(1234567891)).await.unwrap());
    }

    #[tokio::test]
    async fn test_cache_eviction() {
        let (cache, _temp_dir) = create_test_cache().await;

        // Five 300-byte entries cannot all fit under the 1024-byte limit.
        let content = vec![b'x'; 300];
        for i in 0..5 {
            let url = format!("https://example.com/test{}.txt", i);
            put_plain(&cache, &url, &content).await;
        }

        let stats = cache.stats().await;
        assert!(stats.entry_count <= 3);
        assert!(stats.total_size <= 1024);
    }

    #[tokio::test]
    async fn test_cache_overwrite_updates_size() {
        let (cache, _temp_dir) = create_test_cache().await;
        let url = "https://example.com/test.txt";

        put_plain(&cache, url, b"12345678").await;
        put_plain(&cache, url, b"1234").await;

        let stats = cache.stats().await;
        assert_eq!(stats.entry_count, 1);
        assert_eq!(stats.total_size, 4);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let (cache, _temp_dir) = create_test_cache().await;

        put_plain(&cache, "https://example.com/test1.txt", b"content1").await;
        put_plain(&cache, "https://example.com/test2.txt", b"content2").await;

        cache.clear().await.unwrap();

        let stats = cache.stats().await;
        assert_eq!(stats.entry_count, 0);
        assert_eq!(stats.total_size, 0);
    }

    #[tokio::test]
    async fn test_metadata_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let config = CacheConfig {
            dir: temp_dir.path().to_path_buf(),
            max_size: Some(1024),
            max_age: None,
            persist_metadata: true,
        };
        let url = "https://example.com/test.txt";

        let cache1 = Cache::new(config.clone()).await.unwrap();
        cache1
            .put(
                url.to_string(),
                b"Hello, World!",
                Some("\"etag123\"".to_string()),
                None,
                None,
                false,
            )
            .await
            .unwrap();
        drop(cache1);

        // A second cache over the same directory must see the saved entry.
        let cache2 = Cache::new(config).await.unwrap();
        let entry = cache2.get(url).await.unwrap().unwrap();
        assert_eq!(entry.etag, Some("\"etag123\"".to_string()));
    }
}