use crate::constants;
use crate::error::Error;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Settings controlling where and how `ResponseCache` stores responses.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Directory holding one JSON file per cached response.
    pub cache_dir: PathBuf,
    /// TTL applied when `ResponseCache::store` is called without an explicit one.
    pub default_ttl: Duration,
    /// Maximum number of cache files kept per API name; after each store the oldest
    /// (by modification time) beyond this count are deleted.
    pub max_entries: usize,
    /// When `false`, `store` is a no-op and `get` always reports a miss.
    pub enabled: bool,
    /// Not read by `ResponseCache` in this module. NOTE(review): presumably consulted
    /// by callers to decide whether authenticated requests may be cached — confirm.
    pub allow_authenticated: bool,
}

impl Default for CacheConfig {
    /// Cache under `.cache/responses` (a relative path), with a 5-minute TTL,
    /// at most 1000 entries per API, enabled, and authenticated responses disallowed.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);
        Self {
            enabled: true,
            allow_authenticated: false,
            max_entries: 1000,
            default_ttl: five_minutes,
            cache_dir: PathBuf::from(".cache/responses"),
        }
    }
}
/// A cached HTTP response as persisted (pretty-printed JSON) in a single cache file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// Response body, stored verbatim as text.
    pub body: String,
    /// HTTP status code of the response.
    pub status_code: u16,
    /// Response headers exactly as handed to `ResponseCache::store` (no scrubbing is
    /// applied to them there).
    pub headers: HashMap<String, String>,
    /// Unix timestamp, in whole seconds, at which the entry was written.
    pub cached_at: u64,
    /// Lifetime in seconds; lookups treat the entry as expired once the current time
    /// is greater than `cached_at + ttl_seconds`.
    pub ttl_seconds: u64,
    /// Description of the request that produced this response.
    pub request_info: CachedRequestInfo,
}
/// Summary of the originating request, stored alongside each `CachedResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedRequestInfo {
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Request URL.
    pub url: String,
    /// Request headers as supplied by the caller. NOTE(review): the cache persists
    /// these as given; callers are presumably expected to run them through
    /// `scrub_auth_headers` first so credentials never reach disk — confirm.
    pub headers: HashMap<String, String>,
    /// Presumably a digest of the request body; it is not computed in this module —
    /// confirm the algorithm/encoding with the code that populates it.
    pub body_hash: Option<String>,
}
#[derive(Debug)]
pub struct CacheKey {
pub api_name: String,
pub operation_id: String,
pub request_hash: String,
}
impl CacheKey {
pub fn from_request(
api_name: &str,
operation_id: &str,
method: &str,
url: &str,
headers: &HashMap<String, String>,
body: Option<&str>,
) -> Result<Self, Error> {
let mut hasher = Sha256::new();
hasher.update(method.as_bytes());
hasher.update(url.as_bytes());
let mut sorted_headers: Vec<_> = headers
.iter()
.filter(|(key, _)| !is_auth_header(key))
.collect();
sorted_headers.sort_by_key(|(key, _)| *key);
for (key, value) in sorted_headers {
hasher.update(key.as_bytes());
hasher.update(value.as_bytes());
}
if let Some(body_content) = body {
hasher.update(body_content.as_bytes());
}
let hash = hasher.finalize();
let request_hash = format!("{hash:x}");
Ok(Self {
api_name: api_name.to_string(),
operation_id: operation_id.to_string(),
request_hash,
})
}
#[must_use]
pub fn to_filename(&self) -> String {
let hash_prefix = if self.request_hash.len() >= 16 {
&self.request_hash[..16]
} else {
&self.request_hash
};
format!(
"{}_{}_{}_{}{}",
self.api_name,
self.operation_id,
hash_prefix,
constants::CACHE_SUFFIX,
constants::FILE_EXT_JSON
)
}
}
/// File-backed HTTP response cache: one JSON file per cache key inside
/// `CacheConfig::cache_dir`.
pub struct ResponseCache {
    /// Settings supplied to `ResponseCache::new`.
    config: CacheConfig,
}
impl ResponseCache {
pub fn new(config: CacheConfig) -> Result<Self, Error> {
std::fs::create_dir_all(&config.cache_dir)
.map_err(|e| Error::io_error(format!("Failed to create cache directory: {e}")))?;
Ok(Self { config })
}
async fn acquire_lock(&self) -> Result<crate::atomic::DirLock, Error> {
let cache_dir = self.config.cache_dir.clone();
tokio::task::spawn_blocking(move || crate::atomic::DirLock::acquire(&cache_dir))
.await
.map_err(|e| Error::io_error(format!("Lock task failed: {e}")))?
.map_err(|e| Error::io_error(format!("Failed to acquire cache lock: {e}")))
}
pub async fn store(
&self,
key: &CacheKey,
body: &str,
status_code: u16,
headers: &HashMap<String, String>,
request_info: CachedRequestInfo,
ttl: Option<Duration>,
) -> Result<(), Error> {
if !self.config.enabled {
return Ok(());
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::invalid_config(format!("System time error: {e}")))?
.as_secs();
let ttl_seconds = ttl.unwrap_or(self.config.default_ttl).as_secs();
let cached_response = CachedResponse {
body: body.to_string(),
status_code,
headers: headers.clone(),
cached_at: now,
ttl_seconds,
request_info,
};
let cache_file = self.config.cache_dir.join(key.to_filename());
let json_content = serde_json::to_string_pretty(&cached_response).map_err(|e| {
Error::serialization_error(format!("Failed to serialize cached response: {e}"))
})?;
let _lock = self.acquire_lock().await?;
crate::atomic::atomic_write(&cache_file, json_content.as_bytes())
.await
.map_err(|e| Error::io_error(format!("Failed to write cache file: {e}")))?;
self.cleanup_old_entries(&key.api_name).await?;
Ok(())
}
pub async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>, Error> {
if !self.config.enabled {
return Ok(None);
}
let cache_file = self.config.cache_dir.join(key.to_filename());
if !cache_file.exists() {
return Ok(None);
}
let json_content = tokio::fs::read_to_string(&cache_file)
.await
.map_err(|e| Error::io_error(format!("Failed to read cache file: {e}")))?;
let cached_response: CachedResponse = serde_json::from_str(&json_content).map_err(|e| {
Error::serialization_error(format!("Failed to deserialize cached response: {e}"))
})?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::invalid_config(format!("System time error: {e}")))?
.as_secs();
if now > cached_response.cached_at + cached_response.ttl_seconds {
return Ok(None);
}
Ok(Some(cached_response))
}
pub async fn is_cached(&self, key: &CacheKey) -> Result<bool, Error> {
Ok(self.get(key).await?.is_some())
}
pub async fn clear_api_cache(&self, api_name: &str) -> Result<usize, Error> {
let _lock = self.acquire_lock().await?;
let mut cleared_count = 0;
let mut entries = tokio::fs::read_dir(&self.config.cache_dir)
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?
{
let filename = entry.file_name();
let filename_str = filename.to_string_lossy();
if filename_str.starts_with(&format!("{api_name}_"))
&& filename_str.ends_with(constants::CACHE_FILE_SUFFIX)
{
tokio::fs::remove_file(entry.path())
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
cleared_count += 1;
}
}
Ok(cleared_count)
}
pub async fn clear_all(&self) -> Result<usize, Error> {
let _lock = self.acquire_lock().await?;
let mut cleared_count = 0;
let mut entries = tokio::fs::read_dir(&self.config.cache_dir)
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?
{
let filename = entry.file_name();
let filename_str = filename.to_string_lossy();
if filename_str.ends_with(constants::CACHE_FILE_SUFFIX) {
tokio::fs::remove_file(entry.path())
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
cleared_count += 1;
}
}
Ok(cleared_count)
}
pub async fn get_stats(&self, api_name: Option<&str>) -> Result<CacheStats, Error> {
let mut stats = CacheStats::default();
let mut entries = tokio::fs::read_dir(&self.config.cache_dir)
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?
{
let filename = entry.file_name();
let filename_str = filename.to_string_lossy();
if !filename_str.ends_with(constants::CACHE_FILE_SUFFIX) {
continue;
}
let Some(target_api) = api_name else {
stats.total_entries += 1;
let Ok(metadata) = entry.metadata().await else {
continue;
};
stats.total_size_bytes += metadata.len();
let Ok(json_content) = tokio::fs::read_to_string(entry.path()).await else {
continue;
};
let Ok(cached_response) = serde_json::from_str::<CachedResponse>(&json_content)
else {
continue;
};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::invalid_config(format!("System time error: {e}")))?
.as_secs();
if now > cached_response.cached_at + cached_response.ttl_seconds {
stats.expired_entries += 1;
} else {
stats.valid_entries += 1;
}
continue;
};
if !filename_str.starts_with(&format!("{target_api}_")) {
continue;
}
stats.total_entries += 1;
let Ok(metadata) = entry.metadata().await else {
continue;
};
stats.total_size_bytes += metadata.len();
let Ok(json_content) = tokio::fs::read_to_string(entry.path()).await else {
continue;
};
let Ok(cached_response) = serde_json::from_str::<CachedResponse>(&json_content) else {
continue;
};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::invalid_config(format!("System time error: {e}")))?
.as_secs();
if now > cached_response.cached_at + cached_response.ttl_seconds {
stats.expired_entries += 1;
} else {
stats.valid_entries += 1;
}
}
Ok(stats)
}
async fn collect_stale_temp_file(
&self,
entry: &tokio::fs::DirEntry,
now: SystemTime,
stale_files: &mut Vec<std::path::PathBuf>,
) {
let is_stale = entry
.metadata()
.await
.ok()
.and_then(|m| m.modified().ok())
.is_some_and(|modified| {
now.duration_since(modified).unwrap_or(Duration::ZERO) > Duration::from_secs(3600)
});
if is_stale {
stale_files.push(entry.path());
}
}
async fn cleanup_old_entries(&self, api_name: &str) -> Result<(), Error> {
let mut entries = Vec::new();
let mut stale_tmp_files = Vec::new();
let now_system = SystemTime::now();
let mut dir_entries = tokio::fs::read_dir(&self.config.cache_dir)
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?;
while let Some(entry) = dir_entries
.next_entry()
.await
.map_err(|e| Error::io_error(format!("I/O operation failed: {e}")))?
{
let filename = entry.file_name();
let filename_str = filename.to_string_lossy();
let is_temp_file = filename_str.starts_with('.')
&& filename_str.ends_with(".tmp")
&& filename_str.len() > 4;
if is_temp_file {
self.collect_stale_temp_file(&entry, now_system, &mut stale_tmp_files)
.await;
continue;
}
if !filename_str.starts_with(&format!("{api_name}_"))
|| !filename_str.ends_with(constants::CACHE_FILE_SUFFIX)
{
continue;
}
let Ok(metadata) = entry.metadata().await else {
continue;
};
let Ok(modified) = metadata.modified() else {
continue;
};
entries.push((entry.path(), modified));
}
for path in &stale_tmp_files {
let _ = tokio::fs::remove_file(path).await;
}
if entries.len() > self.config.max_entries {
entries.sort_by_key(|(_, modified)| *modified);
let to_remove = entries.len() - self.config.max_entries;
for (path, _) in entries.iter().take(to_remove) {
let _ = tokio::fs::remove_file(path).await;
}
}
Ok(())
}
}
/// Summary of cache contents as reported by `ResponseCache::get_stats`.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of matching cache files, including ones that could not be read or parsed.
    pub total_entries: usize,
    /// Parsed entries whose TTL has not elapsed.
    pub valid_entries: usize,
    /// Parsed entries whose TTL has elapsed (they remain on disk).
    pub expired_entries: usize,
    /// Total size in bytes of matching files whose metadata could be read.
    pub total_size_bytes: u64,
}
/// Reports whether `header_name` should be treated as carrying credentials.
///
/// True if `constants::is_auth_header` accepts the name as given, or if the
/// lowercased name starts with `constants::HEADER_PREFIX_X_AUTH` or
/// `constants::HEADER_PREFIX_X_API`. Such headers are excluded from cache keys and
/// removed by `scrub_auth_headers`.
/// NOTE(review): a broad `x-api…` prefix would also catch non-credential headers
/// (e.g. a version header), which would then not distinguish cache keys — confirm
/// the prefix value.
#[must_use]
pub fn is_auth_header(header_name: &str) -> bool {
    if constants::is_auth_header(header_name) {
        return true;
    }
    // Lowercase once and reuse it for both prefix checks.
    let lowered = header_name.to_lowercase();
    lowered.starts_with(constants::HEADER_PREFIX_X_AUTH)
        || lowered.starts_with(constants::HEADER_PREFIX_X_API)
}
/// Returns a copy of `headers` without the entries `is_auth_header` flags.
///
/// The input is not modified. The result always uses the default hasher,
/// whatever `S` the input map uses.
#[must_use]
pub fn scrub_auth_headers<S: std::hash::BuildHasher>(
    headers: &HashMap<String, String, S>,
) -> HashMap<String, String> {
    let mut scrubbed = HashMap::new();
    for (name, value) in headers {
        if is_auth_header(name) {
            continue;
        }
        scrubbed.insert(name.clone(), value.clone());
    }
    scrubbed
}
#[cfg(test)]
mod tests {
    //! Unit tests for cache keys, auth-header handling and the file-backed cache.
    use super::*;
    use tempfile::TempDir;

