use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
/// Time-to-live used by [`FileCache::create_default`]: one day
/// (24 h × 60 min × 60 s = 86 400 s).
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(86_400);
/// One cached payload together with the metadata persisted alongside it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CacheEntry {
    /// Wall-clock time at which the entry was written; compared against the
    /// cache's TTL to decide whether the entry has expired.
    pub created_at: SystemTime,
    /// The raw cached bytes.
    pub data: Vec<u8>,
    /// Content type string supplied by the caller when the entry was stored.
    pub content_type: String,
}
/// A directory-backed cache that stores one JSON file per key and treats
/// entries older than a fixed TTL as expired.
#[derive(Debug, Clone)]
pub struct FileCache {
    /// Directory holding the cache files.
    cache_dir: PathBuf,
    /// Maximum age an entry may reach before it is considered expired.
    ttl: Duration,
}
impl FileCache {
    /// Creates a cache rooted at `cache_dir`, creating the directory (and any
    /// missing parents) if needed. Entries older than `ttl` count as expired.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the cache directory cannot be created.
    pub fn new(cache_dir: impl AsRef<Path>, ttl: Duration) -> Result<Self> {
        let cache_dir = cache_dir.as_ref().to_path_buf();
        fs::create_dir_all(&cache_dir)
            .map_err(|e| Error::io(format!("Failed to create cache directory: {e}")))?;
        Ok(Self { cache_dir, ttl })
    }

    /// Creates a cache under `<platform cache dir>/ferrous-forge/github` using
    /// [`DEFAULT_CACHE_TTL`].
    ///
    /// # Errors
    ///
    /// Returns a configuration error if the platform cache directory cannot be
    /// determined, or an I/O error if the directory cannot be created.
    pub fn create_default() -> Result<Self> {
        let cache_dir = dirs::cache_dir()
            .ok_or_else(|| Error::config("Could not determine cache directory"))?
            .join("ferrous-forge")
            .join("github");
        Self::new(cache_dir, DEFAULT_CACHE_TTL)
    }

    /// Returns the entry stored under `key`, or `None` if it is missing,
    /// cannot be read or parsed, or has expired.
    ///
    /// Expired entries are deleted (best effort) as a side effect.
    pub fn get(&self, key: &str) -> Option<CacheEntry> {
        let path = self.cache_path(key);
        // A missing file already surfaces as a read error, so a separate
        // `exists()` probe would only add a syscall and a race window.
        let entry = self.read_entry(&path).ok()?;
        if self.is_expired(&entry) {
            let _ = fs::remove_file(&path);
            return None;
        }
        Some(entry)
    }

    /// Stores `data` under `key`, stamped with the current time, overwriting
    /// any existing file for that key.
    ///
    /// # Errors
    ///
    /// Returns an error if the entry cannot be serialized or the cache file
    /// cannot be written.
    pub fn set(&self, key: &str, data: Vec<u8>, content_type: &str) -> Result<()> {
        let path = self.cache_path(key);
        let entry = CacheEntry {
            created_at: SystemTime::now(),
            data,
            content_type: content_type.to_string(),
        };
        let json = serde_json::to_vec(&entry)
            .map_err(|e| Error::Validation(format!("Failed to serialize cache entry: {e}")))?;
        fs::write(&path, json)
            .map_err(|e| Error::io(format!("Failed to write cache file: {e}")))?;
        Ok(())
    }

    /// Returns `true` if the `FERROUS_FORGE_OFFLINE` environment variable is
    /// set (to any value), or if the cache directory holds at least one
    /// readable, unexpired entry.
    pub fn should_use_offline(&self) -> bool {
        // `var_os` treats a non-UTF-8 value as "set"; `env::var` would return
        // `Err(NotUnicode)` and silently ignore the override.
        std::env::var_os("FERROUS_FORGE_OFFLINE").is_some() || self.has_valid_cache()
    }

    /// Returns `true` as soon as any file in the cache directory parses as an
    /// unexpired entry; `false` if none does or the directory can't be listed.
    fn has_valid_cache(&self) -> bool {
        let Ok(entries) = fs::read_dir(&self.cache_dir) else {
            return false;
        };
        entries.flatten().any(|entry| {
            self.read_entry(&entry.path())
                .is_ok_and(|cache_entry| !self.is_expired(&cache_entry))
        })
    }

    /// Maps `key` to a file path inside the cache directory.
    ///
    /// Path separators, `:`, spaces, the remaining characters Windows forbids
    /// in file names (`* ? " < > |`) and control characters are replaced by
    /// `_`, so every key yields a single, portable file name (e.g. keys that
    /// carry a URL query string). The mapping is not injective: `a/b` and
    /// `a_b` share a file, so keys must not differ only in these characters.
    fn cache_path(&self, key: &str) -> PathBuf {
        let safe_key = key.replace(
            |c: char| {
                matches!(
                    c,
                    '/' | '\\' | ':' | ' ' | '*' | '?' | '"' | '<' | '>' | '|'
                ) || c.is_control()
            },
            "_",
        );
        self.cache_dir.join(format!("{safe_key}.json"))
    }

    /// Reads and deserializes the entry stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the file cannot be read, or a parse error if
    /// its contents are not a serialized [`CacheEntry`].
    fn read_entry(&self, path: &Path) -> Result<CacheEntry> {
        let data =
            fs::read(path).map_err(|e| Error::io(format!("Failed to read cache file: {e}")))?;
        serde_json::from_slice(&data)
            .map_err(|e| Error::parse(format!("Failed to parse cache entry: {e}")))
    }

