pub mod cache_layer;
pub mod engine;
pub mod forward;
pub mod layers;
pub mod listener;
pub mod middleware;
use std::net::SocketAddr;
use bytes::Bytes;
use crate::codec::{
header::Header,
message::{Query, Question},
synth::EdnsInfo,
};
/// Boxed, type-erased error that is `Send + Sync`, so it can cross thread and
/// task boundaries; used as the error type of pipeline services.
pub type BoxError = Box<dyn std::error::Error + Sync + Send + 'static>;
/// A single DNS query travelling through the pipeline, together with
/// per-request metadata that pipeline layers can read or adjust.
#[derive(Debug, Clone)]
pub struct DnsRequest {
    /// The parsed query as received.
    query: Query,
    /// Address of the client that sent the query.
    client: SocketAddr,
    /// EDNS information scanned out of `query` once, at construction time.
    edns: Option<EdnsInfo>,
    /// Per-request bypass flag; starts out `false`.
    // NOTE(review): the layer that consumes this flag is outside this file — confirm its meaning there.
    allow_bypass: bool,
    /// Optional upstream override; starts out `None`.
    // NOTE(review): presumably consulted by the forwarding layer — confirm.
    forward_target: Option<SocketAddr>,
}

impl DnsRequest {
    /// Wraps `query` received from `client`.
    ///
    /// The query is scanned for EDNS information up front so later layers can
    /// read it via [`DnsRequest::edns`] without re-parsing.
    pub fn new(query: Query, client: SocketAddr) -> Self {
        Self {
            // Struct-literal fields evaluate in written order, so this borrow of
            // `query` finishes before `query` is moved into the struct below.
            edns: EdnsInfo::scan(&query),
            query,
            client,
            allow_bypass: false,
            forward_target: None,
        }
    }

    /// The parsed query.
    pub fn query(&self) -> &Query {
        &self.query
    }

    /// The raw bytes the query was parsed from.
    pub fn raw(&self) -> &Bytes {
        self.query().raw()
    }

    /// The query's header.
    pub fn header(&self) -> &Header {
        self.query().header()
    }

    /// The query's question.
    pub fn question(&self) -> &Question {
        self.query().question()
    }

    /// Address of the client that sent the query.
    pub fn client(&self) -> SocketAddr {
        self.client
    }

    /// EDNS information found when the request was built, or `None` if the
    /// scan found none (e.g. the query carried no OPT record).
    pub fn edns(&self) -> Option<&EdnsInfo> {
        self.edns.as_ref()
    }

    /// Current value of the bypass flag.
    pub fn allow_bypass(&self) -> bool {
        self.allow_bypass
    }

    /// Overwrites the bypass flag.
    pub fn set_allow_bypass(&mut self, value: bool) {
        self.allow_bypass = value;
    }

    /// The upstream override, if one has been set.
    pub fn forward_target(&self) -> Option<SocketAddr> {
        self.forward_target
    }

