use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
use crate::{
codec::synth::{EdnsInfo, LocalRecord, Response},
resolver::{
local::{LocalMatch, RecordData},
pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
state::ResolverState,
},
};
/// Time-to-live, in seconds, stamped on every synthesized block response
/// (admin-blacklist and blocklist hits alike).
pub const BLOCK_TTL_SECS: u32 = 60;
/// Tower middleware that decides, per DNS query, whether to answer it from local
/// records, synthesize a block response, or hand it on to the wrapped service.
#[derive(Clone)]
pub struct DecisionStack<S> {
    /// Shared resolver state: local records, admin blacklist, allowlist, blocklist
    /// and runtime settings are all read from here.
    state: Arc<ResolverState>,
    /// Downstream service that receives every query not answered by this stage.
    inner: S,
}
impl<S> DecisionStack<S> {
pub fn new(state: Arc<ResolverState>, inner: S) -> Self {
Self { state, inner }
}
}
impl<S> Service<DnsRequest> for DecisionStack<S>
where
    S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = PipelineResponse;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<PipelineResponse, BoxError>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service; this stage adds no
    /// back-pressure of its own.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Decides the fate of one query, in this order:
    ///
    /// 1. Local records: an answer yields `Outcome::Local`; a known name without the
    ///    requested type yields `Outcome::LocalNoData`.
    /// 2. Admin blacklist: yields `Outcome::BlockedByAdmin`, even if the name is also
    ///    allowlisted.
    /// 3. Allowlist: marks the request with `set_allow_bypass(true)` and skips the
    ///    blocklist (presumably later stages honour this flag too — confirm).
    /// 4. Blocklist: yields `Outcome::BlockedByBlocklist`.
    /// 5. Otherwise the request is passed to the inner service.
    ///
    /// Block responses use the current `block_mode` setting and [`BLOCK_TTL_SECS`].
    fn call(&mut self, mut req: DnsRequest) -> Self::Future {
        let state = Arc::clone(&self.state);
        // `self.inner` is the instance that `poll_ready` drove to readiness, so move it
        // into the future and leave a fresh clone behind for the next request.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        Box::pin(async move {
            // Owned copy: `req` is mutated below (allow-bypass flag) and finally moved
            // into the inner service, so we cannot keep a borrow into it.
            let name = req.question().name.clone();
            let qtype = req.question().qtype;
            let edns = EdnsInfo::scan(req.query());

            // 1. Local records take precedence over every list.
            match state.local().lookup(&name, qtype) {
                LocalMatch::Answer { data, ttl } => {
                    let bytes = match data {
                        RecordData::A(addr) => {
                            let octets = addr.octets();
                            // RR TYPE 1 = A.
                            let record = LocalRecord {
                                rtype: 1,
                                rdata: &octets,
                            };
                            Response::local(req.query(), &[record], ttl, edns.as_ref())
                        }
                        RecordData::Aaaa(addr) => {
                            let octets = addr.octets();
                            // RR TYPE 28 = AAAA.
                            let record = LocalRecord {
                                rtype: 28,
                                rdata: &octets,
                            };
                            Response::local(req.query(), &[record], ttl, edns.as_ref())
                        }
                    };
                    return Ok(PipelineResponse::new(bytes, Outcome::Local));
                }
                LocalMatch::NameExistsNoData => {
                    let bytes = Response::local_nodata(req.query(), edns.as_ref());
                    return Ok(PipelineResponse::new(bytes, Outcome::LocalNoData));
                }
                LocalMatch::Miss => {}
            }

            // 2-4. List checks. The admin blacklist is consulted first, so it wins over
            // the allowlist; an allowlist hit short-circuits the blocklist lookup.
            let verdict = if state.blacklist().contains(&name) {
                Some(Outcome::BlockedByAdmin)
            } else if state.allowlist().contains(&name) {
                req.set_allow_bypass(true);
                None
            } else if state.blocklist().contains(&name) {
                Some(Outcome::BlockedByBlocklist)
            } else {
                None
            };

            if let Some(outcome) = verdict {
                // Settings are only read when a block response is actually synthesized,
                // keeping the forward/local hot paths free of the read + clone.
                let block_mode = state.settings().block_mode.clone();
                let bytes =
                    Response::block(req.query(), &block_mode, BLOCK_TTL_SECS, edns.as_ref());
                return Ok(PipelineResponse::new(bytes, outcome));
            }

            // 5. Not answered here: forward to the wrapped service.
            inner.call(req).await
        })
    }
}
/// [`tower::Layer`] that wraps a service in a `DecisionStack` sharing one resolver
/// state.
///
/// Derives `Clone` (cheap: only the `Arc` is cloned) so the layer — and any
/// `ServiceBuilder` stack containing it — can be reused to build several services.
#[derive(Clone)]
pub struct DecisionLayer {
    /// Resolver state handed to every service this layer produces.
    state: Arc<ResolverState>,
}
impl DecisionLayer {
pub fn new(state: Arc<ResolverState>) -> Self {
Self { state }
}
}
impl<S> Layer<S> for DecisionLayer {
type Service = DecisionStack<S>;
fn layer(&self, inner: S) -> Self::Service {
DecisionStack::new(self.state.clone(), inner)
}
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use bytes::Bytes;
    use tempfile::TempDir;
    use tower::ServiceExt as _;

