use crate::types::{RobotsCacheKey, RobotsPolicy};
use dashmap::DashMap;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, info};
/// Hard upper bound (24 hours) applied to every TTL this module accepts:
/// constructor defaults and `max-age` values are clamped to it.
pub const MAX_CACHE_TTL: Duration = Duration::from_secs(86_400);
/// Default TTL (1 hour) used by `RobotsCache::default()`.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(3_600);
/// Atomic counters describing cache activity.
///
/// Every counter is updated and read with `Ordering::Relaxed`, so counters are
/// individually exact but carry no cross-counter consistency guarantee.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of hits recorded via [`CacheStats::record_hit`].
    pub hits: AtomicU64,
    /// Number of misses recorded via [`CacheStats::record_miss`].
    pub misses: AtomicU64,
    /// Number of evictions recorded (singly or in bulk) by the owner.
    pub evictions: AtomicU64,
    /// Entry count maintained directly by the owning cache.
    pub entries: AtomicU64,
}

impl CacheStats {
    /// Adds one to `counter`; shared by the `record_*` helpers.
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one cache hit.
    pub fn record_hit(&self) {
        Self::bump(&self.hits);
    }

    /// Counts one cache miss.
    pub fn record_miss(&self) {
        Self::bump(&self.misses);
    }

    /// Counts one eviction.
    pub fn record_eviction(&self) {
        Self::bump(&self.evictions);
    }

    /// Returns `hits / (hits + misses)`, or `0.0` when nothing has been recorded.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.load(Ordering::Relaxed);
        match hits + self.misses.load(Ordering::Relaxed) {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }

    /// Copies the current counter values into a plain [`CacheStatsSnapshot`].
    pub fn snapshot(&self) -> CacheStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        CacheStatsSnapshot {
            hits: read(&self.hits),
            misses: read(&self.misses),
            evictions: read(&self.evictions),
            entries: read(&self.entries),
        }
    }
}

/// Plain-value copy of [`CacheStats`]; each field is read independently.
#[derive(Debug, Clone)]
pub struct CacheStatsSnapshot {
    /// Recorded hits.
    pub hits: u64,
    /// Recorded misses.
    pub misses: u64,
    /// Recorded evictions.
    pub evictions: u64,
    /// Entry count at the time of the snapshot.
    pub entries: u64,
}
/// Concurrent cache of parsed robots.txt policies, keyed by [`RobotsCacheKey`],
/// with optional persistence to a directory on disk.
pub struct RobotsCache {
    /// Shared concurrent map from key to parsed policy.
    cache: Arc<DashMap<RobotsCacheKey, RobotsPolicy>>,
    /// TTL given at construction, clamped to `MAX_CACHE_TTL` by the constructors.
    default_ttl: Duration,
    /// Directory used by `save_to_disk` / `load_from_disk`; `None` makes both no-ops.
    persist_dir: Option<String>,
    /// Activity counters, shareable via `stats()`.
    stats: Arc<CacheStats>,
}

impl Default for RobotsCache {
    /// Builds an in-memory cache using `DEFAULT_CACHE_TTL`.
    fn default() -> Self {
        RobotsCache::new(DEFAULT_CACHE_TTL)
    }
}
impl RobotsCache {
    /// Creates an in-memory cache; `default_ttl` is clamped to [`MAX_CACHE_TTL`].
    pub fn new(default_ttl: Duration) -> Self {
        Self {
            cache: Arc::new(DashMap::new()),
            default_ttl: default_ttl.min(MAX_CACHE_TTL),
            persist_dir: None,
            stats: Arc::new(CacheStats::default()),
        }
    }

    /// Creates a cache that can be saved to / loaded from `persist_dir` via
    /// [`save_to_disk`](Self::save_to_disk) and [`load_from_disk`](Self::load_from_disk).
    ///
    /// `default_ttl` is clamped exactly as in [`new`](Self::new).
    pub fn with_persistence(default_ttl: Duration, persist_dir: &str) -> Self {
        Self {
            persist_dir: Some(persist_dir.to_string()),
            ..Self::new(default_ttl)
        }
    }

    /// Returns the (clamped) default TTL this cache was configured with.
    pub fn default_ttl(&self) -> Duration {
        self.default_ttl
    }

    /// Returns a shared handle to this cache's activity counters.
    pub fn stats(&self) -> Arc<CacheStats> {
        Arc::clone(&self.stats)
    }