    /// Records `target` as the upstream override, replacing any previous one.
    pub fn set_forward_target(&mut self, target: SocketAddr) {
        self.forward_target = Some(target);
    }
}
/// Action associated with a query-log entry, as chosen by `Outcome::log_action`.
///
/// Answered outcomes (`Cached`, `Forwarded`) map to `Block`; blocked outcomes
/// map to `Unblock`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogAction {
    /// Returned for queries that were answered (cached or forwarded).
    Block,
    /// Returned for queries that were blocked (by admin or blocklist).
    Unblock,
}
/// Classification of how a query was answered; carried on every `PipelineResponse`.
///
/// Each variant has a stable `strum` token (returned by `Outcome::as_str` and
/// accepted by `FromStr`). Those tokens, not the `Display` labels, are the
/// serialized form: the tests require every token to round-trip and to be distinct.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, strum::EnumString, strum::IntoStaticStr, strum::EnumIter,
)]
pub enum Outcome {
/// Token `local`, category `local` (presumably a locally-produced answer).
#[strum(serialize = "local")]
Local,
/// Token `local-nodata`, category `local` (presumably a local answer with no data).
#[strum(serialize = "local-nodata")]
LocalNoData,
/// Token `blocked-admin`, category `blocked`; `is_blocked()` is true.
#[strum(serialize = "blocked-admin")]
BlockedByAdmin,
/// Token `blocked-blocklist`, category `blocked`; `is_blocked()` is true.
#[strum(serialize = "blocked-blocklist")]
BlockedByBlocklist,
/// Token `cached`, category `cached`.
#[strum(serialize = "cached")]
Cached,
/// Token `forwarded`, category `forwarded`.
#[strum(serialize = "forwarded")]
Forwarded,
/// Token `refused`, category `other`.
#[strum(serialize = "refused")]
Refused,
/// Token `formerr`, category `other`.
#[strum(serialize = "formerr")]
Formerr,
/// Token `servfail`, category `other`.
#[strum(serialize = "servfail")]
Servfail,
/// Token `error`, category `other`.
#[strum(serialize = "error")]
Error,
}
impl Outcome {
    /// Returns `true` for the two blocked outcomes,
    /// [`Outcome::BlockedByAdmin`] and [`Outcome::BlockedByBlocklist`].
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self, Self::BlockedByAdmin | Self::BlockedByBlocklist)
    }

    /// Action associated with this outcome's log entry, if any.
    ///
    /// Answered outcomes (`Cached`, `Forwarded`) map to [`LogAction::Block`].
    /// Blocked outcomes map to [`LogAction::Unblock`]. Every other outcome
    /// has no action.
    #[must_use]
    pub fn log_action(&self) -> Option<LogAction> {
        // Deliberately exhaustive (no `_` arm): adding a variant must force an
        // explicit decision here instead of silently defaulting to `None`.
        match self {
            Self::Cached | Self::Forwarded => Some(LogAction::Block),
            Self::BlockedByAdmin | Self::BlockedByBlocklist => Some(LogAction::Unblock),
            Self::Local
            | Self::LocalNoData
            | Self::Refused
            | Self::Formerr
            | Self::Servfail
            | Self::Error => None,
        }
    }

    /// The stable `strum` token for this variant (e.g. `"blocked-admin"`).
    ///
    /// This is the form accepted back by `FromStr`; it differs from the
    /// human-readable `Display` label.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        self.into()
    }

    /// Coarse grouping of variants: `"blocked"`, `"cached"`, `"forwarded"`,
    /// `"local"` or `"other"`.
    #[must_use]
    pub fn category(&self) -> &'static str {
        match self {
            Self::BlockedByAdmin | Self::BlockedByBlocklist => "blocked",
            Self::Cached => "cached",
            Self::Forwarded => "forwarded",
            Self::Local | Self::LocalNoData => "local",
            Self::Refused | Self::Formerr | Self::Servfail | Self::Error => "other",
        }
    }
}
impl std::fmt::Display for Outcome {
    /// Writes a human-readable label such as `"blocked (admin)"`.
    ///
    /// These labels differ from the `strum` tokens returned by
    /// [`Outcome::as_str`] and are not accepted by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Local => "local",
            Self::LocalNoData => "local (no data)",
            Self::BlockedByAdmin => "blocked (admin)",
            Self::BlockedByBlocklist => "blocked (blocklist)",
            Self::Cached => "cached",
            Self::Forwarded => "forwarded",
            Self::Refused => "refused",
            Self::Formerr => "formerr",
            Self::Servfail => "servfail",
            Self::Error => "error",
        };
        // `pad` (unlike `write_str`) honours width/fill/alignment/precision
        // flags, so e.g. `format!("{:<20}", outcome)` lines up in tables.
        // Without flags the output is identical to writing the label directly.
        f.pad(label)
    }
}
/// What the pipeline returns for a request: the encoded payload, how it was
/// obtained, and optionally which upstream produced it.
#[derive(Debug, Clone)]
pub struct PipelineResponse {
    /// Encoded response payload.
    pub bytes: Bytes,
    /// Classification of how the response was obtained.
    pub outcome: Outcome,
    /// Upstream that produced the response. `None` unless set through
    /// [`PipelineResponse::with_upstream`].
    pub upstream: Option<SocketAddr>,
}

impl PipelineResponse {
    /// Builds a response with no upstream recorded.
    pub fn new(bytes: Bytes, outcome: Outcome) -> Self {
        let upstream = None;
        Self {
            bytes,
            outcome,
            upstream,
        }
    }

