use serde::{Deserialize, Serialize};
use std::fs;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::time::{Duration, SystemTime};
use thiserror::Error;
/// Errors produced by the cache storage operations (`CacheStorage::open`, `load`, `store`,
/// `remove`, `clean`).
#[derive(Debug, Error)]
pub enum CacheError {
    /// No cache directory could be determined for this application
    /// (`directories::ProjectDirs::from` returned `None` in `CacheStorage::open`).
    #[error("Failed to determine cache directory location")]
    CacheDirectoryNotFound,
    /// The per-cache directory could not be created.
    #[error("Failed to create cache directory at {path}: {source}")]
    DirectoryCreationFailed {
        /// Directory that `create_dir_all` was asked to create.
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Reading a cache file failed; also used when listing the cache directory fails.
    #[error("Failed to read cache file {path}: {source}")]
    ReadFailed {
        /// The cache file (or the cache directory, when listing it failed).
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Writing a cache file failed; also used when removing a cache file fails.
    #[error("Failed to write cache file {path}: {source}")]
    WriteFailed {
        /// The cache file being written or removed.
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// A cache file's contents could not be parsed into the expected entry type.
    #[error("Failed to deserialize cache file {path}: {source}")]
    DeserializationFailed {
        /// The cache file whose contents failed to parse.
        path: PathBuf,
        /// Underlying JSON error.
        source: serde_json::Error,
    },
    /// Serializing a value for storage failed.
    ///
    /// NOTE(review): because of `#[from]`, a bare `?` on *any* `serde_json::Error` (including a
    /// parse error) converts into this variant, so deserialization paths must map explicitly.
    #[error("Failed to serialize data: {0}")]
    SerializationFailed(#[from] serde_json::Error),
}
/// Envelope written to each cache file: the cached value plus the moment it was stored.
#[derive(Debug, Serialize, Deserialize)]
struct CachedItem<T> {
    /// The cached payload.
    data: T,
    /// Set to `SystemTime::now()` when the entry is stored; compared against the TTL on load
    /// and during cleanup.
    timestamp: SystemTime,
}
/// A directory-backed cache of `T` values: one pretty-printed JSON file per (sanitized)
/// identifier, with optional time-based expiry.
pub(crate) struct CacheStorage<T> {
    /// Directory holding this cache's `<sanitized-identifier>.json` files.
    cache_dir: PathBuf,
    /// Maximum age of an entry; `None` means entries never expire.
    ttl: Option<Duration>,
    /// Ties the storage to the value type `T` without holding a `T`.
    _phantom: PhantomData<T>,
}
impl<T> CacheStorage<T>
where
    T: Serialize + for<'de> Deserialize<'de>,
{
    /// Opens the cache namespace `name`, creating its directory if necessary.
    ///
    /// The namespace is a subdirectory of the project cache directory resolved by
    /// `directories::ProjectDirs::from("de", "westhoffswelt", "dialogdetective")`. `name` is
    /// passed through `sanitize_name` first, so it is always a single, safe path component.
    /// `ttl` is the maximum age of an entry; `None` disables expiry.
    ///
    /// # Errors
    ///
    /// * `CacheError::CacheDirectoryNotFound` if `ProjectDirs::from` returns `None`.
    /// * `CacheError::DirectoryCreationFailed` if the namespace directory cannot be created.
    pub fn open(name: &str, ttl: Option<Duration>) -> Result<Self, CacheError> {
        let proj_dirs = directories::ProjectDirs::from("de", "westhoffswelt", "dialogdetective")
            .ok_or(CacheError::CacheDirectoryNotFound)?;
        let cache_dir = proj_dirs.cache_dir().join(sanitize_name(name));
        fs::create_dir_all(&cache_dir).map_err(|e| CacheError::DirectoryCreationFailed {
            path: cache_dir.clone(),
            source: e,
        })?;
        Ok(Self {
            cache_dir,
            ttl,
            _phantom: PhantomData,
        })
    }

    /// Loads the entry stored under `identifier`.
    ///
    /// Returns `Ok(None)` if no entry exists, or if it is older than the TTL. An expired
    /// entry is also deleted, best-effort.
    ///
    /// # Errors
    ///
    /// * `CacheError::ReadFailed` if the entry file exists but cannot be read.
    /// * `CacheError::DeserializationFailed` if its contents do not parse as an entry of `T`.
    pub fn load(&self, identifier: &str) -> Result<Option<T>, CacheError> {
        let file_path = self.entry_path(identifier);
        // Read directly rather than probing `exists()` first. That probe races with concurrent
        // removals (e.g. `clean`), and it reports permission/metadata errors as "absent".
        let content = match fs::read_to_string(&file_path) {
            Ok(content) => content,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => {
                return Err(CacheError::ReadFailed {
                    path: file_path,
                    source: e,
                })
            }
        };
        let cached_item: CachedItem<T> =
            serde_json::from_str(&content).map_err(|e| CacheError::DeserializationFailed {
                path: file_path,
                source: e,
            })?;
        let expired = self.ttl.map_or(false, |ttl| {
            Self::is_expired(cached_item.timestamp, SystemTime::now(), ttl)
        });
        if expired {
            // Best-effort eviction: the entry is reported as absent even if deleting it fails.
            let _ = self.remove(identifier);
            return Ok(None);
        }
        Ok(Some(cached_item.data))
    }

    /// Stores `data` under `identifier`, timestamped with the current time.
    ///
    /// Any previous entry for the same identifier is replaced.
    ///
    /// # Errors
    ///
    /// * `CacheError::SerializationFailed` if `data` cannot be serialized to JSON.
    /// * `CacheError::WriteFailed` if the entry cannot be written or moved into place.
    pub fn store(&self, identifier: &str, data: &T) -> Result<(), CacheError> {
        let file_path = self.entry_path(identifier);
        let cached_item = CachedItem {
            data,
            timestamp: SystemTime::now(),
        };
        let content = serde_json::to_string_pretty(&cached_item)?;
        // Write to a sibling temp file and rename it over the target. A crash or a concurrent
        // `load` then never sees a half-written entry, which would otherwise surface as
        // `DeserializationFailed`. The temp file's extension is `tmp`, so `clean` skips it.
        let tmp_path = file_path.with_extension("json.tmp");
        let result =
            fs::write(&tmp_path, content).and_then(|()| fs::rename(&tmp_path, &file_path));
        if let Err(source) = result {
            // Don't leave a stray temp file behind; failing to delete it is not worth reporting.
            let _ = fs::remove_file(&tmp_path);
            return Err(CacheError::WriteFailed {
                path: file_path,
                source,
            });
        }
        Ok(())
    }

    /// Removes the entry stored under `identifier`. A missing entry is not an error.
    ///
    /// # Errors
    ///
    /// `CacheError::WriteFailed` if the entry exists but cannot be deleted.
    pub fn remove(&self, identifier: &str) -> Result<(), CacheError> {
        let file_path = self.entry_path(identifier);
        match fs::remove_file(&file_path) {
            Ok(()) => Ok(()),
            // Already gone (possibly removed concurrently): the desired end state holds.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(CacheError::WriteFailed {
                path: file_path,
                source: e,
            }),
        }
    }

    /// Returns the directory holding this cache's entry files.
    pub fn cache_dir(&self) -> &PathBuf {
        &self.cache_dir
    }

    /// Deletes every expired `*.json` entry in the cache directory.
    ///
    /// Returns `Ok(None)` when no TTL is configured (nothing can expire). Otherwise it returns
    /// `Ok(Some(n))`, where `n` is the number of files actually deleted. Files that cannot be
    /// read, parsed, or deleted are skipped silently. Entries whose timestamp lies in the
    /// future are kept.
    ///
    /// # Errors
    ///
    /// `CacheError::ReadFailed` if the cache directory cannot be listed.
    pub fn clean(&self) -> Result<Option<usize>, CacheError> {
        let ttl = match self.ttl {
            Some(ttl) => ttl,
            None => return Ok(None),
        };
        // One reference instant for the whole sweep.
        let now = SystemTime::now();
        let entries = fs::read_dir(&self.cache_dir).map_err(|e| CacheError::ReadFailed {
            path: self.cache_dir.clone(),
            source: e,
        })?;
        let mut removed_count = 0;
        for entry in entries {
            let path = entry
                .map_err(|e| CacheError::ReadFailed {
                    path: self.cache_dir.clone(),
                    source: e,
                })?
                .path();
            if !path.extension().map_or(false, |ext| ext == "json") {
                continue;
            }
            // Only the timestamp matters here, so the payload is parsed as an untyped value.
            let item = fs::read_to_string(&path).ok().and_then(|content| {
                serde_json::from_str::<CachedItem<serde_json::Value>>(&content).ok()
            });
            if let Some(item) = item {
                if Self::is_expired(item.timestamp, now, ttl) && fs::remove_file(&path).is_ok() {
                    removed_count += 1;
                }
            }
        }
        Ok(Some(removed_count))
    }

    /// Returns the file path of the entry for `identifier`: the sanitized identifier plus `.json`.
    fn entry_path(&self, identifier: &str) -> PathBuf {
        self.cache_dir
            .join(format!("{}.json", sanitize_name(identifier)))
    }

    /// Returns `true` if an entry written at `timestamp` is older than `ttl` as of `now`.
    ///
    /// A `timestamp` later than `now` (e.g. after the system clock moved backwards) makes
    /// `duration_since` fail; such entries are treated as fresh.
    fn is_expired(timestamp: SystemTime, now: SystemTime, ttl: Duration) -> bool {
        now.duration_since(timestamp).map_or(false, |age| age > ttl)
    }
}
/// Turns an arbitrary name into a lowercase, file-name-safe component.
///
/// The whole input is lowercased first. Then every character that is not an ASCII letter,
/// an ASCII digit, or `-` is replaced by `_`. Because `.` and path separators are replaced,
/// the result cannot refer to a parent or nested directory. Distinct inputs can collide
/// (e.g. `"A B"` and `"a_b"` both become `"a_b"`).
fn sanitize_name(name: &str) -> String {
    let lowered = name.to_lowercase();
    // Each input char becomes exactly one ASCII byte, so this capacity is an upper bound.
    let mut sanitized = String::with_capacity(lowered.len());
    for ch in lowered.chars() {
        let allowed = ch == '-' || ch.is_ascii_alphanumeric();
        sanitized.push(if allowed { ch } else { '_' });
    }
    sanitized
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_name() {
        assert_eq!(sanitize_name("Simple"), "simple");
        assert_eq!(sanitize_name("With Spaces"), "with_spaces");
        assert_eq!(sanitize_name("With-Hyphens"), "with-hyphens");
        assert_eq!(sanitize_name("Special!@#$%"), "special_____");
        assert_eq!(sanitize_name("Mixed123ABC"), "mixed123abc");
    }

    #[test]
    fn test_sanitize_name_edge_cases() {
        // Empty input stays empty.
        assert_eq!(sanitize_name(""), "");
        // Dots and path separators are replaced, so names cannot escape the cache directory.
        assert_eq!(sanitize_name("../etc/passwd"), "___etc_passwd");
        assert_eq!(sanitize_name("a.b\\c"), "a_b_c");
        // Non-ASCII characters are replaced after lowercasing.
        assert_eq!(sanitize_name("Café"), "caf_");
        // Distinct inputs can collide after sanitization; pin this so a change is deliberate.
        assert_eq!(sanitize_name("a b"), sanitize_name("A_B"));
    }
}