use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
time::Duration,
};
use kameo::{
actor::ActorRef,
message::{Context, Message},
};
use netstack::{CreateSocket, netcore::Channel};
use tokio::{sync::watch, task::JoinSet, time::timeout};
use ts_control::{DnsConfig, DnsResolver, Node};
use ts_dns_wire::{Name, QType, RData, Rcode, decode_query, encode_response};
use crate::{
Error,
env::Env,
peer_tracker::{PeerDb, PeerState},
};
/// How long a single upstream resolver is given to answer before the next one is tried.
const UPSTREAM_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Largest forwarded response relayed to a client; bigger ones are cut and flagged TC.
/// 1232 matches the DNS Flag Day 2020 EDNS buffer recommendation (presumably the reason
/// for this value — confirm).
const MAX_UPSTREAM_RESPONSE: usize = 1_232;
/// Address the MagicDNS responder listens on.
const MAGIC_DNS_IP: Ipv4Addr = Ipv4Addr::new(100, 100, 100, 100);
/// UDP port the MagicDNS responder listens on.
const MAGIC_DNS_PORT: u16 = 53;
/// Snapshot of everything the MagicDNS responder consults when answering a query.
///
/// A fresh snapshot is published through a `watch` channel on every relevant update,
/// and each inbound query is decided against a single snapshot.
#[derive(Clone, Default)]
pub(crate) struct DnsView {
    /// DNS configuration from the control plane (search domains, routes, resolvers, ...).
    pub(crate) cfg: DnsConfig,
    /// Whether this host accepts tailnet DNS; when false every query is refused.
    pub(crate) accept_dns: bool,
    /// Whether AAAA queries are answered with tailnet IPv6 addresses (otherwise NODATA).
    pub(crate) enable_ipv6: bool,
    /// Current peer table, once the peer tracker has published one.
    pub(crate) peers: Option<Arc<PeerDb>>,
    /// This node's own record, if known.
    pub(crate) self_node: Option<Node>,
    /// Peer-API DoH endpoint of the active exit node, if one is selected.
    pub(crate) exit_doh: Option<SocketAddr>,
}
impl DnsView {
    /// Looks up a node by name: the peer table first, then this node's own record.
    fn node_by_name(&self, name: &str) -> Option<Node> {
        if let Some(node) = self
            .peers
            .as_ref()
            .and_then(|p| p.get(&name).map(|(_, n)| n.clone()))
        {
            return Some(node);
        }
        self.self_node
            .as_ref()
            .filter(|n| n.matches_name(name))
            .cloned()
    }

    /// Resolves `canon` to a tailnet address of the requested family (`want_v4`).
    ///
    /// Lookup order:
    /// 1. a node whose name matches `canon` directly;
    /// 2. a node matching `canon` expanded with each search domain;
    /// 3. a control-plane extra record named exactly `canon` (never search-expanded)
    ///    whose address family matches the request.
    ///
    /// Extra-record names are compared ASCII-case-insensitively, and a single trailing
    /// root dot on the record name is ignored. DNS names are case-insensitive, and the
    /// control plane may send record names in FQDN form.
    fn resolve_addr(&self, canon: &str, want_v4: bool) -> Option<IpAddr> {
        let addr_of = |node: Node| -> IpAddr {
            if want_v4 {
                IpAddr::from(node.tailnet_address.ipv4.addr())
            } else {
                IpAddr::from(node.tailnet_address.ipv6.addr())
            }
        };
        if let Some(node) = self.node_by_name(canon) {
            return Some(addr_of(node));
        }
        for suffix in &self.cfg.search_domains {
            if let Some(node) = self.node_by_name(&format!("{canon}.{suffix}")) {
                return Some(addr_of(node));
            }
        }
        self.cfg.extra_records.iter().find_map(|rec| {
            let family_ok = matches!(
                (rec.addr, want_v4),
                (IpAddr::V4(_), true) | (IpAddr::V6(_), false)
            );
            // Tolerate "Name.Example." style records from the control plane.
            let rec_name = rec.name.strip_suffix('.').unwrap_or(rec.name.as_str());
            (family_ok && rec_name.eq_ignore_ascii_case(canon)).then_some(rec.addr)
        })
    }

    /// Finds the node owning tailnet address `ip`: the peer table first, then this
    /// node (matching either its IPv4 or IPv6 tailnet address).
    fn node_by_ip(&self, ip: IpAddr) -> Option<Node> {
        if let Some(node) = self
            .peers
            .as_ref()
            .and_then(|p| p.get(&ip).map(|(_, n)| n.clone()))
        {
            return Some(node);
        }
        self.self_node
            .as_ref()
            .filter(|n| {
                IpAddr::from(n.tailnet_address.ipv4.addr()) == ip
                    || IpAddr::from(n.tailnet_address.ipv6.addr()) == ip
            })
            .cloned()
    }

