use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant},
};
use bytes::{BufMut, Bytes, BytesMut};
use moka::future::Cache;
use crate::codec::message::Question;
/// A cached response together with the metadata needed to re-serve it.
///
/// `patched_for` copies `bytes` and rewrites the transaction ID and every
/// TTL field listed in `ttl_offsets` before the copy is handed to a client.
#[derive(Clone, Debug)]
struct CachedResponse {
/// The stored message bytes; never mutated in place, only copied by `patched_for`.
bytes: Bytes,
/// Byte offsets into `bytes` at which a 4-byte big-endian TTL field starts.
ttl_offsets: Arc<[usize]>,
/// When the entry was built (`DnsCache::insert` sets it to `Instant::now()`);
/// `DnsCache::get` derives the elapsed seconds from it.
stored_at: Instant,
/// Lifetime reported to the moka cache by `DnsCacheExpiry`.
expiry: Duration,
}
/// Per-entry expiry policy: every cached response lives for the duration
/// recorded in its own `expiry` field.
#[derive(Debug, Clone)]
struct DnsCacheExpiry;

impl moka::Expiry<Question, CachedResponse> for DnsCacheExpiry {
    /// Lifetime of a freshly inserted entry: the value's own `expiry`.
    fn expire_after_create(
        &self,
        _key: &Question,
        value: &CachedResponse,
        _current_time: Instant,
    ) -> Option<Duration> {
        Some(value.expiry)
    }

    /// Lifetime of an entry whose value was replaced by a newer response.
    ///
    /// moka's default implementation returns `duration_until_expiry`, i.e. it
    /// keeps the deadline of the value that was overwritten. The replacement
    /// carries a fresh `stored_at` and its own `expiry`, so inheriting the old
    /// deadline would either evict it early or keep it long past its TTL.
    /// Restart the clock from the new value instead.
    fn expire_after_update(
        &self,
        _key: &Question,
        value: &CachedResponse,
        _updated_at: Instant,
        _duration_until_expiry: Option<Duration>,
    ) -> Option<Duration> {
        Some(value.expiry)
    }
}
impl CachedResponse {
    /// Returns a copy of the cached message tailored to one client.
    ///
    /// * `client_txn_id` is written big-endian over the first two bytes
    ///   (skipped when the message is shorter than two bytes).
    /// * `elapsed_secs` is subtracted, saturating at zero, from the 4-byte
    ///   big-endian TTL found at each offset in `ttl_offsets`.
    ///
    /// Offsets whose 4-byte window does not fit inside the message are
    /// ignored, including offsets so large that `offset + 4` would overflow
    /// `usize`. The stored bytes are left untouched.
    fn patched_for(&self, client_txn_id: u16, elapsed_secs: u32) -> Bytes {
        let mut buf = BytesMut::with_capacity(self.bytes.len());
        buf.put_slice(&self.bytes);

        {
            let wire: &mut [u8] = &mut buf[..];

            if let Some(id) = wire.get_mut(..2) {
                id.copy_from_slice(&client_txn_id.to_be_bytes());
            }

            for &offset in self.ttl_offsets.iter() {
                // `checked_add` guards against wrap-around for huge offsets;
                // `get_mut` rejects windows that run past the end.
                let field = offset
                    .checked_add(4)
                    .and_then(|end| wire.get_mut(offset..end));
                if let Some(field) = field {
                    let mut raw = [0u8; 4];
                    raw.copy_from_slice(field);
                    let remaining = u32::from_be_bytes(raw).saturating_sub(elapsed_secs);
                    field.copy_from_slice(&remaining.to_be_bytes());
                }
            }
        }

        buf.freeze()
    }
}
/// Response cache keyed by `Question`, with adjustable TTL clamping bounds.
#[derive(Clone, Debug)]
pub struct DnsCache {
/// moka cache holding the stored responses; expiry is driven by `DnsCacheExpiry`.
inner: Cache<Question, CachedResponse>,
/// Minimum TTL in the high 32 bits and maximum TTL in the low 32 bits
/// (see `pack_bounds` / `unpack_bounds`), so both are read and written
/// with a single atomic operation. Behind an `Arc`, so clones share it.
ttl_bounds: Arc<AtomicU64>,
}
impl DnsCache {
    /// Builds a cache whose moka `max_capacity` is `capacity` and whose entry
    /// lifetimes are clamped to `min_ttl..=max_ttl` seconds.
    ///
    /// `min_ttl` is expected not to exceed `max_ttl`; this is checked in debug
    /// builds only.
    #[must_use]
    pub fn new(capacity: u64, min_ttl: u32, max_ttl: u32) -> Self {
        debug_assert!(min_ttl <= max_ttl, "min_ttl must not exceed max_ttl");
        let inner = Cache::builder()
            .max_capacity(capacity)
            .expire_after(DnsCacheExpiry)
            .build();
        Self {
            inner,
            ttl_bounds: Arc::new(AtomicU64::new(pack_bounds(min_ttl, max_ttl))),
        }
    }