    use super::*;
    use crate::{
        codec::{
            header::{Header, Rcode},
            message::Query,
            name::Name,
            reader::Reader,
            writer::Writer,
        },
        resolver::{
            local::{LocalRecords, RecordData as LRecordData},
            pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
            state::{ResolverState, RuntimeSettings},
        },
        storage::Db,
    };

    /// DNS QTYPE for an IPv4 address record (RFC 1035).
    const QTYPE_A: u16 = 1;
    /// DNS QTYPE for an IPv6 address record (RFC 3596).
    const QTYPE_AAAA: u16 = 28;
    /// DNS QCLASS for the Internet class (RFC 1035).
    const QCLASS_IN: u16 = 1;

    /// Opens a fresh database in a temporary directory. The `TempDir` is returned so
    /// the caller keeps it alive (and the directory on disk) for the whole test.
    async fn open_temp_db() -> (TempDir, Db) {
        let dir = TempDir::new().expect("temp dir");
        let path = dir.path().join("test.db");
        let db = Db::connect(&path).await.expect("connect");
        (dir, db)
    }

    /// Parses a domain name, panicking on invalid input (test helper).
    fn name(s: &str) -> Name {
        s.parse().expect("valid domain name")
    }

    /// Builds a single-question, recursion-desired query for `domain` with the given
    /// QTYPE in class IN.
    fn build_query(id: u16, domain: &str, qtype: u16) -> Bytes {
        let mut w = Writer::with_capacity(64);
        Header::new(id).with_qdcount(1).with_rd(true).write(&mut w);
        name(domain).write(&mut w);
        w.write_u16(qtype);
        w.write_u16(QCLASS_IN);
        w.finish()
    }

    /// Builds an A/IN query for `domain`.
    fn build_a_query(id: u16, domain: &str) -> Bytes {
        build_query(id, domain, QTYPE_A)
    }

    /// Builds an AAAA/IN query for `domain`.
    fn build_aaaa_query(id: u16, domain: &str) -> Bytes {
        build_query(id, domain, QTYPE_AAAA)
    }

    /// Wraps raw query bytes in a `DnsRequest` from a fixed loopback client.
    fn make_request(raw: Bytes) -> DnsRequest {
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let query = Query::try_from(raw).expect("valid query");
        DnsRequest::new(query, client)
    }

