use crate::config::{CacheConfig, Config};
use crate::models::{SearchQuery, SearchResponse};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
/// Bookkeeping stored next to every cached response in the on-disk JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CacheMetadata {
/// Unix timestamp (seconds) at which the entry was written.
cached_at: u64,
/// Unix timestamp (seconds) from which the entry counts as expired.
expires_at: u64,
/// Name of the source the response was stored for.
source: String,
/// Query text for searches, or `"citations for <paper_id>"` for citation entries.
query: String,
}
/// The unit written to and read from a single cache file: the cached
/// response together with the metadata needed to decide whether it is stale.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedSearchResponse {
/// When the entry was cached, when it expires, and what it was cached for.
metadata: CacheMetadata,
/// The response payload handed back to callers on a cache hit.
response: SearchResponse,
}
/// Outcome of a cache lookup.
pub enum CacheResult<T> {
/// A non-expired entry was found; carries the cached value.
Hit(T),
/// No usable entry: the cache is disabled, or the cache file could not be
/// read or parsed.
Miss,
/// An entry was found but its expiry time has passed.
Expired,
}
/// File-backed cache for search and citation responses.
///
/// Each entry is one JSON file named after a hash of the request parameters,
/// stored under `searches/` or `citations/` inside the cache root.
#[derive(Debug, Clone)]
pub struct CacheService {
/// Root directory of the cache.
base_dir: PathBuf,
/// `<base_dir>/searches` — one file per cached search.
search_dir: PathBuf,
/// `<base_dir>/citations` — one file per cached citation lookup.
citation_dir: PathBuf,
/// Cache settings (enabled flag, directory override, TTLs).
config: CacheConfig,
}
impl CacheService {
    /// Creates a cache service from the `cache` section of `Config::default()`.
    pub fn new() -> Self {
        Self::from_config(Config::default().cache)
    }

    /// Creates a cache service from an explicit configuration.
    ///
    /// The cache root is `config.directory` when set, otherwise whatever
    /// `crate::config::default_cache_dir()` returns. Nothing is created on
    /// disk here; call [`CacheService::initialize`] for that.
    pub fn from_config(config: CacheConfig) -> Self {
        let base_dir = config
            .directory
            .clone()
            .unwrap_or_else(crate::config::default_cache_dir);
        Self {
            search_dir: base_dir.join("searches"),
            citation_dir: base_dir.join("citations"),
            base_dir,
            config,
        }
    }

    /// Creates the search and citation directories when the cache is enabled.
    ///
    /// Does nothing (beyond a debug log line) when the cache is disabled.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if a directory cannot be created.
    pub fn initialize(&self) -> std::io::Result<()> {
        if self.config.enabled {
            fs::create_dir_all(&self.search_dir)?;
            fs::create_dir_all(&self.citation_dir)?;
            tracing::info!("Cache initialized at: {}", self.base_dir.display());
        } else {
            tracing::debug!("Cache is disabled");
        }
        Ok(())
    }

    /// Returns `true` when caching is enabled in the configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Returns the root directory of the cache.
    pub fn cache_dir(&self) -> &PathBuf {
        &self.base_dir
    }

    /// Seconds since the Unix epoch; `0` if the system clock reports a time
    /// before the epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Lower-case hexadecimal MD5 digest of `input`, used as a file name.
    fn hash_key(input: &str) -> String {
        format!("{:x}", md5::compute(input.as_bytes()))
    }

    /// Builds the cache file name for a search request.
    ///
    /// The key covers the query text, source, result limit and the optional
    /// year/author/category filters (an absent filter hashes like an empty one).
    fn search_cache_key(query: &SearchQuery, source: &str) -> String {
        // NOTE(review): fields are joined with '|' without escaping, so values
        // that themselves contain '|' can produce the same key for different
        // requests. Left as-is to keep existing on-disk keys valid.
        Self::hash_key(&format!(
            "{}|{}|{}|{}|{}|{}",
            query.query,
            source,
            query.max_results,
            query.year.as_deref().unwrap_or_default(),
            query.author.as_deref().unwrap_or_default(),
            query.category.as_deref().unwrap_or_default()
        ))
    }

    /// Builds the cache file name for a citation request.
    fn citation_cache_key(paper_id: &str, source: &str, max_results: usize) -> String {
        Self::hash_key(&format!("{}|{}|{}", paper_id, source, max_results))
    }

    /// Returns `true` once the current time has reached `expires_at`
    /// (Unix seconds). A TTL of zero therefore expires immediately.
    fn is_expired(expires_at: u64) -> bool {
        Self::now_secs() >= expires_at
    }

