use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
use crate::{XXResult, file, hash::hash_to_str};
/// On-disk representation of a single cached value, serialized as JSON.
///
/// The field names are part of the persisted format: renaming a field (or
/// adding a new non-`Option` one) makes previously written entries fail to
/// deserialize, which `CacheManager::get` treats as a cache miss.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry<T> {
    /// The cached payload.
    data: T,
    /// Write time, in whole seconds since the Unix epoch.
    created_at: u64,
    /// The manager's version string at write time; a mismatch makes the entry stale.
    version: String,
    /// Hash of the watched files' modification times at write time, or `None`
    /// when the writing manager watched no files.
    files_hash: Option<String>,
}
/// Configures and constructs a [`CacheManager`].
///
/// Obtain one with `CacheManager::builder()` (or `CacheManagerBuilder::default()`),
/// chain the setter methods, then call `build()`.
#[derive(Debug, Clone, Default)]
pub struct CacheManagerBuilder {
    /// Directory that will hold the cache files; `build()` fails if unset.
    cache_dir: Option<PathBuf>,
    /// Version stamp written into each entry; defaults to the empty string.
    version: String,
    /// Maximum entry age; `None` disables age-based expiry.
    fresh_duration: Option<Duration>,
    /// Files whose modification times are tracked to invalidate entries.
    fresh_files: Vec<PathBuf>,
}
impl CacheManagerBuilder {
    /// Sets the directory in which cache entries are stored. Required.
    #[must_use]
    pub fn cache_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
        self.cache_dir = Some(dir.as_ref().to_path_buf());
        self
    }

    /// Sets the version stamp; entries written under a different version are
    /// treated as stale.
    #[must_use]
    pub fn version<S: Into<String>>(mut self, version: S) -> Self {
        self.version = version.into();
        self
    }

    /// Sets the maximum age an entry may reach before it is considered stale.
    #[must_use]
    pub fn fresh_duration(mut self, duration: Duration) -> Self {
        self.fresh_duration = Some(duration);
        self
    }

    /// Adds one file whose modification time invalidates cached entries when it changes.
    #[must_use]
    pub fn fresh_file<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.fresh_files.push(path.as_ref().to_path_buf());
        self
    }

    /// Adds several files whose modification times invalidate cached entries
    /// when any of them changes.
    #[must_use]
    pub fn fresh_files<I, P>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: AsRef<Path>,
    {
        self.fresh_files
            .extend(paths.into_iter().map(|path| path.as_ref().to_path_buf()));
        self
    }

    /// Finalizes the configuration, creating the cache directory if needed.
    ///
    /// # Errors
    /// Returns an error if no cache directory was set, or if it cannot be created.
    pub fn build(self) -> XXResult<CacheManager> {
        let cache_dir = self
            .cache_dir
            .ok_or_else(|| crate::error!("cache_dir is required"))?;
        file::mkdirp(&cache_dir)?;
        Ok(CacheManager {
            cache_dir,
            version: self.version,
            fresh_duration: self.fresh_duration,
            fresh_files: self.fresh_files,
        })
    }
}
/// File-backed key/value cache that stores each entry as a JSON file.
///
/// Construct it with `CacheManager::builder()`. An entry is only returned while
/// it is fresh: same version, younger than the optional fresh duration, and
/// the watched files unchanged since it was written.
#[derive(Debug, Clone)]
pub struct CacheManager {
    /// Directory holding one JSON file per cached key.
    cache_dir: PathBuf,
    /// Version stamp; entries with a different stamp are stale.
    version: String,
    /// Maximum entry age; `None` means entries never expire by age.
    fresh_duration: Option<Duration>,
    /// Files whose modification times are hashed into every entry.
    fresh_files: Vec<PathBuf>,
}
impl CacheManager {
    /// Returns a [`CacheManagerBuilder`] with default settings.
    pub fn builder() -> CacheManagerBuilder {
        CacheManagerBuilder::default()
    }

