use crate::RegisterPlugin;
use crate::Result;
use crate::ShutdownPlugin;
use crate::config::PluginConfig;
use crate::dns::Message;
use crate::error::Error;
#[cfg(feature = "metrics")]
use crate::metrics;
use crate::plugin::traits::Shutdown;
use crate::plugin::{Context, Plugin, PluginHandler, RETURN_FLAG};
use crate::utils::task_queue::{RefreshCoordinator, RefreshTask};
use async_trait::async_trait;
use dashmap::DashSet;
use lru::LruCache;
use std::fmt;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, trace};
/// TTL (seconds) stamped onto records when a stale cached response is served
/// while a background refresh is in flight, so downstream resolvers re-query soon.
const STALE_RESPONSE_TTL_SECS: u32 = 5;
/// One cached DNS response plus the bookkeeping needed to decide when it is
/// stale (message TTL) and when it must be evicted (cache-retention TTL).
#[derive(Clone)]
struct CacheEntry {
    /// Cached DNS message, shared cheaply across readers.
    response: Arc<Message>,
    /// Insertion time; both TTL clocks start here.
    cached_at: Instant,
    /// Message TTL in seconds — freshness window of the DNS answer itself.
    ttl: u32,
    /// Cache-retention TTL in seconds — how long the entry may stay cached,
    /// possibly longer than `ttl` to allow stale serving.
    cache_ttl: u32,
    /// Message TTL at insertion time, kept for refresh-threshold math.
    original_ttl: u32,
    /// Last read time (updated via `touch`).
    last_accessed: Instant,
}

impl CacheEntry {
    /// Wraps `response` in a fresh entry; both TTL clocks start now.
    fn new(response: Message, ttl: u32, cache_ttl: u32) -> Self {
        let created = Instant::now();
        Self {
            response: Arc::new(response),
            cached_at: created,
            ttl,
            cache_ttl,
            original_ttl: ttl,
            last_accessed: created,
        }
    }

    /// True once the retention window has fully elapsed. A zero `cache_ttl`
    /// counts as immediately expired.
    fn is_cache_expired(&self) -> bool {
        self.cache_ttl == 0
            || self.cached_at.elapsed() >= Duration::from_secs(u64::from(self.cache_ttl))
    }

    /// Records an access right now.
    fn touch(&mut self) {
        self.last_accessed = Instant::now();
    }

    /// Seconds of message TTL still left, saturating at zero.
    fn remaining_ttl(&self) -> u32 {
        let age = self.cached_at.elapsed().as_secs() as u32;
        self.ttl.saturating_sub(age)
    }

    /// Seconds of cache-retention TTL still left, saturating at zero.
    fn remaining_cache_ttl(&self) -> u32 {
        let age = self.cached_at.elapsed().as_secs() as u32;
        self.cache_ttl.saturating_sub(age)
    }
}
/// Hit/miss/eviction/expiration counters for the cache.
///
/// All counters are relaxed atomics, so they can be bumped concurrently from
/// any thread without locking; exact cross-counter consistency is not needed.
#[derive(Debug, Default)]
pub struct CacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
    evictions: AtomicU64,
    expirations: AtomicU64,
}

impl CacheStats {
    /// All counters start at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Counts one cache hit (and bumps the Prometheus counter when the
    /// `metrics` feature is enabled).
    fn record_hit(&self) {
        self.hits.fetch_add(1, Ordering::Relaxed);
        #[cfg(feature = "metrics")]
        {
            metrics::CACHE_HITS_TOTAL.inc();
        }
    }

    /// Counts one cache miss.
    fn record_miss(&self) {
        self.misses.fetch_add(1, Ordering::Relaxed);
        #[cfg(feature = "metrics")]
        {
            metrics::CACHE_MISSES_TOTAL.inc();
        }
    }

    /// Counts one LRU capacity eviction.
    fn record_eviction(&self) {
        self.evictions.fetch_add(1, Ordering::Relaxed);
        #[cfg(feature = "metrics")]
        {
            crate::metrics::DNS_CACHE_EVICTIONS_TOTAL.inc();
        }
    }

    /// Counts one TTL-based expiration.
    fn record_expiration(&self) {
        self.expirations.fetch_add(1, Ordering::Relaxed);
        #[cfg(feature = "metrics")]
        {
            crate::metrics::DNS_CACHE_EXPIRATIONS_TOTAL.inc();
        }
    }

    /// Total hits recorded so far.
    pub fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Total misses recorded so far.
    pub fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }

    /// Total capacity evictions recorded so far.
    pub fn evictions(&self) -> u64 {
        self.evictions.load(Ordering::Relaxed)
    }

    /// Total TTL expirations recorded so far.
    pub fn expirations(&self) -> u64 {
        self.expirations.load(Ordering::Relaxed)
    }

    /// Fraction of lookups that were hits, or 0.0 when nothing was recorded.
    pub fn hit_rate(&self) -> f64 {
        let hit_count = self.hits();
        let total = hit_count + self.misses();
        match total {
            0 => 0.0,
            _ => hit_count as f64 / total as f64,
        }
    }

    /// Hits plus misses.
    pub fn total_requests(&self) -> u64 {
        self.hits() + self.misses()
    }
}

impl fmt::Display for CacheStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "CacheStats {{ hits: {}, misses: {}, evictions: {}, expirations: {}, hit_rate: {:.2}% }}",
            self.hits(),
            self.misses(),
            self.evictions(),
            self.expirations(),
            self.hit_rate() * 100.0
        )
    }
}
/// Counters for background (lazycache / stale-serving) refresh activity.
#[derive(Debug, Default)]
pub struct LazyCacheStats {
    refreshes: AtomicU64,
    successful_refreshes: AtomicU64,
    failed_refreshes: AtomicU64,
}