    /// Builds an enabled config rooted in a fresh temporary directory (60 s default
    /// TTL, 10 entries per API). Tests keep the returned `TempDir` bound to
    /// `_temp_dir` (not `_`) so the directory is not cleaned up mid-test.
    fn create_test_cache_config() -> (CacheConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let config = CacheConfig {
            cache_dir: temp_dir.path().to_path_buf(),
            default_ttl: Duration::from_secs(60),
            max_entries: 10,
            enabled: true,
            allow_authenticated: false,
        };
        (config, temp_dir)
    }

    /// Key fields are copied through and the file name has the expected shape;
    /// the Authorization header is present but excluded from the hash.
    #[test]
    fn test_cache_key_generation() {
        let mut headers = HashMap::new();
        headers.insert(
            constants::HEADER_CONTENT_TYPE_LC.to_string(),
            constants::CONTENT_TYPE_JSON.to_string(),
        );
        headers.insert(
            constants::HEADER_AUTHORIZATION_LC.to_string(),
            "Bearer secret".to_string(),
        );
        let key = CacheKey::from_request(
            "test_api",
            "getUser",
            constants::HTTP_METHOD_GET,
            "https://api.example.com/users/123",
            &headers,
            None,
        )
        .unwrap();
        assert_eq!(key.api_name, "test_api");
        assert_eq!(key.operation_id, "getUser");
        assert!(!key.request_hash.is_empty());
        let filename = key.to_filename();
        assert!(filename.starts_with("test_api_getUser_"));
        assert!(filename.ends_with(constants::CACHE_FILE_SUFFIX));
    }