    /// Returns a clone of the policy for `key` if it is present and not expired.
    ///
    /// An expired entry is evicted as a side effect (decrementing `entries` and
    /// counting an eviction), and the lookup is reported as a miss.
    pub fn get(&self, key: &RobotsCacheKey) -> Option<RobotsPolicy> {
        if let Some(entry) = self.cache.get(key) {
            if !entry.is_expired() {
                self.stats.record_hit();
                debug!("Cache hit for {}", key.robots_url());
                return Some(entry.value().clone());
            }
            // The shard read guard must be released before `remove_if` takes the
            // write lock on the same shard, otherwise DashMap deadlocks.
            drop(entry);
            // Re-check expiry under the write lock: another thread may have
            // replaced the stale entry with a fresh one in the meantime.
            if self
                .cache
                .remove_if(key, |_, policy| policy.is_expired())
                .is_some()
            {
                self.stats.entries.fetch_sub(1, Ordering::Relaxed);
                self.stats.record_eviction();
            }
        }
        self.stats.record_miss();
        debug!("Cache miss for {}", key.robots_url());
        None
    }

    /// Inserts or replaces the policy for `key`.
    ///
    /// The policy's own expiry is kept as-is; `entries` grows only for new keys.
    pub fn insert(&self, key: RobotsCacheKey, policy: RobotsPolicy) {
        let old = self.cache.insert(key.clone(), policy);
        if old.is_none() {
            self.stats.entries.fetch_add(1, Ordering::Relaxed);
        }
        debug!("Cached robots.txt for {}", key.robots_url());
    }

    /// Removes `key`, returning its policy and counting an eviction if it was present.
    pub fn remove(&self, key: &RobotsCacheKey) -> Option<RobotsPolicy> {
        let removed = self.cache.remove(key).map(|(_, v)| v);
        if removed.is_some() {
            self.stats.entries.fetch_sub(1, Ordering::Relaxed);
            self.stats.record_eviction();
        }
        removed
    }

    /// Removes every expired entry and returns how many were dropped.
    pub fn evict_expired(&self) -> usize {
        let mut evicted = 0;
        self.cache.retain(|_, policy| {
            if policy.is_expired() {
                evicted += 1;
                false
            } else {
                true
            }
        });
        if evicted > 0 {
            self.stats.entries.fetch_sub(evicted as u64, Ordering::Relaxed);
            self.stats.evictions.fetch_add(evicted as u64, Ordering::Relaxed);
            info!("Evicted {} expired robots.txt entries", evicted);
        }
        evicted
    }

    /// Drops every entry and resets the `entries` counter.
    ///
    /// Cleared entries are not counted as evictions.
    pub fn clear(&self) {
        let count = self.cache.len();
        self.cache.clear();
        self.stats.entries.store(0, Ordering::Relaxed);
        info!("Cleared {} robots.txt cache entries", count);
    }

    /// Number of entries currently stored, including expired ones not yet evicted.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Returns the authority of every cached key.
    ///
    /// The result is unordered. An authority may appear more than once if it is
    /// cached under several schemes.
    pub fn domains(&self) -> Vec<String> {
        self.cache
            .iter()
            .map(|entry| entry.key().authority.clone())
            .collect()
    }

    /// Writes every non-expired entry to `persist_dir` as one JSON file per key.
    ///
    /// The stored `ttl_secs` is the entry's *remaining* lifetime, because
    /// [`load_from_disk`](Self::load_from_disk) restarts the clock at load time.
    /// Persistence is best-effort per entry: serialization and write failures
    /// are logged at debug level and skipped.
    ///
    /// Returns the number of files written, or `Ok(0)` when persistence is not
    /// configured.
    ///
    /// # Errors
    /// Returns an error if `persist_dir` cannot be created.
    pub async fn save_to_disk(&self) -> std::io::Result<usize> {
        let Some(persist_dir) = self.persist_dir.as_deref() else {
            return Ok(0);
        };
        fs::create_dir_all(persist_dir).await?;

        let now_ms = Self::unix_now_ms();
        // Snapshot before any `.await`: DashMap iterator items hold shard read
        // locks, and holding those across an await point can deadlock writers.
        let snapshot: Vec<(std::path::PathBuf, CacheEntry)> = self
            .cache
            .iter()
            .filter(|entry| !entry.value().is_expired())
            .map(|entry| {
                let key = entry.key();
                let policy = entry.value();
                let remaining_ms = policy.expires_at_ms.saturating_sub(now_ms);
                (
                    Path::new(persist_dir).join(self.cache_filename(key)),
                    CacheEntry {
                        key: key.clone(),
                        groups: policy.groups.clone(),
                        sitemaps: policy.sitemaps.clone(),
                        content_size: policy.content_size,
                        ttl_secs: remaining_ms / 1000,
                    },
                )
            })
            .collect();

        let mut saved = 0;
        for (filepath, cache_entry) in snapshot {
            let json = match serde_json::to_string_pretty(&cache_entry) {
                Ok(json) => json,
                Err(err) => {
                    debug!("Failed to serialize {}: {}", filepath.display(), err);
                    continue;
                }
            };
            match Self::write_file(&filepath, json.as_bytes()).await {
                Ok(()) => saved += 1,
                Err(err) => {
                    debug!("Failed to write {}: {}", filepath.display(), err);
                }
            }
        }
        info!("Saved {} robots.txt entries to disk", saved);
        Ok(saved)
    }