    /// Stand-in inner service: echoes the request's raw bytes tagged as `Forwarded`,
    /// so tests can tell when the decision stage let a query through.
    fn stub_fn(req: DnsRequest) -> std::future::Ready<Result<PipelineResponse, BoxError>> {
        std::future::ready(Ok(PipelineResponse::new(
            req.raw().clone(),
            Outcome::Forwarded,
        )))
    }

    /// Decodes just the DNS header of a response.
    fn parse_header(bytes: &Bytes) -> Header {
        let mut r = Reader::new(bytes.clone());
        Header::read(&mut r).expect("valid DNS header")
    }

    #[tokio::test]
    async fn blacklisted_and_allowlisted_still_blocked_by_admin() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let target = name("evil.example.com");
        state
            .blacklist()
            .store([target.clone()].into_iter().collect());
        state
            .allowlist()
            .store([target.clone()].into_iter().collect());
        let raw = build_a_query(0x0001, "evil.example.com");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::BlockedByAdmin,
            "admin blacklist must win over allowlist"
        );
    }

    #[tokio::test]
    async fn allowlisted_and_blocklisted_forwards() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let target = name("safe.example.com");
        state
            .allowlist()
            .store([target.clone()].into_iter().collect());
        state
            .blocklist()
            .store([target.clone()].into_iter().collect());
        let raw = build_a_query(0x0002, "safe.example.com");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::Forwarded,
            "allowlist must bypass blocklist → stub returns Forwarded"
        );
    }

    #[tokio::test]
    async fn local_record_wins_over_all_blocking() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let target = name("router.home.lan");
        state
            .blacklist()
            .store([target.clone()].into_iter().collect());
        state
            .blocklist()
            .store([target.clone()].into_iter().collect());
        let mut b = LocalRecords::builder();
        b.add(
            "router.home.lan",
            LRecordData::A("192.168.1.1".parse().unwrap()),
            300,
        )
        .unwrap();
        state.local().store(b.build());
        let raw = build_a_query(0x0003, "router.home.lan");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::Local,
            "local record must win over all blocking"
        );
    }

    #[tokio::test]
    async fn local_name_exists_but_qtype_absent_returns_nodata() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let mut b = LocalRecords::builder();
        b.add("host.lan", LRecordData::A("10.0.0.1".parse().unwrap()), 60)
            .unwrap();
        state.local().store(b.build());
        let raw = build_aaaa_query(0x0004, "host.lan");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::LocalNoData,
            "AAAA query for A-only local name must return LocalNoData"
        );
    }

    #[tokio::test]
    async fn plain_blocklist_hit_returns_blocked_by_blocklist() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let target = name("tracker.bad.example");
        state
            .blocklist()
            .store([target.clone()].into_iter().collect());
        let raw = build_a_query(0x0005, "tracker.bad.example");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::BlockedByBlocklist,
            "plain blocklist hit must return BlockedByBlocklist"
        );
    }

    #[tokio::test]
    async fn plain_non_match_falls_through_to_forwarded() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let raw = build_a_query(0x0006, "nobody.example.com");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(
            resp.outcome,
            Outcome::Forwarded,
            "plain miss must fall through to Forwarded"
        );
    }

    #[tokio::test]
    async fn blocklist_null_ip_response_is_well_formed() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let settings_guard = state.settings();
        assert_eq!(
            settings_guard.block_mode,
            crate::codec::synth::BlockMode::null_ip(),
            "seeded default must be null-ip"
        );
        drop(settings_guard);
        let target = name("blocked.example");
        state
            .blocklist()
            .store([target.clone()].into_iter().collect());
        let query_id: u16 = 0x1234;
        let raw = build_a_query(query_id, "blocked.example");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::BlockedByBlocklist);
        let hdr = parse_header(&resp.bytes);
        assert_eq!(hdr.id, query_id, "response id must match query id");
        assert!(hdr.qr(), "QR must be set");
        assert_eq!(hdr.rcode(), Rcode::NoError, "null-ip A → NOERROR");
        assert_eq!(hdr.ancount, 1, "null-ip A → one answer RR");
    }

    #[tokio::test]
    async fn admin_blacklist_nxdomain_response_is_well_formed() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let new_settings = RuntimeSettings {
            block_mode: crate::codec::synth::BlockMode::NxDomain,
            ..(*state.settings_full()).clone()
        };
        state.store_settings(new_settings);
        let target = name("evil.example");
        state
            .blacklist()
            .store([target.clone()].into_iter().collect());
        let query_id: u16 = 0x5678;
        let raw = build_a_query(query_id, "evil.example");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::BlockedByAdmin);
        let hdr = parse_header(&resp.bytes);
        assert_eq!(hdr.id, query_id, "response id must match query id");
        assert_eq!(hdr.rcode(), Rcode::NxDomain, "NxDomain mode → NXDOMAIN");
        assert_eq!(hdr.ancount, 0, "NXDOMAIN → no answer RRs");
    }

    #[tokio::test]
    async fn decision_layer_wraps_correctly() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let layer = DecisionLayer::new(state);
        let svc = layer.layer(tower::service_fn(stub_fn));
        let raw = build_a_query(0x9999, "via-layer.example.com");
        let req = make_request(raw);
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::Forwarded);
    }

    #[tokio::test]
    async fn local_a_record_response_is_authoritative() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let ip: Ipv4Addr = "192.168.1.42".parse().unwrap();
        let mut b = LocalRecords::builder();
        b.add("myhost.lan", LRecordData::A(ip), 120).unwrap();
        state.local().store(b.build());
        let query_id: u16 = 0xABCD;
        let raw = build_a_query(query_id, "myhost.lan");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::Local);
        let hdr = parse_header(&resp.bytes);
        assert_eq!(hdr.id, query_id);
        assert!(hdr.aa(), "local record must set AA=1");
        assert_eq!(hdr.rcode(), Rcode::NoError);
        assert_eq!(hdr.ancount, 1);
    }

    /// Covers the `RecordData::Aaaa` arm of the local-answer path, which the other
    /// tests never reach.
    #[tokio::test]
    async fn local_aaaa_record_response_is_authoritative() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let mut b = LocalRecords::builder();
        b.add(
            "v6host.lan",
            LRecordData::Aaaa("fd00::42".parse().unwrap()),
            120,
        )
        .unwrap();
        state.local().store(b.build());
        let query_id: u16 = 0xBEEF;
        let raw = build_aaaa_query(query_id, "v6host.lan");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::Local);
        let hdr = parse_header(&resp.bytes);
        assert_eq!(hdr.id, query_id);
        assert!(hdr.aa(), "local record must set AA=1");
        assert_eq!(hdr.rcode(), Rcode::NoError);
        assert_eq!(hdr.ancount, 1);
    }

    #[tokio::test]
    async fn local_nodata_response_is_authoritative_nodata() {
        let (_dir, db) = open_temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let mut b = LocalRecords::builder();
        b.add(
            "nodata.lan",
            LRecordData::A("10.0.0.2".parse().unwrap()),
            60,
        )
        .unwrap();
        state.local().store(b.build());
        let query_id: u16 = 0xDEAD;
        let raw = build_aaaa_query(query_id, "nodata.lan");
        let req = make_request(raw);
        let stack = DecisionStack::new(state, tower::service_fn(stub_fn));
        let resp = stack.oneshot(req).await.unwrap();
        assert_eq!(resp.outcome, Outcome::LocalNoData);
        let hdr = parse_header(&resp.bytes);
        assert_eq!(hdr.id, query_id);
        assert!(hdr.aa(), "NODATA response must be authoritative");
        assert_eq!(hdr.rcode(), Rcode::NoError);
        assert_eq!(hdr.ancount, 0, "NODATA must have no answer RRs");
    }
}