use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use tower::Service;
use crate::{
codec::{header::Rcode, synth::Response, ttl::TtlScan},
resolver::{
pipeline::{
BoxError, DnsRequest, Outcome, PipelineResponse,
cache_layer::{CacheDirective, ForwardOutput},
},
state::ResolverState,
upstream::SharedUpstreamPool,
},
};
/// Tower service that forwards a DNS question upstream and converts the reply
/// into a [`ForwardOutput`]: a client-facing response plus a cache directive.
///
/// Requests that carry a forward target are sent through the resolver state's
/// forward zones; every other request goes to the shared upstream pool.
#[derive(Clone)]
pub struct ForwardService {
    /// Upstream pool used when the request has no forward target.
    pool: Arc<SharedUpstreamPool>,
    /// Shared resolver state: forward zones and settings (such as the negative-TTL cap).
    state: Arc<ResolverState>,
}

impl ForwardService {
    /// Builds a forwarding service over `pool` and `state`.
    ///
    /// Both handles are reference-counted, so cloning the service is cheap.
    pub fn new(pool: Arc<SharedUpstreamPool>, state: Arc<ResolverState>) -> Self {
        ForwardService { state, pool }
    }
}
impl Service<DnsRequest> for ForwardService {
    type Response = ForwardOutput;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<ForwardOutput, BoxError>> + Send>>;

    /// This service holds no per-call capacity of its own, so it is always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Forwards the request's question and builds the client-facing reply.
    ///
    /// Requests with a `forward_target()` go through the resolver state's
    /// forward zones; all others go to the shared upstream pool.
    ///
    /// On success, the upstream reply is returned with its transaction ID
    /// rewritten to the client's. It is tagged `Outcome::Forwarded` with the
    /// answering upstream and paired with a cache directive:
    /// - `Skip` if the reply cannot be TTL-scanned;
    /// - for negative replies, the upstream's negative TTL capped at
    ///   `settings.negative_ttl_cap` (`Skip` when there is none);
    /// - for positive replies, the minimum TTL found by the scan;
    /// - `Skip` whenever the resulting TTL is zero. RFC 1035 §3.2.1 says
    ///   zero-TTL data is for the current transaction only.
    ///
    /// On upstream failure, a SERVFAIL is built from the original query. It is
    /// tagged `Outcome::Servfail` and is never cached.
    ///
    /// The returned future never resolves to `Err`. Both outcomes are encoded
    /// in the `ForwardOutput`.
    fn call(&mut self, req: DnsRequest) -> Self::Future {
        let pool = self.pool.clone();
        let state = self.state.clone();
        Box::pin(async move {
            let question = req.question().clone();
            let client_id = req.header().id;

            let forward_result = match req.forward_target() {
                Some(target) => state.forward_zones().forward(target, &question).await,
                None => pool.forward(&question).await,
            };

            match forward_result {
                Ok(fr) => {
                    let settings = state.settings_full();
                    let directive = match TtlScan::scan(&fr.bytes) {
                        Ok(scan) => {
                            let expiry = if fr.is_negative {
                                fr.negative_ttl.map(|t| t.min(settings.negative_ttl_cap))
                            } else {
                                scan.min_ttl
                            };
                            match expiry {
                                // A zero TTL (or a zero negative-TTL cap) means "do not
                                // cache"; storing it would create an entry that must
                                // never be served from cache.
                                Some(expiry) if expiry > 0 => CacheDirective::Store {
                                    bytes: fr.bytes.clone(),
                                    ttl_offsets: scan.ttl_offsets.clone(),
                                    expiry,
                                },
                                _ => CacheDirective::Skip,
                            }
                        }
                        Err(_) => {
                            // The reply is still forwarded to the client; it is just
                            // not cacheable because its TTLs could not be located.
                            tracing::debug!(
                                qname = %question.name,
                                qtype = %question.qtype,
                                "upstream reply could not be TTL-scanned; not caching"
                            );
                            CacheDirective::Skip
                        }
                    };

                    let reply = fr.bytes.with_txn_id(client_id);
                    Ok(ForwardOutput::new(
                        PipelineResponse::new(reply, Outcome::Forwarded)
                            .with_upstream(fr.upstream),
                        directive,
                    ))
                }
                Err(e) => {
                    tracing::warn!(
                        qname = %question.name,
                        qtype = %question.qtype,
                        error = %e,
                        "all upstreams failed; returning SERVFAIL"
                    );
                    let bytes = Response::error_response(req.query(), Rcode::ServFail, req.edns());
                    Ok(ForwardOutput::new(
                        PipelineResponse::new(bytes, Outcome::Servfail),
                        CacheDirective::Skip,
                    ))
                }
            }
        })
    }
}
/// Helpers for raw wire-format DNS messages held in a [`Bytes`] buffer.
trait DnsMessageBytes {
    /// Returns a copy of the message with its first two bytes (the header's
    /// transaction ID) replaced by `id` in big-endian (network) byte order.
    ///
    /// A buffer shorter than two bytes cannot hold an ID and is returned unchanged.
    fn with_txn_id(&self, id: u16) -> Bytes;
}