    /// Chooses upstreams for a name MagicDNS cannot answer itself.
    ///
    /// The longest split-DNS route suffix matching `name` wins. A route with no
    /// resolvers means "block". Without a matching route, the fallback resolvers are
    /// used, then the global resolvers. `Upstreams::None` is returned if neither list
    /// is configured.
    fn route_for(&self, name: &str) -> Upstreams<'_> {
        let mut best: Option<(&str, &Vec<DnsResolver>)> = None;
        for (suffix, upstreams) in &self.cfg.routes {
            if suffix_matches(name, suffix) && best.is_none_or(|(b, _)| suffix.len() > b.len()) {
                best = Some((suffix.as_str(), upstreams));
            }
        }
        if let Some((_, upstreams)) = best {
            return if upstreams.is_empty() {
                Upstreams::Block
            } else {
                Upstreams::Route(upstreams)
            };
        }
        if !self.cfg.fallback_resolvers.is_empty() {
            return Upstreams::Recursive(&self.cfg.fallback_resolvers);
        }
        if !self.cfg.resolvers.is_empty() {
            return Upstreams::Recursive(&self.cfg.resolvers);
        }
        Upstreams::None
    }
}
/// Routing outcome for a name MagicDNS cannot answer itself (see `DnsView::route_for`).
enum Upstreams<'a> {
    /// No resolvers are configured at all; the query fails closed with NXDOMAIN.
    None,
    /// A split-DNS route matched but lists no resolvers; answer NXDOMAIN locally.
    Block,
    /// A split-DNS route matched; forward only to these resolvers.
    Route(&'a [DnsResolver]),
    /// No route matched; forward to the fallback (or global) resolvers.
    Recursive(&'a [DnsResolver]),
}

/// What the responder should do with one inbound query (output of `decide`).
pub(crate) enum Decision {
    /// Forward the query upstream, replying with `nxdomain` if no upstream answers.
    Forward {
        /// The client's query bytes exactly as received.
        query: Vec<u8>,
        /// IPv4 upstream resolvers to try, in order.
        upstreams: Vec<SocketAddr>,
        /// Pre-encoded NXDOMAIN reply for when forwarding fails.
        nxdomain: Vec<u8>,
        /// True when no split-DNS route matched; such forwards may go via an exit node.
        recursive: bool,
    },
    /// Send these bytes back to the client as-is.
    Reply(Vec<u8>),
}
/// Reports whether `name` equals `suffix` or lies beneath it on a label boundary.
///
/// Comparison follows DNS rules:
/// - ASCII letters compare case-insensitively (RFC 4343).
/// - A single trailing root dot on either side is ignored, so a configured suffix of
///   `"Corp.Example."` matches `"api.corp.example"`.
///
/// An empty suffix (or a bare `"."`) matches nothing, so a blank config entry can never
/// capture every name.
fn suffix_matches(name: &str, suffix: &str) -> bool {
    let name = name.strip_suffix('.').unwrap_or(name).as_bytes();
    let suffix = suffix.strip_suffix('.').unwrap_or(suffix).as_bytes();
    if suffix.is_empty() || name.len() < suffix.len() {
        return false;
    }
    let split = name.len() - suffix.len();
    // Require a label boundary so that "acorp" does not match "corp".
    let on_boundary = split == 0 || name[split - 1] == b'.';
    // Byte-slice comparison avoids any UTF-8 char-boundary slicing concerns.
    on_boundary && name[split..].eq_ignore_ascii_case(suffix)
}
/// Reports whether `name` falls under any of the tailnet's configured search domains.
///
/// Such names are answered authoritatively and must never be forwarded upstream.
fn is_tailnet_name(view: &DnsView, name: &str) -> bool {
    for domain in &view.cfg.search_domains {
        if suffix_matches(name, domain) {
            return true;
        }
    }
    false
}
/// Reports whether `name` is `ip6.arpa` or a name beneath it (the IPv6 reverse tree).
fn is_ip6_arpa(name: &str) -> bool {
    const IP6_ARPA: &str = "ip6.arpa";
    suffix_matches(name, IP6_ARPA)
}
/// Reports whether `ip` lies in 100.64.0.0/10, the CGNAT block tailnet IPv4 addresses
/// are drawn from (e.g. 100.64.0.1, 100.100.100.100).
fn is_tailnet_cgnat(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [100, 64..=127, _, _])
}
/// Decides how to handle one raw DNS query datagram.
///
/// Returns `None` when `buf` does not decode as a query; the caller drops it silently.
/// Otherwise the outcome is one of:
/// - REFUSED if MagicDNS is off, this host does not accept tailnet DNS, the class is
///   not IN, or the type is not A/AAAA/PTR.
/// - An authoritative answer for known tailnet names and tailnet IPv4 reverse lookups.
///   With IPv6 disabled, AAAA for a known name is NODATA.
/// - NXDOMAIN for PTR queries under `ip6.arpa` and for unknown IPv4 addresses in the
///   tailnet CGNAT range.
/// - Otherwise, whatever `forward_or_nxdomain` chooses.
pub(crate) fn decide(view: &DnsView, buf: &[u8]) -> Option<Decision> {
    // DNS class IN (Internet), the only class served.
    const CLASS_IN: u16 = 1;

    let query = decode_query(buf).ok()?;
    let question = &query.question;
    let qid = query.id;

    let respond =
        |rcode, answers: &[RData]| Decision::Reply(encode_response(qid, question, rcode, answers));

    let serving = view.cfg.magic_dns && view.accept_dns && question.qclass == CLASS_IN;
    if !serving {
        return Some(respond(Rcode::Refused, &[]));
    }

    let canon = question.name.to_canon();
    // Anything not answered authoritatively goes through the leak/routing checks.
    let fallthrough = || forward_or_nxdomain(view, &canon, buf, qid, question);

    let decision = match &question.qtype {
        QType::A => {
            if let Some(IpAddr::V4(v4)) = view.resolve_addr(&canon, true) {
                respond(Rcode::NoError, &[RData::A(v4.octets())])
            } else {
                fallthrough()
            }
        }
        QType::Aaaa => match view.resolve_addr(&canon, false) {
            // Known name but IPv6 disabled: the name exists, so NODATA rather than NXDOMAIN.
            Some(IpAddr::V6(_)) if !view.enable_ipv6 => respond(Rcode::NoError, &[]),
            Some(IpAddr::V6(v6)) => respond(Rcode::NoError, &[RData::Aaaa(v6.octets())]),
            _ => fallthrough(),
        },
        QType::Ptr => match question.name.ptr_to_ipv4() {
            None if is_ip6_arpa(&canon) => respond(Rcode::NxDomain, &[]),
            None => fallthrough(),
            Some(octets) => {
                let v4 = Ipv4Addr::from(octets);
                match view.node_by_ip(IpAddr::V4(v4)) {
                    Some(node) => {
                        let labels = node
                            .fqdn(false)
                            .split('.')
                            .map(String::from)
                            .collect::<Vec<String>>();
                        respond(Rcode::NoError, &[RData::Ptr(Name(labels))])
                    }
                    // Tailnet address with no owner: authoritative NXDOMAIN, never leaked.
                    None if is_tailnet_cgnat(v4) => respond(Rcode::NxDomain, &[]),
                    None => fallthrough(),
                }
            }
        },
        // NOTE(review): this also refuses TXT/MX/HTTPS/... for non-tailnet names instead
        // of forwarding them — confirm that is intended.
        QType::Other(_) => respond(Rcode::Refused, &[]),
    };
    Some(decision)
}
/// Fallback for queries that cannot be answered authoritatively.
///
/// Names under a tailnet search domain get NXDOMAIN so they never leak upstream. Other
/// names are routed with `DnsView::route_for`. The result is NXDOMAIN for blocked or
/// unroutable names and for routes without any IPv4 resolver. Otherwise the raw query
/// is handed back for forwarding, together with a pre-encoded NXDOMAIN fallback.
fn forward_or_nxdomain(
    view: &DnsView,
    canon: &str,
    buf: &[u8],
    id: u16,
    q: &ts_dns_wire::Question,
) -> Decision {
    let nxdomain = encode_response(id, q, Rcode::NxDomain, &[]);

    let target = if is_tailnet_name(view, canon) {
        None
    } else {
        match view.route_for(canon) {
            Upstreams::Route(list) => Some((list, false)),
            Upstreams::Recursive(list) => Some((list, true)),
            Upstreams::Block | Upstreams::None => None,
        }
    };
    let Some((resolvers, recursive)) = target else {
        return Decision::Reply(nxdomain);
    };

    // Only IPv4 upstreams are kept (forward sockets are bound on 0.0.0.0).
    let mut upstreams = Vec::new();
    for resolver in resolvers {
        let addr = resolver.udp_addr();
        if addr.is_ipv4() {
            upstreams.push(addr);
        }
    }
    if upstreams.is_empty() {
        return Decision::Reply(nxdomain);
    }
    Decision::Forward {
        upstreams,
        query: buf.to_vec(),
        nxdomain,
        recursive,
    }
}
/// How a recursive (non-split-DNS) query is sent upstream; see `recursive_plan`.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum RecursivePlan {
    /// Delegate to the active exit node's peer-API DNS-over-HTTPS endpoint.
    Doh(SocketAddr),
    /// Plain UDP to these resolvers, tried in order.
    Udp(Vec<SocketAddr>),
}
/// Chooses how a recursive query is forwarded.
///
/// - No exit node: `default_upstreams` are used unchanged.
/// - Exit node active: only the IPv4 resolvers returned by
///   `cfg.resolvers_with_exit_node()` stay local over UDP.
/// - Exit node active but none of those remain: the query is delegated to the exit
///   node's DoH endpoint.
pub(crate) fn recursive_plan(view: &DnsView, default_upstreams: Vec<SocketAddr>) -> RecursivePlan {
    match view.exit_doh {
        None => RecursivePlan::Udp(default_upstreams),
        Some(doh) => {
            let local_v4: Vec<SocketAddr> = view
                .cfg
                .resolvers_with_exit_node()
                .filter_map(|resolver| {
                    let addr = resolver.udp_addr();
                    addr.is_ipv4().then_some(addr)
                })
                .collect();
            if local_v4.is_empty() {
                RecursivePlan::Doh(doh)
            } else {
                RecursivePlan::Udp(local_v4)
            }
        }
    }
}
/// Caps a forwarded response at `MAX_UPSTREAM_RESPONSE` bytes.
///
/// An oversized response is cut to exactly the limit, and the TC (truncated) header
/// flag is set so the client knows the answer is incomplete. Responses at or under the
/// limit are returned untouched.
///
/// NOTE(review): the cut is at a raw byte offset. The header counts may then describe
/// records that are only partially present, and some stub resolvers may treat that as
/// malformed rather than honouring TC.
fn cap_response(resp: Vec<u8>) -> Vec<u8> {
    // TC is bit 0x02 of header byte 2 (QR | Opcode | AA | TC | RD).
    const TC_FLAG: u8 = 0x02;
    if resp.len() <= MAX_UPSTREAM_RESPONSE {
        return resp;
    }
    let mut capped = resp;
    capped.truncate(MAX_UPSTREAM_RESPONSE);
    if let Some(flags_hi) = capped.get_mut(2) {
        *flags_hi |= TC_FLAG;
    }
    capped
}
/// Length of the fixed DNS message header (RFC 1035 §4.1.1).
const DNS_HEADER_LEN: usize = 12;