    /// Returns `true` once more than `ttl` has elapsed since the entry was
    /// created. An entry whose `created_at` lies in the future (so the elapsed
    /// time cannot be computed) is also treated as expired rather than served.
    fn is_expired(&self, entry: &CacheEntry) -> bool {
        SystemTime::now()
            .duration_since(entry.created_at)
            .map(|elapsed| elapsed > self.ttl)
            .unwrap_or(true)
    }

    /// Deletes every file in the cache directory.
    ///
    /// Removal is best effort: entries that cannot be removed (including
    /// subdirectories, which `remove_file` rejects) are skipped silently.
    ///
    /// # Errors
    ///
    /// Returns an I/O error only if the cache directory cannot be listed.
    pub fn clear(&self) -> Result<()> {
        let entries = fs::read_dir(&self.cache_dir)
            .map_err(|e| Error::io(format!("Failed to read cache directory: {e}")))?;
        for entry in entries.flatten() {
            let _ = fs::remove_file(entry.path());
        }
        Ok(())
    }

    /// Scans the cache directory and summarizes its contents.
    ///
    /// Every directory entry counts toward `total_entries`, and toward
    /// `total_size` when its metadata is readable. Only entries that parse are
    /// classified as valid or expired. Returns all-zero stats if the directory
    /// cannot be listed.
    pub fn stats(&self) -> CacheStats {
        let mut stats = CacheStats {
            total_entries: 0,
            valid_entries: 0,
            expired_entries: 0,
            total_size: 0,
        };
        let Ok(entries) = fs::read_dir(&self.cache_dir) else {
            return stats;
        };
        for entry in entries.flatten() {
            stats.total_entries += 1;
            if let Ok(metadata) = entry.metadata() {
                stats.total_size += metadata.len();
            }
            if let Ok(cache_entry) = self.read_entry(&entry.path()) {
                if self.is_expired(&cache_entry) {
                    stats.expired_entries += 1;
                } else {
                    stats.valid_entries += 1;
                }
            }
        }
        stats
    }
}
/// Summary of the cache directory's contents as produced by a directory scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of directory entries found, including unreadable/unparseable ones.
    pub total_entries: usize,
    /// Entries that parsed and are still within the TTL.
    pub valid_entries: usize,
    /// Entries that parsed but are past the TTL.
    pub expired_entries: usize,
    /// Sum of the on-disk sizes of the scanned entries, in bytes.
    pub total_size: u64,
}
impl std::fmt::Display for CacheStats {
    /// Renders a one-line summary of the form
    /// `Cache: <total> total, <valid> valid, <expired> expired, <size> bytes`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_entries,
            valid_entries,
            expired_entries,
            total_size,
        } = self;
        write!(
            f,
            "Cache: {} total, {} valid, {} expired, {} bytes",
            total_entries, valid_entries, expired_entries, total_size
        )
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a cache in a fresh temporary directory. The `TempDir` guard is
    /// returned so the directory outlives every use of the cache in the test.
    fn temp_cache(ttl: Duration) -> (tempfile::TempDir, FileCache) {
        let temp_dir = tempfile::tempdir().unwrap();
        let cache = FileCache::new(temp_dir.path(), ttl).unwrap();
        (temp_dir, cache)
    }

    /// Writes an entry for `key` straight to disk with an explicit timestamp,
    /// bypassing `set` (which always stamps the current time).
    fn write_entry_at(cache: &FileCache, key: &str, created_at: SystemTime) -> PathBuf {
        let entry = CacheEntry {
            created_at,
            data: b"payload".to_vec(),
            content_type: "text/plain".to_string(),
        };
        let path = cache.cache_path(key);
        fs::write(&path, serde_json::to_vec(&entry).unwrap()).unwrap();
        path
    }

    #[test]
    fn test_cache_basic() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        cache
            .set("test-key", b"test data".to_vec(), "text/plain")
            .unwrap();
        let entry = cache.get("test-key").unwrap();
        assert_eq!(entry.data, b"test data");
        assert_eq!(entry.content_type, "text/plain");
    }

    #[test]
    fn test_cache_missing_key() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        assert!(cache.get("absent").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        cache
            .set("test-key", b"test data".to_vec(), "text/plain")
            .unwrap();
        assert!(cache.get("test-key").is_some());

        // Back-date an entry on disk instead of sleeping past a tiny TTL, which
        // was timing-sensitive on slow or heavily loaded CI machines.
        let path = write_entry_at(
            &cache,
            "stale-key",
            SystemTime::now() - Duration::from_secs(120),
        );
        assert!(cache.get("stale-key").is_none());
        // `get` removes expired entries from disk.
        assert!(!path.exists());
    }

    #[test]
    fn test_cache_future_timestamp_is_expired() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        write_entry_at(
            &cache,
            "future-key",
            SystemTime::now() + Duration::from_secs(3600),
        );
        assert!(cache.get("future-key").is_none());
    }

    #[test]
    fn test_cache_stats() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        cache.set("key1", b"data1".to_vec(), "text/plain").unwrap();
        cache.set("key2", b"data2".to_vec(), "text/plain").unwrap();
        write_entry_at(&cache, "key3", SystemTime::now() - Duration::from_secs(120));
        let stats = cache.stats();
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 1);
        assert!(stats.total_size > 0);
    }

    #[test]
    fn test_cache_clear() {
        let (_dir, cache) = temp_cache(Duration::from_secs(60));
        cache.set("key1", b"data1".to_vec(), "text/plain").unwrap();
        cache.set("key2", b"data2".to_vec(), "text/plain").unwrap();
        cache.clear().unwrap();
        assert!(cache.get("key1").is_none());
        assert_eq!(cache.stats().total_entries, 0);
    }
}