impl DnsMessageBytes for Bytes {
    fn with_txn_id(&self, id: u16) -> Bytes {
        // `get(2..)` is `None` exactly when the buffer is shorter than two bytes.
        let Some(tail) = self.get(2..) else {
            return self.clone();
        };
        let mut patched = BytesMut::with_capacity(self.len());
        patched.extend_from_slice(&id.to_be_bytes());
        patched.extend_from_slice(tail);
        patched.freeze()
    }
}
// Unit tests for `with_txn_id` and end-to-end tests for `ForwardService`.
// The end-to-end tests run against mock UDP upstreams from `test_support`.
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, sync::Arc, time::Duration};
    use bytes::Bytes;
    use tokio::time::timeout;
    use tokio_util::task::TaskTracker;
    use tower::ServiceExt as _;
    use super::*;
    use crate::{
        codec::{header::Header, message::Query, reader::Reader},
        resolver::{
            state::ResolverState,
            upstream::{UpstreamConfig, UpstreamPool, UpstreamTransport},
        },
        test_support::{
            a_query, mock_udp_upstream, nxdomain_handler, nxdomain_with_soa_handler,
            positive_a_handler,
        },
    };

    // The first two bytes are overwritten big-endian; the rest of the message
    // is left as-is.
    #[test]
    fn with_txn_id_rewrites_first_two_bytes() {
        let msg = Bytes::from_static(&[0x12, 0x34, 0xAA, 0xBB]);
        let patched = msg.with_txn_id(0xBEEF);
        assert_eq!(&patched[..], &[0xBE, 0xEF, 0xAA, 0xBB]);
    }

    // Buffers too short to hold a transaction ID (1 and 0 bytes) come back
    // unchanged.
    #[test]
    fn with_txn_id_short_buffer_is_unchanged() {
        let one = Bytes::from_static(&[0xAA]);
        assert_eq!(&one.with_txn_id(0xBEEF)[..], &[0xAA]);
        assert_eq!(Bytes::new().with_txn_id(0x1234).len(), 0);
    }

    // Plain-UDP upstream config for `addr`: no TLS server name, no HTTP endpoint.
    fn udp_config(addr: SocketAddr) -> UpstreamConfig {
        UpstreamConfig {
            addr,
            transport: UpstreamTransport::Udp,
            tls_server_name: None,
            http_endpoint: None,
        }
    }

    // Parses raw wire bytes into a `DnsRequest` from a fixed client address
    // (127.0.0.1:5353). Panics if the bytes are not a valid query.
    fn make_request(raw: Bytes) -> DnsRequest {
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let query = Query::try_from(raw).expect("valid query");
        DnsRequest::new(query, client)
    }

    // Decodes the DNS header at the start of `bytes`; panics if it is invalid.
    fn parse_header(bytes: &Bytes) -> Header {
        let mut r = Reader::new(bytes.clone());
        Header::read(&mut r).expect("valid DNS header")
    }

    // Happy path: the reply is tagged Forwarded, attributed to the mock upstream,
    // and carries the client's transaction ID rather than the upstream's.
    #[tokio::test]
    async fn forward_returns_reply_with_client_id() {
        let client_query_id: u16 = 0xBEEF;
        let addr = mock_udp_upstream(positive_a_handler).await;
        let tracker = TaskTracker::new();
        // NOTE(review): the `0` and 500 ms arguments are `UpstreamPool::connect`
        // tuning parameters (presumably a retry count and per-query timeout) —
        // confirm against its signature.
        let pool = UpstreamPool::connect(
            &[udp_config(addr)],
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        // `_dir` (unlike `_`) keeps the guard returned by `temp_db` alive until
        // the end of the test.
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let svc = ForwardService::new(pool, state);
        let raw = a_query(client_query_id, "example.com");
        let req = make_request(raw);
        // Safety timeout so a hung upstream fails the test instead of stalling it.
        let out = timeout(Duration::from_secs(5), svc.oneshot(req))
            .await
            .expect("safety timeout")
            .expect("service must not error");
        assert_eq!(
            out.reply.outcome,
            Outcome::Forwarded,
            "outcome must be Forwarded"
        );
        assert_eq!(
            out.reply.upstream,
            Some(addr),
            "forwarded reply must attribute the answering upstream"
        );
        let hdr = parse_header(&out.reply.bytes);
        assert_eq!(
            hdr.id, client_query_id,
            "reply txn-id must be patched to the client's query id"
        );
    }

    // A positive answer yields a Store directive whose expiry is the answer's
    // minimum TTL. The 300 presumably matches the TTL `positive_a_handler`
    // emits; verify in test_support.
    #[tokio::test]
    async fn positive_answer_directive_stores_min_ttl() {
        let addr = mock_udp_upstream(positive_a_handler).await;
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &[udp_config(addr)],
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let svc = ForwardService::new(pool, state);
        let raw = a_query(0x1234, "example.com");
        let req = make_request(raw);
        let out = timeout(Duration::from_secs(5), svc.oneshot(req))
            .await
            .expect("safety timeout")
            .expect("service must not error");
        assert_eq!(out.reply.outcome, Outcome::Forwarded);
        assert!(
            matches!(out.directive, CacheDirective::Store { expiry: 300, .. }),
            "positive answer → Store with min TTL 300, got {:?}",
            out.directive
        );
    }

    // NXDOMAIN with an SOA is still Forwarded and is cached with the negative
    // TTL (60 here, presumably from the mock's SOA and below the default cap).
    #[tokio::test]
    async fn negative_with_soa_directive_stores() {
        let addr = mock_udp_upstream(nxdomain_with_soa_handler).await;
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &[udp_config(addr)],
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let svc = ForwardService::new(pool, state);
        let raw = a_query(0x5678, "example.com");
        let req = make_request(raw);
        let out = timeout(Duration::from_secs(5), svc.oneshot(req))
            .await
            .expect("safety timeout")
            .expect("service must not error");
        assert_eq!(
            out.reply.outcome,
            Outcome::Forwarded,
            "NXDOMAIN with SOA must still return Forwarded"
        );
        assert!(
            matches!(out.directive, CacheDirective::Store { expiry: 60, .. }),
            "NXDOMAIN with SOA → Store with negative TTL 60, got {:?}",
            out.directive
        );
    }

    // NXDOMAIN without an SOA has no negative TTL, so it must not be cached.
    #[tokio::test]
    async fn negative_without_soa_directive_skips() {
        let addr = mock_udp_upstream(nxdomain_handler).await;
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &[udp_config(addr)],
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let svc = ForwardService::new(pool, state);
        let raw = a_query(0x9ABC, "example.com");
        let req = make_request(raw);
        let out = timeout(Duration::from_secs(5), svc.oneshot(req))
            .await
            .expect("safety timeout")
            .expect("service must not error");
        assert!(
            matches!(out.directive, CacheDirective::Skip),
            "NXDOMAIN without SOA → Skip (not cacheable), got {:?}",
            out.directive
        );
    }

    // With an empty upstream list, forwarding fails. The service must still
    // return Ok with an uncached SERVFAIL that echoes the client's query ID.
    #[tokio::test]
    async fn all_upstreams_fail_returns_servfail() {
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &[],
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let svc = ForwardService::new(pool, state);
        let client_id: u16 = 0xDEAD;
        let raw = a_query(client_id, "example.com");
        let req = make_request(raw);
        let out = timeout(Duration::from_secs(5), svc.oneshot(req))
            .await
            .expect("safety timeout")
            .expect("service must not error");
        assert_eq!(
            out.reply.outcome,
            Outcome::Servfail,
            "outcome must be Servfail"
        );
        assert!(
            matches!(out.directive, CacheDirective::Skip),
            "SERVFAIL must not be cached"
        );
        let hdr = parse_header(&out.reply.bytes);
        assert_eq!(
            hdr.id, client_id,
            "SERVFAIL reply must echo the client's query id"
        );
        assert_eq!(
            hdr.rcode(),
            crate::codec::header::Rcode::ServFail,
            "RCODE must be ServFail"
        );
    }
}