    /// Wraps `response` in a cache entry stamped with the current time.
    ///
    /// The clock is read once so `cached_at` and `expires_at` are consistent,
    /// and the expiry saturates instead of overflowing for very large TTLs.
    fn new_entry(
        source: &str,
        query: String,
        ttl_seconds: u64,
        response: &SearchResponse,
    ) -> CachedSearchResponse {
        let now = Self::now_secs();
        CachedSearchResponse {
            metadata: CacheMetadata {
                cached_at: now,
                expires_at: now.saturating_add(ttl_seconds),
                source: source.to_string(),
                query,
            },
            response: response.clone(),
        }
    }

    /// Reads the entry at `path` and classifies it as hit, expired or miss.
    ///
    /// Any read or parse failure is reported as a miss. `kind` and `key` are
    /// only used for log messages.
    fn lookup(path: &Path, kind: &str, key: &str) -> CacheResult<SearchResponse> {
        match Self::read_cache_file::<CachedSearchResponse>(path) {
            Ok(cached) if Self::is_expired(cached.metadata.expires_at) => {
                tracing::debug!("Cache expired for {}: {}", kind, key);
                CacheResult::Expired
            }
            Ok(cached) => {
                tracing::debug!("Cache HIT for {}: {}", kind, key);
                CacheResult::Hit(cached.response)
            }
            Err(_) => {
                tracing::debug!("Cache MISS for {}: {}", kind, key);
                CacheResult::Miss
            }
        }
    }

    /// Looks up a cached search response for `query` against `source`.
    ///
    /// Returns [`CacheResult::Miss`] when the cache is disabled or no readable
    /// entry exists, [`CacheResult::Expired`] when the entry is past its TTL,
    /// and [`CacheResult::Hit`] otherwise.
    pub fn get_search(&self, query: &SearchQuery, source: &str) -> CacheResult<SearchResponse> {
        if !self.is_enabled() {
            return CacheResult::Miss;
        }
        let key = Self::search_cache_key(query, source);
        Self::lookup(&self.search_dir.join(&key), "search", &key)
    }

    /// Stores a search response for `query` against `source`.
    ///
    /// Best effort: a write failure is logged as a warning and otherwise
    /// ignored. Does nothing when the cache is disabled.
    pub fn set_search(&self, source: &str, query: &SearchQuery, response: &SearchResponse) {
        if !self.is_enabled() {
            return;
        }
        let key = Self::search_cache_key(query, source);
        let entry = Self::new_entry(
            source,
            query.query.clone(),
            self.config.search_ttl_seconds,
            response,
        );
        match Self::write_cache_file(&self.search_dir.join(&key), &entry) {
            Ok(()) => tracing::debug!("Cached search result: {}", key),
            Err(e) => tracing::warn!("Failed to cache search result: {}", e),
        }
    }

    /// Looks up cached citations for `paper_id` from `source`, keyed by the
    /// requested `max_results`.
    ///
    /// Returns [`CacheResult::Miss`] when the cache is disabled or no readable
    /// entry exists, [`CacheResult::Expired`] when the entry is past its TTL,
    /// and [`CacheResult::Hit`] otherwise.
    pub fn get_citations(
        &self,
        paper_id: &str,
        source: &str,
        max_results: usize,
    ) -> CacheResult<SearchResponse> {
        if !self.is_enabled() {
            return CacheResult::Miss;
        }
        let key = Self::citation_cache_key(paper_id, source, max_results);
        Self::lookup(&self.citation_dir.join(&key), "citations", &key)
    }

    /// Stores citations for `paper_id`, keyed by the number of papers in
    /// `response`.
    ///
    /// Because [`CacheService::get_citations`] keys on the *requested* limit,
    /// an entry stored here is only found again when the response contains
    /// exactly that many papers. Callers that know the requested limit should
    /// prefer [`CacheService::set_citations_with_max_results`].
    pub fn set_citations(&self, source: &str, paper_id: &str, response: &SearchResponse) {
        self.set_citations_with_max_results(source, paper_id, response.papers.len(), response);
    }

    /// Stores citations for `paper_id` under the same key that
    /// [`CacheService::get_citations`] uses for `max_results`.
    ///
    /// Best effort: a write failure is logged as a warning and otherwise
    /// ignored. Does nothing when the cache is disabled.
    pub fn set_citations_with_max_results(
        &self,
        source: &str,
        paper_id: &str,
        max_results: usize,
        response: &SearchResponse,
    ) {
        if !self.is_enabled() {
            return;
        }
        let key = Self::citation_cache_key(paper_id, source, max_results);
        let entry = Self::new_entry(
            source,
            format!("citations for {}", paper_id),
            self.config.citation_ttl_seconds,
            response,
        );
        match Self::write_cache_file(&self.citation_dir.join(&key), &entry) {
            Ok(()) => tracing::debug!("Cached citations: {}", key),
            Err(e) => tracing::warn!("Failed to cache citations: {}", e),
        }
    }

