use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::Bytes;
use tower::{Layer, Service};
use crate::resolver::{
pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
state::ResolverState,
};
/// Caching decision the inner service attaches to each [`ForwardOutput`].
///
/// [`CacheService`] inspects this after a cache miss: `Store` inserts an entry
/// keyed by the request's question, `Skip` leaves the cache untouched.
#[derive(Debug, Clone)]
pub enum CacheDirective {
/// Insert the given payload into the cache.
Store {
/// Message bytes handed unchanged to the cache's `insert`.
bytes: Bytes,
/// Offsets passed through to the cache's `insert`; presumably the byte
/// positions of TTL fields inside `bytes` — TODO confirm in the cache.
ttl_offsets: Vec<usize>,
/// Expiry value passed through to the cache's `insert`.
/// NOTE(review): units/semantics (relative seconds vs. absolute time) are
/// defined by the cache — confirm there; the tests in this file use `300`.
expiry: u32,
},
/// Do not cache this reply.
Skip,
}
/// Response type the inner service must produce for [`CacheService`].
///
/// It carries the reply for the client together with a decision about whether
/// that reply should be cached.
#[derive(Debug, Clone)]
pub struct ForwardOutput {
/// Response returned to the caller as-is on a cache miss.
pub reply: PipelineResponse,
/// Whether (and what) [`CacheService`] should store after the call.
pub directive: CacheDirective,
}
impl ForwardOutput {
pub fn new(reply: PipelineResponse, directive: CacheDirective) -> Self {
Self { reply, directive }
}
}
/// Tower [`Layer`] that wraps a service in a [`CacheService`] backed by the
/// resolver's shared state.
///
/// Cloning is cheap (a single `Arc` reference-count bump). It lets the layer be
/// used where `Layer + Clone` is required, e.g. inside a cloned
/// `ServiceBuilder` stack, matching the already-`Clone` [`CacheService`].
#[derive(Clone)]
pub struct CacheLayer {
    /// Shared resolver state handed to every service this layer produces.
    state: Arc<ResolverState>,
}
impl CacheLayer {
pub fn new(state: Arc<ResolverState>) -> Self {
Self { state }
}
}
impl<S> Layer<S> for CacheLayer {
type Service = CacheService<S>;
fn layer(&self, inner: S) -> Self::Service {
CacheService::new(self.state.clone(), inner)
}
}
/// Tower service produced by [`CacheLayer`].
///
/// It answers from the resolver cache when possible. Otherwise it forwards to
/// `inner` and caches the result when the inner service asks for it via
/// [`CacheDirective::Store`].
#[derive(Clone)]
pub struct CacheService<S> {
/// Shared resolver state; the cache is reached through its `cache()` accessor.
state: Arc<ResolverState>,
/// Wrapped service consulted on cache misses.
inner: S,
}
impl<S> CacheService<S> {
pub fn new(state: Arc<ResolverState>, inner: S) -> Self {
Self { state, inner }
}
}
impl<S> Service<DnsRequest> for CacheService<S>
where
S: Service<DnsRequest, Response = ForwardOutput, Error = BoxError> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = PipelineResponse;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<PipelineResponse, BoxError>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: DnsRequest) -> Self::Future {
let question = req.question().clone();
let client_id = req.header().id;
let state = self.state.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
if let Some(bytes) = state.cache().get(&question, client_id).await {
return Ok(PipelineResponse::new(bytes, Outcome::Cached));
}
let out = inner.call(req).await?;
if let CacheDirective::Store {
bytes,
ttl_offsets,
expiry,
} = out.directive
{
state
.cache()
.insert(question, bytes, ttl_offsets, expiry)
.await;
}
Ok(out.reply)
})
}
}
#[cfg(test)]
mod tests {
    use std::{
        net::SocketAddr,
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };
    use bytes::Bytes;
    use tempfile::TempDir;
    use tokio::time::timeout;
    use tower::ServiceExt as _;
    use super::*;
    use crate::{
        codec::message::{Qclass, Qtype, Query, Question},
        resolver::state::ResolverState,
    };