    /// Returns the cached value for `key`, or `None` on a miss.
    ///
    /// A miss is any of: no entry file, an unreadable file, JSON that does not
    /// deserialize as `T`, or a stale entry (version mismatch, older than the
    /// configured fresh duration, or watched files changed).
    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        let entry = self.read_fresh_entry::<T>(key)?;
        trace!("Cache hit: {}", key);
        Some(entry.data)
    }

    /// Serializes `data` and stores it under `key`.
    ///
    /// The entry is stamped with the current time, this manager's version and
    /// the current watched-files hash, overwriting any previous entry.
    ///
    /// # Errors
    /// Returns an error if serialization fails or the entry file cannot be written.
    pub fn set<T: Serialize>(&self, key: &str, data: &T) -> XXResult<()> {
        let path = self.cache_path(key);
        let entry = CacheEntry {
            data,
            created_at: Self::now_secs(),
            version: self.version.clone(),
            files_hash: self.compute_files_hash(),
        };
        let content = serde_json::to_string_pretty(&entry)
            .map_err(|e| crate::error!("Failed to serialize cache entry: {}", e))?;
        file::write(&path, content)?;
        trace!("Cache set: {}", key);
        Ok(())
    }

    /// Deletes the cached entry for `key`.
    ///
    /// Removing a key that has no entry is a no-op, so callers can invalidate
    /// unconditionally.
    ///
    /// # Errors
    /// Returns an error if the entry file exists but cannot be deleted.
    pub fn remove(&self, key: &str) -> XXResult<()> {
        let path = self.cache_path(key);
        if !path.exists() {
            return Ok(());
        }
        file::remove_file(&path)
    }

    /// Deletes every entry by recreating the cache directory.
    ///
    /// Also succeeds if the directory was removed externally in the meantime.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be removed or recreated.
    pub fn clear(&self) -> XXResult<()> {
        if self.cache_dir.exists() {
            file::remove_dir_all(&self.cache_dir)?;
        }
        file::mkdirp(&self.cache_dir)
    }

    /// Returns the cached value for `key`, or computes it with `f` and stores it.
    ///
    /// `f` is only invoked on a miss.
    ///
    /// # Errors
    /// Propagates any error returned by `f`. A failure to write the freshly
    /// computed value is converted into `E` and returned.
    pub fn get_or_try<T, F, E>(&self, key: &str, f: F) -> Result<T, E>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Result<T, E>,
        E: From<crate::XXError>,
    {
        if let Some(value) = self.get::<T>(key) {
            return Ok(value);
        }
        let value = f()?;
        self.set(key, &value)?;
        Ok(value)
    }

    /// Reports whether a fresh entry exists for `key`, without needing to know
    /// the stored value's type.
    pub fn contains(&self, key: &str) -> bool {
        self.read_fresh_entry::<serde_json::Value>(key).is_some()
    }

    /// Reads and parses the entry for `key`, returning it only if it is fresh.
    ///
    /// Shared by `get` and `contains` so both apply identical rules; every
    /// failure (missing file, I/O error, bad JSON, staleness) yields `None`.
    fn read_fresh_entry<T: DeserializeOwned>(&self, key: &str) -> Option<CacheEntry<T>> {
        let path = self.cache_path(key);
        // Cheap pre-check so an ordinary miss does not go through the error path.
        if !path.exists() {
            return None;
        }
        let content = file::read_to_string(&path).ok()?;
        let entry: CacheEntry<T> = serde_json::from_str(&content).ok()?;
        if !self.is_entry_fresh(
            key,
            entry.created_at,
            &entry.version,
            entry.files_hash.as_deref(),
        ) {
            return None;
        }
        Some(entry)
    }

    /// Path of the JSON file backing `key`.
    ///
    /// The key is hashed rather than used verbatim as the file name.
    fn cache_path(&self, key: &str) -> PathBuf {
        let hash = hash_to_str(&key);
        self.cache_dir.join(format!("{}.json", hash))
    }

    /// Current wall-clock time in whole seconds since the Unix epoch
    /// (0 if the system clock reads earlier than the epoch).
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Decides whether an entry's metadata still satisfies this manager's rules.
    ///
    /// The checks run in this order:
    /// - The version must match exactly.
    /// - If a fresh duration is set, the entry's age must be below it. Age is
    ///   compared in whole seconds, so a zero or sub-second duration always expires.
    /// - The watched-files hash must equal the current one.
    ///
    /// The reason for each miss is traced.
    fn is_entry_fresh(
        &self,
        key: &str,
        created_at: u64,
        version: &str,
        files_hash: Option<&str>,
    ) -> bool {
        if version != self.version {
            trace!("Cache miss (version mismatch): {}", key);
            return false;
        }
        if let Some(duration) = self.fresh_duration {
            // saturating_sub: an entry stamped "in the future" (clock moved back) has age 0.
            if Self::now_secs().saturating_sub(created_at) >= duration.as_secs() {
                trace!("Cache miss (expired): {}", key);
                return false;
            }
        }
        let current_hash = self.compute_files_hash();
        if current_hash.as_deref() != files_hash {
            trace!("Cache miss (files changed): {}", key);
            return false;
        }
        true
    }

    /// Hashes the modification times of all watched files. Returns `None` when
    /// no files are watched.
    ///
    /// Unreadable or missing files contribute `None`, so creating or deleting a
    /// watched file also changes the hash. Nanosecond precision is used so an
    /// edit made within the same second as the cache write is still detected
    /// (on filesystems that record sub-second mtimes). Changing this
    /// representation invalidates previously written entries once.
    fn compute_files_hash(&self) -> Option<String> {
        if self.fresh_files.is_empty() {
            return None;
        }
        let mtimes: Vec<Option<u128>> = self
            .fresh_files
            .iter()
            .map(|path| file::modified_time(path).ok().map(|m| m.as_nanos()))
            .collect();
        Some(hash_to_str(&mtimes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_basic() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        assert!(cache.get::<String>("key1").is_none());
        cache.set("key1", &"value1".to_string()).unwrap();
        assert_eq!(cache.get::<String>("key1"), Some("value1".to_string()));
        cache.remove("key1").unwrap();
        assert!(cache.get::<String>("key1").is_none());
    }

    #[test]
    fn test_cache_version_invalidation() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache_v1 = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        cache_v1.set("key", &"value".to_string()).unwrap();
        assert_eq!(cache_v1.get::<String>("key"), Some("value".to_string()));
        let cache_v2 = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("2.0")
            .build()
            .unwrap();
        assert!(cache_v2.get::<String>("key").is_none());
    }

    #[test]
    fn test_cache_duration_expiration() {
        let tmpdir = tempfile::tempdir().unwrap();
        // A zero duration means every entry is already expired (age >= 0),
        // so no sleep is needed to observe the expiry.
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .fresh_duration(Duration::from_secs(0))
            .build()
            .unwrap();
        cache.set("key", &"value".to_string()).unwrap();
        assert!(cache.get::<String>("key").is_none());
    }

    #[test]
    fn test_cache_fresh_file_invalidation() {
        let tmpdir = tempfile::tempdir().unwrap();
        let watched = tmpdir.path().join("watched.txt");
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path().join("cache"))
            .version("1.0")
            .fresh_file(&watched)
            .build()
            .unwrap();
        cache.set("key", &"value".to_string()).unwrap();
        assert_eq!(cache.get::<String>("key"), Some("value".to_string()));
        // Creating the previously missing watched file changes the files hash.
        std::fs::write(&watched, "changed").unwrap();
        assert!(cache.get::<String>("key").is_none());
    }

    #[test]
    fn test_cache_get_or_try_computes_once() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        let mut calls = 0;
        let first: XXResult<String> = cache.get_or_try("key", || {
            calls += 1;
            Ok("computed".to_string())
        });
        assert_eq!(first.ok(), Some("computed".to_string()));
        let second: XXResult<String> = cache.get_or_try("key", || {
            calls += 1;
            Ok("recomputed".to_string())
        });
        assert_eq!(second.ok(), Some("computed".to_string()));
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_cache_contains() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        assert!(!cache.contains("key"));
        cache.set("key", &"value".to_string()).unwrap();
        assert!(cache.contains("key"));
    }

    #[test]
    fn test_cache_clear() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        cache.set("key1", &"value1".to_string()).unwrap();
        cache.set("key2", &"value2".to_string()).unwrap();
        cache.clear().unwrap();
        assert!(!cache.contains("key1"));
        assert!(!cache.contains("key2"));
    }

    #[test]
    fn test_cache_complex_types() {
        let tmpdir = tempfile::tempdir().unwrap();
        let cache = CacheManager::builder()
            .cache_dir(tmpdir.path())
            .version("1.0")
            .build()
            .unwrap();
        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct TestData {
            name: String,
            values: Vec<i32>,
        }
        let data = TestData {
            name: "test".to_string(),
            values: vec![1, 2, 3],
        };
        cache.set("complex", &data).unwrap();
        let retrieved: Option<TestData> = cache.get("complex");
        assert_eq!(retrieved, Some(data));
    }
}