pub mod cache_layer;
pub mod engine;
pub mod forward;
pub mod layers;
pub mod listener;
pub mod middleware;
use std::net::SocketAddr;
use bytes::Bytes;
use crate::codec::{
header::Header,
message::{Query, Question},
};
/// Type-erased, heap-allocated error that is both `Send` and `Sync`.
///
/// Any concrete error implementing `std::error::Error + Send + Sync` converts
/// into it with `?` / `.into()`.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// A decoded DNS [`Query`] bundled with the socket address recorded for its
/// client and a per-request `allow_bypass` flag.
///
/// All fields are private; read them through the accessors below.
#[derive(Debug, Clone)]
pub struct DnsRequest {
    /// The decoded query this request carries.
    query: Query,
    /// Socket address recorded for the client.
    client: SocketAddr,
    /// Per-request flag; `false` until changed via `set_allow_bypass`.
    // NOTE(review): the stage that consumes this flag is not visible in this
    // file - confirm its exact meaning against the reader.
    allow_bypass: bool,
}

impl DnsRequest {
    /// Builds a request from `query` and `client`.
    ///
    /// The `allow_bypass` flag always starts out as `false`.
    pub fn new(query: Query, client: SocketAddr) -> Self {
        let allow_bypass = false;
        DnsRequest {
            query,
            client,
            allow_bypass,
        }
    }

    /// Borrows the wrapped query.
    pub fn query(&self) -> &Query {
        &self.query
    }

    /// Borrows the bytes exposed by the query's own `raw()` accessor.
    pub fn raw(&self) -> &Bytes {
        self.query().raw()
    }

    /// Borrows the query's header (delegates to the query's `header()`).
    pub fn header(&self) -> &Header {
        self.query().header()
    }

    /// Borrows the query's question (delegates to the query's `question()`).
    pub fn question(&self) -> &Question {
        self.query().question()
    }

    /// Returns a copy of the client's socket address.
    pub fn client(&self) -> SocketAddr {
        self.client
    }

    /// Returns the current value of the `allow_bypass` flag.
    pub fn allow_bypass(&self) -> bool {
        self.allow_bypass
    }

    /// Overwrites the `allow_bypass` flag with `value`.
    pub fn set_allow_bypass(&mut self, value: bool) {
        self.allow_bypass = value;
    }
}
/// Action tag produced by [`Outcome::log_action`].
///
/// NOTE(review): this reads as the action that can still be applied to a
/// logged query (a resolved name can be blocked, a blocked one unblocked) -
/// confirm with whatever consumes the log entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogAction {
/// Returned for `Outcome::Cached` and `Outcome::Forwarded`.
Block,
/// Returned for `Outcome::BlockedByAdmin` and `Outcome::BlockedByBlocklist`.
Unblock,
}
/// Classification attached to a pipeline response.
///
/// The variants are grouped in three different ways by the inherent methods
/// ([`is_blocked`](Self::is_blocked), [`log_action`](Self::log_action),
/// [`category`](Self::category)), and `Display` renders a short lowercase
/// label for each of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// Displayed as `"local"`; category `"local"`.
    Local,
    /// Displayed as `"local (no data)"`; category `"local"`.
    LocalNoData,
    /// Displayed as `"blocked (admin)"`; category `"blocked"`.
    BlockedByAdmin,
    /// Displayed as `"blocked (blocklist)"`; category `"blocked"`.
    BlockedByBlocklist,
    /// Displayed as `"cached"`; category `"cached"`.
    Cached,
    /// Displayed as `"forwarded"`; category `"forwarded"`.
    Forwarded,
    /// Displayed as `"refused"`; category `"other"`.
    Refused,
    /// Displayed as `"formerr"`; category `"other"`.
    Formerr,
    /// Displayed as `"servfail"`; category `"other"`.
    Servfail,
    /// Displayed as `"error"`; category `"other"`.
    Error,
}