    /// Loads every `*.json` entry from `persist_dir` into the cache.
    ///
    /// - Each entry's TTL is clamped to [`MAX_CACHE_TTL`] and restarts from now.
    /// - Entries with a zero TTL, and unreadable or unparsable files, are skipped.
    /// - Loaded policies are marked `FetchStatus::Success`, with no etag or
    ///   last-modified validators.
    ///
    /// Returns `Ok(0)` when persistence is not configured or the directory does
    /// not exist yet.
    ///
    /// # Errors
    /// Propagates errors from opening or iterating the directory listing.
    pub async fn load_from_disk(&self) -> std::io::Result<usize> {
        let Some(persist_dir) = self.persist_dir.as_deref() else {
            return Ok(0);
        };
        // A missing directory just means nothing was saved yet. Checking the
        // error avoids a blocking `Path::exists()` call and a TOCTOU race.
        let mut dir = match fs::read_dir(persist_dir).await {
            Ok(dir) => dir,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(0),
            Err(err) => return Err(err),
        };

        let mut loaded = 0;
        while let Some(dir_entry) = dir.next_entry().await? {
            let filepath = dir_entry.path();
            if !filepath.extension().is_some_and(|ext| ext == "json") {
                continue;
            }
            let Some(cache_entry) = Self::read_entry(&filepath).await else {
                continue;
            };
            // On-disk TTLs are untrusted. Clamping them keeps the cache bound
            // intact and keeps the millisecond arithmetic below far from overflow.
            let ttl = Duration::from_secs(cache_entry.ttl_secs).min(MAX_CACHE_TTL);
            if ttl == Duration::ZERO {
                continue;
            }
            let now = Self::unix_now_ms();
            let policy = RobotsPolicy {
                fetched_at_ms: now,
                // `ttl` is at most 24h, so its millisecond count fits in u64.
                expires_at_ms: now.saturating_add(ttl.as_millis() as u64),
                fetch_status: crate::types::FetchStatus::Success,
                groups: cache_entry.groups,
                sitemaps: cache_entry.sitemaps,
                content_size: cache_entry.content_size,
                etag: None,
                last_modified: None,
            };
            self.insert(cache_entry.key, policy);
            loaded += 1;
        }
        info!("Loaded {} robots.txt entries from disk", loaded);
        Ok(loaded)
    }

    /// Reads and deserializes one persisted entry.
    ///
    /// Returns `None` if the file cannot be opened, read or parsed; parse
    /// failures are logged at debug level.
    async fn read_entry(path: &Path) -> Option<CacheEntry> {
        let mut file = fs::File::open(path).await.ok()?;
        let mut content = String::new();
        file.read_to_string(&mut content).await.ok()?;
        match serde_json::from_str(&content) {
            Ok(entry) => Some(entry),
            Err(err) => {
                debug!("Skipping unreadable cache file {}: {}", path.display(), err);
                None
            }
        }
    }

    /// Creates or truncates `path` and writes `bytes` to it.
    async fn write_file(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
        let mut file = fs::File::create(path).await?;
        file.write_all(bytes).await?;
        // tokio's File completes writes in the background; `flush` waits for
        // them so that I/O errors are reported here instead of being dropped.
        file.flush().await
    }