    /// Auth detection is case-insensitive for the x-auth / x-api prefixes.
    #[test]
    fn test_is_auth_header() {
        assert!(is_auth_header(constants::HEADER_AUTHORIZATION));
        assert!(is_auth_header("X-API-Key"));
        assert!(is_auth_header("x-auth-token"));
        assert!(!is_auth_header(constants::HEADER_CONTENT_TYPE));
        assert!(!is_auth_header("User-Agent"));
    }

    /// Scrubbing drops exactly the three credential headers and keeps the rest intact.
    #[test]
    fn test_scrub_auth_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer secret".to_string());
        headers.insert("X-API-Key".to_string(), "api-key-123".to_string());
        headers.insert("x-auth-token".to_string(), "token-456".to_string());
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("User-Agent".to_string(), "test-agent".to_string());
        headers.insert("Accept".to_string(), "application/json".to_string());
        let scrubbed = scrub_auth_headers(&headers);
        assert!(!scrubbed.contains_key("Authorization"));
        assert!(!scrubbed.contains_key("X-API-Key"));
        assert!(!scrubbed.contains_key("x-auth-token"));
        assert_eq!(
            scrubbed.get("Content-Type"),
            Some(&"application/json".to_string())
        );
        assert_eq!(scrubbed.get("User-Agent"), Some(&"test-agent".to_string()));
        assert_eq!(
            scrubbed.get("Accept"),
            Some(&"application/json".to_string())
        );
        assert_eq!(scrubbed.len(), 3);
    }

    /// Round trip: a stored response is returned unchanged by `get`.
    #[tokio::test]
    async fn test_cache_store_and_retrieve() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        let key = CacheKey {
            api_name: "test_api".to_string(),
            operation_id: "getUser".to_string(),
            request_hash: "abc123".to_string(),
        };
        let mut headers = HashMap::new();
        headers.insert(
            constants::HEADER_CONTENT_TYPE_LC.to_string(),
            constants::CONTENT_TYPE_JSON.to_string(),
        );
        let request_info = CachedRequestInfo {
            method: constants::HTTP_METHOD_GET.to_string(),
            url: "https://api.example.com/users/123".to_string(),
            headers: headers.clone(),
            body_hash: None,
        };
        cache
            .store(
                &key,
                r#"{"id": 123, "name": "John"}"#,
                200,
                &headers,
                request_info,
                Some(Duration::from_secs(60)),
            )
            .await
            .unwrap();
        let cached = cache.get(&key).await.unwrap();
        assert!(cached.is_some());
        let response = cached.unwrap();
        assert_eq!(response.body, r#"{"id": 123, "name": "John"}"#);
        assert_eq!(response.status_code, 200);
    }

    /// Expiry without sleeping: `cached_at` is rewound on disk by 2 s so the 1 s
    /// TTL has lapsed. The entry then reads as a miss, but the file is kept.
    #[tokio::test]
    async fn test_cache_expiration() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        let key = CacheKey {
            api_name: "test_api".to_string(),
            operation_id: "getUser".to_string(),
            request_hash: "abc123def456".to_string(),
        };
        let headers = HashMap::new();
        let request_info = CachedRequestInfo {
            method: constants::HTTP_METHOD_GET.to_string(),
            url: "https://api.example.com/users/123".to_string(),
            headers: headers.clone(),
            body_hash: None,
        };
        cache
            .store(
                &key,
                "test response",
                200,
                &headers,
                request_info,
                Some(Duration::from_secs(1)),
            )
            .await
            .unwrap();
        assert!(cache.is_cached(&key).await.unwrap());
        let cache_file = cache.config.cache_dir.join(key.to_filename());
        let mut cached_response: CachedResponse = {
            let json_content = tokio::fs::read_to_string(&cache_file).await.unwrap();
            serde_json::from_str(&json_content).unwrap()
        };
        // Back-date the entry so `now > cached_at + ttl_seconds` holds.
        cached_response.cached_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 2;
        let json_content = serde_json::to_string_pretty(&cached_response).unwrap();
        tokio::fs::write(&cache_file, json_content).await.unwrap();
        assert!(!cache.is_cached(&key).await.unwrap());
        assert!(cache_file.exists());
    }

    /// Stores a small 200 response for (`api_name`, `operation_id`). The
    /// `request_hash` is set to `"{api_name}_{operation_id}"` so names are
    /// deterministic, and the TTL is 300 s.
    async fn store_entry(cache: &ResponseCache, api_name: &str, operation_id: &str) {
        let key = CacheKey {
            api_name: api_name.to_string(),
            operation_id: operation_id.to_string(),
            request_hash: format!("{api_name}_{operation_id}"),
        };
        let request_info = CachedRequestInfo {
            method: constants::HTTP_METHOD_GET.to_string(),
            url: "https://api.example.com/test".to_string(),
            headers: HashMap::new(),
            body_hash: None,
        };
        cache
            .store(
                &key,
                r#"{"ok": true}"#,
                200,
                &HashMap::new(),
                request_info,
                Some(Duration::from_secs(300)),
            )
            .await
            .unwrap();
    }

    /// Clearing one API leaves another API's entries in place. The names
    /// `api_a`/`api_b` do not prefix each other, so file-name prefix matching is
    /// unambiguous here.
    #[tokio::test]
    async fn test_clear_api_cache_removes_only_target_api() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        store_entry(&cache, "api_a", "op1").await;
        store_entry(&cache, "api_b", "op2").await;
        let cleared = cache.clear_api_cache("api_a").await.unwrap();
        assert_eq!(
            cleared, 1,
            "should have cleared exactly one entry for api_a"
        );
        let stats = cache.get_stats(Some("api_b")).await.unwrap();
        assert_eq!(stats.total_entries, 1, "api_b entry must remain");
        let stats_a = cache.get_stats(Some("api_a")).await.unwrap();
        assert_eq!(stats_a.total_entries, 0, "api_a entries must be gone");
    }

    /// All of an API's entries are removed and the returned count matches.
    #[tokio::test]
    async fn test_clear_api_cache_multiple_entries() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        store_entry(&cache, "api_a", "op1").await;
        store_entry(&cache, "api_a", "op2").await;
        store_entry(&cache, "api_a", "op3").await;
        store_entry(&cache, "api_b", "opX").await;
        let cleared = cache.clear_api_cache("api_a").await.unwrap();
        assert_eq!(cleared, 3, "should clear all three api_a entries");
        let remaining = cache.get_stats(None).await.unwrap();
        assert_eq!(remaining.total_entries, 1, "only api_b entry should remain");
    }

    /// `clear_all` removes entries across every API.
    #[tokio::test]
    async fn test_clear_all_empties_the_cache() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        store_entry(&cache, "api_a", "op1").await;
        store_entry(&cache, "api_b", "op2").await;
        store_entry(&cache, "api_c", "op3").await;
        let cleared = cache.clear_all().await.unwrap();
        assert_eq!(cleared, 3);
        let stats = cache.get_stats(None).await.unwrap();
        assert_eq!(
            stats.total_entries, 0,
            "cache must be empty after clear_all"
        );
    }

    /// Clearing an empty directory is not an error.
    #[tokio::test]
    async fn test_clear_all_on_empty_cache() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        let cleared = cache.clear_all().await.unwrap();
        assert_eq!(cleared, 0, "clearing an empty cache returns 0");
    }

    /// Unfiltered stats count every API's fresh entries as valid and report a size.
    #[tokio::test]
    async fn test_get_stats_no_filter_counts_all_entries() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        store_entry(&cache, "api_a", "op1").await;
        store_entry(&cache, "api_b", "op2").await;
        let stats = cache.get_stats(None).await.unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 0);
        assert!(stats.total_size_bytes > 0);
    }

    /// An API filter restricts the counts to that API's files.
    #[tokio::test]
    async fn test_get_stats_with_api_filter() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        store_entry(&cache, "api_a", "op1").await;
        store_entry(&cache, "api_a", "op2").await;
        store_entry(&cache, "api_b", "opX").await;
        let stats = cache.get_stats(Some("api_a")).await.unwrap();
        assert_eq!(stats.total_entries, 2, "filter must restrict to api_a");
        assert_eq!(stats.valid_entries, 2);
        let stats_b = cache.get_stats(Some("api_b")).await.unwrap();
        assert_eq!(stats_b.total_entries, 1);
    }

    /// One entry is back-dated 10 s past its 1 s TTL and one is fresh. Stats must
    /// split them into expired and valid, and the later store's cleanup must not
    /// delete the expired file.
    #[tokio::test]
    async fn test_get_stats_counts_expired_entries() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config).unwrap();
        let key = CacheKey {
            api_name: "api_a".to_string(),
            operation_id: "expiredOp".to_string(),
            request_hash: "expiredhash".to_string(),
        };
        let request_info = CachedRequestInfo {
            method: constants::HTTP_METHOD_GET.to_string(),
            url: "https://api.example.com/test".to_string(),
            headers: HashMap::new(),
            body_hash: None,
        };
        cache
            .store(
                &key,
                "body",
                200,
                &HashMap::new(),
                request_info,
                Some(Duration::from_secs(1)),
            )
            .await
            .unwrap();
        let cache_file = cache.config.cache_dir.join(key.to_filename());
        let json = tokio::fs::read_to_string(&cache_file).await.unwrap();
        let mut entry: CachedResponse = serde_json::from_str(&json).unwrap();
        entry.cached_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 10;
        tokio::fs::write(&cache_file, serde_json::to_string_pretty(&entry).unwrap())
            .await
            .unwrap();
        store_entry(&cache, "api_a", "validOp").await;
        let stats = cache.get_stats(Some("api_a")).await.unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.expired_entries, 1);
        assert_eq!(stats.valid_entries, 1);
    }

    /// An orphaned hidden `.tmp` file whose mtime is set to the Unix epoch exceeds
    /// the one-hour staleness threshold. It must be swept by the cleanup pass that
    /// runs inside `store`.
    #[tokio::test]
    async fn test_cleanup_removes_stale_tmp_files() {
        let (config, _temp_dir) = create_test_cache_config();
        let cache = ResponseCache::new(config.clone()).unwrap();
        let tmp_path = config.cache_dir.join(".orphaned.1a2b3c.tmp");
        tokio::fs::write(&tmp_path, b"partial write").await.unwrap();
        assert!(tmp_path.exists(), "temp file must exist before cleanup");
        let epoch = std::time::SystemTime::UNIX_EPOCH;
        let file = std::fs::OpenOptions::new()
            .write(true)
            .open(&tmp_path)
            .expect("temp file must be openable");
        file.set_modified(epoch)
            .expect("setting mtime to epoch must succeed");
        store_entry(&cache, "api_sweep", "op1").await;
        assert!(
            !tmp_path.exists(),
            "stale temp file must be removed by cleanup_old_entries"
        );
    }
}