impl Outcome {
    /// Returns `true` for the two blocked variants
    /// (`BlockedByAdmin`, `BlockedByBlocklist`) and `false` for all others.
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self, Self::BlockedByAdmin | Self::BlockedByBlocklist)
    }

    /// Returns the [`LogAction`] associated with this outcome, if any.
    ///
    /// * `Cached` / `Forwarded` -> `Some(LogAction::Block)`
    /// * `BlockedByAdmin` / `BlockedByBlocklist` -> `Some(LogAction::Unblock)`
    /// * every other variant -> `None`
    ///
    /// NOTE(review): the mapping is intentionally "opposite" to the outcome -
    /// it appears to name the action that can still be taken on the queried
    /// name rather than what already happened. The unit tests pin it.
    #[must_use]
    pub fn log_action(&self) -> Option<LogAction> {
        // Exhaustive on purpose (no `_` arm): adding a variant to `Outcome`
        // must force an explicit decision here instead of defaulting to `None`.
        match self {
            Self::Cached | Self::Forwarded => Some(LogAction::Block),
            Self::BlockedByAdmin | Self::BlockedByBlocklist => Some(LogAction::Unblock),
            Self::Local
            | Self::LocalNoData
            | Self::Refused
            | Self::Formerr
            | Self::Servfail
            | Self::Error => None,
        }
    }

    /// Returns the coarse bucket this outcome belongs to: one of
    /// `"blocked"`, `"cached"`, `"forwarded"`, `"local"` or `"other"`.
    #[must_use]
    pub fn category(&self) -> &'static str {
        match self {
            Self::BlockedByAdmin | Self::BlockedByBlocklist => "blocked",
            Self::Cached => "cached",
            Self::Forwarded => "forwarded",
            Self::Local | Self::LocalNoData => "local",
            Self::Refused | Self::Formerr | Self::Servfail | Self::Error => "other",
        }
    }
}

/// Renders the per-variant lowercase label (finer-grained than `category`).
impl std::fmt::Display for Outcome {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Local => "local",
            Self::LocalNoData => "local (no data)",
            Self::BlockedByAdmin => "blocked (admin)",
            Self::BlockedByBlocklist => "blocked (blocklist)",
            Self::Cached => "cached",
            Self::Forwarded => "forwarded",
            Self::Refused => "refused",
            Self::Formerr => "formerr",
            Self::Servfail => "servfail",
            Self::Error => "error",
        };
        f.write_str(label)
    }
}
/// A response payload together with the [`Outcome`] assigned to it.
#[derive(Debug, Clone)]
pub struct PipelineResponse {
    /// Response payload bytes.
    pub bytes: Bytes,
    /// Classification attached to this response.
    pub outcome: Outcome,
}