    /// Reads and deserializes a JSON cache file.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading the file, or an
    /// `ErrorKind::InvalidData` error when the content is not valid JSON for `T`.
    fn read_cache_file<T: for<'de> Deserialize<'de>>(path: &Path) -> Result<T, std::io::Error> {
        let content = fs::read_to_string(path)?;
        serde_json::from_str(&content)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
    }

    /// Serializes `data` as pretty-printed JSON and writes it to `path`.
    fn write_cache_file<T: Serialize>(path: &Path, data: &T) -> Result<(), std::io::Error> {
        let content = serde_json::to_string_pretty(data)?;
        fs::write(path, content)
    }

    /// Removes `dir` and everything below it; a directory that does not exist
    /// counts as already removed.
    fn remove_dir_if_exists(dir: &Path) -> std::io::Result<()> {
        match fs::remove_dir_all(dir) {
            Err(e) if e.kind() != std::io::ErrorKind::NotFound => Err(e),
            _ => Ok(()),
        }
    }

    /// Deletes the whole cache root and recreates the empty directory layout.
    ///
    /// Does nothing when the cache is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache root exists but cannot be removed, or if
    /// the directories cannot be recreated.
    pub fn clear_all(&self) -> std::io::Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        Self::remove_dir_if_exists(&self.base_dir)?;
        self.initialize()?;
        tracing::info!("Cache cleared");
        Ok(())
    }

    /// Deletes all cached search responses and recreates the empty directory.
    ///
    /// Does nothing when the cache is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory exists but cannot be removed, or if
    /// it cannot be recreated.
    pub fn clear_searches(&self) -> std::io::Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        Self::remove_dir_if_exists(&self.search_dir)?;
        fs::create_dir_all(&self.search_dir)?;
        tracing::info!("Search cache cleared");
        Ok(())
    }

    /// Deletes all cached citation responses and recreates the empty directory.
    ///
    /// Does nothing when the cache is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory exists but cannot be removed, or if
    /// it cannot be recreated.
    pub fn clear_citations(&self) -> std::io::Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        Self::remove_dir_if_exists(&self.citation_dir)?;
        fs::create_dir_all(&self.citation_dir)?;
        tracing::info!("Citation cache cleared");
        Ok(())
    }

    /// Collects entry counts and on-disk sizes for both cache directories.
    ///
    /// Sizes are reported in units of 1024 bytes, rounded down per directory;
    /// the total is the sum of the two rounded values. Unreadable directories
    /// count as empty. Returns an all-zero snapshot when the cache is disabled.
    pub fn stats(&self) -> CacheStats {
        if !self.is_enabled() {
            return CacheStats::disabled();
        }
        let entry_count = |dir: &Path| dir.read_dir().map(|e| e.count()).unwrap_or(0);
        let search_size_kb = Self::dir_size(&self.search_dir) / 1024;
        let citation_size_kb = Self::dir_size(&self.citation_dir) / 1024;
        CacheStats {
            enabled: true,
            cache_dir: self.base_dir.clone(),
            search_count: entry_count(&self.search_dir),
            citation_count: entry_count(&self.citation_dir),
            search_size_kb,
            citation_size_kb,
            total_size_kb: search_size_kb + citation_size_kb,
            ttl_search: Duration::from_secs(self.config.search_ttl_seconds),
            ttl_citations: Duration::from_secs(self.config.citation_ttl_seconds),
        }
    }

    /// Total size in bytes of all files below `path`, recursing into
    /// subdirectories. Entries or directories that cannot be read count as 0.
    fn dir_size(path: &Path) -> u64 {
        path.read_dir()
            .map(|entries| {
                entries
                    .flatten()
                    .map(|entry| {
                        let entry_path = entry.path();
                        if entry_path.is_dir() {
                            Self::dir_size(&entry_path)
                        } else {
                            entry.metadata().map(|m| m.len()).unwrap_or(0)
                        }
                    })
                    .sum::<u64>()
            })
            .unwrap_or(0)
    }
}
impl Default for CacheService {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary of the on-disk cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Whether caching was enabled when the snapshot was taken.
    pub enabled: bool,
    /// Root directory of the cache (empty path in the disabled snapshot).
    pub cache_dir: PathBuf,
    /// Number of entries in the search cache directory.
    pub search_count: usize,
    /// Number of entries in the citation cache directory.
    pub citation_count: usize,
    /// Size of the search cache, in units of 1024 bytes.
    pub search_size_kb: u64,
    /// Size of the citation cache, in units of 1024 bytes.
    pub citation_size_kb: u64,
    /// Combined size of both caches, in units of 1024 bytes.
    pub total_size_kb: u64,
    /// Time-to-live applied to search entries.
    pub ttl_search: Duration,
    /// Time-to-live applied to citation entries.
    pub ttl_citations: Duration,
}