    /// Replaces the TTL clamping bounds used by subsequent [`insert`](Self::insert)
    /// calls.
    ///
    /// Entries already stored keep the expiry computed when they were
    /// inserted. `min_ttl` is expected not to exceed `max_ttl`; this is
    /// checked in debug builds only.
    pub fn set_ttl_bounds(&self, min_ttl: u32, max_ttl: u32) {
        debug_assert!(min_ttl <= max_ttl, "min_ttl must not exceed max_ttl");
        // Both bounds live in one word, so a single store updates them together.
        self.ttl_bounds
            .store(pack_bounds(min_ttl, max_ttl), Ordering::Relaxed);
    }

    /// Stores `bytes` under `key`.
    ///
    /// * `ttl_offsets` - byte offsets of the 4-byte TTL fields inside `bytes`
    ///   that [`get`](Self::get) ages on every hit.
    /// * `expiry_ttl_secs` - requested lifetime in seconds; clamped to the
    ///   current TTL bounds before being used as the entry's expiry.
    pub async fn insert(
        &self,
        key: Question,
        bytes: Bytes,
        ttl_offsets: Vec<usize>,
        expiry_ttl_secs: u32,
    ) {
        let lifetime_secs = self.clamp_ttl(expiry_ttl_secs);
        let entry = CachedResponse {
            bytes,
            ttl_offsets: Arc::from(ttl_offsets),
            stored_at: Instant::now(),
            expiry: Duration::from_secs(u64::from(lifetime_secs)),
        };
        self.inner.insert(key, entry).await;
    }

    /// Looks up `key` and, on a hit, returns a copy of the stored message with
    /// its transaction ID set to `client_txn_id` and its TTL fields reduced by
    /// the whole seconds elapsed since insertion.
    ///
    /// Returns `None` when the cache has no entry for `key`.
    pub async fn get(&self, key: &Question, client_txn_id: u16) -> Option<Bytes> {
        let entry = self.inner.get(key).await?;
        // An elapsed time that does not fit in u32 is treated as u32::MAX,
        // which drives every TTL to zero.
        let elapsed_secs =
            u32::try_from(entry.stored_at.elapsed().as_secs()).unwrap_or(u32::MAX);
        Some(entry.patched_for(client_txn_id, elapsed_secs))
    }

    /// Runs the underlying moka cache's pending maintenance tasks.
    pub async fn run_pending_tasks(&self) {
        self.inner.run_pending_tasks().await;
    }