impl PipelineResponse {
    /// Pairs `bytes` with `outcome`; no validation or copying is performed.
    pub fn new(bytes: Bytes, outcome: Outcome) -> Self {
        PipelineResponse { bytes, outcome }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{header::Header, message::Qtype, name::Name, writer::Writer};

    /// Serialises a one-question query for `name` with the given message `id`
    /// and the RD flag set, and returns the resulting bytes.
    fn build_a_query(id: u16, name: &str) -> Bytes {
        let mut w = Writer::with_capacity(64);
        let hdr = Header::new(id).with_qdcount(1).with_rd(true);
        hdr.write(&mut w);
        let n: Name = name.parse().expect("valid name in test helper");
        n.write(&mut w);
        // QTYPE = 1 (A); the accessor test below checks it decodes as `Qtype::A`.
        w.write_u16(1u16);
        // QCLASS = 1 (IN, per RFC 1035).
        w.write_u16(1u16);
        w.finish()
    }

    /// Accessors expose what the request was built from, and the bypass flag
    /// defaults to `false` and can be toggled both ways.
    #[test]
    fn dns_request_accessors_and_allow_bypass() {
        let raw = build_a_query(0x1234, "example.com");
        let query = Query::try_from(raw.clone()).expect("valid query");
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let mut req = DnsRequest::new(query, client);

        assert_eq!(req.client(), client);
        assert_eq!(req.header().id, 0x1234);
        assert!(req.header().rd());
        assert_eq!(req.question().name.to_string(), "example.com.");
        assert_eq!(req.question().qtype, Qtype::A);
        assert_eq!(req.raw(), &raw);

        assert!(!req.allow_bypass());
        req.set_allow_bypass(true);
        assert!(req.allow_bypass());
        req.set_allow_bypass(false);
        assert!(!req.allow_bypass());
    }

    /// Only the two `BlockedBy*` variants count as blocked.
    #[test]
    fn outcome_is_blocked() {
        assert!(Outcome::BlockedByAdmin.is_blocked());
        assert!(Outcome::BlockedByBlocklist.is_blocked());
        assert!(!Outcome::Local.is_blocked());
        assert!(!Outcome::LocalNoData.is_blocked());
        assert!(!Outcome::Cached.is_blocked());
        assert!(!Outcome::Forwarded.is_blocked());
        assert!(!Outcome::Refused.is_blocked());
        assert!(!Outcome::Formerr.is_blocked());
        assert!(!Outcome::Servfail.is_blocked());
        assert!(!Outcome::Error.is_blocked());
    }

    /// Pins the outcome -> log-action table for every variant.
    #[test]
    fn outcome_log_action() {
        assert_eq!(Outcome::Cached.log_action(), Some(LogAction::Block));
        assert_eq!(Outcome::Forwarded.log_action(), Some(LogAction::Block));
        assert_eq!(
            Outcome::BlockedByAdmin.log_action(),
            Some(LogAction::Unblock)
        );
        assert_eq!(
            Outcome::BlockedByBlocklist.log_action(),
            Some(LogAction::Unblock)
        );
        assert_eq!(Outcome::Local.log_action(), None);
        assert_eq!(Outcome::LocalNoData.log_action(), None);
        assert_eq!(Outcome::Refused.log_action(), None);
        assert_eq!(Outcome::Formerr.log_action(), None);
        assert_eq!(Outcome::Servfail.log_action(), None);
        assert_eq!(Outcome::Error.log_action(), None);
    }

    /// Pins the outcome -> category table; every variant is listed so the
    /// table stays in step with the enum.
    #[test]
    fn outcome_category_groups_variants() {
        assert_eq!(Outcome::BlockedByAdmin.category(), "blocked");
        assert_eq!(Outcome::BlockedByBlocklist.category(), "blocked");
        assert_eq!(Outcome::Cached.category(), "cached");
        assert_eq!(Outcome::Forwarded.category(), "forwarded");
        assert_eq!(Outcome::Local.category(), "local");
        assert_eq!(Outcome::LocalNoData.category(), "local");
        assert_eq!(Outcome::Refused.category(), "other");
        assert_eq!(Outcome::Formerr.category(), "other");
        assert_eq!(Outcome::Servfail.category(), "other");
        assert_eq!(Outcome::Error.category(), "other");
    }

    /// Pins the `Display` label of every variant.
    #[test]
    fn outcome_display() {
        assert_eq!(Outcome::Local.to_string(), "local");
        assert_eq!(Outcome::LocalNoData.to_string(), "local (no data)");
        assert_eq!(Outcome::BlockedByAdmin.to_string(), "blocked (admin)");
        assert_eq!(
            Outcome::BlockedByBlocklist.to_string(),
            "blocked (blocklist)"
        );
        assert_eq!(Outcome::Cached.to_string(), "cached");
        assert_eq!(Outcome::Forwarded.to_string(), "forwarded");
        assert_eq!(Outcome::Refused.to_string(), "refused");
        assert_eq!(Outcome::Formerr.to_string(), "formerr");
        assert_eq!(Outcome::Servfail.to_string(), "servfail");
        assert_eq!(Outcome::Error.to_string(), "error");
    }

    /// `DnsRequest` -> `Result<PipelineResponse, BoxError>` works as the
    /// request/response/error triple of a `tower` service.
    #[tokio::test]
    async fn service_shape_round_trip() {
        use tower::ServiceExt as _;

        let raw = build_a_query(0xBEEF, "roundtrip.test");
        let query = Query::try_from(raw.clone()).expect("valid query");
        let client: SocketAddr = "127.0.0.1:5353".parse().unwrap();
        let request = DnsRequest::new(query, client);

        // Echo service: hands the request's raw bytes straight back.
        let svc = tower::service_fn(|req: DnsRequest| async move {
            Ok::<_, BoxError>(PipelineResponse::new(req.raw().clone(), Outcome::Forwarded))
        });

        let resp = svc.oneshot(request).await.expect("service must not error");
        assert_eq!(resp.outcome, Outcome::Forwarded);
        assert_eq!(resp.bytes, raw);
    }
}