use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
use crate::{
codec::message::Qtype,
codec::synth::{LocalRecord, Response},
resolver::{
local::{LocalMatch, RecordData},
pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
state::ResolverState,
},
};
/// TTL in seconds (one minute) passed to `Response::block` for every
/// synthesized block response.
pub const BLOCK_TTL_SECS: u32 = 60;
/// Tower middleware that decides, per query, whether to answer from local
/// records, synthesize a block response, or hand the request to the wrapped
/// service.
#[derive(Clone)]
pub struct DecisionStack<S> {
    /// Shared resolver state consulted for every decision (local records,
    /// forward zones, allow/block lists, runtime settings).
    state: Arc<ResolverState>,
    /// Downstream service that receives every query not decided here.
    inner: S,
}

impl<S> DecisionStack<S> {
    /// Wraps `inner` with the decision logic driven by `state`.
    pub fn new(state: Arc<ResolverState>, inner: S) -> Self {
        DecisionStack { state, inner }
    }
}
impl<S> Service<DnsRequest> for DecisionStack<S>
where
S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = PipelineResponse;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<PipelineResponse, BoxError>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: DnsRequest) -> Self::Future {
let state = self.state.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
let name = req.question().name.clone();
let qtype = req.question().qtype;
let edns = req.edns().cloned();
let block_mode = state.settings().block_mode.clone();
if qtype == Qtype::Ptr
&& let Some(ip) = name.reverse_addr()
&& let Some((target, ttl)) = state.local().reverse_lookup(ip)
{
let bytes = Response::local_ptr(req.query(), &target, ttl, edns.as_ref());
return Ok(PipelineResponse::new(bytes, Outcome::Local));
}
match state.local().lookup(&name, qtype) {
LocalMatch::Answer { data, ttl } => {
let bytes = match data {
RecordData::A(addr) => {
let octets = addr.octets();
let record = LocalRecord {
rtype: 1,
rdata: &octets,
};
Response::local(req.query(), &[record], ttl, edns.as_ref())
}
RecordData::Aaaa(addr) => {
let octets = addr.octets();
let record = LocalRecord {
rtype: 28,
rdata: &octets,
};
Response::local(req.query(), &[record], ttl, edns.as_ref())
}
};
return Ok(PipelineResponse::new(bytes, Outcome::Local));
}
LocalMatch::NameExistsNoData => {
let bytes = Response::local_nodata(req.query(), edns.as_ref());
return Ok(PipelineResponse::new(bytes, Outcome::LocalNoData));
}
LocalMatch::Miss => {} }
if let Some(target) = state.forward_zones().match_target(&name) {
req.set_forward_target(target);
return inner.call(req).await;
}
if state.blocking_paused() {
return inner.call(req).await;
}
if state.blacklist().contains(&name) {
let bytes =
Response::block(req.query(), &block_mode, BLOCK_TTL_SECS, edns.as_ref());
return Ok(PipelineResponse::new(bytes, Outcome::BlockedByAdmin));
}
let mut bypass = false;
if state.allowlist().contains(&name) {
bypass = true;
req.set_allow_bypass(true);
}
if !bypass && state.blocklist().contains(&name) {
let bytes =
Response::block(req.query(), &block_mode, BLOCK_TTL_SECS, edns.as_ref());
return Ok(PipelineResponse::new(bytes, Outcome::BlockedByBlocklist));
}
inner.call(req).await
})
}
}
/// [`Layer`] that wraps a service in a [`DecisionStack`] backed by a shared
/// [`ResolverState`].
///
/// Cloning the layer only bumps the `Arc` refcount. A single layer can
/// therefore be reused to build several service stacks, or handed to APIs that
/// require `L: Clone`.
#[derive(Clone)]
pub struct DecisionLayer {
    state: Arc<ResolverState>,
}

impl DecisionLayer {
    /// Creates a layer whose stacks all share `state`.
    pub fn new(state: Arc<ResolverState>) -> Self {
        Self { state }
    }
}

impl<S> Layer<S> for DecisionLayer {
    type Service = DecisionStack<S>;

