use super::{Cache, CacheError};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// JSON envelope stored in each cache file.
///
/// The field names are part of the persisted on-disk JSON format: renaming a
/// field would make files written before the rename fail to deserialize.
#[derive(Serialize, Deserialize)]
struct DiskCacheEntry {
/// The cached value, itself already serialized to a JSON string.
data: String,
/// Seconds since the Unix epoch at which the entry was written.
cached_at: u64,
/// Time-to-live in whole seconds (sub-second parts are truncated), counted
/// from `cached_at`.
ttl_seconds: u64,
}
/// File-system backed cache that stores one JSON file per key under
/// `cache_dir/<category>/`.
#[derive(Debug)]
pub struct DiskCache {
/// Root directory of the cache; the category subdirectories live beneath it.
/// Crate-visible (the tests in this file construct a `DiskCache` directly).
pub(crate) cache_dir: PathBuf,
}
impl DiskCache {
    /// Category subdirectories created under the cache root. Keys are routed
    /// into them by a `"<category>:"` prefix.
    const SUBDIRS: [&'static str; 4] = ["search", "info", "comments", "pkgbuild"];

    /// Longest sanitized key (in bytes) used verbatim as a file stem. Longer
    /// keys are replaced by a hash so file names stay well below the common
    /// 255-byte file-name limit.
    const MAX_STEM_LEN: usize = 200;

    /// Creates a disk cache rooted at `<platform cache dir>/arch-toolkit`.
    ///
    /// The root directory and every category subdirectory are created if they
    /// do not exist yet.
    ///
    /// # Errors
    ///
    /// Returns an error if the platform cache directory cannot be determined
    /// or if any of the directories cannot be created.
    pub fn new() -> Result<Self, std::io::Error> {
        let cache_dir = Self::get_cache_dir()?;
        fs::create_dir_all(&cache_dir)?;
        for subdir in Self::SUBDIRS {
            fs::create_dir_all(cache_dir.join(subdir))?;
        }
        Ok(Self { cache_dir })
    }

    /// Returns `<platform cache dir>/arch-toolkit` without creating it.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::NotFound` when `dirs::cache_dir()` cannot determine
    /// a cache directory for this platform.
    fn get_cache_dir() -> Result<PathBuf, std::io::Error> {
        dirs::cache_dir()
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "Cannot determine cache directory",
                )
            })
            .map(|dir| dir.join("arch-toolkit"))
    }

    /// Maps a cache key to the JSON file that stores it.
    ///
    /// The `search:` / `info:` / `comments:` / `pkgbuild:` prefix selects the
    /// subdirectory; keys without a known prefix go to `search`. The remaining
    /// text becomes the file stem, with every character outside
    /// `[alphanumeric - _ .]` replaced by `_`.
    ///
    /// Because that replacement is lossy, a stable hash of the original text is
    /// appended after a `+` whenever sanitizing changed the text or left it
    /// empty. `+` itself is always sanitized away, so these names can never
    /// collide with a verbatim key. Stems longer than [`Self::MAX_STEM_LEN`]
    /// are replaced entirely by `+<hash of the full key>`.
    fn get_file_path(&self, key: &str) -> PathBuf {
        let (subdir, key_part) = Self::SUBDIRS
            .iter()
            .find_map(|&dir| {
                key.strip_prefix(dir)
                    .and_then(|rest| rest.strip_prefix(':'))
                    .map(|rest| (dir, rest))
            })
            .unwrap_or(("search", key));

        let safe: String = key_part
            .chars()
            .map(|c| {
                if c.is_alphanumeric() || matches!(c, '-' | '_' | '.') {
                    c
                } else {
                    '_'
                }
            })
            .collect();

        let stem = if safe.len() > Self::MAX_STEM_LEN {
            format!("+{:016x}", Self::stable_hash(key))
        } else if safe.is_empty() || safe != key_part {
            // Lossy (e.g. "foo bar" and "foo_bar" both sanitize to "foo_bar")
            // or empty (".json" would have no extension as far as `Path` is
            // concerned): disambiguate with a hash of the original text.
            format!("{safe}+{:016x}", Self::stable_hash(key_part))
        } else {
            safe
        };
        self.cache_dir.join(subdir).join(format!("{stem}.json"))
    }

    /// 64-bit FNV-1a hash of `s`.
    ///
    /// Used instead of `DefaultHasher`, whose algorithm is unspecified across
    /// Rust releases, because the result is persisted in file names.
    fn stable_hash(s: &str) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const PRIME: u64 = 0x0000_0100_0000_01b3;
        s.bytes().fold(OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(PRIME)
        })
    }

    /// Returns `true` once `cached_at + ttl_seconds` lies in the past.
    ///
    /// If the system clock reads before the Unix epoch, "now" is treated as 0
    /// and nothing is considered expired.
    fn is_expired(entry: &DiskCacheEntry) -> bool {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        // Saturate: a huge TTL would otherwise overflow, panicking in debug
        // builds and wrapping to an already-expired timestamp in release.
        entry.cached_at.saturating_add(entry.ttl_seconds) < now
    }

    /// Deletes every expired `*.json` entry from all category subdirectories.
    ///
    /// Best-effort: unreadable directories, unreadable or unparsable files, and
    /// failed deletions are skipped silently.
    // NOTE(review): `dead_code` is allowed presumably because nothing in the
    // crate calls this yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn cleanup_expired(&self) {
        for subdir in Self::SUBDIRS {
            let Ok(entries) = fs::read_dir(self.cache_dir.join(subdir)) else {
                continue;
            };
            for path in entries.flatten().map(|entry| entry.path()) {
                if path.extension() != Some(std::ffi::OsStr::new("json")) {
                    continue;
                }
                let Ok(content) = fs::read_to_string(&path) else {
                    continue;
                };
                let Ok(cache_entry) = serde_json::from_str::<DiskCacheEntry>(&content) else {
                    continue;
                };
                if Self::is_expired(&cache_entry) {
                    let _ = fs::remove_file(&path);
                }
            }
        }
    }
}
impl<K, V> Cache<K, V> for DiskCache
where
    K: AsRef<str>,
    V: Clone + Serialize + for<'de> Deserialize<'de>,
{
    /// Loads the value stored for `key`.
    ///
    /// Returns `None` in these cases:
    /// - the file is missing or unreadable;
    /// - the envelope or the inner value fails to deserialize;
    /// - the entry has expired (its file is then deleted, best-effort).
    fn get(&self, key: &K) -> Option<V> {
        let path = self.get_file_path(key.as_ref());
        let content = fs::read_to_string(&path).ok()?;
        let entry: DiskCacheEntry = serde_json::from_str(&content).ok()?;
        if Self::is_expired(&entry) {
            let _ = fs::remove_file(&path);
            return None;
        }
        serde_json::from_str(&entry.data).ok()
    }

    /// Serializes `value` and stores it under `key` for `ttl` (whole seconds).
    ///
    /// The JSON is written to a temp file that is unique to this process and
    /// call, and then renamed over the target. Concurrent readers therefore see
    /// either the old or the new entry, and concurrent writers never share a
    /// temp file.
    ///
    /// # Errors
    ///
    /// - `CacheError::Serialization` if the value or envelope cannot be serialized.
    /// - `CacheError::Other` if the system clock is before the Unix epoch.
    /// - `CacheError::Io` if the directory, temp file, or rename fails. In that
    ///   case the temp file is removed, best-effort.
    fn set(&self, key: &K, value: &V, ttl: Duration) -> Result<(), CacheError> {
        // Process-wide counter; combined with the PID it names each temp file.
        static TEMP_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

        let path = self.get_file_path(key.as_ref());
        let data =
            serde_json::to_string(value).map_err(|e| CacheError::Serialization(e.to_string()))?;
        let cached_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|e| CacheError::Other(format!("System time error: {e}")))?;
        let entry = DiskCacheEntry {
            data,
            cached_at: cached_at.as_secs(),
            ttl_seconds: ttl.as_secs(),
        };
        let json =
            serde_json::to_string(&entry).map_err(|e| CacheError::Serialization(e.to_string()))?;

        // The category directory may have been removed after construction.
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(CacheError::Io)?;
        }

        // "<stem>.<pid>.<n>.tmp": still ends in `.tmp`, so `clear` removes strays.
        let unique = TEMP_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let temp_path = path.with_extension(format!("{}.{unique}.tmp", std::process::id()));
        let written = fs::write(&temp_path, json).and_then(|()| fs::rename(&temp_path, &path));
        if let Err(err) = written {
            let _ = fs::remove_file(&temp_path);
            return Err(CacheError::Io(err));
        }
        Ok(())
    }

    /// Removes the entry for `key`; a missing entry is not an error.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Io` if the file exists but cannot be removed.
    fn invalidate(&self, key: &K) -> Result<(), CacheError> {
        let path = self.get_file_path(key.as_ref());
        // Remove directly instead of `exists()` + remove: another writer or
        // reader may delete the file between the check and the removal.
        match fs::remove_file(&path) {
            Err(err) if err.kind() != std::io::ErrorKind::NotFound => Err(CacheError::Io(err)),
            _ => Ok(()),
        }
    }

    /// Deletes every `*.json` and `*.tmp` file in all category subdirectories.
    ///
    /// Individual file removals are best-effort (failures are ignored).
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Io` if an existing subdirectory cannot be listed.
    fn clear(&self) -> Result<(), CacheError> {
        for subdir in ["search", "info", "comments", "pkgbuild"] {
            let dir = self.cache_dir.join(subdir);
            if !dir.exists() {
                continue;
            }
            for entry in fs::read_dir(&dir).map_err(CacheError::Io)? {
                let path = entry.map_err(CacheError::Io)?.path();
                let removable = path.extension() == Some(std::ffi::OsStr::new("json"))
                    || path.extension() == Some(std::ffi::OsStr::new("tmp"));
                if removable {
                    let _ = fs::remove_file(&path);
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Builds a `DiskCache` rooted in a fresh temporary directory with every
    /// category subdirectory pre-created. Keep the returned `TempDir` alive for
    /// as long as the cache is used; dropping it deletes the directory.
    fn create_test_cache() -> (DiskCache, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let cache_dir = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_dir).unwrap();
        ["search", "info", "comments", "pkgbuild"]
            .iter()
            .for_each(|sub| fs::create_dir_all(cache_dir.join(sub)).unwrap());
        (DiskCache { cache_dir }, temp_dir)
    }

    /// Reads `key` through the `Cache<String, String>` implementation.
    fn cache_get(cache: &DiskCache, key: &str) -> Option<String> {
        <DiskCache as Cache<String, String>>::get(cache, &key.to_string())
    }

    /// Stores `value` under `key` with a 60-second TTL; panics on failure.
    fn cache_set(cache: &DiskCache, key: &str, value: &str) {
        <DiskCache as Cache<String, String>>::set(
            cache,
            &key.to_string(),
            &value.to_string(),
            Duration::from_secs(60),
        )
        .unwrap();
    }

    /// Invalidates `key`; panics on failure.
    fn cache_invalidate(cache: &DiskCache, key: &str) {
        <DiskCache as Cache<String, String>>::invalidate(cache, &key.to_string()).unwrap();
    }

    /// Clears the whole cache; panics on failure.
    fn cache_clear(cache: &DiskCache) {
        <DiskCache as Cache<String, String>>::clear(cache).unwrap();
    }

    #[test]
    fn test_disk_cache_get_set() {
        let (cache, _temp_dir) = create_test_cache();
        assert!(cache_get(&cache, "test_key").is_none());
        cache_set(&cache, "test_key", "test_value");
        assert_eq!(cache_get(&cache, "test_key"), Some("test_value".to_string()));
    }

    #[test]
    fn test_disk_cache_invalidate() {
        let (cache, _temp_dir) = create_test_cache();
        cache_set(&cache, "test_key", "test_value");
        assert!(cache_get(&cache, "test_key").is_some());
        cache_invalidate(&cache, "test_key");
        assert!(cache_get(&cache, "test_key").is_none());
    }

    #[test]
    fn test_disk_cache_clear() {
        let (cache, _temp_dir) = create_test_cache();
        for key in ["key1", "key2"] {
            cache_set(&cache, key, "test_value");
        }
        cache_clear(&cache);
        assert!(cache_get(&cache, "key1").is_none());
        assert!(cache_get(&cache, "key2").is_none());
    }
}