impl CacheStats {
    /// Snapshot reported when caching is turned off: every counter is zero,
    /// the path is empty and both TTLs are zero-length.
    fn disabled() -> Self {
        let no_ttl = Duration::from_secs(0);
        CacheStats {
            enabled: false,
            cache_dir: PathBuf::default(),
            search_count: 0,
            citation_count: 0,
            search_size_kb: 0,
            citation_size_kb: 0,
            total_size_kb: 0,
            ttl_search: no_ttl,
            ttl_citations: no_ttl,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Baseline enabled configuration. Tests override `directory` so nothing
    /// is written outside a temporary directory.
    fn test_cache_config() -> CacheConfig {
        CacheConfig {
            enabled: true,
            directory: None,
            search_ttl_seconds: 60,
            citation_ttl_seconds: 30,
            max_size_mb: 10,
        }
    }

    /// Builds a cache service rooted in `dir` using the remaining settings
    /// from `config`.
    fn cache_in(dir: &TempDir, config: CacheConfig) -> CacheService {
        CacheService::from_config(CacheConfig {
            directory: Some(dir.path().to_path_buf()),
            ..config
        })
    }

    /// An empty response attributed to `test_source` / `test query`.
    fn sample_response() -> SearchResponse {
        SearchResponse::new(vec![], "test_source".to_string(), "test query".to_string())
    }

    // The cache API is synchronous, so plain `#[test]` is sufficient; no async
    // runtime is needed.

    #[test]
    fn test_cache_search() {
        let temp_dir = TempDir::new().unwrap();
        let cache = cache_in(&temp_dir, test_cache_config());
        cache.initialize().unwrap();

        let query = SearchQuery::new("test query");
        cache.set_search("test_source", &query, &sample_response());

        match cache.get_search(&query, "test_source") {
            CacheResult::Hit(r) => {
                assert_eq!(r.source, "test_source");
                assert_eq!(r.query, "test query");
            }
            CacheResult::Miss => panic!("Expected cache hit, got miss"),
            CacheResult::Expired => panic!("Expected cache hit, got expired"),
        }

        let other_query = SearchQuery::new("different query");
        assert!(
            matches!(
                cache.get_search(&other_query, "test_source"),
                CacheResult::Miss
            ),
            "Expected cache miss for different query"
        );

        cache.clear_all().unwrap();
    }

    #[test]
    fn test_cache_disabled() {
        let temp_dir = TempDir::new().unwrap();
        let cache = cache_in(
            &temp_dir,
            CacheConfig {
                enabled: false,
                ..test_cache_config()
            },
        );

        let query = SearchQuery::new("test query");
        cache.set_search("test_source", &query, &sample_response());

        assert!(
            matches!(cache.get_search(&query, "test_source"), CacheResult::Miss),
            "Expected cache miss when disabled"
        );
    }

    #[test]
    fn test_cache_expiration() {
        let temp_dir = TempDir::new().unwrap();
        // A TTL of zero makes every entry expire the moment it is written.
        let cache = cache_in(
            &temp_dir,
            CacheConfig {
                search_ttl_seconds: 0,
                citation_ttl_seconds: 0,
                ..test_cache_config()
            },
        );
        cache.initialize().unwrap();

        let query = SearchQuery::new("test query");
        cache.set_search("test_source", &query, &sample_response());

        assert!(
            matches!(
                cache.get_search(&query, "test_source"),
                CacheResult::Expired
            ),
            "Expected cache expired"
        );
    }

    #[test]
    fn test_cache_citations_round_trip() {
        let temp_dir = TempDir::new().unwrap();
        let cache = cache_in(&temp_dir, test_cache_config());
        cache.initialize().unwrap();

        // `set_citations` keys on the number of papers in the response, which
        // is zero for the empty sample response.
        cache.set_citations("test_source", "paper-1", &sample_response());

        match cache.get_citations("paper-1", "test_source", 0) {
            CacheResult::Hit(r) => assert_eq!(r.source, "test_source"),
            CacheResult::Miss => panic!("Expected citation cache hit, got miss"),
            CacheResult::Expired => panic!("Expected citation cache hit, got expired"),
        }

        assert!(
            matches!(
                cache.get_citations("paper-2", "test_source", 0),
                CacheResult::Miss
            ),
            "Expected cache miss for a different paper"
        );
    }

    #[test]
    fn test_clear_searches_removes_entries() {
        let temp_dir = TempDir::new().unwrap();
        let cache = cache_in(&temp_dir, test_cache_config());
        cache.initialize().unwrap();

        let query = SearchQuery::new("test query");
        cache.set_search("test_source", &query, &sample_response());
        cache.clear_searches().unwrap();

        assert!(
            matches!(cache.get_search(&query, "test_source"), CacheResult::Miss),
            "Expected cache miss after clearing searches"
        );
    }
}