use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use tower::Service;
use crate::{
codec::{
header::Rcode,
synth::{EdnsInfo, Response},
ttl::TtlScan,
},
resolver::{
pipeline::{
BoxError, DnsRequest, Outcome, PipelineResponse,
cache_layer::{CacheDirective, ForwardOutput},
},
state::ResolverState,
upstream::SharedUpstreamPool,
},
};
/// Pipeline stage that forwards a DNS request's question to the upstream pool.
///
/// Both fields are reference-counted handles, so cloning the service only clones the
/// `Arc`s; the `Service::call` implementation clones them into each request's future.
#[derive(Clone)]
pub struct ForwardService {
/// Upstream pool that `call` forwards the request's question to.
pool: Arc<SharedUpstreamPool>,
/// Resolver state; `call` reads its settings (the negative-TTL cap) when deciding
/// how long a negative answer may be cached.
state: Arc<ResolverState>,
}
impl ForwardService {
pub fn new(pool: Arc<SharedUpstreamPool>, state: Arc<ResolverState>) -> Self {
Self { pool, state }
}
}
impl Service<DnsRequest> for ForwardService {
type Response = ForwardOutput;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<ForwardOutput, BoxError>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: DnsRequest) -> Self::Future {
let pool = self.pool.clone();
let state = self.state.clone();
Box::pin(async move {
let question = req.question().clone();
let client_id = req.header().id;
let edns = EdnsInfo::scan(req.query());
match pool.forward(&question).await {
Ok(fr) => {
let settings = state.settings_full();
let scan = TtlScan::scan(&fr.bytes);
let directive = if let Ok(s) = scan.as_ref() {
let expiry = if fr.is_negative {
fr.negative_ttl.map(|t| t.min(settings.negative_ttl_cap))
} else {
s.min_ttl
};
match expiry {
Some(expiry) => CacheDirective::Store {
bytes: fr.bytes.clone(),
ttl_offsets: s.ttl_offsets.clone(),
expiry,
},
None => CacheDirective::Skip,
}
} else {
CacheDirective::Skip
};
let reply = fr.bytes.with_txn_id(client_id);
Ok(ForwardOutput::new(
PipelineResponse::new(reply, Outcome::Forwarded),
directive,
))
}
Err(e) => {
tracing::warn!(
qname = %question.name,
qtype = ?question.qtype,
error = %e,
"all upstreams failed; returning SERVFAIL"
);
let bytes =
Response::error_response(req.query(), Rcode::ServFail, edns.as_ref());
Ok(ForwardOutput::new(
PipelineResponse::new(bytes, Outcome::Servfail),
CacheDirective::Skip,
))
}
}
})
}
}
/// Private extension trait for raw DNS message buffers.
trait DnsMessageBytes {
/// Returns a copy of the message whose first two bytes (the transaction id) are
/// replaced by `id` in big-endian order. The `Bytes` implementation in this file
/// returns buffers shorter than two bytes unchanged.
fn with_txn_id(&self, id: u16) -> Bytes;
}
impl DnsMessageBytes for Bytes {
    /// Produces a new buffer equal to `self` except that the leading two bytes hold `id`
    /// in network (big-endian) byte order.
    ///
    /// A buffer with fewer than two bytes has no room for an id and is returned as a
    /// plain clone.
    fn with_txn_id(&self, id: u16) -> Bytes {
        // `get(2..)` is `None` exactly when the buffer is shorter than two bytes.
        match self.get(2..) {
            None => self.clone(),
            Some(rest) => {
                let mut patched = BytesMut::with_capacity(self.len());
                patched.extend_from_slice(&id.to_be_bytes());
                patched.extend_from_slice(rest);
                patched.freeze()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Tests for `with_txn_id` and for `ForwardService` driven against a local mock
    //! UDP upstream.

    use std::{net::SocketAddr, sync::Arc, time::Duration};

    use bytes::Bytes;
    use hickory_net::proto::op::{Message, MessageType, ResponseCode};
    use hickory_net::proto::rr::rdata::{A, SOA};
    use hickory_net::proto::rr::{Name, RData, Record};
    use tempfile::TempDir;
    use tokio::net::UdpSocket;
    use tokio::time::timeout;
    use tokio_util::task::TaskTracker;
    use tower::ServiceExt as _;

    use super::*;
    use crate::{
        codec::{
            header::Header, message::Query, name::Name as DnsName, reader::Reader, writer::Writer,
        },
        resolver::{
            state::ResolverState,
            upstream::{UpstreamConfig, UpstreamPool, UpstreamTransport},
        },
        storage::Db,
    };

    /// Timeout handed to `UpstreamPool::connect` in every test.
    const UPSTREAM_TIMEOUT: Duration = Duration::from_millis(500);
    /// Upper bound on a single service call so a hung test fails instead of blocking.
    const SAFETY_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn with_txn_id_rewrites_first_two_bytes() {
        let msg = Bytes::from_static(&[0x12, 0x34, 0xAA, 0xBB]);
        let patched = msg.with_txn_id(0xBEEF);
        assert_eq!(&patched[..], &[0xBE, 0xEF, 0xAA, 0xBB]);
    }

    #[test]
    fn with_txn_id_short_buffer_is_unchanged() {
        let one = Bytes::from_static(&[0xAA]);
        assert_eq!(&one.with_txn_id(0xBEEF)[..], &[0xAA]);
        assert_eq!(Bytes::new().with_txn_id(0x1234).len(), 0);
    }

    /// Opens a database at `test.db` inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller controls how long the directory lives.
    async fn open_temp_db() -> (TempDir, Db) {
        let dir = TempDir::new().expect("temp dir");
        let path = dir.path().join("test.db");
        let db = Db::connect(&path).await.expect("connect");
        (dir, db)
    }

    /// Upstream configuration for a plain-UDP server at `addr`.
    fn udp_config(addr: SocketAddr) -> UpstreamConfig {
        UpstreamConfig {
            addr,
            transport: UpstreamTransport::Udp,
            tls_server_name: None,
            http_endpoint: None,
        }
    }

    /// Serializes a one-question query for `name` with transaction id `id` and RD set.
    fn build_a_query(id: u16, name: &str) -> Bytes {
        let mut w = Writer::with_capacity(64);
        Header::new(id).with_qdcount(1).with_rd(true).write(&mut w);
        let n: DnsName = name.parse().expect("valid name");
        n.write(&mut w);
        w.write_u16(1u16); // QTYPE = A
        w.write_u16(1u16); // QCLASS = IN
        w.finish()
    }

    /// Parses `raw` as a query and wraps it in a `DnsRequest` from a fixed loopback client.
    fn make_request(raw: Bytes) -> DnsRequest {
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let query = Query::try_from(raw).expect("valid query");
        DnsRequest::new(query, client)
    }

    /// Binds a UDP socket on an ephemeral loopback port and spawns a task that answers
    /// each parseable datagram with whatever `handler` returns.
    ///
    /// Unparseable datagrams are ignored; a `None` from the handler (or a response that
    /// fails to serialize) sends nothing. The task ends when `recv_from` fails.
    async fn spawn_mock_udp<F>(mut handler: F) -> SocketAddr
    where
        F: FnMut(Message) -> Option<Message> + Send + 'static,
    {
        let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let addr = sock.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 512];
            loop {
                let Ok((len, peer)) = sock.recv_from(&mut buf).await else {
                    break;
                };
                let Ok(req) = Message::from_vec(&buf[..len]) else {
                    continue;
                };
                if let Some(resp) = handler(req)
                    && let Ok(resp_bytes) = resp.to_vec()
                {
                    let _ = sock.send_to(&resp_bytes, peer).await;
                }
            }
        });
        addr
    }

