use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
/// A file snapshot held by the cache.
///
/// The payload fields are wrapped in [`Arc`], so cloning a `CachedFile`
/// bumps reference counts instead of copying the bytes.
#[derive(Clone, Debug)]
pub struct CachedFile {
    /// Raw bytes of the file.
    pub body: Arc<Vec<u8>>,
    /// Digest of `body`, rendered as a hexadecimal string.
    pub digest_hex: Arc<String>,
    /// Moment this snapshot was created; used to decide whether it is stale.
    pub fetched_at: Instant,
}
/// Cache of file snapshots keyed by path, with a time-to-live per entry.
///
/// The map lives behind an [`Arc`], so every clone of a `WebsiteCache`
/// reads from and writes to the same underlying entries.
#[derive(Clone, Debug)]
pub struct WebsiteCache {
    /// Shared path-to-snapshot map, guarded by an async read/write lock.
    inner: Arc<RwLock<HashMap<PathBuf, CachedFile>>>,
    /// How long an entry may be served before it is considered stale.
    ttl: Duration,
}
impl WebsiteCache {
    /// Creates an empty cache whose entries are served for `ttl_secs` seconds.
    ///
    /// With a TTL of zero no entry is ever considered fresh, so every
    /// [`get`](Self::get) call reads the file again.
    pub fn new(ttl_secs: u64) -> Self {
        let ttl = Duration::from_secs(ttl_secs);
        let inner = Arc::new(RwLock::new(HashMap::new()));
        Self { inner, ttl }
    }

    /// Returns the snapshot for `path`, reading the file when no fresh entry exists.
    ///
    /// A stored entry is returned as-is while its age is below the TTL.
    /// Otherwise the file is read, its SHA-256 digest is hex-encoded, and the
    /// new snapshot replaces whatever was stored under `path`.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error from reading `path`; the map is left
    /// untouched in that case.
    pub async fn get(&self, path: &std::path::Path) -> std::io::Result<CachedFile> {
        // Fast path: the read guard is a temporary of this statement, so it is
        // released before any file I/O or write-lock acquisition below.
        let cached = self
            .inner
            .read()
            .await
            .get(path)
            .filter(|hit| hit.fetched_at.elapsed() < self.ttl)
            .cloned();
        if let Some(hit) = cached {
            return Ok(hit);
        }

        // Slow path: load the file and build a new snapshot.
        let contents = tokio::fs::read(path).await?;
        let digest = Sha256::digest(&contents);
        let snapshot = CachedFile {
            body: Arc::new(contents),
            digest_hex: Arc::new(hex::encode(digest)),
            fetched_at: Instant::now(),
        };

        // Hold the write lock only for the insertion itself.
        {
            let mut entries = self.inner.write().await;
            entries.insert(path.to_path_buf(), snapshot.clone());
        }
        Ok(snapshot)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two lookups inside the TTL must return the same digest, and the second
    /// one must come from the cache rather than from a fresh file read.
    #[tokio::test]
    async fn cache_returns_consistent_digest() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("hello.txt");
        std::fs::write(&file, b"hello world").unwrap();
        let cache = WebsiteCache::new(60);
        let a = cache.get(&file).await.unwrap();
        let b = cache.get(&file).await.unwrap();
        assert_eq!(a.digest_hex, b.digest_hex);
        assert!(
            a.digest_hex.starts_with("b94d27b9934d3e08"),
            "got {}",
            a.digest_hex
        );
        assert_eq!(a.body.as_slice(), b"hello world");
        // Equal digests alone would also hold if the file were simply read
        // twice; sharing the same allocation proves the second call was a hit.
        assert!(
            std::sync::Arc::ptr_eq(&a.body, &b.body),
            "second get within the TTL must be served from the cache"
        );
    }

    /// With a zero TTL every lookup is a miss, so a change on disk must be
    /// visible on the very next call.
    #[tokio::test]
    async fn cache_misses_after_ttl_expiry() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("hello.txt");
        std::fs::write(&file, b"v1").unwrap();
        let cache = WebsiteCache::new(0);
        let a = cache.get(&file).await.unwrap();
        std::fs::write(&file, b"v2-different").unwrap();
        let b = cache.get(&file).await.unwrap();
        assert_ne!(a.digest_hex, b.digest_hex, "TTL expiry must re-read");
        assert_eq!(a.body.as_slice(), b"v1");
        assert_eq!(b.body.as_slice(), b"v2-different");
    }
}