    /// Creates a throwaway database and hydrates a fresh [`ResolverState`] from it.
    ///
    /// The [`TempDir`] is returned so callers can keep it bound for the whole
    /// test. `TempDir` removes its directory when dropped; presumably the
    /// database lives there.
    async fn hydrate_state() -> (TempDir, Arc<ResolverState>) {
        let (dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        (dir, state)
    }

    /// Builds a request for `name` (via `test_support::a_query`) with header ID
    /// `id`, coming from a fixed loopback client address.
    fn request(id: u16, name: &str) -> DnsRequest {
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let query = Query::try_from(crate::test_support::a_query(id, name)).expect("valid query");
        DnsRequest::new(query, client)
    }

    /// The A/IN question for `name`, matching what `request` produces.
    fn question(name: &str) -> Question {
        Question {
            name: name.parse().unwrap(),
            qtype: Qtype::A,
            qclass: Qclass::In,
        }
    }

    /// Reads the DNS transaction ID: the first two message bytes, big-endian.
    /// Panics if `bytes` is shorter than two bytes.
    fn txn_id(bytes: &Bytes) -> u16 {
        u16::from_be_bytes([bytes[0], bytes[1]])
    }

    #[tokio::test]
    async fn hit_serves_cached_without_calling_inner() {
        let (_dir, state) = hydrate_state().await;
        let stored = crate::test_support::a_query(0xAAAA, "example.com");
        state
            .cache()
            .insert(question("example.com"), stored, vec![], 300)
            .await;
        // Count inner invocations so "without calling inner" is asserted
        // directly rather than only inferred from the outcome.
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let inner = tower::service_fn(move |_req: DnsRequest| {
            counter.fetch_add(1, Ordering::SeqCst);
            async move {
                Ok::<_, BoxError>(ForwardOutput::new(
                    PipelineResponse::new(Bytes::from_static(b"unused"), Outcome::Forwarded),
                    CacheDirective::Skip,
                ))
            }
        });
        let svc = CacheService::new(state, inner);
        let client_id = 0xBEEF;
        let resp = timeout(
            Duration::from_secs(5),
            svc.oneshot(request(client_id, "example.com")),
        )
        .await
        .expect("safety timeout")
        .expect("service ok");
        assert_eq!(resp.outcome, Outcome::Cached, "hit must serve from cache");
        assert_eq!(
            txn_id(&resp.bytes),
            client_id,
            "cache must patch the txn id to the requesting client"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "a cache hit must not reach the inner service"
        );
    }

    #[tokio::test]
    async fn miss_with_store_directive_inserts() {
        let (_dir, state) = hydrate_state().await;
        let inner = tower::service_fn(|req: DnsRequest| async move {
            let bytes = req.raw().clone();
            Ok::<_, BoxError>(ForwardOutput::new(
                PipelineResponse::new(bytes.clone(), Outcome::Forwarded),
                CacheDirective::Store {
                    bytes,
                    ttl_offsets: vec![],
                    expiry: 300,
                },
            ))
        });
        let svc = CacheService::new(state.clone(), inner);
        let resp = timeout(
            Duration::from_secs(5),
            svc.oneshot(request(0x1234, "store.example")),
        )
        .await
        .expect("safety timeout")
        .expect("service ok");
        assert_eq!(resp.outcome, Outcome::Forwarded);
        assert!(
            state
                .cache()
                .get(&question("store.example"), 0x9999)
                .await
                .is_some(),
            "Store directive must insert into the cache"
        );
    }

    #[tokio::test]
    async fn stored_miss_is_served_from_cache_on_next_request() {
        let (_dir, state) = hydrate_state().await;
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let inner = tower::service_fn(move |req: DnsRequest| {
            counter.fetch_add(1, Ordering::SeqCst);
            async move {
                let bytes = req.raw().clone();
                Ok::<_, BoxError>(ForwardOutput::new(
                    PipelineResponse::new(bytes.clone(), Outcome::Forwarded),
                    CacheDirective::Store {
                        bytes,
                        ttl_offsets: vec![],
                        expiry: 300,
                    },
                ))
            }
        });
        let svc = CacheService::new(Arc::clone(&state), inner);

        let first = timeout(
            Duration::from_secs(5),
            svc.clone().oneshot(request(0x1111, "roundtrip.example")),
        )
        .await
        .expect("safety timeout")
        .expect("service ok");
        assert_eq!(first.outcome, Outcome::Forwarded, "first lookup must miss");

        // Exercises the full service path: the key used for the insert must
        // match the key used for the subsequent lookup.
        let second = timeout(
            Duration::from_secs(5),
            svc.oneshot(request(0x2222, "roundtrip.example")),
        )
        .await
        .expect("safety timeout")
        .expect("service ok");
        assert_eq!(second.outcome, Outcome::Cached, "second lookup must hit");
        assert_eq!(
            txn_id(&second.bytes),
            0x2222,
            "cached reply must carry the second client's txn id"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "only the initial miss may reach the inner service"
        );
    }

    #[tokio::test]
    async fn miss_with_skip_directive_does_not_insert() {
        let (_dir, state) = hydrate_state().await;
        let inner = tower::service_fn(|req: DnsRequest| async move {
            Ok::<_, BoxError>(ForwardOutput::new(
                PipelineResponse::new(req.raw().clone(), Outcome::Servfail),
                CacheDirective::Skip,
            ))
        });
        let svc = CacheService::new(state.clone(), inner);
        let resp = timeout(
            Duration::from_secs(5),
            svc.oneshot(request(0x4321, "skip.example")),
        )
        .await
        .expect("safety timeout")
        .expect("service ok");
        assert_eq!(resp.outcome, Outcome::Servfail);
        assert!(
            state
                .cache()
                .get(&question("skip.example"), 0x9999)
                .await
                .is_none(),
            "Skip directive must not insert into the cache"
        );
    }
}