    /// Reads the DNS header from the start of `bytes`.
    fn parse_header(bytes: &Bytes) -> Header {
        let mut r = Reader::new(bytes.clone());
        Header::read(&mut r).expect("valid DNS header")
    }

    /// Turns a received query into a response carrying `rcode`, keeping everything else.
    fn into_response(mut msg: Message, rcode: ResponseCode) -> Message {
        msg.metadata.message_type = MessageType::Response;
        msg.metadata.response_code = rcode;
        msg
    }

    /// Mock handler: NOERROR with a single `example.com. 300 A 93.184.216.34` answer.
    fn a_record_reply(req: Message) -> Option<Message> {
        let mut resp = into_response(req, ResponseCode::NoError);
        let name = Name::from_ascii("example.com.").unwrap();
        let rdata = RData::A(A::new(93, 184, 216, 34));
        resp.add_answer(Record::from_rdata(name, 300, rdata));
        Some(resp)
    }

    /// A service under test together with the resources the tests keep alive for the
    /// duration of the call.
    ///
    /// Field order sets drop order: database, then its directory, then the task tracker.
    struct Fixture {
        svc: ForwardService,
        _db: Db,
        _dir: TempDir,
        _tracker: TaskTracker,
    }

    /// Builds a `ForwardService` over `upstreams` with state hydrated from a fresh
    /// temporary database.
    async fn build_fixture(upstreams: &[UpstreamConfig]) -> Fixture {
        let tracker = TaskTracker::new();
        // NOTE(review): the meaning of the `0` argument is not visible from this file.
        let pool = UpstreamPool::connect(
            upstreams,
            &tracker,
            Arc::new(crate::resolver::upstream::RandomSelector),
            0,
            UPSTREAM_TIMEOUT,
        )
        .await;
        let pool = Arc::new(SharedUpstreamPool::new(pool));
        let (dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        Fixture {
            svc: ForwardService::new(pool, state),
            _db: db,
            _dir: dir,
            _tracker: tracker,
        }
    }

    /// Sends `raw` through `svc` once, failing the test on a hang or a service error.
    async fn run(svc: ForwardService, raw: Bytes) -> ForwardOutput {
        timeout(SAFETY_TIMEOUT, svc.oneshot(make_request(raw)))
            .await
            .expect("safety timeout")
            .expect("service must not error")
    }

    #[tokio::test]
    async fn forward_returns_reply_with_client_id() {
        let client_query_id: u16 = 0xBEEF;
        let addr = spawn_mock_udp(a_record_reply).await;
        let fixture = build_fixture(&[udp_config(addr)]).await;
        let out = run(fixture.svc, build_a_query(client_query_id, "example.com")).await;
        assert_eq!(
            out.reply.outcome,
            Outcome::Forwarded,
            "outcome must be Forwarded"
        );
        let hdr = parse_header(&out.reply.bytes);
        assert_eq!(
            hdr.id, client_query_id,
            "reply txn-id must be patched to the client's query id"
        );
    }

    #[tokio::test]
    async fn positive_answer_directive_stores_min_ttl() {
        let addr = spawn_mock_udp(a_record_reply).await;
        let fixture = build_fixture(&[udp_config(addr)]).await;
        let out = run(fixture.svc, build_a_query(0x1234, "example.com")).await;
        assert_eq!(out.reply.outcome, Outcome::Forwarded);
        assert!(
            matches!(out.directive, CacheDirective::Store { expiry: 300, .. }),
            "positive answer → Store with min TTL 300, got {:?}",
            out.directive
        );
    }

    #[tokio::test]
    async fn negative_with_soa_directive_stores() {
        let addr = spawn_mock_udp(|req| {
            let mut resp = into_response(req, ResponseCode::NXDomain);
            let zone = Name::from_ascii("example.com.").unwrap();
            let mname = Name::from_ascii("ns1.example.com.").unwrap();
            let rname = Name::from_ascii("hostmaster.example.com.").unwrap();
            // Record TTL is 120 but SOA MINIMUM is 60; the test expects 60 to win.
            let soa = SOA::new(mname, rname, 1, 3600, 900, 604800, 60);
            resp.add_authority(Record::from_rdata(zone, 120, RData::SOA(soa)));
            Some(resp)
        })
        .await;
        let fixture = build_fixture(&[udp_config(addr)]).await;
        let out = run(fixture.svc, build_a_query(0x5678, "example.com")).await;
        assert_eq!(
            out.reply.outcome,
            Outcome::Forwarded,
            "NXDOMAIN with SOA must still return Forwarded"
        );
        assert!(
            matches!(out.directive, CacheDirective::Store { expiry: 60, .. }),
            "NXDOMAIN with SOA → Store with negative TTL 60, got {:?}",
            out.directive
        );
    }

    #[tokio::test]
    async fn negative_without_soa_directive_skips() {
        let addr = spawn_mock_udp(|req| Some(into_response(req, ResponseCode::NXDomain))).await;
        let fixture = build_fixture(&[udp_config(addr)]).await;
        let out = run(fixture.svc, build_a_query(0x9ABC, "example.com")).await;
        assert!(
            matches!(out.directive, CacheDirective::Skip),
            "NXDOMAIN without SOA → Skip (not cacheable), got {:?}",
            out.directive
        );
    }

    #[tokio::test]
    async fn all_upstreams_fail_returns_servfail() {
        // An empty upstream list makes every forward attempt fail.
        let fixture = build_fixture(&[]).await;
        let client_id: u16 = 0xDEAD;
        let out = run(fixture.svc, build_a_query(client_id, "example.com")).await;
        assert_eq!(
            out.reply.outcome,
            Outcome::Servfail,
            "outcome must be Servfail"
        );
        assert!(
            matches!(out.directive, CacheDirective::Skip),
            "SERVFAIL must not be cached"
        );
        let hdr = parse_header(&out.reply.bytes);
        assert_eq!(
            hdr.id, client_id,
            "SERVFAIL reply must echo the client's query id"
        );
        assert_eq!(
            hdr.rcode(),
            crate::codec::header::Rcode::ServFail,
            "RCODE must be ServFail"
        );
    }
}