use std::fs;
use std::path::PathBuf;
use std::sync::Once;
use anyhow::{Context, Result};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use tracing::warn;
/// Guards the "Cache directory unavailable" warning emitted by `FileCacheImpl::new`
/// so that it is logged at most once per process.
static CACHE_UNAVAILABLE_WARNING: Once = Once::new();
/// Default time-to-live of 60 minutes.
// NOTE(review): by its name this is meant for cached issue data; no use site is visible here.
pub const DEFAULT_ISSUE_TTL_MINS: i64 = 60;
/// Default time-to-live of 24 hours.
// NOTE(review): by its name this is meant for cached repository data; confirm with callers.
pub const DEFAULT_REPO_TTL_HOURS: i64 = 24;
/// Default time-to-live of 86400 seconds (24 hours). Unlike the other defaults this is a `u64`.
// NOTE(review): by its name this is meant for cached model data; confirm with callers.
pub const DEFAULT_MODEL_TTL_SECS: u64 = 86400;
/// Default time-to-live of 7 days.
// NOTE(review): by its name this is meant for cached security data; confirm with callers.
pub const DEFAULT_SECURITY_TTL_DAYS: i64 = 7;
/// A cached value bundled with the metadata stored alongside it.
///
/// The type is serde-serializable; `FileCacheImpl` persists it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry<T> {
/// The cached payload.
pub data: T,
/// Creation timestamp in UTC; `new` and `with_etag` set it to `Utc::now()`.
pub cached_at: DateTime<Utc>,
/// Optional ETag string; the field is omitted from the serialized form when `None`.
// NOTE(review): presumably an HTTP ETag used for conditional requests — confirm with callers.
#[serde(skip_serializing_if = "Option::is_none")]
pub etag: Option<String>,
}
impl<T> CacheEntry<T> {
    /// Wraps `data` in an entry stamped with the current UTC time and no ETag.
    pub fn new(data: T) -> Self {
        let cached_at = Utc::now();
        Self {
            data,
            cached_at,
            etag: None,
        }
    }

    /// Wraps `data` in an entry stamped with the current UTC time and carrying `etag`.
    pub fn with_etag(data: T, etag: String) -> Self {
        let mut entry = Self::new(data);
        entry.etag = Some(etag);
        entry
    }