/// Returns the byte range of the first question entry (QNAME, QTYPE, QCLASS) in `msg`.
///
/// The QNAME is walked label by label, starting right after the header. Returns `None`
/// in any of these cases:
/// - the message is too short;
/// - a label runs past the end;
/// - the QNAME uses a compression pointer or reserved label type (either of the top
///   two bits of a length byte set).
fn question_range(msg: &[u8]) -> Option<std::ops::Range<usize>> {
    let mut off = DNS_HEADER_LEN;
    loop {
        let len = usize::from(*msg.get(off)?);
        if len & 0xC0 != 0 {
            return None;
        }
        off += 1;
        if len == 0 {
            // Root label: end of QNAME.
            break;
        }
        off = off.checked_add(len)?;
        if off > msg.len() {
            return None;
        }
    }
    // QTYPE + QCLASS: two big-endian u16 fields after the root label.
    let end = off.checked_add(4)?;
    (end <= msg.len()).then_some(DNS_HEADER_LEN..end)
}

/// Reports whether `resp` is a plausible reply to `query`.
///
/// All of the following must hold:
/// - both messages carry a full header;
/// - the 16-bit IDs match;
/// - `resp` has the QR (response) bit set;
/// - both have the same first question.
///
/// The QNAME is compared with ASCII case folding, because DNS names are
/// case-insensitive (RFC 4343). Without folding, an upstream echo that differs only in
/// letter case would be discarded and the query would time out. QTYPE and QCLASS are
/// binary fields and must match byte-for-byte.
fn response_matches_query(query: &[u8], resp: &[u8]) -> bool {
    if query.len() < DNS_HEADER_LEN || resp.len() < DNS_HEADER_LEN {
        return false;
    }
    let id_matches = query[0..2] == resp[0..2];
    let is_response = resp[2] & 0x80 != 0;
    if !id_matches || !is_response {
        return false;
    }
    let (Some(q_range), Some(r_range)) = (question_range(query), question_range(resp)) else {
        return false;
    };
    let (q, r) = (&query[q_range], &resp[r_range]);
    if q.len() != r.len() {
        return false;
    }
    // Every question range ends with the 4-byte QTYPE/QCLASS trailer, so `len - 4` is
    // the QNAME length. Folding must not touch the trailer: QTYPE 65 (0x0041, 'A') and
    // 97 (0x0061, 'a') would otherwise compare equal. Length bytes are <= 63, below any
    // ASCII letter, so folding cannot make differently-shaped names equal.
    let name_len = q.len() - 4;
    let (q_name, q_trailer) = q.split_at(name_len);
    let (r_name, r_trailer) = r.split_at(name_len);
    q_trailer == r_trailer && q_name.eq_ignore_ascii_case(r_name)
}
pub(crate) async fn forward_query(
channel: &Channel,
upstreams: &[SocketAddr],
query: &[u8],
nxdomain: Vec<u8>,
) -> Vec<u8> {
for upstream in upstreams {
let socket = match channel
.udp_bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
.await
{
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, %upstream, "magic dns upstream bind failed");
continue;
}
};
if let Err(e) = socket.send_to(*upstream, query).await {
tracing::warn!(error = %e, %upstream, "magic dns upstream send failed");
continue;
}
match timeout(UPSTREAM_TIMEOUT, socket.recv_from_bytes()).await {
Ok(Ok((from, resp))) if !resp.is_empty() => {
if from.ip() != upstream.ip() || !response_matches_query(query, &resp) {
tracing::debug!(%upstream, %from, "magic dns dropping unsolicited/mismatched response");
continue;
}
return cap_response(resp.to_vec());
}
Ok(Ok(_)) => continue,
Ok(Err(e)) => {
tracing::warn!(error = %e, %upstream, "magic dns upstream recv failed");
continue;
}
Err(_) => {
tracing::debug!(%upstream, "magic dns upstream timed out");
continue;
}
}
}
nxdomain
}
async fn serve(
socket: netstack::netsock::UdpSocket,
rx: watch::Receiver<Arc<DnsView>>,
channel: Channel,
) {
let socket = Arc::new(socket);
let mut forwards = JoinSet::new();
loop {
let (src, buf) = match socket.recv_from_bytes().await {
Ok(pkt) => pkt,
Err(e) => {
tracing::warn!(error = %e, "magic dns socket recv failed, stopping responder");
return;
}
};
let view = rx.borrow().clone();
match decide(&view, &buf) {
None => continue,
Some(Decision::Reply(resp)) => {
if let Err(e) = socket.send_to(src, &resp).await {
tracing::warn!(error = %e, %src, "magic dns response send failed");
}
}
Some(Decision::Forward {
upstreams,
query,
nxdomain,
recursive,
}) => {
let plan = if recursive {
recursive_plan(&view, upstreams)
} else {
RecursivePlan::Udp(upstreams)
};
let socket = socket.clone();
let channel = channel.clone();
forwards.spawn(async move {
let resp = match plan {
RecursivePlan::Udp(upstreams) => {
forward_query(&channel, &upstreams, &query, nxdomain).await
}
RecursivePlan::Doh(doh_addr) => {
crate::peerapi_doh::forward_doh(&channel, doh_addr, &query, nxdomain)
.await
}
};
if let Err(e) = socket.send_to(src, &resp).await {
tracing::warn!(error = %e, %src, "magic dns forwarded response send failed");
}
});
}
}
while forwards.try_join_next().is_some() {}
}
}
/// Actor that owns the MagicDNS responder on 100.100.100.100:53. When a peer-API port
/// is configured, it also owns the peer-API server.
///
/// It keeps their shared `DnsView` current as control-plane, peer and exit-node
/// updates arrive.
pub struct MagicDnsActor {
    // Background tasks spawned in `on_start` (responder, peer API), kept for the actor's
    // lifetime. Tokio aborts a JoinSet's tasks when it is dropped.
    _joinset: JoinSet<()>,
    // Publishes new `DnsView` snapshots to the responder and peer-API tasks.
    view_tx: watch::Sender<Arc<DnsView>>,
    // Environment handle; `accept_dns()` is re-read on every update.
    env: Env,
}
impl kameo::Actor for MagicDnsActor {
    type Args = (Env, Channel);
    type Error = Error;