impl LazyCacheStats {
    /// All counters start at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Counts one refresh being triggered.
    fn record_refresh(&self) {
        self.refreshes.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one refresh that completed successfully.
    #[allow(dead_code)]
    fn record_successful_refresh(&self) {
        self.successful_refreshes.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one refresh that failed.
    #[allow(dead_code)]
    fn record_failed_refresh(&self) {
        self.failed_refreshes.fetch_add(1, Ordering::Relaxed);
    }

    /// Total refreshes triggered so far.
    pub fn refreshes(&self) -> u64 {
        self.refreshes.load(Ordering::Relaxed)
    }

    /// Total refreshes that succeeded.
    pub fn successful_refreshes(&self) -> u64 {
        self.successful_refreshes.load(Ordering::Relaxed)
    }

    /// Total refreshes that failed.
    pub fn failed_refreshes(&self) -> u64 {
        self.failed_refreshes.load(Ordering::Relaxed)
    }

    /// Fraction of triggered refreshes that succeeded, or 0.0 when none
    /// were triggered yet.
    pub fn refresh_success_rate(&self) -> f64 {
        let total = self.refreshes();
        match total {
            0 => 0.0,
            _ => self.successful_refreshes() as f64 / total as f64,
        }
    }
}

impl fmt::Display for LazyCacheStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "LazyCacheStats {{ refreshes: {}, successful: {}, failed: {}, success_rate: {:.2}% }}",
            self.refreshes(),
            self.successful_refreshes(),
            self.failed_refreshes(),
            self.refresh_success_rate() * 100.0
        )
    }
}
/// DNS response cache plugin: an LRU cache keyed by `qname:qtype:qclass`,
/// with optional negative caching, prefetch, lazy (serve-then-refresh)
/// caching, stale serving via `cache_ttl`, and periodic cleanup.
///
/// `Clone` shares all interior state through `Arc`s, so clones observe the
/// same cache and statistics.
#[derive(Clone, RegisterPlugin, ShutdownPlugin)]
pub struct CachePlugin {
    /// LRU map from `"qname:qtype:qclass"` keys to cached responses.
    cache: Arc<parking_lot::RwLock<LruCache<String, CacheEntry>>>,
    /// Maximum entry count (also the basis for pressure calculations).
    max_size: usize,
    /// Hit/miss/eviction/expiration counters.
    stats: Arc<CacheStats>,
    /// Whether error responses are cached at all.
    negative_cache: bool,
    /// TTL (seconds) applied to cached error responses.
    negative_ttl: u32,
    /// Prefetch toggle (threshold below).
    enable_prefetch: bool,
    /// Remaining-TTL fraction below which prefetch would trigger.
    prefetch_threshold: f32,
    /// Lazycache (serve-then-refresh) toggle.
    enable_lazycache: bool,
    /// Remaining-TTL fraction below which a background refresh triggers.
    lazycache_threshold: f32,
    /// Stale-serving retention TTL in seconds; `None` disables stale serving.
    cache_ttl: Option<u32>,
    /// Counters for background refresh activity.
    lazycache_stats: Arc<LazyCacheStats>,
    /// Runtime-adjustable threshold (see `set_lazycache_threshold_async`).
    /// NOTE(review): the hit path reads `lazycache_threshold`, not this
    /// value — confirm which is meant to drive refresh decisions.
    lazycache_threshold_dynamic: Arc<tokio::sync::RwLock<f32>>,
    /// Keys with a background refresh currently in flight (dedup guard).
    refreshing_keys: Arc<DashSet<String>>,
    /// Optional instance tag from configuration.
    tag: Option<String>,
    /// Worker pool that executes queued background refreshes; `None` until
    /// lazycache or stale serving is enabled.
    refresh_coordinator: Arc<Mutex<Option<RefreshCoordinator>>>,
    /// Whether the periodic cleanup task should run.
    enable_cleanup: bool,
    /// Cleanup tick interval in seconds.
    cleanup_interval_secs: u64,
    /// Fill fraction above which an extra cleanup pass runs.
    cleanup_pressure_threshold: f32,
}
impl CachePlugin {
pub fn new(max_size: usize) -> Self {
let capacity = NonZeroUsize::new(max_size.max(1)).unwrap();
Self {
cache: Arc::new(parking_lot::RwLock::new(LruCache::new(capacity))),
max_size,
stats: Arc::new(CacheStats::new()),
negative_cache: false,
negative_ttl: 300, enable_prefetch: false,
prefetch_threshold: 0.1, enable_lazycache: false,
lazycache_threshold: 0.05, cache_ttl: None,
lazycache_stats: Arc::new(LazyCacheStats::new()),
lazycache_threshold_dynamic: Arc::new(tokio::sync::RwLock::new(0.05)),
refreshing_keys: Arc::new(DashSet::new()),
tag: None,
refresh_coordinator: Arc::new(Mutex::new(None)),
enable_cleanup: true,
cleanup_interval_secs: 60,
cleanup_pressure_threshold: 0.8,
}
}
pub fn with_negative_cache(mut self, ttl: u32) -> Self {
self.negative_cache = true;
self.negative_ttl = ttl;
self
}
pub fn with_prefetch(mut self, threshold: f32) -> Self {
self.enable_prefetch = true;
self.prefetch_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn with_lazycache(mut self, threshold: f32) -> Self {
self.enable_lazycache = true;
self.lazycache_threshold = threshold.clamp(0.0, 1.0);
match self.refresh_coordinator.try_lock() {
Ok(mut guard) if guard.is_none() => {
*guard = Some(RefreshCoordinator::new(4, 1000));
}
_ => {}
}
self
}
pub fn with_cache_ttl(mut self, ttl_secs: u32) -> Self {
if ttl_secs > 0 {
self.cache_ttl = Some(ttl_secs);
match self.refresh_coordinator.try_lock() {
Ok(mut guard) if guard.is_none() => {
*guard = Some(RefreshCoordinator::new(4, 1000));
}
_ => {}
}
}
self
}
pub fn with_cleanup(
mut self,
enabled: bool,
interval_secs: u64,
pressure_threshold: f32,
) -> Self {
self.enable_cleanup = enabled;
self.cleanup_interval_secs = interval_secs.max(1); self.cleanup_pressure_threshold = pressure_threshold.clamp(0.0, 1.0);
self
}
    /// Hit/miss statistics shared by all clones of this plugin.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }
    /// Background refresh statistics.
    pub fn lazycache_stats(&self) -> &LazyCacheStats {
        &self.lazycache_stats
    }
    /// The static lazycache threshold configured at build time; this is the
    /// value the cache-hit path actually uses.
    pub fn get_lazycache_threshold(&self) -> f32 {
        self.lazycache_threshold
    }
    /// NOTE(review): this only logs the requested value — `clamped` is never
    /// stored (`self.lazycache_threshold` cannot be mutated through `&self`,
    /// and `lazycache_threshold_dynamic` is not written here). Callers that
    /// need a runtime update must use `set_lazycache_threshold_async`;
    /// confirm whether this no-op is intentional.
    pub fn set_lazycache_threshold(&self, threshold: f32) {
        let clamped = threshold.clamp(0.0, 1.0);
        debug!("Updating LazyCache threshold to {:.2}%", clamped * 100.0);
    }
    /// Stores a new dynamic threshold, clamped to [0, 1].
    /// NOTE(review): the hit path reads the static `lazycache_threshold`
    /// field, not this dynamic value — verify which one is meant to drive
    /// refresh decisions.
    pub async fn set_lazycache_threshold_async(&self, threshold: f32) {
        let clamped = threshold.clamp(0.0, 1.0);
        let mut dynamic_threshold = self.lazycache_threshold_dynamic.write().await;
        *dynamic_threshold = clamped;
        debug!("LazyCache threshold updated to {:.2}%", clamped * 100.0);
    }
    /// Reads the threshold last set via `set_lazycache_threshold_async`.
    pub async fn get_lazycache_threshold_async(&self) -> f32 {
        *self.lazycache_threshold_dynamic.read().await
    }
pub fn size(&self) -> usize {
self.cache.read().len()
}
pub fn cleanup_expired(&self) -> usize {
let mut cache = self.cache.write();
let mut removed = 0;
debug!("Cleanup: starting cache cleanup of expired entries");
let expired_keys: Vec<String> = cache
.iter()
.filter(|(_, entry)| entry.is_cache_expired())
.map(|(k, _)| k.clone())
.collect();
for key in expired_keys {
debug!("Cleanup: removing expired cache entry: {}", key);
if let Some(removed_entry) = cache.pop(&key) {
drop(removed_entry); self.stats.record_expiration();
removed += 1;
}
}
#[cfg(feature = "metrics")]
{
metrics::CACHE_SIZE.set(cache.len() as i64);
}
if removed > 0 {
debug!("Cleanup removed {} expired cache entries", removed);
}
removed
}
fn should_cleanup_pressure(&self) -> bool {
let size = self.size();
let threshold = (self.max_size as f32 * self.cleanup_pressure_threshold) as usize;
size > threshold
}
pub fn is_cleanup_enabled(&self) -> bool {
self.enable_cleanup
}
pub fn spawn_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut interval =
tokio::time::interval(Duration::from_secs(self.cleanup_interval_secs));
loop {
interval.tick().await;
let removed = self.cleanup_expired();
if self.should_cleanup_pressure() {
debug!(
"Memory pressure detected: {} / {}",
self.size(),
self.max_size
);
let pressure_removed = self.cleanup_expired();
debug!(
"Pressure cleanup removed {} entries (total in this cycle: {})",
pressure_removed,
removed + pressure_removed
);
}
}
})
}
pub fn clear(&self) {
self.cache.write().clear();
#[cfg(feature = "metrics")]
{
metrics::CACHE_SIZE.set(0);
}
}
fn make_key(message: &Message) -> Option<String> {
message.questions().first().map(|q| {
let qname_lower = q.qname().to_lowercase();
let key = format!(
"{}:{}:{}",
qname_lower,
q.qtype().to_u16(),
q.qclass().to_u16()
);
key
})
}
    /// Inserts (or replaces) `key` in the LRU cache and updates statistics.
    ///
    /// `LruCache::push` returns the displaced pair both when an unrelated
    /// entry is evicted for capacity AND when an existing entry under the
    /// same key is replaced. `key_exists` (checked with the non-reordering
    /// `contains`) distinguishes the two so only true capacity evictions are
    /// counted as evictions.
    fn store(&self, key: String, entry: CacheEntry) {
        let mut cache = self.cache.write();
        let key_exists = cache.contains(&key);
        if let Some((evicted_key, _)) = cache.push(key, entry) {
            if !key_exists {
                // A different key was pushed out to make room.
                self.stats.record_eviction();
                debug!("LRU evicted cache entry: {}", evicted_key);
            } else {
                // Same key: the old value was replaced, not evicted.
                trace!("Cache store: replaced existing entry: {}", evicted_key);
            }
        }
        trace!(
            stats = ?self.stats,
            "Cache stats after store operation"
        );
        #[cfg(feature = "metrics")]
        {
            metrics::CACHE_SIZE.set(cache.len() as i64);
        }
    }