    /// Reports whether the entry is younger than `ttl`.
    ///
    /// The comparison is strict: an entry whose age equals `ttl` is already expired.
    pub fn is_valid(&self, ttl: Duration) -> bool {
        let age = Utc::now().signed_duration_since(self.cached_at);
        age < ttl
    }
}
/// Returns the application's cache directory: the `aptu` folder inside the directory
/// reported by `dirs::cache_dir()`, or `None` when that lookup yields nothing.
#[must_use]
pub fn cache_dir() -> Option<PathBuf> {
    let base = dirs::cache_dir()?;
    Some(base.join("aptu"))
}
/// String-keyed cache of values of type `V`.
///
/// The method descriptions below reflect how `FileCacheImpl` in this file implements them.
pub trait FileCache<V> {
/// Returns the value stored under `key` if it exists and is still within the TTL,
/// otherwise `Ok(None)`.
fn get(&self, key: &str) -> Result<Option<V>>;
/// Returns the value stored under `key` regardless of its age, or `Ok(None)` if absent.
fn get_stale(&self, key: &str) -> Result<Option<V>>;
/// Stores `value` under `key`, replacing any previous entry.
fn set(&self, key: &str, value: &V) -> Result<()>;
/// Deletes the entry stored under `key`; a missing entry is not an error.
fn remove(&self, key: &str) -> Result<()>;
}
/// File-backed cache that stores one JSON file per key.
///
/// Files live at `<cache_dir>/<subdirectory>/<key>.json`.
pub struct FileCacheImpl<V> {
/// Root directory for cache files; `None` means caching is disabled.
cache_dir: Option<PathBuf>,
/// Maximum age an entry may have and still be returned by `get`.
ttl: Duration,
/// Folder under `cache_dir` that holds this cache's files.
subdirectory: String,
/// Ties the cache to its value type `V` without storing a `V`.
_phantom: std::marker::PhantomData<V>,
}
impl<V> FileCacheImpl<V>
where
    V: Serialize + for<'de> Deserialize<'de>,
{
    /// Creates a cache rooted at the directory returned by [`cache_dir`].
    ///
    /// When no cache directory is available the cache is still constructed, but in a
    /// disabled state, and a warning is logged (only once per process).
    #[must_use]
    pub fn new(subdirectory: impl Into<String>, ttl: Duration) -> Self {
        let root = cache_dir();
        if root.is_none() {
            CACHE_UNAVAILABLE_WARNING.call_once(|| {
                warn!("Cache directory unavailable, caching disabled");
            });
        }
        Self::with_dir(root, subdirectory, ttl)
    }

    /// Creates a cache rooted at an explicit directory.
    ///
    /// Passing `None` for `cache_dir` yields a disabled cache.
    #[must_use]
    pub fn with_dir(
        cache_dir: Option<PathBuf>,
        subdirectory: impl Into<String>,
        ttl: Duration,
    ) -> Self {
        let subdirectory = subdirectory.into();
        Self {
            cache_dir,
            ttl,
            subdirectory,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Whether a root directory is configured, i.e. whether caching is active.
    fn is_enabled(&self) -> bool {
        self.cache_dir.is_some()
    }

    /// Maps `key` to the path of its cache file, or `None` when caching is disabled.
    ///
    /// A `.json` suffix is appended unless the key already ends in one
    /// (compared case-insensitively).
    ///
    /// # Panics
    ///
    /// Panics if `key` contains `/`, `\` or `..`. The check runs before the
    /// disabled-cache check, so it fires even when no directory is configured.
    // NOTE(review): keys carrying a Windows drive prefix (e.g. `C:name`) are not rejected
    // here, and `Path::join` documents that an absolute component replaces the base;
    // worth confirming whether such keys can reach this function.
    fn cache_path(&self, key: &str) -> Option<PathBuf> {
        let escapes_directory =
            key.contains("..") || key.chars().any(|c| c == '/' || c == '\\');
        assert!(
            !escapes_directory,
            "cache key must not contain path separators or '..': {key}"
        );

        let already_json = std::path::Path::new(key)
            .extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("json"));
        let mut filename = key.to_string();
        if !already_json {
            filename.push_str(".json");
        }

        let root = self.cache_dir.as_ref()?;
        Some(root.join(&self.subdirectory).join(filename))
    }
}
impl<V> FileCache<V> for FileCacheImpl<V>
where
    V: Serialize + for<'de> Deserialize<'de>,
{
    /// Returns the cached value for `key` if it exists and is younger than the TTL.
    ///
    /// Returns `Ok(None)` when caching is disabled, the file is missing, or the entry
    /// has expired.
    ///
    /// # Errors
    ///
    /// Fails if the file exists but cannot be read or parsed.
    ///
    /// # Panics
    ///
    /// Panics if caching is enabled and `key` contains `/`, `\` or `..`.
    fn get(&self, key: &str) -> Result<Option<V>> {
        if !self.is_enabled() {
            return Ok(None);
        }
        let Some(path) = self.cache_path(key) else {
            return Ok(None);
        };
        let entry = read_cache_entry::<V>(&path)?;
        Ok(entry
            .filter(|entry| entry.is_valid(self.ttl))
            .map(|entry| entry.data))
    }

    /// Returns the cached value for `key` regardless of its age.
    ///
    /// Returns `Ok(None)` when caching is disabled or the file is missing.
    ///
    /// # Errors
    ///
    /// Fails if the file exists but cannot be read or parsed.
    ///
    /// # Panics
    ///
    /// Panics if caching is enabled and `key` contains `/`, `\` or `..`.
    fn get_stale(&self, key: &str) -> Result<Option<V>> {
        if !self.is_enabled() {
            return Ok(None);
        }
        let Some(path) = self.cache_path(key) else {
            return Ok(None);
        };
        Ok(read_cache_entry::<V>(&path)?.map(|entry| entry.data))
    }

    /// Stores `value` under `key`, stamped with the current time.
    ///
    /// The entry is written to a sibling `.tmp` file and then renamed over the final
    /// path, so readers never observe a partially written file. A disabled cache
    /// accepts the call and does nothing.
    ///
    /// # Errors
    ///
    /// Fails if the directory cannot be created, the value cannot be serialized, or
    /// the file cannot be written or renamed.
    ///
    /// # Panics
    ///
    /// Panics if caching is enabled and `key` contains `/`, `\` or `..`.
    fn set(&self, key: &str, value: &V) -> Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        let Some(path) = self.cache_path(key) else {
            return Ok(());
        };
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).with_context(|| {
                format!("Failed to create cache directory: {}", parent.display())
            })?;
        }
        let entry = CacheEntry::new(value);
        let contents =
            serde_json::to_string_pretty(&entry).context("Failed to serialize cache entry")?;
        let temp_path = path.with_extension("tmp");
        let outcome = fs::write(&temp_path, contents)
            .with_context(|| format!("Failed to write cache temp file: {}", temp_path.display()))
            .and_then(|()| {
                fs::rename(&temp_path, &path)
                    .with_context(|| format!("Failed to rename cache file: {}", path.display()))
            });
        if outcome.is_err() {
            // Best-effort cleanup so a failed write does not leave a stray temp file
            // behind; the original error is what the caller needs to see.
            let _ = fs::remove_file(&temp_path);
        }
        outcome
    }

    /// Deletes the cache file for `key`.
    ///
    /// A missing file and a disabled cache are both treated as success.
    ///
    /// # Errors
    ///
    /// Fails if the file exists but cannot be removed.
    ///
    /// # Panics
    ///
    /// Panics if caching is enabled and `key` contains `/`, `\` or `..`.
    fn remove(&self, key: &str) -> Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        let Some(path) = self.cache_path(key) else {
            return Ok(());
        };
        // Attempt the removal directly instead of checking `exists()` first: the file
        // may disappear between the check and the call.
        match fs::remove_file(&path) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(err)
                .with_context(|| format!("Failed to remove cache file: {}", path.display())),
        }
    }
}