    /// Milliseconds since the Unix epoch (0 if the clock reads before the epoch).
    fn unix_now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// File name for `key`: lowercase hex of `"<scheme>_<authority>"` plus `.json`.
    ///
    /// Hex encoding keeps the name filesystem-safe and free of path separators.
    fn cache_filename(&self, key: &RobotsCacheKey) -> String {
        let combined = format!("{}_{}", key.scheme, key.authority);
        let encoded = base64_encode(&combined);
        format!("{}.json", encoded)
    }
}
/// On-disk JSON form of one cached policy, written by `RobotsCache::save_to_disk`
/// and read back by `RobotsCache::load_from_disk`.
///
/// Field order determines key order in the pretty-printed JSON output.
#[derive(serde::Deserialize, serde::Serialize)]
struct CacheEntry {
    /// Cache key (scheme + authority) the policy belongs to.
    key: RobotsCacheKey,
    /// Groups copied from the policy.
    groups: Vec<crate::types::Group>,
    /// Sitemap URLs copied from the policy.
    sitemaps: Vec<String>,
    /// `content_size` copied from the policy.
    content_size: usize,
    /// Lifetime in whole seconds; entries with `0` are skipped on load.
    ttl_secs: u64,
}
/// Encodes the UTF-8 bytes of `s` as **lowercase hexadecimal** (two digits per byte).
///
/// Despite the historical name, this is hex, not base64. The output contains
/// only `[0-9a-f]`, so it is always safe to use as a file name. The name is kept
/// because existing callers depend on it.
fn base64_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Build into one preallocated buffer instead of a `format!` per byte.
    let mut encoded = String::with_capacity(s.len() * 2);
    for byte in s.bytes() {
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
/// Extracts the `max-age` directive from a `Cache-Control` header value.
///
/// The parse follows the recipient rules of RFC 9111:
/// - Directive names are matched case-insensitively (`Max-Age=60` is accepted).
/// - A quoted argument (`max-age="60"`) is accepted.
/// - A value too large for `u64` is treated as the largest representable value
///   (§1.2.2), not ignored.
///
/// The result is always clamped to [`MAX_CACHE_TTL`]. A malformed `max-age`
/// directive is skipped, and later directives are still examined. Returns `None`
/// when no directive yields a usable `max-age`.
pub fn parse_cache_control(header: &str) -> Option<Duration> {
    header.split(',').find_map(|directive| {
        let (name, value) = directive.split_once('=')?;
        if !name.trim().eq_ignore_ascii_case("max-age") {
            return None;
        }
        let value = value.trim();
        let value = value
            .strip_prefix('"')
            .and_then(|inner| inner.strip_suffix('"'))
            .unwrap_or(value);
        let secs = match value.parse::<u64>() {
            Ok(secs) => secs,
            // Overflowing delta-seconds must be treated as "very large".
            Err(err) if matches!(err.kind(), std::num::IntErrorKind::PosOverflow) => u64::MAX,
            Err(_) => return None,
        };
        Some(Duration::from_secs(secs).min(MAX_CACHE_TTL))
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::RobotsParser;

    /// Parses a small robots.txt that disallows `/admin` for every agent (1h TTL).
    fn create_test_policy() -> RobotsPolicy {
        let robots_txt = "User-agent: *\nDisallow: /admin";
        RobotsParser::new().parse(robots_txt, Duration::from_secs(3600))
    }

    /// Key for `https://example.com`, shared by the tests below.
    fn example_key() -> RobotsCacheKey {
        RobotsCacheKey {
            scheme: String::from("https"),
            authority: String::from("example.com"),
        }
    }

    /// Fresh cache with a one-hour default TTL.
    fn hour_cache() -> RobotsCache {
        RobotsCache::new(Duration::from_secs(3600))
    }

    #[test]
    fn test_cache_insert_get() {
        let (cache, key) = (hour_cache(), example_key());
        cache.insert(key.clone(), create_test_policy());
        assert!(cache.get(&key).is_some());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_miss() {
        assert!(hour_cache().get(&example_key()).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let (cache, key) = (hour_cache(), example_key());
        let _ = cache.get(&key); // miss: nothing inserted yet
        cache.insert(key.clone(), create_test_policy());
        let _ = cache.get(&key); // hit
        let stats = cache.stats().snapshot();
        assert_eq!((stats.hits, stats.misses), (1, 1));
    }

    #[test]
    fn test_cache_clear() {
        let cache = hour_cache();
        cache.insert(example_key(), create_test_policy());
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_parse_cache_control() {
        let cases = [
            ("max-age=3600", Some(Duration::from_secs(3600))),
            ("public, max-age=7200", Some(Duration::from_secs(7200))),
            ("no-cache", None),
            ("max-age=999999", Some(MAX_CACHE_TTL)),
        ];
        for (header, expected) in cases {
            assert_eq!(parse_cache_control(header), expected, "header: {header}");
        }
    }

    #[test]
    fn test_max_ttl_enforcement() {
        let cache = RobotsCache::new(Duration::from_secs(100_000));
        assert_eq!(cache.default_ttl(), MAX_CACHE_TTL);
    }
}