    /// Starts the actor and its background tasks.
    ///
    /// - Subscribes to control-plane state, peer state and exit-node updates.
    /// - Publishes an initial `DnsView` seeded from `env`.
    /// - Spawns the MagicDNS responder, and the peer-API server when
    ///   `env.peerapi_port` is set.
    ///
    /// Failing to bind 100.100.100.100:53 is logged but not fatal: the actor still
    /// starts, and the responder simply never runs.
    async fn on_start(
        (env, channel): Self::Args,
        slf: ActorRef<Self>,
    ) -> Result<Self, Self::Error> {
        env.subscribe::<Arc<ts_control::StateUpdate>>(&slf).await?;
        env.subscribe::<Arc<PeerState>>(&slf).await?;
        env.subscribe::<crate::route_updater::ActiveExitNode>(&slf)
            .await?;

        let initial_view = DnsView {
            enable_ipv6: env.enable_ipv6,
            accept_dns: env.accept_dns(),
            ..DnsView::default()
        };
        let (view_tx, view_rx) = watch::channel(Arc::new(initial_view));

        let mut tasks = JoinSet::new();
        let addr = SocketAddr::from((MAGIC_DNS_IP, MAGIC_DNS_PORT));
        match channel.udp_bind(addr).await {
            Err(e) => {
                tracing::error!(error = %e, %addr, "magic dns udp bind failed; responder inert");
            }
            Ok(socket) => {
                tracing::debug!(%addr, "magic dns responder bound");
                tasks.spawn(serve(socket, view_rx.clone(), channel.clone()));
            }
        }

        if let Some(port) = env.peerapi_port {
            tasks.spawn(crate::peerapi::serve(
                channel.clone(),
                port,
                view_rx.clone(),
                env.forward_exit_egress,
                env.taildrop_store.clone(),
                env.funnel_ingress.clone(),
            ));
        }

        Ok(Self {
            _joinset: tasks,
            view_tx,
            env,
        })
    }
}
impl Message<Arc<ts_control::StateUpdate>> for MagicDnsActor {
type Reply = ();
async fn handle(
&mut self,
update: Arc<ts_control::StateUpdate>,
_ctx: &mut Context<Self, Self::Reply>,
) {
let accept_dns = self.env.accept_dns();
self.view_tx.send_modify(|view| {
let mut next = (**view).clone();
next.cfg = update.dns_config.clone().unwrap_or_default();
next.self_node = update.node.clone();
next.accept_dns = accept_dns;
*view = Arc::new(next);
});
}
}
impl Message<Arc<PeerState>> for MagicDnsActor {
type Reply = ();
async fn handle(&mut self, state: Arc<PeerState>, _ctx: &mut Context<Self, Self::Reply>) {
let accept_dns = self.env.accept_dns();
self.view_tx.send_modify(|view| {
let mut next = (**view).clone();
next.peers = Some(state.peers.clone());
next.accept_dns = accept_dns;
*view = Arc::new(next);
});
}
}
impl Message<crate::route_updater::ActiveExitNode> for MagicDnsActor {
type Reply = ();
async fn handle(
&mut self,
active: crate::route_updater::ActiveExitNode,
_ctx: &mut Context<Self, Self::Reply>,
) {
let exit_doh = active.node.as_ref().and_then(|n| n.peerapi_doh_addr());
self.view_tx.send_modify(|view| {
let mut next = (**view).clone();
next.exit_doh = exit_doh;
*view = Arc::new(next);
});
}
}
// Unit tests for the pure decision logic: `decide`, routing, wire helpers and
// recursive plans. Nothing here touches the network. Queries are built as raw bytes,
// and responses are inspected through their header fields and trailing RDATA bytes.
#[cfg(test)]
mod tests {
    use ts_control::{StableNodeId, TailnetAddress};
    use super::*;