    /// Returns the entry count reported by the underlying moka cache.
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }

    /// Invalidates every entry in the cache.
    pub fn clear(&self) {
        self.inner.invalidate_all();
    }

    /// Clamps `secs` to the currently configured `[min_ttl, max_ttl]` range.
    ///
    /// Written as `max` then `min` rather than `Ord::clamp`, because `clamp`
    /// panics when `min > max`. The bounds are only validated with
    /// `debug_assert!`, so an inverted pair can be stored in release builds
    /// and must not bring down every later `insert`. With inverted bounds the
    /// result is `max_ttl`; with valid bounds it is identical to `clamp`.
    #[inline]
    fn clamp_ttl(&self, secs: u32) -> u32 {
        let (min_ttl, max_ttl) = unpack_bounds(self.ttl_bounds.load(Ordering::Relaxed));
        secs.max(min_ttl).min(max_ttl)
    }
}
/// Packs a `(min_ttl, max_ttl)` pair into one `u64`: `min_ttl` occupies the
/// upper 32 bits and `max_ttl` the lower 32 bits.
#[inline]
fn pack_bounds(min_ttl: u32, max_ttl: u32) -> u64 {
    let upper = u64::from(min_ttl) << 32;
    let lower = u64::from(max_ttl);
    upper | lower
}
/// Splits a packed bounds word back into `(min_ttl, max_ttl)`: the upper
/// 32 bits are `min_ttl`, the lower 32 bits are `max_ttl`.
#[inline]
fn unpack_bounds(packed: u64) -> (u32, u32) {
    let min_ttl = (packed >> 32) as u32;
    // Mask first so the narrowing cast is visibly lossless.
    let max_ttl = (packed & u64::from(u32::MAX)) as u32;
    (min_ttl, max_ttl)
}
// Unit tests covering TTL clamping, response patching (`patched_for`), and the
// async insert/get/clear/expiry behaviour of `DnsCache`.
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::{
header::{Header, Rcode},
message::{Qclass, Qtype, Question},
name::Name,
reader::Reader,
ttl::TtlScan,
writer::Writer,
};
// Builds a response message with one question and one answer record for
// `name`, carrying `ttl` and the 4-byte payload 127.0.0.1.
fn build_a_response(id: u16, name: &str, ttl: u32) -> Bytes {
let mut w = Writer::with_capacity(128);
Header::new(id)
.with_qr(true)
.with_rcode(Rcode::NoError)
.with_qdcount(1)
.with_ancount(1)
.write(&mut w);
let n: Name = name.parse().expect("valid name");
// Question: name followed by two u16 fields (presumably QTYPE=A, QCLASS=IN).
n.write(&mut w);
w.write_u16(1); w.write_u16(1);
// Answer: name, two u16 fields (presumably TYPE/CLASS), the TTL, then a
// u16 length of 4 and the 4 data bytes.
n.write(&mut w);
w.write_u16(1); w.write_u16(1); w.write_u32(ttl);
w.write_u16(4); w.write_slice(&[127, 0, 0, 1]);
w.finish()
}
// Reads the big-endian u16 stored in the first two bytes.
fn read_txn_id(bytes: &[u8]) -> u16 {
u16::from_be_bytes([bytes[0], bytes[1]])
}
// Reads the big-endian u32 starting at `offset`.
fn read_u32_at(bytes: &[u8], offset: usize) -> u32 {
u32::from_be_bytes([
bytes[offset],
bytes[offset + 1],
bytes[offset + 2],
bytes[offset + 3],
])
}
// Builds an A/IN question for `name`.
fn make_question(name: &str) -> Question {
Question {
name: name.parse().expect("valid name"),
qtype: Qtype::A,
qclass: Qclass::In,
}
}
// Wraps `bytes` in a `CachedResponse` for calling `patched_for` directly.
// `expiry` is zero because these values are never inserted into the cache.
fn cached(bytes: Bytes, ttl_offsets: Vec<usize>) -> CachedResponse {
CachedResponse {
bytes,
ttl_offsets: Arc::from(ttl_offsets),
stored_at: Instant::now(),
expiry: Duration::from_secs(0),
}
}
// Values below the minimum are raised to the minimum.
#[test]
fn clamp_ttl_below_min_returns_min() {
let cache = DnsCache::new(100, 5, 3600);
assert_eq!(cache.clamp_ttl(0), 5);
assert_eq!(cache.clamp_ttl(4), 5);
}
// Values above the maximum are lowered to the maximum.
#[test]
fn clamp_ttl_above_max_returns_max() {
let cache = DnsCache::new(100, 5, 3600);
assert_eq!(cache.clamp_ttl(7200), 3600);
assert_eq!(cache.clamp_ttl(u32::MAX), 3600);
}
// `set_ttl_bounds` takes effect for subsequent clamps on the same cache.
#[test]
fn set_ttl_bounds_is_live() {
let cache = DnsCache::new(100, 5, 3600);
assert_eq!(cache.clamp_ttl(4), 5);
assert_eq!(cache.clamp_ttl(7200), 3600);
cache.set_ttl_bounds(60, 86400);
assert_eq!(cache.clamp_ttl(4), 60, "raised min applies live");
assert_eq!(cache.clamp_ttl(7200), 7200, "widened max applies live");
assert_eq!(cache.clamp_ttl(100_000), 86400);
}
// Values inside the range, including both endpoints, pass through unchanged.
#[test]
fn clamp_ttl_in_range_unchanged() {
let cache = DnsCache::new(100, 5, 3600);
assert_eq!(cache.clamp_ttl(300), 300);
assert_eq!(cache.clamp_ttl(5), 5);
assert_eq!(cache.clamp_ttl(3600), 3600);
}
// Patching rewrites the transaction ID and subtracts the elapsed seconds
// from the TTL found by `TtlScan`.
#[test]
fn patch_response_txn_id_and_ttl_decrement() {
let original_ttl: u32 = 300;
let original_id: u16 = 0x1234;
let msg = build_a_response(original_id, "example.com", original_ttl);
let bytes_ref: Bytes = msg.clone();
let scan = TtlScan::scan(&bytes_ref).expect("scan must succeed");
assert_eq!(scan.ttl_offsets.len(), 1, "expected exactly one TTL offset");
let elapsed: u32 = 30;
let client_id: u16 = 0xBEEF;
let patched = cached(msg.clone(), scan.ttl_offsets.clone()).patched_for(client_id, elapsed);
assert_eq!(
read_txn_id(&patched),
client_id,
"patched txn-id must match client_id"
);
let expected_ttl = original_ttl.saturating_sub(elapsed);
let actual_ttl = read_u32_at(&patched, scan.ttl_offsets[0]);
assert_eq!(
actual_ttl, expected_ttl,
"TTL must be decremented by elapsed"
);
}
// With zero elapsed seconds only the transaction ID changes.
#[test]
fn patch_response_txn_id_only() {
let msg = build_a_response(0xAAAA, "a.example.com", 60);
let scan = TtlScan::scan(&msg.clone()).expect("scan");
let patched = cached(msg.clone(), scan.ttl_offsets.clone()).patched_for(0xBBBB, 0);
assert_eq!(read_txn_id(&patched), 0xBBBB);
assert_eq!(read_u32_at(&patched, scan.ttl_offsets[0]), 60);
}
// Elapsed time larger than the TTL yields zero rather than wrapping.
#[test]
fn patch_response_ttl_saturates_at_zero() {
let ttl: u32 = 10;
let msg = build_a_response(0x0001, "sat.example.com", ttl);
let scan = TtlScan::scan(&msg.clone()).expect("scan");
let patched = cached(msg.clone(), scan.ttl_offsets.clone()).patched_for(0x0001, ttl + 50);
assert_eq!(
read_u32_at(&patched, scan.ttl_offsets[0]),
0,
"TTL must saturate at 0"
);
}
// A TTL that is already zero stays zero.
#[test]
fn patch_response_zero_ttl_stays_zero() {
let msg = build_a_response(0x0001, "zero.example.com", 0);
let scan = TtlScan::scan(&msg.clone()).expect("scan");
let patched = cached(msg.clone(), scan.ttl_offsets.clone()).patched_for(0x0001, 100);
assert_eq!(read_u32_at(&patched, scan.ttl_offsets[0]), 0);
}
// Offsets whose 4-byte window would run past the end must not panic; the
// test passes as long as `patched_for` returns.
#[test]
fn patch_response_out_of_bounds_offset_skipped() {
let msg = build_a_response(0x0001, "oob.example.com", 100);
let bad_offsets: Vec<usize> = vec![msg.len() - 1, msg.len(), msg.len() + 100];
let _ = cached(msg.clone(), bad_offsets).patched_for(0x0001, 10);
}
// A one-byte message is too short for a transaction ID and is returned unchanged.
#[test]
fn patch_response_short_buffer_no_panic() {
let buf = &[0xAAu8]; let patched = cached(Bytes::copy_from_slice(buf), vec![]).patched_for(0xBEEF, 0);
assert_eq!(&patched[..], &[0xAAu8]);
}
// An empty message stays empty.
#[test]
fn patch_response_empty_buffer_no_panic() {
let patched = cached(Bytes::new(), vec![]).patched_for(0x1234, 10);
assert_eq!(patched.len(), 0);
}
// A hit returns the stored message with the caller's transaction ID.
#[tokio::test]
async fn cache_hit_returns_patched_bytes() {
let cache = DnsCache::new(100, 1, 3600);
let question = make_question("hit.example.com");
let ttl: u32 = 300;
let msg = build_a_response(0xAAAA, "hit.example.com", ttl);
let bytes_ref: Bytes = msg.clone();
let scan = TtlScan::scan(&bytes_ref).expect("scan");
cache
.insert(question.clone(), msg, scan.ttl_offsets.clone(), ttl)
.await;
let client_id: u16 = 0xBEEF;
let result = cache.get(&question, client_id).await;
assert!(result.is_some(), "expected a cache hit");
let patched = result.unwrap();
assert_eq!(
read_txn_id(&patched),
client_id,
"transaction ID must be patched to client_id"
);
}
// After `clear()` a previously cached question misses.
#[tokio::test]
async fn clear_invalidates_all_entries() {
let cache = DnsCache::new(100, 1, 3600);
let question = make_question("flush.example.com");
let msg = build_a_response(0xAAAA, "flush.example.com", 300);
let scan = TtlScan::scan(&msg.clone()).expect("scan");
cache
.insert(question.clone(), msg, scan.ttl_offsets, 300)
.await;
assert!(cache.get(&question, 0x1).await.is_some(), "entry cached");
cache.clear();
cache.run_pending_tasks().await;
assert!(
cache.get(&question, 0x1).await.is_none(),
"clear() must invalidate the entry"
);
}
// Looking up a question that was never inserted returns `None`.
#[tokio::test]
async fn cache_miss_returns_none() {
let cache = DnsCache::new(100, 1, 3600);
let question = make_question("miss.example.com");
let result = cache.get(&question, 0x1234).await;
assert!(result.is_none(), "expected a cache miss");
}
// Two distinct questions are stored independently; a third, uninserted one misses.
#[tokio::test]
async fn different_questions_do_not_collide() {
let cache = DnsCache::new(100, 1, 3600);
let q1 = make_question("a.example.com");
let q2 = make_question("b.example.com");
let msg1 = build_a_response(0x0001, "a.example.com", 100);
let msg2 = build_a_response(0x0002, "b.example.com", 200);
let bytes1: Bytes = msg1.clone();
let scan1 = TtlScan::scan(&bytes1).unwrap();
let bytes2: Bytes = msg2.clone();
let scan2 = TtlScan::scan(&bytes2).unwrap();
cache.insert(q1.clone(), msg1, scan1.ttl_offsets, 100).await;
cache.insert(q2.clone(), msg2, scan2.ttl_offsets, 200).await;
assert!(cache.get(&q1, 0xAAAA).await.is_some(), "q1 should hit");
assert!(cache.get(&q2, 0xBBBB).await.is_some(), "q2 should hit");
let q3 = make_question("c.example.com");
assert!(cache.get(&q3, 0xCCCC).await.is_none(), "q3 should miss");
}
// An entry inserted with a 1 s expiry is gone after waiting 1.1 s.
#[tokio::test]
async fn cache_entry_expires_after_ttl() {
let cache = DnsCache::new(100, 1, 3600);
let question = make_question("expire.example.com");
let msg = build_a_response(0x0001, "expire.example.com", 1);
let bytes_ref: Bytes = msg.clone();
let scan = TtlScan::scan(&bytes_ref).expect("scan");
cache
.insert(question.clone(), msg, scan.ttl_offsets, 1)
.await;
assert!(
cache.get(&question, 0x0001).await.is_some(),
"entry should be present immediately after insert"
);
tokio::time::sleep(Duration::from_millis(1100)).await;
cache.run_pending_tasks().await;
assert!(
cache.get(&question, 0x0001).await.is_none(),
"entry should have expired after 1.1 s"
);
}
// Inserting with an expiry of 0 is clamped up to min_ttl = 5, so the entry
// is still present 100 ms later.
#[tokio::test]
async fn insert_clamps_ttl_to_min() {
let cache = DnsCache::new(100, 5, 3600);
let question = make_question("clamp.example.com");
let msg = build_a_response(0x0001, "clamp.example.com", 1);
let bytes_ref: Bytes = msg.clone();
let scan = TtlScan::scan(&bytes_ref).expect("scan");
cache
.insert(question.clone(), msg, scan.ttl_offsets, 0)
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
cache.run_pending_tasks().await;
assert!(
cache.get(&question, 0x0001).await.is_some(),
"entry clamped to min_ttl=5 should still be present after 100 ms"
);
}
// The patched output must still scan and parse, with the new ID in the
// header and the aged TTL at the original offset.
#[test]
fn patched_bytes_are_re_parseable() {
let original_ttl: u32 = 120;
let msg = build_a_response(0xFFFF, "reparse.example.com", original_ttl);
let bytes_ref: Bytes = msg.clone();
let scan = TtlScan::scan(&bytes_ref).expect("scan");
let elapsed: u32 = 20;
let client_id: u16 = 0x4242;
let patched_bytes =
cached(msg.clone(), scan.ttl_offsets.clone()).patched_for(client_id, elapsed);
let patched_scan = TtlScan::scan(&patched_bytes).expect("patched message must scan");
assert_eq!(patched_scan.ttl_offsets.len(), 1);
let mut reader = Reader::new(patched_bytes.clone());
let header = Header::read(&mut reader).expect("header must parse");
assert_eq!(
header.id, client_id,
"patched header.id must match client_id"
);
let expected = original_ttl.saturating_sub(elapsed);
assert_eq!(
read_u32_at(&patched_bytes, scan.ttl_offsets[0]),
expected,
"TTL at offset must be decremented"
);
}
}