    /// Wraps `inner` in a [`DecisionStack`] sharing this layer's state.
    fn layer(&self, inner: S) -> Self::Service {
        DecisionStack::new(Arc::clone(&self.state), inner)
    }
}
// Tests for `DecisionStack` / `DecisionLayer`.
//
// Most tests wrap `stub_fn`, an inner service that echoes the request bytes back
// with `Outcome::Forwarded`. A `Forwarded` outcome therefore means "the decision
// layer passed the query through". The forward-zone tests at the end instead use
// a real `CacheService<ForwardService>` inner stack (see `stack_with_real_inner`).
#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, SocketAddr};
use bytes::Bytes;
use tower::ServiceExt as _;
use super::*;
use crate::test_support::{
a_query, aaaa_query, mock_udp_upstream, positive_a_handler, ptr_query,
};
use crate::{
codec::{
header::{Header, Rcode},
message::Query,
name::Name,
reader::Reader,
},
resolver::{
forward_zone::ForwardZoneSet,
local::{LocalRecords, RecordData as LRecordData},
pipeline::{
BoxError, DnsRequest, Outcome, PipelineResponse, cache_layer::CacheService,
forward::ForwardService,
},
state::{ResolverState, RuntimeSettings},
upstream::{RandomSelector, SharedUpstreamPool, UpstreamPool},
},
storage::forward_zones::ForwardZone,
};
use std::time::Duration;
use tokio_util::task::TaskTracker;
// Parses `s` into a `Name`, panicking if it is not a valid domain name.
fn name(s: &str) -> Name {
s.parse().expect("valid domain name")
}
// Wraps raw query bytes in a `DnsRequest` from a fixed client address (127.0.0.1:5353).
fn make_request(raw: Bytes) -> DnsRequest {
let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
let query = Query::try_from(raw).expect("valid query");
DnsRequest::new(query, client)
}
// Stub inner service: echoes the raw request bytes as the response, tagged `Outcome::Forwarded`.
fn stub_fn(req: DnsRequest) -> std::future::Ready<Result<PipelineResponse, BoxError>> {
std::future::ready(Ok(PipelineResponse::new(
req.raw().clone(),
Outcome::Forwarded,
)))
}
// Decodes the DNS header at the start of `bytes`, panicking if it is malformed.
fn parse_header(bytes: &Bytes) -> Header {
let mut r = Reader::new(bytes.clone());
Header::read(&mut r).expect("valid DNS header")
}
// The admin blacklist takes precedence over the allowlist.
#[tokio::test]
async fn blacklisted_and_allowlisted_still_blocked_by_admin() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("evil.example.com");
state
.blacklist()
.store([target.clone()].into_iter().collect());
state
.allowlist()
.store([target.clone()].into_iter().collect());
let raw = a_query(0x0001, "evil.example.com");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::BlockedByAdmin,
"admin blacklist must win over allowlist"
);
}
// The allowlist bypasses the blocklist, so the query reaches the inner service.
#[tokio::test]
async fn allowlisted_and_blocklisted_forwards() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("safe.example.com");
state
.allowlist()
.store([target.clone()].into_iter().collect());
state
.blocklist()
.store([(target.clone(), 1)].into_iter().collect());
let raw = a_query(0x0002, "safe.example.com");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"allowlist must bypass blocklist → stub returns Forwarded"
);
}
// A local record answers even when the name is both blacklisted and blocklisted.
#[tokio::test]
async fn local_record_wins_over_all_blocking() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("router.home.lan");
state
.blacklist()
.store([target.clone()].into_iter().collect());
state
.blocklist()
.store([(target.clone(), 1)].into_iter().collect());
let mut b = LocalRecords::builder();
b.add(
"router.home.lan",
LRecordData::A("192.168.1.1".parse().unwrap()),
300,
)
.unwrap();
state.local().store(b.build());
let raw = a_query(0x0003, "router.home.lan");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Local,
"local record must win over all blocking"
);
}
// An AAAA query for a local name that only has an A record yields LocalNoData.
#[tokio::test]
async fn local_name_exists_but_qtype_absent_returns_nodata() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let mut b = LocalRecords::builder();
b.add("host.lan", LRecordData::A("10.0.0.1".parse().unwrap()), 60)
.unwrap();
state.local().store(b.build());
let raw = aaaa_query(0x0004, "host.lan");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::LocalNoData,
"AAAA query for A-only local name must return LocalNoData"
);
}
// A blocklist-only hit is reported as BlockedByBlocklist.
#[tokio::test]
async fn plain_blocklist_hit_returns_blocked_by_blocklist() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("tracker.bad.example");
state
.blocklist()
.store([(target.clone(), 1)].into_iter().collect());
let raw = a_query(0x0005, "tracker.bad.example");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::BlockedByBlocklist,
"plain blocklist hit must return BlockedByBlocklist"
);
}
// A name on no list falls through to the inner service.
#[tokio::test]
async fn plain_non_match_falls_through_to_forwarded() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let raw = a_query(0x0006, "nobody.example.com");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"plain miss must fall through to Forwarded"
);
}
// With the seeded null-ip block mode, a blocked A query gets a NOERROR response
// with the query's id and exactly one answer RR.
#[tokio::test]
async fn blocklist_null_ip_response_is_well_formed() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let settings_guard = state.settings();
assert_eq!(
settings_guard.block_mode,
crate::codec::synth::BlockMode::null_ip(),
"seeded default must be null-ip"
);
// Release the settings guard before the request is processed.
drop(settings_guard);
let target = name("blocked.example");
state
.blocklist()
.store([(target.clone(), 1)].into_iter().collect());
let query_id: u16 = 0x1234;
let raw = a_query(query_id, "blocked.example");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::BlockedByBlocklist);
let hdr = parse_header(&resp.bytes);
assert_eq!(hdr.id, query_id, "response id must match query id");
assert!(hdr.qr(), "QR must be set");
assert_eq!(hdr.rcode(), Rcode::NoError, "null-ip A → NOERROR");
assert_eq!(hdr.ancount, 1, "null-ip A → one answer RR");
}
// After switching the block mode to NxDomain, a blacklisted A query gets
// NXDOMAIN with no answer RRs.
#[tokio::test]
async fn admin_blacklist_nxdomain_response_is_well_formed() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let new_settings = RuntimeSettings {
block_mode: crate::codec::synth::BlockMode::NxDomain,
..(*state.settings_full()).clone()
};
state.store_settings(new_settings);
let target = name("evil.example");
state
.blacklist()
.store([target.clone()].into_iter().collect());
let query_id: u16 = 0x5678;
let raw = a_query(query_id, "evil.example");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::BlockedByAdmin);
let hdr = parse_header(&resp.bytes);
assert_eq!(hdr.id, query_id, "response id must match query id");
assert_eq!(hdr.rcode(), Rcode::NxDomain, "NxDomain mode → NXDOMAIN");
assert_eq!(hdr.ancount, 0, "NXDOMAIN → no answer RRs");
}
// While blocking is paused, blacklisted names are passed through.
#[tokio::test]
async fn paused_blacklisted_name_falls_through() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("evil.example.com");
state
.blacklist()
.store([target.clone()].into_iter().collect());
state.pause_for_secs(300);
let raw = a_query(0x0101, "evil.example.com");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"paused blocking must let a blacklisted name through"
);
}
// While blocking is paused, blocklisted names are passed through.
#[tokio::test]
async fn paused_blocklisted_name_falls_through() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("tracker.bad.example");
state
.blocklist()
.store([(target.clone(), 1)].into_iter().collect());
state.pause_for_secs(300);
let raw = a_query(0x0102, "tracker.bad.example");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"paused blocking must let a blocklisted name through"
);
}
// Pausing blocking does not disable local records.
#[tokio::test]
async fn paused_local_record_still_answers() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let mut b = LocalRecords::builder();
b.add(
"router.home.lan",
LRecordData::A("192.168.1.1".parse().unwrap()),
300,
)
.unwrap();
state.local().store(b.build());
state.pause_for_secs(300);
let raw = a_query(0x0103, "router.home.lan");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Local,
"local records must answer even while blocking is paused"
);
}
// Calling `resume()` after a pause re-enables blocking immediately.
#[tokio::test]
async fn resumed_blocking_blocks_again() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let target = name("ads.example.com");
state
.blacklist()
.store([target.clone()].into_iter().collect());
state.pause_for_secs(300);
state.resume();
let raw = a_query(0x0104, "ads.example.com");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::BlockedByAdmin,
"resuming must restore blocking immediately"
);
}
// `DecisionLayer::layer` produces a working `DecisionStack`.
#[tokio::test]
async fn decision_layer_wraps_correctly() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let layer = DecisionLayer::new(state);
let svc = layer.layer(tower::service_fn(stub_fn));
let raw = a_query(0x9999, "via-layer.example.com");
let req = make_request(raw);
let resp = svc.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::Forwarded);
}
// Local A answers echo the query id and carry AA=1, NOERROR and one answer RR.
#[tokio::test]
async fn local_a_record_response_is_authoritative() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let ip: Ipv4Addr = "192.168.1.42".parse().unwrap();
let mut b = LocalRecords::builder();
b.add("myhost.lan", LRecordData::A(ip), 120).unwrap();
state.local().store(b.build());
let query_id: u16 = 0xABCD;
let raw = a_query(query_id, "myhost.lan");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::Local);
let hdr = parse_header(&resp.bytes);
assert_eq!(hdr.id, query_id);
assert!(hdr.aa(), "local record must set AA=1");
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
}
// A PTR query for an IP owned by a local A record is answered locally with
// AA=1 and one PTR answer RR.
#[tokio::test]
async fn ptr_for_local_record_is_authoritative() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let mut b = LocalRecords::builder();
b.add(
"router.home.lan",
LRecordData::A("192.168.1.1".parse().unwrap()),
300,
)
.unwrap();
state.local().store(b.build());
let query_id: u16 = 0x0ABC;
let raw = ptr_query(query_id, "1.1.168.192.in-addr.arpa");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::Local, "PTR for owned IP → Local");
let hdr = parse_header(&resp.bytes);
assert_eq!(hdr.id, query_id);
assert!(hdr.aa(), "PTR answer must be authoritative");
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1, "one PTR answer RR");
}
// A PTR query for an IP with no local record falls through to the inner service.
#[tokio::test]
async fn ptr_for_unknown_ip_falls_through() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let raw = ptr_query(0x0DEF, "5.1.168.192.in-addr.arpa");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"PTR for an unknown IP must fall through, not be answered here"
);
}
// Local NODATA responses are authoritative NOERROR with no answer RRs.
#[tokio::test]
async fn local_nodata_response_is_authoritative_nodata() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let mut b = LocalRecords::builder();
b.add(
"nodata.lan",
LRecordData::A("10.0.0.2".parse().unwrap()),
60,
)
.unwrap();
state.local().store(b.build());
let query_id: u16 = 0xDEAD;
let raw = aaaa_query(query_id, "nodata.lan");
let req = make_request(raw);
let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
let resp = stack.oneshot(req).await.unwrap();
assert_eq!(resp.outcome, Outcome::LocalNoData);
let hdr = parse_header(&resp.bytes);
assert_eq!(hdr.id, query_id);
assert!(hdr.aa(), "NODATA response must be authoritative");
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 0, "NODATA must have no answer RRs");
}
// Upstream pool built from an empty server list. The tests below that use it
// expect queries not routed to a forward zone to end in SERVFAIL.
async fn empty_pool() -> Arc<SharedUpstreamPool> {
let tracker = TaskTracker::new();
let pool = UpstreamPool::connect(
&[],
&tracker,
Arc::new(RandomSelector),
0,
Duration::from_millis(500),
)
.await;
Arc::new(SharedUpstreamPool::new(pool))
}
// Installs one enabled forward zone per (suffix, target) pair into `state`.
// Ids start at 1 and `sort_order` follows the input order.
async fn install_zones(state: &Arc<ResolverState>, zones: &[(&str, std::net::SocketAddr)]) {
let tracker = TaskTracker::new();
let rows: Vec<ForwardZone> = zones
.iter()
.enumerate()
.map(|(i, (suffix, target))| ForwardZone {
id: i as i64 + 1,
zone_suffix: (*suffix).to_owned(),
target: Some(target.to_string()),
enabled: true,
sort_order: i as i64,
})
.collect();
let set = ForwardZoneSet::build(&rows, &tracker).await;
state.store_forward_zones(set);
}
// `DecisionStack` over the real `CacheService` -> `ForwardService` chain using `pool`.
fn stack_with_real_inner(
state: Arc<ResolverState>,
pool: Arc<SharedUpstreamPool>,
) -> DecisionStack<CacheService<ForwardService>> {
let forward = ForwardService::new(pool, state.clone());
let cached = CacheService::new(state.clone(), forward);
DecisionStack::new(state, cached)
}
// A PTR under an enabled zone is forwarded to that zone's target, and a repeat
// of the same question is then served from the cache.
#[tokio::test]
async fn ptr_under_enabled_zone_forwards_and_caches() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let zone_addr = mock_udp_upstream(positive_a_handler).await;
install_zones(&state, &[("168.192.in-addr.arpa", zone_addr)]).await;
let stack = stack_with_real_inner(state, empty_pool().await);
let raw = ptr_query(0x0001, "1.1.168.192.in-addr.arpa");
let resp = stack.clone().oneshot(make_request(raw)).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Forwarded,
"zone-matched query must be forwarded, not SERVFAIL"
);
assert_eq!(
resp.upstream,
Some(zone_addr),
"must forward to the zone's target resolver"
);
let raw2 = ptr_query(0x0002, "1.1.168.192.in-addr.arpa");
let resp2 = stack.oneshot(make_request(raw2)).await.unwrap();
assert_eq!(
resp2.outcome,
Outcome::Cached,
"second identical zone query must be served from cache"
);
}
// A name outside every zone takes the normal upstream path (SERVFAIL here,
// because the default pool is empty).
#[tokio::test]
async fn non_zone_query_falls_through() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let zone_addr = mock_udp_upstream(positive_a_handler).await;
install_zones(&state, &[("168.192.in-addr.arpa", zone_addr)]).await;
let stack = stack_with_real_inner(state, empty_pool().await);
let raw = a_query(0x0003, "example.com");
let resp = stack.oneshot(make_request(raw)).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Servfail,
"a non-matching name must take the normal upstream path"
);
}
// With no zones installed, a reverse query is not zone-routed.
#[tokio::test]
async fn no_enabled_zones_ignores_reverse_query() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let stack = stack_with_real_inner(state, empty_pool().await);
let raw = ptr_query(0x0004, "1.1.168.192.in-addr.arpa");
let resp = stack.oneshot(make_request(raw)).await.unwrap();
assert_eq!(
resp.outcome,
Outcome::Servfail,
"with no enabled zones the reverse query must not be zone-routed"
);
}
// When zones nest, the more specific (longer) suffix's target answers.
#[tokio::test]
async fn most_specific_zone_wins_in_pipeline() {
let (_dir, db) = crate::test_support::temp_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let general = mock_udp_upstream(positive_a_handler).await;
let specific = mock_udp_upstream(positive_a_handler).await;
install_zones(
&state,
&[
("10.in-addr.arpa", general),
("0.10.in-addr.arpa", specific),
],
)
.await;
let stack = stack_with_real_inner(state, empty_pool().await);
let raw = ptr_query(0x0005, "5.1.0.10.in-addr.arpa");
let resp = stack.oneshot(make_request(raw)).await.unwrap();
assert_eq!(resp.outcome, Outcome::Forwarded);
assert_eq!(
resp.upstream,
Some(specific),
"the most-specific zone's target must answer"
);
}
}