    /// Runs `decide` and returns the local reply; panics if the query would be forwarded.
    fn answer(view: &DnsView, buf: &[u8]) -> Option<Vec<u8>> {
        match decide(view, buf)? {
            Decision::Reply(resp) => Some(resp),
            Decision::Forward { .. } => panic!("unexpected forward in authoritative-only test"),
        }
    }

    /// A peer named "host" in tailnet "user.ts.net" at 100.64.0.1 / fd7a::1.
    fn test_node() -> Node {
        Node {
            id: 1,
            stable_id: StableNodeId("n1".to_string()),
            hostname: "host".to_string(),
            user_id: 0,
            tailnet: Some("user.ts.net".to_string()),
            tags: vec![],
            tailnet_address: TailnetAddress {
                ipv4: "100.64.0.1/32".parse().unwrap(),
                ipv6: "fd7a::1/128".parse().unwrap(),
            },
            node_key: [0u8; 32].into(),
            node_key_expiry: None,
            online: None,
            last_seen: None,
            key_signature: vec![],
            machine_key: None,
            disco_key: None,
            accepted_routes: vec![],
            underlay_addresses: vec![],
            derp_region: None,
            cap: Default::default(),
            cap_map: Default::default(),
            peerapi_port: None,
            peerapi_dns_proxy: false,
            is_wireguard_only: false,
            exit_node_dns_resolvers: vec![],
            peer_relay: false,
            service_vips: Default::default(),
        }
    }