/// Reads and parses the cache file at `path`.
///
/// Returns `Ok(None)` when the file does not exist. The read is attempted directly
/// rather than guarded by an `exists()` check, so a file deleted concurrently is
/// reported as a miss instead of an error.
fn read_cache_entry<V>(path: &std::path::Path) -> Result<Option<CacheEntry<V>>>
where
    V: for<'de> Deserialize<'de>,
{
    let contents = match fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(err) => {
            return Err(err)
                .with_context(|| format!("Failed to read cache file: {}", path.display()));
        }
    };
    let entry = serde_json::from_str(&contents)
        .with_context(|| format!("Failed to parse cache file: {}", path.display()))?;
    Ok(Some(entry))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestData {
        value: String,
        count: u32,
    }

    /// Shorthand for building a `TestData` fixture.
    fn sample(value: &str, count: u32) -> TestData {
        TestData {
            value: value.to_string(),
            count,
        }
    }

    /// A per-test scratch directory under the system temp dir, removed on drop.
    ///
    /// Filesystem tests use this instead of `FileCacheImpl::new` so they never touch
    /// the user's real cache directory and do not depend on one being resolvable.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Picks a path unique to this process and test; the directory itself is only
        /// created when a cache first writes into it.
        fn new(test_name: &str) -> Self {
            let dir_name = format!("aptu-cache-test-{}-{test_name}", std::process::id());
            Self(std::env::temp_dir().join(dir_name))
        }

        /// Builds a cache rooted in this scratch directory.
        fn cache(&self, ttl: Duration) -> FileCacheImpl<TestData> {
            FileCacheImpl::with_dir(Some(self.0.clone()), "test_cache", ttl)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Builds a cache with no directory configured, i.e. a disabled one.
    fn disabled_cache() -> FileCacheImpl<TestData> {
        FileCacheImpl::with_dir(None, "test_cache", Duration::hours(1))
    }

    #[test]
    fn test_cache_entry_new() {
        let data = sample("test", 42);
        let entry = CacheEntry::new(data.clone());
        assert_eq!(entry.data, data);
        assert!(entry.etag.is_none());
    }

    #[test]
    fn test_cache_entry_with_etag() {
        let data = sample("test", 42);
        let etag = "abc123".to_string();
        let entry = CacheEntry::with_etag(data.clone(), etag.clone());
        assert_eq!(entry.data, data);
        assert_eq!(entry.etag, Some(etag));
    }

    #[test]
    fn test_cache_entry_is_valid_within_ttl() {
        let entry = CacheEntry::new(sample("test", 42));
        assert!(entry.is_valid(Duration::hours(1)));
    }

    #[test]
    fn test_cache_entry_is_valid_expired() {
        let mut entry = CacheEntry::new(sample("test", 42));
        entry.cached_at = Utc::now() - Duration::hours(2);
        assert!(!entry.is_valid(Duration::hours(1)));
    }

    #[test]
    fn test_cache_dir_path() {
        // Some environments (e.g. sandboxes without a home directory) have no cache
        // directory; only check the suffix when one is reported.
        if let Some(dir) = cache_dir() {
            assert!(dir.ends_with("aptu"));
        }
    }

    #[test]
    fn test_cache_serialization_with_etag() {
        let data = sample("test", 42);
        let etag = "xyz789".to_string();
        let entry = CacheEntry::with_etag(data.clone(), etag.clone());
        let json = serde_json::to_string(&entry).expect("serialize");
        let parsed: CacheEntry<TestData> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.data, data);
        assert_eq!(parsed.etag, Some(etag));
    }

    #[test]
    fn test_file_cache_get_set() {
        let scratch = ScratchDir::new("get_set");
        let cache = scratch.cache(Duration::hours(1));
        let data = sample("test", 42);
        cache.set("test_key", &data).expect("set cache");
        let result = cache.get("test_key").expect("get cache");
        assert_eq!(result, Some(data));
    }

    #[test]
    fn test_file_cache_get_miss() {
        let scratch = ScratchDir::new("get_miss");
        let cache = scratch.cache(Duration::hours(1));
        let result = cache.get("nonexistent").expect("get cache");
        assert!(result.is_none());
    }

    #[test]
    fn test_file_cache_get_stale() {
        // With a zero TTL every entry is already expired (the validity check is a
        // strict `age < ttl`), so no sleep is needed before reading back.
        let scratch = ScratchDir::new("get_stale");
        let cache = scratch.cache(Duration::seconds(0));
        let data = sample("stale", 99);
        cache.set("stale_key", &data).expect("set cache");
        let result = cache.get("stale_key").expect("get cache");
        assert!(result.is_none());
        let stale_result = cache.get_stale("stale_key").expect("get stale cache");
        assert_eq!(stale_result, Some(data));
    }

    #[test]
    fn test_file_cache_remove() {
        let scratch = ScratchDir::new("remove");
        let cache = scratch.cache(Duration::hours(1));
        let data = sample("remove_me", 1);
        cache.set("remove_key", &data).expect("set cache");
        assert!(cache.get("remove_key").expect("get cache").is_some());
        cache.remove("remove_key").expect("remove cache");
        assert!(cache.get("remove_key").expect("get cache").is_none());
    }

    #[test]
    #[should_panic(expected = "cache key must not contain path separators")]
    fn test_cache_key_rejects_forward_slash() {
        let scratch = ScratchDir::new("rejects_forward_slash");
        let cache = scratch.cache(Duration::hours(1));
        let _ = cache.get("../etc/passwd");
    }

    #[test]
    #[should_panic(expected = "cache key must not contain path separators")]
    fn test_cache_key_rejects_backslash() {
        let scratch = ScratchDir::new("rejects_backslash");
        let cache = scratch.cache(Duration::hours(1));
        let _ = cache.get("..\\windows\\system32");
    }

    #[test]
    #[should_panic(expected = "cache key must not contain path separators")]
    fn test_cache_key_rejects_parent_dir() {
        let scratch = ScratchDir::new("rejects_parent_dir");
        let cache = scratch.cache(Duration::hours(1));
        let _ = cache.get("foo..bar");
    }

    #[test]
    fn test_disabled_cache_get_returns_none() {
        let result = disabled_cache().get("any_key").expect("get should succeed");
        assert!(result.is_none());
    }

    #[test]
    fn test_disabled_cache_set_succeeds_silently() {
        let data = sample("test", 42);
        disabled_cache()
            .set("any_key", &data)
            .expect("set should succeed");
    }

    #[test]
    fn test_disabled_cache_remove_succeeds_silently() {
        disabled_cache()
            .remove("any_key")
            .expect("remove should succeed");
    }

    #[test]
    fn test_disabled_cache_get_stale_returns_none() {
        let result = disabled_cache()
            .get_stale("any_key")
            .expect("get_stale should succeed");
        assert!(result.is_none());
    }
}