use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
path::PathBuf,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use super::Fetcher;
use crate::{
error::KumoError,
extract::{Response, response::ResponseBody},
logging::{event, target},
middleware::FetchRequest,
};
/// One cached HTTP response as it is persisted to disk.
///
/// Instances are written and read back through `serde_json`, one JSON file per entry.
#[derive(Serialize, Deserialize)]
struct CacheEntry {
    /// URL recorded for the response; compared against the request URL on lookup.
    url: String,
    /// Status code of the cached response.
    status: u16,
    /// Response body as text.
    body: String,
    /// Time the entry was stored, in whole seconds since the Unix epoch.
    cached_at: u64,
}
/// A [`Fetcher`] wrapper that stores responses as JSON files in a directory and
/// replays them for later requests instead of calling the wrapped fetcher again.
pub struct CachingFetcher {
// Wrapped fetcher that performs the real request on a miss or bypass.
inner: Arc<dyn Fetcher>,
// Directory holding one `<hash>.json` file per cached URL.
dir: PathBuf,
// Maximum entry age; `None` means entries never expire.
ttl: Option<Duration>,
}
impl CachingFetcher {
    /// Builds a caching layer around `inner` that keeps its entries under `dir`.
    ///
    /// The directory (and any missing parents) is created up front. No TTL is set,
    /// so entries stay valid indefinitely until [`CachingFetcher::ttl`] is applied.
    ///
    /// # Errors
    ///
    /// Returns a store error tagged `"http cache"` when the directory cannot be created.
    pub fn new(inner: impl Fetcher + 'static, dir: impl Into<PathBuf>) -> Result<Self, KumoError> {
        let cache_dir: PathBuf = dir.into();
        std::fs::create_dir_all(&cache_dir).map_err(|err| KumoError::store("http cache", err))?;
        Ok(Self {
            inner: Arc::new(inner),
            dir: cache_dir,
            ttl: None,
        })
    }

    /// Sets the maximum age after which a cached entry is no longer served.
    pub fn ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Maps `url` to its cache file: `<dir>/<16 hex digit hash>.json`.
    ///
    /// NOTE(review): the std docs leave `DefaultHasher`'s algorithm unspecified across
    /// releases, so file names may change after a toolchain upgrade. That would only
    /// cause cache misses, since lookups re-check the stored URL.
    fn cache_path(&self, url: &str) -> PathBuf {
        let digest = {
            let mut state = DefaultHasher::new();
            url.hash(&mut state);
            state.finish()
        };
        let file_name = format!("{digest:016x}.json");
        self.dir.join(file_name)
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Yields `0` if the system clock reports a time before the epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs())
    }

    /// Reports whether `entry` is still within the configured TTL.
    ///
    /// Without a TTL every entry is fresh. Ages are compared in whole seconds, and an
    /// entry stamped in the future counts as age zero.
    fn is_fresh(&self, entry: &CacheEntry) -> bool {
        let Some(ttl) = self.ttl else {
            return true;
        };
        let age_secs = Self::now_secs().saturating_sub(entry.cached_at);
        age_secs < ttl.as_secs()
    }

    /// Only `GET` requests are eligible for caching.
    fn is_cacheable(request: &FetchRequest) -> bool {
        reqwest::Method::GET == request.method
    }
}
/// Manual `Debug` that prints the cache directory and TTL and marks the output
/// as non-exhaustive; the wrapped fetcher is left out.
impl std::fmt::Debug for CachingFetcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { dir, ttl, .. } = self;
        let mut out = f.debug_struct("CachingFetcher");
        out.field("dir", dir);
        out.field("ttl", ttl);
        out.finish_non_exhaustive()
    }
}
#[async_trait]
impl Fetcher for CachingFetcher {
    /// Fetches `request`, serving it from the on-disk cache when a usable entry exists.
    ///
    /// * Requests that are not cacheable go straight to the wrapped fetcher.
    /// * A hit requires the cache file to be readable, to parse as a `CacheEntry`, to
    ///   carry the same URL as the request, and to be fresh. The replayed response has
    ///   empty headers because headers are not persisted.
    /// * On a miss the wrapped fetcher is called; if the response exposes a text body
    ///   it is written to the cache on a best-effort basis.
    ///
    /// # Errors
    ///
    /// Only errors from the wrapped fetcher are returned. Cache read, parse and write
    /// failures never fail the request.
    ///
    /// NOTE(review): the `std::fs` calls below are synchronous and run inline in this
    /// async fn; they may stall the executor thread if the cache directory is slow.
    async fn fetch(&self, request: &FetchRequest) -> Result<Response, KumoError> {
        if !Self::is_cacheable(request) {
            tracing::debug!(
                target: target::CACHE,
                event = event::CACHE_BYPASS,
                url = request.url(),
                method = %request.method,
                "cache.bypass"
            );
            return self.inner.fetch(request).await;
        }

        let path = self.cache_path(request.url());

        // Read directly rather than checking `path.exists()` first: a missing file
        // already surfaces as `Err`, so the pre-check only added a syscall and a
        // check-then-use race. The URL comparison guards against hash collisions.
        if let Ok(data) = std::fs::read_to_string(&path)
            && let Ok(entry) = serde_json::from_str::<CacheEntry>(&data)
            && entry.url == request.url()
            && self.is_fresh(&entry)
        {
            tracing::debug!(
                target: target::CACHE,
                event = event::CACHE_HIT,
                url = request.url(),
                method = %request.method,
                "cache.hit"
            );
            return Ok(Response::new(
                entry.url,
                entry.status,
                // Headers are not stored in the cache entry.
                HeaderMap::new(),
                Duration::ZERO,
                ResponseBody::Text(entry.body),
            ));
        }

        tracing::debug!(
            target: target::CACHE,
            event = event::CACHE_MISS,
            url = request.url(),
            method = %request.method,
            "cache.miss"
        );

        let response = self.inner.fetch(request).await?;

        // Only responses that expose a text body are persisted.
        if let Some(body_text) = response.text() {
            // NOTE(review): the entry records `response.url()`, while lookup compares
            // against `request.url()`. If the two can differ (e.g. the wrapped fetcher
            // follows redirects or normalizes the URL) such entries will never be
            // served — confirm against the `Fetcher` implementations.
            let entry = CacheEntry {
                url: response.url().to_string(),
                status: response.status(),
                body: body_text.to_string(),
                cached_at: Self::now_secs(),
            };

            // Best-effort store: a serialization or write failure is logged, never
            // propagated, so a broken cache cannot fail an otherwise good fetch.
            let stored = serde_json::to_string(&entry)
                .is_ok_and(|json| std::fs::write(&path, json).is_ok());
            if !stored {
                tracing::debug!(
                    target: target::CACHE,
                    event = event::CACHE_STORE_SKIP,
                    url = response.url(),
                    path = %path.display(),
                    "cache.store_skip"
                );
            }
        }

        Ok(response)
    }
}