    /// MagicDNS on, search domain "user.ts.net", one peer (`test_node`), no resolvers.
    fn view_with_peer() -> DnsView {
        let mut db = PeerDb::default();
        db.upsert(&test_node());
        DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec!["user.ts.net".to_string()],
                ..Default::default()
            },
            peers: Some(Arc::new(db)),
            self_node: None,
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        }
    }

    /// Builds a single-question DNS query with the given ID, QNAME labels, QTYPE and QCLASS.
    fn build_query(id: u16, labels: &[&str], qtype: u16, qclass: u16) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&id.to_be_bytes());
        buf.extend_from_slice(&0u16.to_be_bytes()); // flags: standard query, QR = 0
        buf.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT = 1
        buf.extend_from_slice(&0u16.to_be_bytes()); // ANCOUNT
        buf.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT
        buf.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT
        for label in labels {
            buf.push(label.len() as u8);
            buf.extend_from_slice(label.as_bytes());
        }
        buf.push(0); // root label terminates the QNAME
        buf.extend_from_slice(&qtype.to_be_bytes());
        buf.extend_from_slice(&qclass.to_be_bytes());
        buf
    }

    /// Extracts (ID, RCODE, ANCOUNT) from a response header.
    fn parse_header(resp: &[u8]) -> (u16, u8, u16) {
        let id = u16::from_be_bytes([resp[0], resp[1]]);
        let flags = u16::from_be_bytes([resp[2], resp[3]]);
        let ancount = u16::from_be_bytes([resp[6], resp[7]]);
        (id, (flags & 0x000F) as u8, ancount)
    }

    #[test]
    fn a_query_for_known_peer_answers_v4() {
        let view = view_with_peer();
        let buf = build_query(0x1234, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (id, rcode, ancount) = parse_header(&resp);
        assert_eq!(id, 0x1234);
        assert_eq!(rcode, 0, "NoError");
        assert_eq!(ancount, 1);
        // The single A record's RDATA is expected to be the last four bytes.
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 1]);
    }

    #[test]
    fn aaaa_query_for_known_peer_is_nodata_when_ipv6_off() {
        let view = view_with_peer();
        assert!(!view.enable_ipv6, "default gate is off");
        let buf = build_query(0x5, &["host", "user", "ts", "net"], 28, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError (NODATA)");
        assert_eq!(ancount, 0, "empty answer: no AAAA handed out with IPv6 off");
    }

    #[test]
    fn a_query_still_resolves_when_ipv6_off() {
        let view = view_with_peer();
        let buf = build_query(0x6, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError");
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 1]);
    }

    #[test]
    fn aaaa_query_for_known_peer_answers_v6_when_ipv6_on() {
        let mut view = view_with_peer();
        view.enable_ipv6 = true;
        let buf = build_query(0x5, &["host", "user", "ts", "net"], 28, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError");
        assert_eq!(ancount, 1);
        let expected = "fd7a::1".parse::<std::net::Ipv6Addr>().unwrap().octets();
        // AAAA RDATA is the trailing 16 bytes.
        let tail = &resp[resp.len() - 16..];
        assert_eq!(tail, expected);
    }

    #[test]
    fn aaaa_for_unknown_tailnet_name_is_nxdomain_not_forwarded_with_ipv6_off() {
        let mut db = PeerDb::default();
        db.upsert(&test_node());
        // A fallback resolver exists, so any forward here would be a leak.
        let view = DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec!["user.ts.net".to_string()],
                fallback_resolvers: vec![DnsResolver {
                    transport: ts_control::ResolverTransport::Udp("9.9.9.9:53".parse().unwrap()),
                    use_with_exit_node: false,
                }],
                ..Default::default()
            },
            peers: Some(Arc::new(db)),
            self_node: None,
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        };
        let buf = build_query(0x5A, &["ghost", "user", "ts", "net"], 28, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(rcode, 3, "NxDomain: tailnet AAAA not leaked upstream");
            }
            Decision::Forward { .. } => panic!("tailnet AAAA must never be forwarded"),
        }
    }

    #[test]
    fn bare_hostname_resolves() {
        let view = view_with_peer();
        let buf = build_query(0x7, &["host"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0);
        assert_eq!(ancount, 1);
    }

    #[test]
    fn unknown_name_is_nxdomain() {
        // `view_with_peer` configures no resolvers, so non-tailnet names fail closed.
        let view = view_with_peer();
        let buf = build_query(0x9, &["nope", "example", "com"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 3, "NxDomain");
        assert_eq!(ancount, 0);
    }

    #[test]
    fn magic_dns_off_is_refused() {
        let mut view = view_with_peer();
        view.cfg.magic_dns = false;
        let buf = build_query(0xAB, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused");
        assert_eq!(ancount, 0);
    }

    #[test]
    fn accept_dns_false_refuses_otherwise_answerable_query() {
        let mut view = view_with_peer();
        assert!(view.cfg.magic_dns, "MagicDNS itself is on");
        view.accept_dns = false;
        let buf = build_query(0xDD, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused: accept_dns off ⇒ serve nothing");
        assert_eq!(ancount, 0);
        view.accept_dns = true;
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError: accept_dns on ⇒ the known peer answers");
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 1], "the peer's tailnet v4 is served");
    }

    #[test]
    fn default_view_serves_nothing() {
        let view = DnsView::default();
        let buf = build_query(0x1, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused");
    }

    #[test]
    fn unsupported_qtype_is_refused() {
        // QTYPE 16 = TXT.
        let view = view_with_peer();
        let buf = build_query(0x1, &["host", "user", "ts", "net"], 16, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused");
    }

    #[test]
    fn malformed_query_is_dropped() {
        let mut buf = build_query(0x1, &["host"], 1, 1);
        buf[2] = 0x80; // set QR: this is a response, not a query
        assert!(answer(&view_with_peer(), &buf).is_none());
    }

    #[test]
    fn ptr_for_known_ip_answers_fqdn() {
        let view = view_with_peer();
        let buf = build_query(0x33, &["1", "0", "64", "100", "in-addr", "arpa"], 12, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError");
        assert_eq!(ancount, 1);
        // Expected PTR RDATA: the peer's FQDN in uncompressed wire format.
        let expected = {
            let mut out = Vec::new();
            for label in ["host", "user", "ts", "net"] {
                out.push(label.len() as u8);
                out.extend_from_slice(label.as_bytes());
            }
            out.push(0);
            out
        };
        let tail = &resp[resp.len() - expected.len()..];
        assert_eq!(tail, expected.as_slice());
    }

    #[test]
    fn ptr_for_unknown_ip_is_nxdomain() {
        // Non-tailnet IP with no resolvers configured: fails closed.
        let view = view_with_peer();
        let buf = build_query(0x34, &["9", "9", "9", "9", "in-addr", "arpa"], 12, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 3, "NxDomain");
    }

    #[test]
    fn ptr_for_unknown_tailnet_ip_is_nxdomain_not_forwarded() {
        let mut db = PeerDb::default();
        db.upsert(&test_node());
        let view = DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec!["user.ts.net".to_string()],
                fallback_resolvers: vec![DnsResolver {
                    transport: ts_control::ResolverTransport::Udp("9.9.9.9:53".parse().unwrap()),
                    use_with_exit_node: false,
                }],
                ..Default::default()
            },
            peers: Some(Arc::new(db)),
            self_node: None,
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        };
        // 100.64.0.9 is in the tailnet CGNAT range but owned by no node.
        let buf = build_query(0x35, &["9", "0", "64", "100", "in-addr", "arpa"], 12, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(rcode, 3, "NxDomain");
            }
            Decision::Forward { .. } => {
                panic!("tailnet CGNAT PTR must never be forwarded upstream")
            }
        }
    }

    #[test]
    fn is_tailnet_cgnat_classifies_range() {
        assert!(is_tailnet_cgnat("100.64.0.0".parse().unwrap()));
        assert!(is_tailnet_cgnat("100.64.0.1".parse().unwrap()));
        assert!(is_tailnet_cgnat("100.127.255.255".parse().unwrap()));
        assert!(!is_tailnet_cgnat("100.63.255.255".parse().unwrap()));
        assert!(!is_tailnet_cgnat("100.128.0.0".parse().unwrap()));
        assert!(!is_tailnet_cgnat("9.9.9.9".parse().unwrap()));
        assert!(is_tailnet_cgnat("100.100.100.100".parse().unwrap()));
    }

    #[test]
    fn response_matches_query_validates_id_and_qr() {
        let query = build_query(0x1234, &["a", "com"], 1, 1);
        let mut good = query.clone();
        good[2] |= 0x80; // mark as response
        assert!(response_matches_query(&query, &good));
        assert!(!response_matches_query(&query, &query));
        let mut wrong_id = good.clone();
        wrong_id[0] ^= 0xFF;
        assert!(!response_matches_query(&query, &wrong_id));
        assert!(!response_matches_query(&query, &[0u8; 2]));
        assert!(!response_matches_query(&[0u8; 3], &good));
    }

    #[test]
    fn self_node_resolves_when_no_peer_match() {
        let view = DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec![],
                ..Default::default()
            },
            peers: None,
            self_node: Some(test_node()),
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        };
        let buf = build_query(0x44, &["host", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0);
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 1]);
    }

    #[test]
    fn partially_qualified_name_resolves_via_search_domain() {
        let mut view = view_with_peer();
        view.cfg.search_domains = vec!["ts.net".to_string()];
        // "host.user" + ".ts.net" expands to the peer's FQDN.
        let buf = build_query(0x55, &["host", "user"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError via search-domain expansion");
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 1]);
    }

    #[test]
    fn extra_record_a_answers_when_no_peer_match() {
        let mut view = view_with_peer();
        view.cfg.extra_records = vec![ts_control::ExtraRecord {
            name: "static.user.ts.net".to_string(),
            addr: IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9)),
        }];
        let buf = build_query(0x77, &["static", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError from extra record");
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 9]);
    }

    #[test]
    fn extra_record_matches_query_case_insensitively() {
        let mut view = view_with_peer();
        view.cfg.extra_records = vec![ts_control::ExtraRecord {
            name: "static.user.ts.net".to_string(),
            addr: IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9)),
        }];
        let buf = build_query(0x7A, &["Static", "User", "TS", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 0, "NoError: case-insensitive match");
        assert_eq!(ancount, 1);
        let tail = &resp[resp.len() - 4..];
        assert_eq!(tail, &[100, 64, 0, 9]);
    }

    #[test]
    fn extra_record_not_expanded_by_search_domain() {
        let mut view = view_with_peer();
        view.cfg.extra_records = vec![ts_control::ExtraRecord {
            name: "static.user.ts.net".to_string(),
            addr: IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9)),
        }];
        let buf = build_query(0x7B, &["static"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 3, "NxDomain: extra records are not search-expanded");
    }

    #[test]
    fn extra_record_aaaa_family_is_isolated() {
        let mut view = view_with_peer();
        view.cfg.extra_records = vec![ts_control::ExtraRecord {
            name: "v4only.user.ts.net".to_string(),
            addr: IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9)),
        }];
        let buf = build_query(0x78, &["v4only", "user", "ts", "net"], 28, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 3, "NxDomain: A record does not satisfy AAAA");
    }

    #[test]
    fn extra_record_ignored_when_magic_dns_off() {
        let mut view = view_with_peer();
        view.cfg.magic_dns = false;
        view.cfg.extra_records = vec![ts_control::ExtraRecord {
            name: "static.user.ts.net".to_string(),
            addr: IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9)),
        }];
        let buf = build_query(0x79, &["static", "user", "ts", "net"], 1, 1);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, _) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused");
    }

    #[test]
    fn non_in_class_is_refused() {
        // QCLASS 3 = CH (Chaos).
        let view = view_with_peer();
        let buf = build_query(0x66, &["host", "user", "ts", "net"], 1, 3);
        let resp = answer(&view, &buf).expect("answers");
        let (_, rcode, ancount) = parse_header(&resp);
        assert_eq!(rcode, 5, "Refused");
        assert_eq!(ancount, 0);
    }

    /// MagicDNS view with no peers and the given split-DNS routes, global resolvers and
    /// fallback resolvers.
    fn view_with_routes(
        routes: std::collections::BTreeMap<String, Vec<DnsResolver>>,
        resolvers: Vec<DnsResolver>,
        fallback: Vec<DnsResolver>,
    ) -> DnsView {
        DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec!["user.ts.net".to_string()],
                routes,
                resolvers,
                fallback_resolvers: fallback,
                ..Default::default()
            },
            peers: None,
            self_node: None,
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        }
    }

    /// Plain UDP resolver at `addr`, not flagged for use with an exit node.
    fn udp(addr: &str) -> DnsResolver {
        DnsResolver {
            transport: ts_control::ResolverTransport::Udp(addr.parse().unwrap()),
            use_with_exit_node: false,
        }
    }

    #[test]
    fn split_dns_route_forwards_to_matching_upstream() {
        let mut routes = std::collections::BTreeMap::new();
        routes.insert("corp.example".to_string(), vec![udp("10.0.0.53:53")]);
        let view = view_with_routes(routes, vec![], vec![]);
        let buf = build_query(0x100, &["api", "corp", "example"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(upstreams, vec!["10.0.0.53:53".parse().unwrap()]);
            }
            Decision::Reply(_) => panic!("expected forward to the split-DNS upstream"),
        }
    }

    #[test]
    fn longest_suffix_route_wins() {
        let mut routes = std::collections::BTreeMap::new();
        routes.insert("example".to_string(), vec![udp("10.0.0.1:53")]);
        routes.insert("corp.example".to_string(), vec![udp("10.0.0.2:53")]);
        let view = view_with_routes(routes, vec![], vec![]);
        let buf = build_query(0x101, &["api", "corp", "example"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(
                    upstreams,
                    vec!["10.0.0.2:53".parse().unwrap()],
                    "longer suffix wins"
                );
            }
            Decision::Reply(_) => panic!("expected forward"),
        }
    }

    #[test]
    fn negative_route_is_nxdomain_not_forwarded() {
        // A route with an empty resolver list blocks the suffix.
        let mut routes = std::collections::BTreeMap::new();
        routes.insert("blocked.example".to_string(), vec![]);
        let view = view_with_routes(routes, vec![udp("8.8.8.8:53")], vec![]);
        let buf = build_query(0x102, &["x", "blocked", "example"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(rcode, 3, "NxDomain: negative route is not forwarded");
            }
            Decision::Forward { .. } => panic!("negative route must not forward"),
        }
    }

    #[test]
    fn unrouted_name_forwards_to_fallback_then_global() {
        let view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![udp("1.1.1.1:53")],
        );
        let buf = build_query(0x103, &["example", "com"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(
                    upstreams,
                    vec!["1.1.1.1:53".parse().unwrap()],
                    "fallback preferred"
                );
            }
            Decision::Reply(_) => panic!("expected forward to fallback"),
        }
    }

    #[test]
    fn unrouted_name_forwards_to_global_when_no_fallback() {
        let view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![],
        );
        let buf = build_query(0x104, &["example", "com"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(upstreams, vec!["8.8.8.8:53".parse().unwrap()]);
            }
            Decision::Reply(_) => panic!("expected forward to global resolver"),
        }
    }

    #[test]
    fn tailnet_name_is_never_forwarded() {
        let view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![udp("1.1.1.1:53")],
        );
        let buf = build_query(0x105, &["ghost", "user", "ts", "net"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(rcode, 3, "NxDomain: tailnet name not leaked upstream");
            }
            Decision::Forward { .. } => panic!("tailnet name must never be forwarded"),
        }
    }

    #[test]
    fn no_resolvers_fails_closed() {
        let view = view_with_routes(std::collections::BTreeMap::new(), vec![], vec![]);
        let buf = build_query(0x106, &["example", "com"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(rcode, 3, "NxDomain");
            }
            Decision::Forward { .. } => panic!("must not forward with no resolvers"),
        }
    }

    #[test]
    fn overlay_match_wins_over_forwarding() {
        let mut db = PeerDb::default();
        db.upsert(&test_node());
        let view = DnsView {
            cfg: DnsConfig {
                magic_dns: true,
                search_domains: vec!["user.ts.net".to_string()],
                resolvers: vec![udp("8.8.8.8:53")],
                ..Default::default()
            },
            peers: Some(Arc::new(db)),
            self_node: None,
            exit_doh: None,
            enable_ipv6: false,
            accept_dns: true,
        };
        let buf = build_query(0x107, &["host", "user", "ts", "net"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, ancount) = parse_header(&resp);
                assert_eq!(rcode, 0, "authoritative answer wins");
                assert_eq!(ancount, 1);
            }
            Decision::Forward { .. } => panic!("overlay match must not forward"),
        }
    }

    #[test]
    fn ipv6_reverse_ptr_is_nxdomain_not_forwarded() {
        let view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![udp("1.1.1.1:53")],
        );
        // Nibble-reversed reverse name for fd7a::1 under ip6.arpa.
        let labels = vec![
            "1", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0",
            "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "a", "7", "d", "f", "ip6",
            "arpa",
        ];
        let buf = build_query(0x200, &labels, 12, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Reply(resp) => {
                let (_, rcode, _) = parse_header(&resp);
                assert_eq!(
                    rcode, 3,
                    "NxDomain: ip6.arpa reverse must not leak upstream"
                );
            }
            Decision::Forward { .. } => panic!("ip6.arpa PTR must never be forwarded"),
        }
    }

    #[test]
    fn cap_response_sets_tc_when_truncated() {
        let mut big = build_query(0x300, &["example", "com"], 1, 1);
        big[2] |= 0x80; // mark as response
        big.resize(MAX_UPSTREAM_RESPONSE + 500, 0xAB); // pad past the cap
        let out = cap_response(big);
        assert_eq!(out.len(), MAX_UPSTREAM_RESPONSE, "capped to one datagram");
        assert_ne!(out[2] & 0x02, 0, "TC bit set on truncation");
    }

    #[test]
    fn cap_response_leaves_small_response_untouched() {
        let mut small = build_query(0x301, &["example", "com"], 1, 1);
        small[2] |= 0x80;
        let before = small.clone();
        let out = cap_response(small);
        assert_eq!(out, before, "small response unchanged");
        assert_eq!(out[2] & 0x02, 0, "TC bit not set when no truncation");
    }

    #[test]
    fn response_matches_query_rejects_mismatched_question() {
        let query = build_query(0x1234, &["a", "com"], 1, 1);
        let mut wrong_question = build_query(0x1234, &["b", "com"], 1, 1);
        wrong_question[2] |= 0x80; // mark as response
        assert!(
            !response_matches_query(&query, &wrong_question),
            "different QNAME must be rejected"
        );
        let mut wrong_qtype = build_query(0x1234, &["a", "com"], 28, 1);
        wrong_qtype[2] |= 0x80;
        assert!(
            !response_matches_query(&query, &wrong_qtype),
            "different QTYPE must be rejected"
        );
        let mut good = query.clone();
        good[2] |= 0x80;
        assert!(
            response_matches_query(&query, &good),
            "matching question accepted"
        );
    }

    #[test]
    fn suffix_matches_handles_boundaries_and_empty() {
        assert!(suffix_matches("corp", "corp"));
        assert!(suffix_matches("a.corp", "corp"));
        assert!(suffix_matches("a.b.corp", "corp"));
        assert!(!suffix_matches("acorp", "corp"));
        assert!(!suffix_matches("anything.example", ""));
        assert!(!suffix_matches("", ""));
    }

    #[test]
    fn empty_search_domain_does_not_capture_everything() {
        let mut view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![],
        );
        view.cfg.search_domains = vec![String::new()];
        let buf = build_query(0x400, &["example", "com"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(upstreams, vec!["8.8.8.8:53".parse().unwrap()]);
            }
            Decision::Reply(_) => {
                panic!("empty search domain must not treat every name as tailnet")
            }
        }
    }

    #[test]
    fn empty_route_suffix_does_not_capture_everything() {
        let mut routes = std::collections::BTreeMap::new();
        routes.insert(String::new(), vec![udp("10.9.9.9:53")]);
        let view = view_with_routes(routes, vec![udp("8.8.8.8:53")], vec![]);
        let buf = build_query(0x401, &["example", "com"], 1, 1);
        match decide(&view, &buf).expect("decides") {
            Decision::Forward { upstreams, .. } => {
                assert_eq!(
                    upstreams,
                    vec!["8.8.8.8:53".parse().unwrap()],
                    "empty route suffix must not capture; falls through to global"
                );
            }
            Decision::Reply(_) => panic!("expected forward to global resolver"),
        }
    }

    /// Plain UDP resolver at `addr`, flagged for continued use while an exit node is active.
    fn udp_exit(addr: &str) -> DnsResolver {
        DnsResolver {
            transport: ts_control::ResolverTransport::Udp(addr.parse().unwrap()),
            use_with_exit_node: true,
        }
    }

    #[test]
    fn recursive_forward_is_flagged_route_forward_is_not() {
        let mut routes = std::collections::BTreeMap::new();
        routes.insert("corp.example".to_string(), vec![udp("10.0.0.53:53")]);
        let view = view_with_routes(routes, vec![udp("8.8.8.8:53")], vec![]);
        let routed = build_query(0x500, &["api", "corp", "example"], 1, 1);
        match decide(&view, &routed).expect("decides") {
            Decision::Forward { recursive, .. } => {
                assert!(!recursive, "split-DNS route is not a recursive forward")
            }
            Decision::Reply(_) => panic!("expected route forward"),
        }
        let global = build_query(0x501, &["example", "com"], 1, 1);
        match decide(&view, &global).expect("decides") {
            Decision::Forward { recursive, .. } => {
                assert!(recursive, "unrouted name is a recursive forward")
            }
            Decision::Reply(_) => panic!("expected recursive forward"),
        }
    }

    #[test]
    fn recursive_plan_keeps_udp_without_exit_node() {
        let view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![],
        );
        let default = vec!["8.8.8.8:53".parse().unwrap()];
        assert_eq!(
            recursive_plan(&view, default.clone()),
            RecursivePlan::Udp(default)
        );
    }

    #[test]
    fn recursive_plan_delegates_to_doh_with_exit_node() {
        let mut view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp("8.8.8.8:53")],
            vec![],
        );
        let doh: SocketAddr = "100.64.0.5:8080".parse().unwrap();
        view.exit_doh = Some(doh);
        assert_eq!(
            recursive_plan(&view, vec!["8.8.8.8:53".parse().unwrap()]),
            RecursivePlan::Doh(doh)
        );
    }

    #[test]
    fn recursive_plan_keeps_use_with_exit_node_resolvers_local() {
        let mut view = view_with_routes(
            std::collections::BTreeMap::new(),
            vec![udp_exit("10.0.0.53:53"), udp("8.8.8.8:53")],
            vec![],
        );
        view.exit_doh = Some("100.64.0.5:8080".parse().unwrap());
        assert_eq!(
            recursive_plan(&view, vec!["8.8.8.8:53".parse().unwrap()]),
            RecursivePlan::Udp(vec!["10.0.0.53:53".parse().unwrap()])
        );
    }
}