    /// Consumes `self` and returns it with `upstream` recorded, replacing any
    /// previously recorded upstream.
    #[must_use]
    pub fn with_upstream(self, upstream: SocketAddr) -> Self {
        Self {
            upstream: Some(upstream),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{header::Header, message::Qtype, name::Name, writer::Writer};

    #[test]
    fn dns_request_accessors_and_allow_bypass() {
        let raw = crate::test_support::a_query(0x1234, "example.com");
        let query = Query::try_from(raw.clone()).expect("valid query");
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let mut req = DnsRequest::new(query, client);
        assert_eq!(req.client(), client);
        assert_eq!(req.header().id, 0x1234);
        assert!(req.header().rd());
        assert_eq!(req.question().name.to_string(), "example.com.");
        assert_eq!(req.question().qtype, Qtype::A);
        assert_eq!(req.raw(), &raw);
        assert!(!req.allow_bypass());
        req.set_allow_bypass(true);
        assert!(req.allow_bypass());
        req.set_allow_bypass(false);
        assert!(!req.allow_bypass());
    }

    #[test]
    fn dns_request_forward_target_defaults_to_none_and_can_be_set() {
        let raw = crate::test_support::a_query(0x4321, "forward.example");
        let query = Query::try_from(raw).expect("valid query");
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let mut req = DnsRequest::new(query, client);
        assert_eq!(req.forward_target(), None);
        let target: SocketAddr = "192.0.2.53:53".parse().unwrap();
        req.set_forward_target(target);
        assert_eq!(req.forward_target(), Some(target));
    }

    #[test]
    fn dns_request_carries_edns_info() {
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let plain = Query::try_from(crate::test_support::a_query(0x0001, "plain.example")).unwrap();
        assert!(DnsRequest::new(plain, client).edns().is_none());

        let mut w = Writer::with_capacity(64);
        Header::new(0x0002)
            .with_qdcount(1)
            .with_arcount(1)
            .write(&mut w);
        let name: Name = "edns.example".parse().unwrap();
        name.write(&mut w);
        // Question section tail: QTYPE = A (1), QCLASS = IN (1).
        w.write_u16(1);
        w.write_u16(1);
        // Additional section: one OPT pseudo-RR (RFC 6891).
        w.write_u8(0x00); // owner name: the root
        w.write_u16(41); // TYPE = OPT
        w.write_u16(4096); // CLASS carries the requestor's UDP payload size
        w.write_u32(0); // TTL carries extended RCODE / version / flags, all zero
        w.write_u16(0); // RDLENGTH = 0: no EDNS options
        let query = Query::try_from(w.finish()).unwrap();

        let req = DnsRequest::new(query, client);
        let edns = req.edns().expect("OPT query must carry EDNS info");
        assert_eq!(edns.udp_payload_size, 4096);
        let fresh = EdnsInfo::scan(req.query()).expect("fresh scan finds the OPT");
        assert_eq!(fresh.udp_payload_size, edns.udp_payload_size);
    }

    #[test]
    fn outcome_is_blocked() {
        assert!(Outcome::BlockedByAdmin.is_blocked());
        assert!(Outcome::BlockedByBlocklist.is_blocked());
        assert!(!Outcome::Local.is_blocked());
        assert!(!Outcome::LocalNoData.is_blocked());
        assert!(!Outcome::Cached.is_blocked());
        assert!(!Outcome::Forwarded.is_blocked());
        assert!(!Outcome::Refused.is_blocked());
        assert!(!Outcome::Formerr.is_blocked());
        assert!(!Outcome::Servfail.is_blocked());
        assert!(!Outcome::Error.is_blocked());
    }

    #[test]
    fn outcome_log_action() {
        assert_eq!(Outcome::Cached.log_action(), Some(LogAction::Block));
        assert_eq!(Outcome::Forwarded.log_action(), Some(LogAction::Block));
        assert_eq!(
            Outcome::BlockedByAdmin.log_action(),
            Some(LogAction::Unblock)
        );
        assert_eq!(
            Outcome::BlockedByBlocklist.log_action(),
            Some(LogAction::Unblock)
        );
        assert_eq!(Outcome::Local.log_action(), None);
        assert_eq!(Outcome::LocalNoData.log_action(), None);
        assert_eq!(Outcome::Refused.log_action(), None);
        assert_eq!(Outcome::Formerr.log_action(), None);
        assert_eq!(Outcome::Servfail.log_action(), None);
        assert_eq!(Outcome::Error.log_action(), None);
    }

    #[test]
    fn outcome_category_groups_variants() {
        assert_eq!(Outcome::BlockedByAdmin.category(), "blocked");
        assert_eq!(Outcome::BlockedByBlocklist.category(), "blocked");
        assert_eq!(Outcome::Cached.category(), "cached");
        assert_eq!(Outcome::Forwarded.category(), "forwarded");
        assert_eq!(Outcome::Local.category(), "local");
        assert_eq!(Outcome::LocalNoData.category(), "local");
        assert_eq!(Outcome::Refused.category(), "other");
        assert_eq!(Outcome::Formerr.category(), "other");
        assert_eq!(Outcome::Servfail.category(), "other");
        assert_eq!(Outcome::Error.category(), "other");
    }

    #[test]
    fn outcome_as_str_from_str_round_trips_every_variant() {
        use std::str::FromStr as _;
        use strum::IntoEnumIterator as _;
        for outcome in Outcome::iter() {
            let token = outcome.as_str();
            let parsed = Outcome::from_str(token)
                .unwrap_or_else(|e| panic!("token {token:?} must parse back: {e}"));
            assert_eq!(parsed, outcome, "round-trip mismatch for {token:?}");
        }
    }

    #[test]
    fn outcome_as_str_tokens_are_distinct() {
        use strum::IntoEnumIterator as _;
        let mut tokens: Vec<&str> = Outcome::iter().map(|o| o.as_str()).collect();
        let count = tokens.len();
        tokens.sort_unstable();
        tokens.dedup();
        assert_eq!(tokens.len(), count, "as_str tokens must all be distinct");
    }

    #[test]
    fn outcome_from_str_rejects_unknown() {
        use std::str::FromStr as _;
        assert!(Outcome::from_str("nonsense").is_err());
        assert!(
            Outcome::from_str("blocked (admin)").is_err(),
            "Display labels are not the persistence format"
        );
    }

    #[test]
    fn outcome_display() {
        assert_eq!(Outcome::Local.to_string(), "local");
        assert_eq!(Outcome::LocalNoData.to_string(), "local (no data)");
        assert_eq!(Outcome::BlockedByAdmin.to_string(), "blocked (admin)");
        assert_eq!(
            Outcome::BlockedByBlocklist.to_string(),
            "blocked (blocklist)"
        );
        assert_eq!(Outcome::Cached.to_string(), "cached");
        assert_eq!(Outcome::Forwarded.to_string(), "forwarded");
        assert_eq!(Outcome::Refused.to_string(), "refused");
        assert_eq!(Outcome::Formerr.to_string(), "formerr");
        assert_eq!(Outcome::Servfail.to_string(), "servfail");
        assert_eq!(Outcome::Error.to_string(), "error");
    }

    #[test]
    fn pipeline_response_upstream_is_optional() {
        let resp = PipelineResponse::new(Bytes::from_static(b"\x12\x34"), Outcome::Cached);
        assert_eq!(resp.upstream, None);
        assert_eq!(resp.outcome, Outcome::Cached);
        let upstream: SocketAddr = "192.0.2.1:53".parse().unwrap();
        let resp = resp.with_upstream(upstream);
        assert_eq!(resp.upstream, Some(upstream));
        assert_eq!(resp.outcome, Outcome::Cached);
        assert_eq!(resp.bytes, Bytes::from_static(b"\x12\x34"));
    }

    #[tokio::test]
    async fn service_shape_round_trip() {
        use tower::ServiceExt as _;
        let raw = crate::test_support::a_query(0xBEEF, "roundtrip.test");
        let query = Query::try_from(raw.clone()).expect("valid query");
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let request = DnsRequest::new(query, client);
        let svc = tower::service_fn(|req: DnsRequest| async move {
            Ok::<_, BoxError>(PipelineResponse::new(req.raw().clone(), Outcome::Forwarded))
        });
        let resp = svc.oneshot(request).await.expect("service must not error");
        assert_eq!(resp.outcome, Outcome::Forwarded);
        assert_eq!(resp.bytes, raw);
    }
}