fn get_min_ttl(message: &Message) -> u32 {
let mut min_ttl = u32::MAX;
for record in message.answers() {
min_ttl = min_ttl.min(record.ttl());
}
for record in message.authority() {
min_ttl = min_ttl.min(record.ttl());
}
for record in message.additional() {
min_ttl = min_ttl.min(record.ttl());
}
if min_ttl == u32::MAX {
300
} else {
min_ttl.max(1)
}
}
    /// Rewrites the TTL of every record in all three record sections to
    /// `remaining_ttl`, so a served cached response reflects the time it has
    /// already spent in the cache.
    // Three separate loops are required: each `*_mut()` call takes a
    // distinct mutable borrow of `message`, so the iterators cannot be
    // chained.
    fn update_ttls(message: &mut Message, remaining_ttl: u32) {
        for record in message.answers_mut() {
            record.set_ttl(remaining_ttl);
        }
        for record in message.authority_mut() {
            record.set_ttl(remaining_ttl);
        }
        for record in message.additional_mut() {
            record.set_ttl(remaining_ttl);
        }
    }
}
/// Manual `Debug`: only size and statistics are shown; the LRU map itself
/// is intentionally omitted.
impl fmt::Debug for CachePlugin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CachePlugin")
            .field("max_size", &self.max_size)
            .field("current_size", &self.size())
            .field("stats", &self.stats())
            .finish()
    }
}
#[async_trait]
impl Plugin for CachePlugin {
    /// Dual-phase cache step for the plugin chain.
    ///
    /// * **Lookup phase** (no response on the context yet): look up the
    ///   request by `qname:qtype:qclass`. A fresh hit answers immediately and
    ///   stops the chain via `RETURN_FLAG`; a hit past its message TTL is
    ///   either served stale with a background refresh (when `cache_ttl` is
    ///   configured) or dropped as expired; a fresh hit close to expiry may
    ///   trigger a lazycache background refresh while still answering.
    /// * **Store phase** (a response is present): store positive answers —
    ///   and, when negative caching is enabled, error answers — for later.
    ///
    /// Metadata flags: `cache_checked` (miss already counted for this
    /// context), `response_from_cache` (answer originated here), and
    /// `background_lazy_refresh` (this request is driven by a refresh
    /// worker); `lazy_refresh_handler`/`lazy_refresh_entry` are injected by
    /// the pipeline and used to build background refresh tasks.
    async fn execute(&self, context: &mut Context) -> Result<()> {
        let key = match Self::make_key(context.request()) {
            Some(k) => k,
            None => {
                debug!("Cannot generate cache key, no questions in request");
                return Ok(());
            }
        };
        // Whether an earlier pass over this context already counted a miss.
        let cache_already_checked = context.get_metadata::<bool>("cache_checked").is_some();
        if context
            .get_metadata::<bool>("response_from_cache")
            .is_some()
        {
            // The response already came from this cache; nothing to do.
            return Ok(());
        }
        if context.response().is_none() {
            // ---- Lookup phase ----
            context.set_metadata("cache_checked", true);
            if context
                .get_metadata::<bool>("background_lazy_refresh")
                .is_some()
            {
                debug!("Skipping cache logic for background lazy refresh");
                return Ok(());
            }
            // Clone the entry out so no lock is held during the logic below
            // (a write lock is needed because LRU `get` reorders entries).
            let cached_entry = {
                let mut cache = self.cache.write();
                cache.get(&key).cloned()
            };
            if let Some(mut entry) = cached_entry {
                if entry.is_cache_expired() {
                    // Retention window elapsed: drop the entry and count it
                    // as both an expiration and a miss.
                    debug!("Cache entry expired: {}", key);
                    self.cache.write().pop(&key);
                    self.stats.record_expiration();
                    self.stats.record_miss();
                    #[cfg(feature = "metrics")]
                    {
                        metrics::CACHE_SIZE.set(self.size() as i64);
                    }
                    return Ok(());
                }
                debug!("Cache hit: {}", key);
                self.stats.record_hit();
                // NOTE(review): `entry` is a clone, so this only touches the
                // copy — the stored entry's `last_accessed` never changes.
                // The LRU `get` above already refreshed recency; confirm
                // whether persisting `last_accessed` was intended.
                entry.touch();
                let remaining_ttl = entry.remaining_ttl();
                if remaining_ttl == 0 {
                    if let Some(lazy_ttl) = self.cache_ttl {
                        // Message TTL is gone but the entry is still inside
                        // its retention window: serve it stale with a short
                        // fixed TTL and refresh it in the background.
                        debug!(
                            "Stale-serving TTL hit (stale entry): {}, cache_remaining: {}s, configured_lazy_ttl: {}s",
                            key,
                            entry.remaining_cache_ttl(),
                            lazy_ttl
                        );
                        // Copy-on-write: only clone the message if other
                        // holders still share the Arc.
                        let mut response_arc = Arc::clone(&entry.response);
                        let response_ref = Arc::make_mut(&mut response_arc);
                        Self::update_ttls(response_ref, STALE_RESPONSE_TTL_SECS);
                        response_ref.set_id(context.request().id());
                        context.set_response_arc(Some(response_arc));
                        context.set_metadata("response_from_cache", true);
                        // First caller to insert the key owns the refresh.
                        if self.refreshing_keys.insert(key.clone()) {
                            self.lazycache_stats.record_refresh();
                            if let (Some(handler), Some(entry_name)) = (
                                context.get_metadata::<Arc<PluginHandler>>("lazy_refresh_handler"),
                                context.get_metadata::<String>("lazy_refresh_entry"),
                            ) {
                                let background_handler = Arc::new(PluginHandler {
                                    registry: Arc::clone(&handler.registry),
                                    entry: entry_name.clone(),
                                });
                                let refreshing_keys_clone = Arc::clone(&self.refreshing_keys);
                                let mut request_clone = context.request().clone();
                                let key_clone = key.clone();
                                let coordinator = Arc::clone(&self.refresh_coordinator);
                                // Sentinel id marks this as a refresh query.
                                request_clone.set_id(0xFFFF);
                                let task = RefreshTask {
                                    key: key_clone.clone(),
                                    message: request_clone,
                                    handler: background_handler,
                                    entry_name: entry_name.clone(),
                                    created_at: Instant::now(),
                                };
                                tokio::spawn(async move {
                                    if let Some(coord) = coordinator.lock().await.as_ref() {
                                        match coord.enqueue(task).await {
                                            Ok(_) => {
                                                debug!(
                                                    "Background stale-serving TTL refresh enqueued for {}",
                                                    key_clone
                                                );
                                            }
                                            Err(e) => {
                                                debug!(
                                                    "Failed to enqueue stale-serving TTL refresh for {}: {}",
                                                    key_clone, e
                                                );
                                                // Release ownership so a later
                                                // request can retry.
                                                refreshing_keys_clone.remove(&key_clone);
                                            }
                                        }
                                    } else {
                                        debug!("Refresh coordinator not initialized");
                                        refreshing_keys_clone.remove(&key_clone);
                                    }
                                });
                            } else {
                                debug!(
                                    "Stale-serving TTL: handler metadata missing, falling back to invalidate stale entry"
                                );
                                // No way to refresh: drop the stale entry
                                // shortly after serving it once.
                                let cache_clone = Arc::clone(&self.cache);
                                let refreshing_keys_clone = Arc::clone(&self.refreshing_keys);
                                let key_clone = key.clone();
                                tokio::spawn(async move {
                                    tokio::time::sleep(tokio::time::Duration::from_millis(10))
                                        .await;
                                    cache_clone.write().pop(&key_clone);
                                    refreshing_keys_clone.remove(&key_clone);
                                });
                            }
                        } else {
                            debug!(
                                "Stale-serving TTL: {} already being refreshed, skip duplicate background refresh",
                                key
                            );
                        }
                        context.set_metadata(RETURN_FLAG, true);
                        return Ok(());
                    } else {
                        // No stale serving configured: TTL-expired means gone.
                        debug!("Cache entry message TTL expired without cache_ttl: {}", key);
                        self.cache.write().pop(&key);
                        self.stats.record_expiration();
                        self.stats.record_miss();
                        #[cfg(feature = "metrics")]
                        {
                            metrics::CACHE_SIZE.set(self.size() as i64);
                        }
                        return Ok(());
                    }
                }
                // Entry is still fresh; decide whether it is close enough to
                // expiry to warrant a proactive background refresh.
                let should_lazy_refresh = if self.enable_lazycache {
                    if context
                        .get_metadata::<bool>("background_lazy_refresh")
                        .is_some()
                    {
                        debug!(
                            "Background lazy refresh: skipping lazy refresh check for {}",
                            key
                        );
                        false
                    } else {
                        // NOTE(review): this reads the static threshold field;
                        // `lazycache_threshold_dynamic` is never consulted.
                        let ttl_percentage = remaining_ttl as f32 / entry.original_ttl as f32;
                        let threshold = self.lazycache_threshold;
                        debug!(
                            "LazyCache check: {}, original_ttl: {}s, remaining: {}s, percentage: {:.2}%, threshold: {:.2}%",
                            key,
                            entry.original_ttl,
                            remaining_ttl,
                            ttl_percentage * 100.0,
                            threshold * 100.0
                        );
                        if ttl_percentage <= threshold {
                            debug!(
                                "LazyCache threshold REACHED for {}: {:.2}% TTL remaining (< {:.2}%), triggering refresh",
                                key,
                                ttl_percentage * 100.0,
                                threshold * 100.0
                            );
                            self.lazycache_stats.record_refresh();
                            true
                        } else {
                            false
                        }
                    }
                } else {
                    debug!(
                        "Lazycache disabled (enable_lazycache={})",
                        self.enable_lazycache
                    );
                    false
                };
                if should_lazy_refresh {
                    // Serve the cached answer immediately with its remaining
                    // TTL, then refresh in the background.
                    let mut response_arc = Arc::clone(&entry.response);
                    let response_ref = Arc::make_mut(&mut response_arc);
                    Self::update_ttls(response_ref, remaining_ttl);
                    response_ref.set_id(context.request().id());
                    context.set_response_arc(Some(response_arc));
                    context.set_metadata("response_from_cache", true);
                    if self.refreshing_keys.insert(key.clone()) {
                        debug!(
                            "LazyCache: returning cached response immediately, triggering background refresh for {}",
                            key
                        );
                        // NOTE(review): `record_refresh` was already called in
                        // the threshold check above, so a triggered refresh is
                        // counted twice — confirm whether that is intended.
                        self.lazycache_stats.record_refresh();
                        if let (Some(handler), Some(entry_name)) = (
                            context.get_metadata::<Arc<PluginHandler>>("lazy_refresh_handler"),
                            context.get_metadata::<String>("lazy_refresh_entry"),
                        ) {
                            let background_handler = Arc::new(PluginHandler {
                                registry: Arc::clone(&handler.registry),
                                entry: entry_name.clone(),
                            });
                            let refreshing_keys_clone = Arc::clone(&self.refreshing_keys);
                            let mut request_clone = context.request().clone();
                            let key_clone = key.clone();
                            let coordinator = Arc::clone(&self.refresh_coordinator);
                            // Sentinel id marks this as a refresh query.
                            request_clone.set_id(0xFFFF);
                            let task = RefreshTask {
                                key: key_clone.clone(),
                                message: request_clone,
                                handler: background_handler,
                                entry_name: entry_name.clone(),
                                created_at: Instant::now(),
                            };
                            tokio::spawn(async move {
                                if let Some(coord) = coordinator.lock().await.as_ref() {
                                    match coord.enqueue(task).await {
                                        Ok(_) => {
                                            debug!(
                                                "Background lazy refresh enqueued for {}",
                                                key_clone
                                            );
                                        }
                                        Err(e) => {
                                            debug!(
                                                "Failed to enqueue lazy refresh for {}: {}",
                                                key_clone, e
                                            );
                                            refreshing_keys_clone.remove(&key_clone);
                                        }
                                    }
                                } else {
                                    debug!("Refresh coordinator not initialized");
                                    refreshing_keys_clone.remove(&key_clone);
                                }
                            });
                        } else {
                            debug!(
                                "LazyCache: lazy_refresh_handler not available in metadata or coordinator not initialized, falling back to cache invalidation"
                            );
                            // No handler: invalidate the entry shortly after
                            // serving it so the next query goes downstream.
                            let cache_clone = Arc::clone(&self.cache);
                            let refreshing_keys_clone = Arc::clone(&self.refreshing_keys);
                            let key_clone = key.clone();
                            tokio::spawn(async move {
                                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                                debug!("Fallback: invalidating cache entry for {}", key_clone);
                                cache_clone.write().pop(&key_clone);
                                refreshing_keys_clone.remove(&key_clone);
                            });
                        }
                    } else {
                        debug!(
                            "LazyCache: {} already being refreshed by another request, skipping duplicate refresh",
                            key
                        );
                    }
                    context.set_metadata(RETURN_FLAG, true);
                    return Ok(());
                } else {
                    if context
                        .get_metadata::<bool>("background_lazy_refresh")
                        .is_some()
                    {
                        // Refresh worker must reach the upstream plugins so
                        // the entry actually gets re-resolved.
                        debug!(
                            "Background lazy refresh: cache hit but continuing downstream for {}",
                            key
                        );
                        return Ok(());
                    } else {
                        // Plain fresh hit: answer with age-adjusted TTLs and
                        // stop the chain.
                        let mut response_arc = Arc::clone(&entry.response);
                        let response_ref = Arc::make_mut(&mut response_arc);
                        Self::update_ttls(response_ref, remaining_ttl);
                        response_ref.set_id(context.request().id());
                        context.set_response_arc(Some(response_arc));
                        context.set_metadata("response_from_cache", true);
                        trace!("Normal cache hit: returning immediately and stopping chain");
                        context.set_metadata(RETURN_FLAG, true);
                        return Ok(());
                    }
                }
            }
            // No entry found: count the miss only once per context.
            if !cache_already_checked {
                self.stats.record_miss();
                debug!("Cache miss: {}", key);
            }
        } else {
            // ---- Store phase ----
            if context
                .get_metadata::<bool>("response_from_cache")
                .is_none()
                && context.response().is_some()
            {
                let response = context.response().unwrap();
                let response_code = response.response_code();
                let is_error = response_code != crate::dns::ResponseCode::NoError;
                if is_error {
                    if self.negative_cache {
                        debug!(
                            "Caching negative response: {:?} (TTL: {}s)",
                            response_code, self.negative_ttl
                        );
                        let cache_ttl = self.cache_ttl.unwrap_or(self.negative_ttl);
                        let entry = CacheEntry::new(response.clone(), self.negative_ttl, cache_ttl);
                        self.store(key.clone(), entry);
                    } else {
                        debug!("Not caching error response: {:?}", response_code);
                    }
                } else if !response.answers().is_empty() {
                    // Only cache positive answers that actually carry records.
                    let ttl = Self::get_min_ttl(response);
                    if ttl > 0 {
                        let cache_ttl = self.cache_ttl.unwrap_or(ttl);
                        debug!(
                            "Storing response in cache: {} (message TTL: {}s, cache TTL: {}s)",
                            key, ttl, cache_ttl
                        );
                        let entry = CacheEntry::new(response.clone(), ttl, cache_ttl);
                        self.store(key.clone(), entry);
                    }
                }
            }
        }
        Ok(())
    }
    /// Plugin type name used in configuration (`type: cache`).
    fn name(&self) -> &str {
        "cache"
    }
    /// Optional instance tag from configuration, used for `$tag` references.
    fn tag(&self) -> Option<&str> {
        self.tag.as_deref()
    }
    /// Mid-chain priority.
    fn priority(&self) -> i32 {
        50
    }
    /// Enables downcasting from `dyn Plugin` back to `CachePlugin`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
fn init(config: &PluginConfig) -> Result<Arc<dyn Plugin>> {
let args = config.effective_args();
use serde_yaml::Value;
let size = match args.get("size") {
Some(Value::Number(n)) => n
.as_i64()
.ok_or_else(|| Error::Config("Invalid size value".to_string()))?
as usize,
Some(_) => return Err(Error::Config("size must be a number".to_string())),
None => 1024,
};
let mut cache = CachePlugin::new(size);
if let Some(Value::Bool(true)) = args.get("negative_cache") {
let negative_ttl = match args.get("negative_ttl") {
Some(Value::Number(n)) => n
.as_i64()
.ok_or_else(|| Error::Config("Invalid negative_ttl value".to_string()))?
as u32,
Some(_) => return Err(Error::Config("negative_ttl must be a number".to_string())),
None => 300,
};
cache = cache.with_negative_cache(negative_ttl);
}
if let Some(Value::Bool(true)) = args.get("enable_prefetch") {
let threshold = match args.get("prefetch_threshold") {
Some(Value::Number(n)) => n
.as_f64()
.ok_or_else(|| Error::Config("Invalid prefetch_threshold value".to_string()))?
as f32,
Some(_) => {
return Err(Error::Config(
"prefetch_threshold must be a number".to_string(),
));
}
None => 0.1,
};
cache = cache.with_prefetch(threshold);
}
if let Some(Value::Number(n)) = args.get("cache_ttl") {
let ttl = n
.as_i64()
.ok_or_else(|| Error::Config("Invalid cache_ttl value".to_string()))?
as u32;
if ttl > 0 {
cache = cache.with_cache_ttl(ttl);
}
}
let worker_count = match args.get("refresh_worker_count") {
Some(Value::Number(n)) => n
.as_i64()
.ok_or_else(|| Error::Config("Invalid refresh_worker_count value".to_string()))?
as usize,
Some(_) => {
return Err(Error::Config(
"refresh_worker_count must be a number".to_string(),
));
}
None => 4, };
let queue_capacity = match args.get("refresh_queue_capacity") {
Some(Value::Number(n)) => n
.as_i64()
.ok_or_else(|| Error::Config("Invalid refresh_queue_capacity value".to_string()))?
as usize,
Some(_) => {
return Err(Error::Config(
"refresh_queue_capacity must be a number".to_string(),
));
}
None => 1000, };
if cache.enable_lazycache || cache.cache_ttl.is_some() {
if let Ok(mut guard) = cache.refresh_coordinator.try_lock() {
if guard.is_none() {
*guard = Some(RefreshCoordinator::new(worker_count, queue_capacity));
}
} else {
let coordinator = RefreshCoordinator::new(worker_count, queue_capacity);
cache.refresh_coordinator = Arc::new(Mutex::new(Some(coordinator)));
}
}
if let Some(Value::Bool(true)) = args.get("enable_lazycache") {
let threshold = match args.get("lazycache_threshold") {
Some(Value::Number(n)) => n
.as_f64()
.ok_or_else(|| Error::Config("Invalid lazycache_threshold value".to_string()))?
as f32,
Some(_) => {
return Err(Error::Config(
"lazycache_threshold must be a number".to_string(),
));
}
None => 0.05, };
cache = cache.with_lazycache(threshold);
}
let enable_cleanup = match args.get("enable_cleanup") {
Some(Value::Bool(b)) => *b,
Some(_) => {
return Err(Error::Config(
"enable_cleanup must be a boolean".to_string(),
));
}
None => true,
};
let cleanup_interval_secs = match args.get("cleanup_interval_secs") {
Some(Value::Number(n)) => n
.as_i64()
.ok_or_else(|| Error::Config("Invalid cleanup_interval_secs value".to_string()))?
as u64,
Some(_) => {
return Err(Error::Config(
"cleanup_interval_secs must be a number".to_string(),
));
}
None => 60,
};
let cleanup_pressure_threshold = match args.get("cleanup_pressure_threshold") {
Some(Value::Number(n)) => n.as_f64().ok_or_else(|| {
Error::Config("Invalid cleanup_pressure_threshold value".to_string())
})? as f32,
Some(_) => {
return Err(Error::Config(
"cleanup_pressure_threshold must be a number".to_string(),
));
}
None => 0.8,
};
cache = cache.with_cleanup(
enable_cleanup,
cleanup_interval_secs,
cleanup_pressure_threshold,
);
cache.tag = config.tag.clone();
debug!(
"CachePlugin initialized: size={}, negative_cache={}, lazycache_enabled={}, lazycache_threshold={:.1}%, cleanup_enabled={}, cleanup_interval={}s",
cache.max_size,
cache.negative_cache,
cache.enable_lazycache,
cache.lazycache_threshold * 100.0,
cache.enable_cleanup,
cache.cleanup_interval_secs
);
Ok(Arc::new(cache))
}
}
#[async_trait]
impl Shutdown for CachePlugin {
    /// Takes the refresh coordinator out of its slot (so shutdown runs at
    /// most once) and shuts its workers down, propagating any error.
    async fn shutdown(&self) -> Result<()> {
        let coordinator = self.refresh_coordinator.lock().await.take();
        match coordinator {
            Some(coord) => {
                debug!("Shutting down CachePlugin refresh coordinator");
                coord.shutdown().await
            }
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the cache plugin: entry TTL bookkeeping, statistics,
    //! key generation, hit/miss/expiration paths, LRU eviction, lazycache
    //! and stale-serving behavior, and the cleanup task.
    use super::*;
    use crate::dns::{Message, Question, RData, RecordClass, RecordType, ResourceRecord};
    use std::net::Ipv4Addr;

    /// Request carrying a single `example.com A IN` question.
    fn create_test_message() -> Message {
        let mut msg = Message::new();
        msg.add_question(Question::new(
            "example.com".to_string(),
            RecordType::A,
            RecordClass::IN,
        ));
        msg
    }

    /// Response to the test question with one A record (TTL 300).
    fn create_test_response() -> Message {
        let mut msg = create_test_message();
        msg.add_answer(ResourceRecord::new(
            "example.com".to_string(),
            RecordType::A,
            RecordClass::IN,
            300,
            RData::A(Ipv4Addr::new(93, 184, 216, 34)),
        ));
        msg
    }

    #[test]
    fn test_cache_entry_creation() {
        let response = create_test_response();
        let entry = CacheEntry::new(response.clone(), 300, 300);
        assert_eq!(entry.ttl, 300);
        assert!(!entry.is_cache_expired());
        assert_eq!(entry.response.answers().len(), response.answers().len());
    }

    #[test]
    fn test_cache_entry_expiration() {
        let response = create_test_response();
        // A zero cache_ttl counts as immediately expired.
        let entry = CacheEntry::new(response, 0, 0);
        assert!(entry.is_cache_expired());
    }

    #[test]
    fn test_cache_entry_remaining_ttl() {
        let response = create_test_response();
        let entry = CacheEntry::new(response, 300, 300);
        let remaining = entry.remaining_ttl();
        // Immediately after creation at most one whole second can have passed.
        assert!(remaining <= 300);
        assert!(remaining >= 299);
    }

    #[test]
    fn test_cache_stats() {
        let stats = CacheStats::new();
        assert_eq!(stats.hits(), 0);
        assert_eq!(stats.misses(), 0);
        assert_eq!(stats.evictions(), 0);
        stats.record_hit();
        stats.record_hit();
        stats.record_miss();
        assert_eq!(stats.hits(), 2);
        assert_eq!(stats.misses(), 1);
        assert_eq!(stats.hit_rate(), 2.0 / 3.0);
    }

    #[test]
    fn test_cache_plugin_creation() {
        let cache = CachePlugin::new(100);
        assert_eq!(cache.max_size, 100);
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().hits(), 0);
    }

    #[test]
    fn test_plugin_as_any_downcast_present() {
        use std::sync::Arc;
        let cache = CachePlugin::new(128);
        let plugin: Arc<dyn crate::plugin::Plugin> = Arc::new(cache);
        assert!(
            plugin
                .as_ref()
                .as_any()
                .downcast_ref::<CachePlugin>()
                .is_some()
        );
    }

    #[test]
    fn test_make_key() {
        let msg = create_test_message();
        let key = CachePlugin::make_key(&msg);
        assert!(key.is_some());
        // A=1, IN=1.
        assert_eq!(key.unwrap(), "example.com:1:1");
    }

    #[test]
    fn test_make_key_case_insensitive() {
        let msg_lower = create_test_message();
        let mut msg_upper = create_test_message();
        msg_upper.questions_mut()[0].set_qname("EXAMPLE.COM".to_string());
        let key_lower = CachePlugin::make_key(&msg_lower);
        let key_upper = CachePlugin::make_key(&msg_upper);
        assert!(key_lower.is_some());
        assert!(key_upper.is_some());
        let key_lower_str = key_lower.unwrap();
        let key_upper_str = key_upper.unwrap();
        assert_eq!(key_lower_str, key_upper_str);
        assert_eq!(key_lower_str, "example.com:1:1");
    }

    #[test]
    fn test_make_key_no_questions() {
        let msg = Message::new();
        let key = CachePlugin::make_key(&msg);
        assert!(key.is_none());
    }

    #[test]
    fn test_get_min_ttl() {
        let response = create_test_response();
        let ttl = CachePlugin::get_min_ttl(&response);
        assert_eq!(ttl, 300);
    }

    #[test]
    fn test_get_min_ttl_no_records() {
        // With no records at all, the default of 300s is returned.
        let msg = create_test_message();
        let ttl = CachePlugin::get_min_ttl(&msg);
        assert_eq!(ttl, 300);
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = CachePlugin::new(100);
        let request = create_test_message();
        let mut context = Context::new(request);
        // Metrics counters are process-global, so compare deltas.
        let prev_misses = metrics::CACHE_MISSES_TOTAL.get();
        cache.execute(&mut context).await.unwrap();
        assert!(context.response().is_none());
        assert!(cache.stats().misses() >= 1);
        assert_eq!(cache.stats().hits(), 0);
        assert_eq!(metrics::CACHE_MISSES_TOTAL.get(), prev_misses + 1);
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn test_cache_hit() {
        let cache = CachePlugin::new(100);
        let response = create_test_response();
        let key = "example.com:1:1".to_string();
        let entry = CacheEntry::new(response.clone(), 300, 300);
        cache.store(key.clone(), entry);
        assert_eq!(metrics::CACHE_SIZE.get(), cache.size() as i64);
        let request = create_test_message();
        let mut context = Context::new(request);
        let prev_hits = metrics::CACHE_HITS_TOTAL.get();
        cache.execute(&mut context).await.unwrap();
        assert!(context.response().is_some());
        assert_eq!(cache.stats().hits(), 1);
        assert_eq!(cache.stats().misses(), 0);
        assert_eq!(metrics::CACHE_HITS_TOTAL.get(), prev_hits + 1);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = CachePlugin::new(100);
        let response = create_test_response();
        let key = "example.com:1:1".to_string();
        // Insert directly with an already-expired entry (both TTLs zero).
        let entry = CacheEntry::new(response.clone(), 0, 0);
        cache.cache.write().push(key.clone(), entry);
        let request = create_test_message();
        let mut context = Context::new(request);
        cache.execute(&mut context).await.unwrap();
        assert!(context.response().is_none());
        assert_eq!(cache.stats().misses(), 1);
        assert_eq!(cache.stats().expirations(), 1);
        assert!(!cache.cache.read().contains(&key));
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn test_cache_clear() {
        let cache = CachePlugin::new(100);
        let response = create_test_response();
        let entry = CacheEntry::new(response.clone(), 300, 300);
        cache.store("key1".to_string(), entry.clone());
        cache.store("key2".to_string(), entry.clone());
        assert_eq!(cache.size(), 2);
        assert_eq!(metrics::CACHE_SIZE.get(), 2);
        cache.clear();
        assert_eq!(cache.size(), 0);
        assert_eq!(metrics::CACHE_SIZE.get(), 0);
    }

    #[test]
    fn test_lru_eviction() {
        // Capacity 2: the third store must evict one entry and count it.
        let cache = CachePlugin::new(2);
        let response = create_test_response();
        let entry1 = CacheEntry::new(response.clone(), 300, 300);
        let entry2 = CacheEntry::new(response.clone(), 300, 300);
        let entry3 = CacheEntry::new(response.clone(), 300, 300);
        cache.cache.write().push("key1".to_string(), entry1);
        cache.cache.write().push("key2".to_string(), entry2);
        assert_eq!(cache.size(), 2);
        cache.store("key3".to_string(), entry3);
        assert_eq!(cache.size(), 2);
        assert_eq!(cache.stats().evictions(), 1);
    }

    #[tokio::test]
    async fn test_configured_cache_sequence_execution() {
        // End-to-end: build a cache and a sequence referencing it from YAML.
        // NOTE(review): the raw string's indentation is semantic YAML; this
        // nesting was reconstructed from the expected config schema.
        let yaml = r#"
plugins:
  - tag: my_cache
    type: cache
    config:
      size: 16
  - tag: seq
    type: sequence
    args:
      - exec: "$my_cache"
"#;
        let cfg = crate::config::Config::from_yaml(yaml).expect("parse yaml");
        let mut builder = crate::plugin::builder::PluginBuilder::new();
        for pc in &cfg.plugins {
            builder.build(pc).expect("build plugin");
        }
        builder
            .resolve_references(&cfg.plugins)
            .expect("resolve refs");
        let plugin = builder.get_plugin("seq").expect("sequence exists");
        let mut ctx = crate::plugin::Context::new(crate::dns::Message::new());
        plugin.execute(&mut ctx).await.expect("execute sequence");
        assert_eq!(plugin.name(), "sequence");
    }

    #[tokio::test]
    async fn test_lazycache_refresh_threshold_triggers() {
        let cache = CachePlugin::new(100);
        // NOTE(review): set_lazycache_threshold does not persist the value,
        // so this call has no effect; the assertions below pass regardless.
        cache.set_lazycache_threshold(0.1);
        let response = create_test_response();
        // Store phase: feed a response through execute.
        let mut ctx = crate::plugin::Context::new(create_test_message());
        ctx.set_response(Some(response.clone()));
        let res = cache.execute(&mut ctx).await;
        assert!(res.is_ok());
        // Lookup phase: fresh hit, far above the threshold.
        let mut ctx = crate::plugin::Context::new(create_test_message());
        let res = cache.execute(&mut ctx).await;
        assert!(res.is_ok());
        assert!(ctx.response().is_some());
        assert!(
            ctx.get_metadata::<bool>("needs_lazycache_refresh")
                .is_none()
        );
        let cache_entry = cache
            .cache
            .read()
            .peek(&"example.com:1:1".to_string())
            .expect("entry exists")
            .clone();
        let ttl_percent = cache_entry.remaining_ttl() as f32 / cache_entry.ttl as f32;
        let threshold = cache.get_lazycache_threshold();
        // Plenty of TTL remains, so no refresh should have been recorded.
        assert!(ttl_percent > threshold);
        assert_eq!(cache.lazycache_stats.refreshes(), 0);
    }

    #[tokio::test]
    async fn test_lazycache_continues_pipeline_on_refresh() {
        let cache = CachePlugin::new(100);
        cache.set_lazycache_threshold(0.05);
        let response = create_test_response();
        let mut ctx = crate::plugin::Context::new(create_test_message());
        ctx.set_response(Some(response));
        cache.execute(&mut ctx).await.expect("cache store");
        assert!(ctx.response().is_some());
        let mut ctx = crate::plugin::Context::new(create_test_message());
        cache.execute(&mut ctx).await.expect("cache hit");
        assert!(ctx.response().is_some());
        assert!(
            ctx.get_metadata::<bool>("needs_lazycache_refresh")
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_cache_ttl_serves_stale_and_refreshes() {
        use tokio::time::{Duration, sleep};
        // Retention of 10s with a 1s message TTL: after 2s the entry is
        // stale but still inside its retention window.
        let cache = CachePlugin::new(100).with_cache_ttl(10);
        let mut response = create_test_response();
        for rr in response.answers_mut() {
            rr.set_ttl(1);
        }
        let mut ctx = crate::plugin::Context::new(create_test_message());
        ctx.set_response(Some(response.clone()));
        cache.execute(&mut ctx).await.expect("cache store");
        sleep(Duration::from_secs(2)).await;
        let mut ctx = crate::plugin::Context::new(create_test_message());
        cache.execute(&mut ctx).await.expect("cache stale hit");
        let resp = ctx.response().expect("stale response returned");
        // Stale answers carry the short fixed TTL.
        assert!(resp.answers()[0].ttl() <= STALE_RESPONSE_TTL_SECS);
        sleep(Duration::from_millis(50)).await;
        assert!(cache.lazycache_stats.refreshes() >= 1);
    }

    #[test]
    fn test_cleanup_expired() {
        let cache = CachePlugin::new(100);
        let response = create_test_response();
        // Two already-expired entries, one fresh.
        let entry1 = CacheEntry::new(response.clone(), 0, 0);
        let entry2 = CacheEntry::new(response.clone(), 0, 0);
        let entry3 = CacheEntry::new(response.clone(), 300, 300);
        cache.cache.write().push("key1".to_string(), entry1);
        cache.cache.write().push("key2".to_string(), entry2);
        cache.cache.write().push("key3".to_string(), entry3);
        assert_eq!(cache.size(), 3);
        assert_eq!(cache.stats().expirations(), 0);
        let removed = cache.cleanup_expired();
        assert_eq!(removed, 2);
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.stats().expirations(), 2);
    }

    #[test]
    fn test_should_cleanup_pressure() {
        // 6 of 10 entries: above a 0.5 threshold, below a 0.9 threshold.
        let mut cache = CachePlugin::new(10);
        cache = cache.with_cleanup(true, 60, 0.5);
        let response = create_test_response();
        for i in 0..6 {
            let entry = CacheEntry::new(response.clone(), 300, 300);
            cache.cache.write().push(format!("key{}", i), entry);
        }
        assert!(cache.should_cleanup_pressure());
        let cache2 = CachePlugin::new(10).with_cleanup(true, 60, 0.9);
        for i in 0..6 {
            let entry = CacheEntry::new(response.clone(), 300, 300);
            cache2.cache.write().push(format!("key{}", i), entry);
        }
        assert!(!cache2.should_cleanup_pressure());
    }

    #[tokio::test]
    async fn test_spawn_cleanup_task() {
        let cache = Arc::new(CachePlugin::new(100));
        let response = create_test_response();
        let entry1 = CacheEntry::new(response.clone(), 0, 0);
        let entry2 = CacheEntry::new(response.clone(), 1, 1);
        let entry3 = CacheEntry::new(response.clone(), 300, 300);
        cache.cache.write().push("key1".to_string(), entry1);
        cache.cache.write().push("key2".to_string(), entry2);
        cache.cache.write().push("key3".to_string(), entry3);
        assert_eq!(cache.size(), 3);
        // Smoke test: the task starts, ticks at least once, and can be aborted.
        let cache_with_short_interval = {
            let mut c = CachePlugin::new(100);
            c.cleanup_interval_secs = 1;
            c.enable_cleanup = true;
            Arc::new(c)
        };
        let cleanup_handle = cache_with_short_interval.clone().spawn_cleanup_task();
        tokio::time::sleep(Duration::from_millis(1500)).await;
        cleanup_handle.abort();
    }
}