use hickory_client::client::{Client, SyncClient};
use hickory_client::udp::UdpClientConnection;
use hickory_server::authority::{Catalog, ZoneType};
use hickory_server::proto::rr::rdata::{A, AAAA};
use hickory_server::proto::rr::{DNSClass, LowerName, Name, RData, Record, RecordType};
use hickory_server::resolver::config::NameServerConfigGroup;
use hickory_server::server::ServerFuture;
use hickory_server::store::in_memory::InMemoryAuthority;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::RwLock;
/// Port the overlay DNS server listens on unless overridden via
/// `DnsConfig::with_port`.
pub const DEFAULT_DNS_PORT: u16 = 15353;
/// Well-known DNS port; used as the port for upstreams parsed from
/// resolv.conf and for the public fallback resolvers.
const STANDARD_DNS_PORT: u16 = 53;
/// Public resolvers (1.1.1.1 and 8.8.8.8) used for forwarding when no usable
/// host upstream can be detected.
const PUBLIC_FALLBACK_UPSTREAMS: [IpAddr; 2] = [
IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
];
/// Host resolver configuration consulted when detecting forwarding upstreams.
pub(crate) const RESOLV_CONF_PATH: &str = "/etc/resolv.conf";
/// Serializable configuration for the overlay DNS server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DnsConfig {
/// Overlay zone the server is authoritative for (e.g. `"overlay.local."`).
pub zone: String,
/// Port to listen on (UDP and TCP).
pub port: u16,
/// Address to bind the listeners to.
pub bind_addr: IpAddr,
/// Explicit forwarding upstreams. When absent or empty, upstreams are
/// detected from the host's resolv.conf; a missing field deserializes
/// as `None`.
#[serde(default)]
pub upstreams: Option<Vec<SocketAddr>>,
}
impl DnsConfig {
    /// Builds a config for `zone` bound to `bind_addr` on
    /// [`DEFAULT_DNS_PORT`], with upstreams left to automatic detection.
    #[must_use]
    pub fn new(zone: &str, bind_addr: IpAddr) -> Self {
        let zone = zone.to_owned();
        Self {
            zone,
            bind_addr,
            port: DEFAULT_DNS_PORT,
            upstreams: None,
        }
    }

    /// Returns this config with the listen port replaced by `port`.
    #[must_use]
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Returns this config with an explicit upstream list, which takes
    /// precedence over resolv.conf detection when non-empty.
    #[must_use]
    pub fn with_upstreams(self, upstreams: Vec<SocketAddr>) -> Self {
        Self {
            upstreams: Some(upstreams),
            ..self
        }
    }
}
/// Returns `true` if `ip` cannot serve as a forwarding upstream.
///
/// Loopback addresses (e.g. a local stub such as `127.0.0.53`, which would
/// loop queries back into the host) and the unspecified address are
/// rejected. IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped
/// first, so `::ffff:127.0.0.1` is treated like `127.0.0.1`.
fn is_unusable_upstream(ip: IpAddr) -> bool {
    let ip = match ip {
        // Ipv6Addr::is_loopback only matches `::1`, so unwrap mapped v4 first.
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
        v4 @ IpAddr::V4(_) => v4,
    };
    ip.is_loopback() || ip.is_unspecified()
}
/// Extracts usable `nameserver` entries from resolv.conf `contents`.
///
/// Blank lines and `#`/`;` comments are skipped; only lines whose first
/// token is `nameserver` are considered. A `%zone` suffix is stripped before
/// parsing (note: the scope is therefore lost, so an IPv6 link-local entry
/// is kept without its interface). Unparseable, loopback and unspecified
/// addresses are dropped. Each survivor is paired with the standard DNS
/// port and duplicates are removed, preserving first-seen order.
fn parse_resolv_conf(contents: &str) -> Vec<SocketAddr> {
    let mut servers: Vec<SocketAddr> = Vec::new();
    let nameserver_ips = contents
        .lines()
        .map(str::trim)
        .filter(|entry| !(entry.is_empty() || entry.starts_with('#') || entry.starts_with(';')))
        .filter_map(|entry| {
            let mut tokens = entry.split_whitespace();
            match (tokens.next(), tokens.next()) {
                (Some("nameserver"), Some(raw)) => Some(raw),
                _ => None,
            }
        })
        .filter_map(|raw| {
            let without_zone = raw.split('%').next().unwrap_or(raw);
            IpAddr::from_str(without_zone).ok()
        })
        .filter(|ip| !is_unusable_upstream(*ip));
    for ip in nameserver_ips {
        let candidate = SocketAddr::new(ip, STANDARD_DNS_PORT);
        if !servers.contains(&candidate) {
            servers.push(candidate);
        }
    }
    servers
}
/// Chooses the upstream resolvers that non-overlay queries are forwarded to.
///
/// Precedence:
/// 1. a non-empty `config.upstreams` list, returned verbatim;
/// 2. usable nameservers parsed from `resolv_conf_path`;
/// 3. the public fallback resolvers, when the file is unreadable or yields
///    nothing usable.
///
/// Each outcome is logged through `tracing`.
pub(crate) fn resolve_upstreams(config: &DnsConfig, resolv_conf_path: &str) -> Vec<SocketAddr> {
    let explicit = config.upstreams.as_deref().filter(|list| !list.is_empty());
    if let Some(list) = explicit {
        tracing::debug!(
            count = list.len(),
            "using explicit overlay DNS upstreams from config (host detection skipped)",
        );
        return list.to_vec();
    }

    let detected = std::fs::read_to_string(resolv_conf_path)
        .map(|contents| parse_resolv_conf(&contents))
        .unwrap_or_else(|e| {
            tracing::warn!(
                path = resolv_conf_path,
                error = %e,
                "could not read host resolv.conf for overlay DNS upstream detection",
            );
            Vec::new()
        });

    if !detected.is_empty() {
        tracing::info!(
            upstreams = ?detected,
            "overlay DNS forwarding to host upstreams (loopback/stub filtered out)",
        );
        return detected;
    }

    let fallback: Vec<SocketAddr> = PUBLIC_FALLBACK_UPSTREAMS
        .iter()
        .map(|&ip| SocketAddr::new(ip, STANDARD_DNS_PORT))
        .collect();
    tracing::warn!(
        fallback = ?fallback,
        "no usable host DNS upstreams found (resolv.conf empty, missing, or stub-only); \
         falling back to public resolvers for overlay forwarding",
    );
    fallback
}
/// Builds the async resolver used to forward non-overlay queries.
///
/// Upstreams are grouped by port (in ascending port order) because
/// `NameServerConfigGroup::from_ips_clear` takes one port per call. The
/// resolver uses no search domains, a 2 s timeout, 2 attempts, and keeps
/// intermediate records (e.g. CNAME chains) in lookup results.
///
/// # Errors
///
/// Returns [`DnsError::Server`] if `upstreams` is empty.
pub(crate) fn build_forward_resolver(
    upstreams: &[SocketAddr],
) -> Result<hickory_server::resolver::TokioAsyncResolver, DnsError> {
    use hickory_server::resolver::config::{ResolverConfig, ResolverOpts};

    if upstreams.is_empty() {
        return Err(DnsError::Server("no upstreams for forward resolver".into()));
    }

    let ips_by_port = upstreams.iter().fold(
        std::collections::BTreeMap::<u16, Vec<IpAddr>>::new(),
        |mut buckets, upstream| {
            buckets.entry(upstream.port()).or_default().push(upstream.ip());
            buckets
        },
    );
    let name_servers = ips_by_port.into_iter().fold(
        NameServerConfigGroup::new(),
        |mut merged, (port, ips)| {
            merged.merge(NameServerConfigGroup::from_ips_clear(&ips, port, true));
            merged
        },
    );

    let mut opts = ResolverOpts::default();
    opts.timeout = Duration::from_secs(2);
    opts.attempts = 2;
    opts.preserve_intermediates = true;

    let resolver_config = ResolverConfig::from_parts(None, vec![], name_servers);
    Ok(hickory_server::resolver::TokioAsyncResolver::tokio(
        resolver_config,
        opts,
    ))
}
/// Request handler that answers the overlay zone from `catalog` and forwards
/// every other name through `resolver`.
struct ForwardingCatalog {
/// Catalog holding the overlay zone's in-memory authority.
catalog: Catalog,
/// Origin of the overlay zone; names inside it are never forwarded.
zone_origin: LowerName,
/// Upstream forwarder. When `None`, non-overlay queries are handed to
/// `catalog` instead (presumably refused there — confirm).
resolver: Option<Arc<hickory_server::resolver::TokioAsyncResolver>>,
}
impl ForwardingCatalog {
fn forward_answer_response<'a>(
request: &'a hickory_server::server::Request,
answers: &'a [Record],
) -> hickory_server::authority::MessageResponse<
'a,
'a,
std::slice::Iter<'a, Record>,
std::iter::Empty<&'a Record>,
std::iter::Empty<&'a Record>,
std::iter::Empty<&'a Record>,
> {
use hickory_server::authority::MessageResponseBuilder;
use hickory_server::proto::op::ResponseCode;
let mut header = hickory_server::proto::op::Header::response_from_request(request.header());
header.set_recursion_available(true);
header.set_response_code(ResponseCode::NoError);
header.set_authoritative(false);
MessageResponseBuilder::from_message_request(request).build(
header,
answers.iter(),
std::iter::empty(),
std::iter::empty(),
std::iter::empty(),
)
}
fn forward_code_response(
request: &hickory_server::server::Request,
code: hickory_server::proto::op::ResponseCode,
) -> hickory_server::authority::MessageResponse<
'_,
'_,
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
> {
use hickory_server::authority::MessageResponseBuilder;
MessageResponseBuilder::from_message_request(request).error_msg(request.header(), code)
}
async fn forward<R: hickory_server::server::ResponseHandler>(
&self,
resolver: &hickory_server::resolver::TokioAsyncResolver,
request: &hickory_server::server::Request,
mut response_handle: R,
) -> hickory_server::server::ResponseInfo {
use hickory_server::proto::op::ResponseCode;
use hickory_server::resolver::error::ResolveErrorKind;
let query = request.request_info().query;
let name = Name::from(query.name());
let rtype = query.query_type();
match resolver.lookup(name, rtype).await {
Ok(lookup) => {
let records: Vec<Record> = lookup.records().to_vec();
let response = Self::forward_answer_response(request, &records);
Self::send_or_servfail(&mut response_handle, response).await
}
Err(e) => {
let code = match e.kind() {
ResolveErrorKind::NoRecordsFound { response_code, .. }
if *response_code == ResponseCode::NXDomain =>
{
ResponseCode::NXDomain
}
ResolveErrorKind::NoRecordsFound { response_code, .. }
if *response_code == ResponseCode::NoError =>
{
ResponseCode::NoError
}
_ => {
tracing::debug!(error = %e, "overlay DNS upstream forward failed; SERVFAIL");
ResponseCode::ServFail
}
};
let response = Self::forward_code_response(request, code);
Self::send_or_servfail(&mut response_handle, response).await
}
}
}
async fn send_or_servfail<'a, R, A, N, S, D>(
response_handle: &mut R,
response: hickory_server::authority::MessageResponse<'_, 'a, A, N, S, D>,
) -> hickory_server::server::ResponseInfo
where
R: hickory_server::server::ResponseHandler,
A: Iterator<Item = &'a Record> + Send + 'a,
N: Iterator<Item = &'a Record> + Send + 'a,
S: Iterator<Item = &'a Record> + Send + 'a,
D: Iterator<Item = &'a Record> + Send + 'a,
{
match response_handle.send_response(response).await {
Ok(info) => info,
Err(e) => {
tracing::error!(error = %e, "failed to send overlay DNS forward response");
let mut header = hickory_server::proto::op::Header::new();
header.set_response_code(hickory_server::proto::op::ResponseCode::ServFail);
header.into()
}
}
}
}
#[async_trait::async_trait]
impl hickory_server::server::RequestHandler for ForwardingCatalog {
    /// Dispatches a request.
    ///
    /// The request is forwarded upstream only when all three hold:
    /// - a forwarder is configured;
    /// - the message is a standard query (message type Query, opcode Query);
    /// - the queried name lies outside the overlay zone.
    ///
    /// Everything else goes to the overlay catalog. That includes UPDATE,
    /// NOTIFY and other opcodes, which cannot be satisfied by a recursive
    /// lookup and must not be turned into one.
    async fn handle_request<R: hickory_server::server::ResponseHandler>(
        &self,
        request: &hickory_server::server::Request,
        response_handle: R,
    ) -> hickory_server::server::ResponseInfo {
        use hickory_server::proto::op::{MessageType, OpCode};
        let is_standard_query = request.header().message_type() == MessageType::Query
            && request.header().op_code() == OpCode::Query;
        let query_name = request.request_info().query.name().clone();
        let is_overlay = self.zone_origin.zone_of(&query_name);
        let forwardable = is_standard_query && !is_overlay;
        match &self.resolver {
            Some(resolver) if forwardable => self.forward(resolver, request, response_handle).await,
            _ => self.catalog.handle_request(request, response_handle).await,
        }
    }
}
/// Derives a short, stable hostname label for an overlay peer address.
///
/// IPv4 peers use the last two octets in decimal (`10.200.1.100` →
/// `node-1-100`). IPv6 peers use the final 16-bit segment as four lowercase
/// hex digits (`fd00::abcd` → `node-abcd`). Only that tail is encoded, so
/// addresses differing only in higher bits map to the same label.
#[must_use]
pub fn peer_hostname(ip: IpAddr) -> String {
    match ip {
        IpAddr::V4(addr) => {
            let [_, _, third, fourth] = addr.octets();
            format!("node-{third}-{fourth}")
        }
        IpAddr::V6(addr) => {
            let [.., tail] = addr.segments();
            format!("node-{tail:04x}")
        }
    }
}
/// Errors produced by the overlay DNS server, handle and client.
#[derive(Debug, thiserror::Error)]
pub enum DnsError {
/// A hostname or zone string failed to parse, or could not be joined to
/// the zone origin.
#[error("Invalid domain name: {0}")]
InvalidName(String),
/// Server-side failure (e.g. forwarder construction or the serve loop).
#[error("DNS server error: {0}")]
Server(String),
/// Failure creating a client connection or performing a client query.
#[error("DNS client error: {0}")]
Client(String),
/// Underlying I/O error, e.g. binding a UDP socket or TCP listener.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A requested record was not found.
#[error("Record not found: {0}")]
NotFound(String),
}
/// Cheap, cloneable handle for mutating and querying the overlay zone while
/// the server runs; all clones share the same authority and serial counter.
#[derive(Clone)]
pub struct DnsHandle {
/// In-memory authority backing the overlay zone.
authority: Arc<InMemoryAuthority>,
/// Zone origin appended to relative hostnames.
zone_origin: Name,
/// Monotonic (wrapping) serial handed to authority updates.
serial: Arc<RwLock<u32>>,
}
impl DnsHandle {
    /// Turns `hostname` into an absolute name.
    ///
    /// A name ending in `.` is taken as already fully qualified; anything
    /// else is treated as relative and placed under the overlay zone.
    fn qualify(&self, hostname: &str) -> Result<Name, DnsError> {
        let name = Name::from_str(hostname)
            .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
        if hostname.ends_with('.') {
            Ok(name)
        } else {
            name.append_domain(&self.zone_origin)
                .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))
        }
    }

    /// Returns the current serial and advances the shared counter
    /// (wrapping at `u32::MAX`).
    async fn next_serial(&self) -> u32 {
        let mut serial = self.serial.write().await;
        let current = *serial;
        *serial = serial.wrapping_add(1);
        current
    }

    /// Adds an `A` (IPv4) or `AAAA` (IPv6) record with a 300 s TTL for
    /// `hostname`; relative names are placed under the overlay zone.
    ///
    /// NOTE(review): this goes through `InMemoryAuthority::upsert`, which
    /// presumably appends a second address rather than replacing an existing
    /// one. Call [`DnsHandle::remove_record`] first to change a peer's address.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::InvalidName`] if `hostname` cannot be parsed or
    /// qualified.
    pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
        let fqdn = self.qualify(hostname)?;
        let rdata = match ip {
            IpAddr::V4(v4) => RData::A(A::from(v4)),
            IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
        };
        let record = Record::from_rdata(fqdn, 300, rdata);
        let serial = self.next_serial().await;
        self.authority.upsert(record, serial).await;
        Ok(())
    }

    /// Removes every `A` and `AAAA` record for `hostname`.
    ///
    /// Returns `Ok(true)` if at least one address record set was deleted, and
    /// `Ok(false)` if the name had none.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::InvalidName`] if `hostname` cannot be parsed or
    /// qualified.
    pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
        use hickory_server::proto::rr::RrKey;
        let owner = LowerName::from(self.qualify(hostname)?);
        // Delete the record sets outright. Upserting an rdata-less
        // placeholder adds a record next to the real one instead of
        // replacing it, so the name would keep resolving.
        let mut records = self.authority.records_mut().await;
        let removed_a = records
            .remove(&RrKey::new(owner.clone(), RecordType::A))
            .is_some();
        let removed_aaaa = records
            .remove(&RrKey::new(owner, RecordType::AAAA))
            .is_some();
        Ok(removed_a || removed_aaaa)
    }

    /// The overlay zone origin.
    #[must_use]
    pub fn zone_origin(&self) -> &Name {
        &self.zone_origin
    }

    /// Looks up the first `A` record for `fqdn` directly in the in-memory
    /// authority, without going over the network.
    ///
    /// Returns `None` if `fqdn` does not parse, the lookup fails, or no `A`
    /// record is present.
    pub async fn lookup_a(&self, fqdn: &str) -> Option<IpAddr> {
        use hickory_server::authority::{Authority, LookupOptions};
        let name = Name::from_str(fqdn).ok()?;
        let lower = LowerName::from(name);
        let lookup = self
            .authority
            .lookup(&lower, RecordType::A, LookupOptions::default())
            .await
            .ok()?;
        lookup.iter().find_map(|record| match record.data() {
            Some(RData::A(a)) => Some(IpAddr::V4((*a).into())),
            _ => None,
        })
    }
}
/// Overlay DNS server: authoritative for one zone held in memory, forwarding
/// other names to `upstreams`.
pub struct DnsServer {
/// Primary address the server binds (UDP and TCP).
listen_addr: SocketAddr,
/// In-memory authority for the overlay zone, shared with every handle.
authority: Arc<InMemoryAuthority>,
/// Origin of the overlay zone.
zone_origin: Name,
/// Serial counter shared with every handle.
serial: Arc<RwLock<u32>>,
/// Upstream resolvers for non-overlay queries; empty disables forwarding.
upstreams: Vec<SocketAddr>,
}
impl DnsServer {
    /// Creates a server for `zone` on `listen_addr`, with forwarding
    /// upstreams detected from the host's resolv.conf (or the public
    /// fallbacks when none are usable).
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::InvalidName`] if `zone` is not a valid DNS name.
    pub fn new(listen_addr: SocketAddr, zone: &str) -> Result<Self, DnsError> {
        let upstreams =
            resolve_upstreams(&DnsConfig::new(zone, listen_addr.ip()), RESOLV_CONF_PATH);
        Self::new_with_upstreams(listen_addr, zone, upstreams)
    }

    /// Creates a server with an explicit upstream list; an empty list
    /// disables forwarding of non-overlay names.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::InvalidName`] if `zone` is not a valid DNS name.
    pub fn new_with_upstreams(
        listen_addr: SocketAddr,
        zone: &str,
        upstreams: Vec<SocketAddr>,
    ) -> Result<Self, DnsError> {
        let zone_origin =
            Name::from_str(zone).map_err(|e| DnsError::InvalidName(format!("{zone}: {e}")))?;
        let authority = Arc::new(InMemoryAuthority::empty(
            zone_origin.clone(),
            ZoneType::Primary,
            false,
        ));
        Ok(Self {
            listen_addr,
            authority,
            zone_origin,
            serial: Arc::new(RwLock::new(1)),
            upstreams,
        })
    }

    /// Creates a server from a [`DnsConfig`]; explicit config upstreams win
    /// over resolv.conf detection.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::InvalidName`] if `config.zone` is not a valid DNS
    /// name.
    pub fn from_config(config: &DnsConfig) -> Result<Self, DnsError> {
        let listen_addr = SocketAddr::new(config.bind_addr, config.port);
        let upstreams = resolve_upstreams(config, RESOLV_CONF_PATH);
        Self::new_with_upstreams(listen_addr, &config.zone, upstreams)
    }

    /// Upstream resolvers used for non-overlay queries.
    #[must_use]
    pub fn upstreams(&self) -> &[SocketAddr] {
        &self.upstreams
    }

    /// Builds the request handler: the overlay authority in a catalog plus,
    /// when `upstreams` is non-empty and the forwarder builds, a resolver
    /// for everything else. A forwarder build failure is logged and
    /// forwarding is disabled; the overlay zone is still served.
    fn build_catalog(
        zone_origin: Name,
        authority: Arc<InMemoryAuthority>,
        upstreams: &[SocketAddr],
    ) -> ForwardingCatalog {
        let lower_origin = LowerName::from(zone_origin.clone());
        let mut catalog = Catalog::new();
        catalog.upsert(zone_origin.into(), Box::new(authority));
        let resolver = if upstreams.is_empty() {
            None
        } else {
            match build_forward_resolver(upstreams) {
                Ok(r) => {
                    tracing::debug!(
                        upstreams = ?upstreams,
                        "overlay DNS forwarder ready for non-overlay queries",
                    );
                    Some(Arc::new(r))
                }
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        "failed to build overlay DNS forwarder; non-overlay queries \
                         will be refused (overlay zone still served)",
                    );
                    None
                }
            }
        };
        ForwardingCatalog {
            catalog,
            zone_origin: lower_origin,
            resolver,
        }
    }

    /// Returns a handle sharing this server's authority and serial counter.
    #[must_use]
    pub fn handle(&self) -> DnsHandle {
        DnsHandle {
            authority: Arc::clone(&self.authority),
            zone_origin: self.zone_origin.clone(),
            serial: Arc::clone(&self.serial),
        }
    }

    /// Adds a record to the overlay zone; see [`DnsHandle::add_record`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`DnsHandle::add_record`].
    pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
        self.handle().add_record(hostname, ip).await
    }

    /// Removes a hostname's address records; see [`DnsHandle::remove_record`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`DnsHandle::remove_record`].
    pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
        self.handle().remove_record(hostname).await
    }

    /// Binds the primary listeners, spawns the serve loop and returns a handle.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::Io`] if the UDP socket or TCP listener cannot be
    /// bound. Errors after startup are logged by the background task.
    pub async fn start(self) -> Result<DnsHandle, DnsError> {
        self.start_background().await
    }

    /// Like [`DnsServer::start`] but keeps `self` usable.
    ///
    /// Sockets are bound before the task is spawned, so bind failures (port
    /// in use, missing privileges) reach the caller instead of only being
    /// logged.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::Io`] if binding fails.
    pub async fn start_background(&self) -> Result<DnsHandle, DnsError> {
        let (udp_socket, tcp_listener) = Self::bind_listeners(self.listen_addr).await?;
        let listen_addr = self.listen_addr;
        let zone_origin = self.zone_origin.clone();
        let authority = Arc::clone(&self.authority);
        let upstreams = self.upstreams.clone();
        tokio::spawn(async move {
            if let Err(e) = Self::run_server(
                listen_addr,
                zone_origin,
                authority,
                upstreams,
                udp_socket,
                tcp_listener,
            )
            .await
            {
                tracing::error!("DNS server error: {}", e);
            }
        });
        Ok(self.handle())
    }

    /// Adds a secondary listener on `bind_ip` at the standard DNS port (53),
    /// serving the same zone.
    ///
    /// # Errors
    ///
    /// See [`DnsServer::bind_secondary`].
    pub async fn bind_windows_fallback(&self, bind_ip: IpAddr) -> Result<DnsHandle, DnsError> {
        self.bind_secondary(SocketAddr::new(bind_ip, STANDARD_DNS_PORT))
            .await
    }

    /// Binds an additional listener at `listen_addr` sharing this server's
    /// authority and upstreams, and serves it in a background task.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::Io`] if the UDP socket or TCP listener cannot be
    /// bound.
    pub async fn bind_secondary(&self, listen_addr: SocketAddr) -> Result<DnsHandle, DnsError> {
        let (udp_socket, tcp_listener) = Self::bind_listeners(listen_addr).await?;
        let zone_origin = self.zone_origin.clone();
        let authority = Arc::clone(&self.authority);
        let upstreams = self.upstreams.clone();
        tokio::spawn(async move {
            let mut server =
                Self::build_server(zone_origin, authority, &upstreams, udp_socket, tcp_listener);
            tracing::info!(
                addr = %listen_addr,
                "secondary DNS listener started",
            );
            if let Err(e) = server.block_until_done().await {
                tracing::error!("secondary DNS listener error: {}", e);
            }
        });
        Ok(self.handle())
    }

    /// Binds the UDP socket and TCP listener for `listen_addr`.
    ///
    /// TCP is bound to the UDP socket's actual local address, so a requested
    /// port of 0 resolves to the same ephemeral port on both transports.
    async fn bind_listeners(
        listen_addr: SocketAddr,
    ) -> Result<(UdpSocket, TcpListener), DnsError> {
        let udp_socket = UdpSocket::bind(listen_addr).await?;
        let tcp_listener = TcpListener::bind(udp_socket.local_addr()?).await?;
        Ok((udp_socket, tcp_listener))
    }

    /// Builds the catalog and registers both bound transports with a new
    /// `ServerFuture` (TCP idle timeout 30 s).
    fn build_server(
        zone_origin: Name,
        authority: Arc<InMemoryAuthority>,
        upstreams: &[SocketAddr],
        udp_socket: UdpSocket,
        tcp_listener: TcpListener,
    ) -> ServerFuture<ForwardingCatalog> {
        let catalog = Self::build_catalog(zone_origin, authority, upstreams);
        let mut server = ServerFuture::new(catalog);
        server.register_socket(udp_socket);
        server.register_listener(tcp_listener, Duration::from_secs(30));
        server
    }

    /// Serves the primary listeners until the server shuts down.
    async fn run_server(
        listen_addr: SocketAddr,
        zone_origin: Name,
        authority: Arc<InMemoryAuthority>,
        upstreams: Vec<SocketAddr>,
        udp_socket: UdpSocket,
        tcp_listener: TcpListener,
    ) -> Result<(), DnsError> {
        let mut server =
            Self::build_server(zone_origin, authority, &upstreams, udp_socket, tcp_listener);
        tracing::info!(addr = %listen_addr, "DNS server listening");
        server
            .block_until_done()
            .await
            .map_err(|e| DnsError::Server(e.to_string()))
    }

    /// Primary listen address as configured (port 0 is not resolved here).
    #[must_use]
    pub fn listen_addr(&self) -> SocketAddr {
        self.listen_addr
    }

    /// Origin of the overlay zone.
    #[must_use]
    pub fn zone_origin(&self) -> &Name {
        &self.zone_origin
    }
}
/// Minimal blocking DNS client that queries a single server over UDP.
pub struct DnsClient {
/// Server that every query is sent to.
server_addr: SocketAddr,
}
impl DnsClient {
#[must_use]
pub fn new(server_addr: SocketAddr) -> Self {
Self { server_addr }
}
pub fn query_a(&self, hostname: &str) -> Result<Option<Ipv4Addr>, DnsError> {
let name = Name::from_str(hostname)
.map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
let conn = UdpClientConnection::new(self.server_addr)
.map_err(|e| DnsError::Client(e.to_string()))?;
let client = SyncClient::new(conn);
let response = client
.query(&name, DNSClass::IN, RecordType::A)
.map_err(|e| DnsError::Client(e.to_string()))?;
for answer in response.answers() {
if let Some(RData::A(a_record)) = answer.data() {
return Ok(Some((*a_record).into()));
}
}
Ok(None)
}
pub fn query_aaaa(&self, hostname: &str) -> Result<Option<Ipv6Addr>, DnsError> {
let name = Name::from_str(hostname)
.map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
let conn = UdpClientConnection::new(self.server_addr)
.map_err(|e| DnsError::Client(e.to_string()))?;
let client = SyncClient::new(conn);
let response = client
.query(&name, DNSClass::IN, RecordType::AAAA)
.map_err(|e| DnsError::Client(e.to_string()))?;
for answer in response.answers() {
if let Some(RData::AAAA(aaaa_record)) = answer.data() {
return Ok(Some((*aaaa_record).into()));
}
}
Ok(None)
}
pub fn query_addr(&self, hostname: &str) -> Result<Option<IpAddr>, DnsError> {
if let Ok(Some(v4)) = self.query_a(hostname) {
return Ok(Some(IpAddr::V4(v4)));
}
if let Ok(Some(v6)) = self.query_aaaa(hostname) {
return Ok(Some(IpAddr::V6(v6)));
}
Ok(None)
}
}
/// Name-to-address lookup backed by a local registry, falling back to DNS
/// queries against `dns_server`.
pub struct ServiceDiscovery {
/// DNS server queried on a local-registry miss.
dns_server: SocketAddr,
/// Locally registered services, consulted before DNS.
records: RwLock<HashMap<String, IpAddr>>,
}
impl ServiceDiscovery {
#[must_use]
pub fn new(dns_server_addr: SocketAddr) -> Self {
Self {
dns_server: dns_server_addr,
records: RwLock::new(HashMap::new()),
}
}
pub async fn register(&self, name: &str, ip: IpAddr) {
let mut records = self.records.write().await;
records.insert(name.to_string(), ip);
}
pub async fn resolve(&self, name: &str) -> Option<IpAddr> {
{
let records = self.records.read().await;
if let Some(ip) = records.get(name) {
return Some(*ip);
}
}
let client = DnsClient::new(self.dns_server);
if let Ok(Some(addr)) = client.query_addr(name) {
return Some(addr);
}
None
}
pub async fn unregister(&self, name: &str) {
let mut records = self.records.write().await;
records.remove(name);
}
pub async fn list_services(&self) -> Vec<String> {
let records = self.records.read().await;
records.keys().cloned().collect()
}
pub fn dns_server(&self) -> SocketAddr {
self.dns_server
}
}
#[cfg(test)]
mod tests {
    // Unit tests for config/parsing helpers plus loopback integration tests
    // that run the server against a stub upstream.
    use super::*;
    #[test]
    fn test_peer_hostname_v4() {
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1))),
            "node-0-1"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 5))),
            "node-0-5"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 1, 100))),
            "node-1-100"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(192, 168, 255, 254))),
            "node-255-254"
        );
    }
    #[test]
    fn test_peer_hostname_v6() {
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::1".parse().unwrap())),
            "node-0001"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::abcd".parse().unwrap())),
            "node-abcd"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00:200::ffff".parse().unwrap())),
            "node-ffff"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::1:0".parse().unwrap())),
            "node-0000"
        );
    }
    #[test]
    fn test_dns_config() {
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
        assert_eq!(config.zone, "overlay.local.");
        assert_eq!(config.port, DEFAULT_DNS_PORT);
        assert_eq!(config.bind_addr, IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
        let config = config.with_port(5353);
        assert_eq!(config.port, 5353);
    }
    #[test]
    fn test_dns_config_serialization() {
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)))
            .with_port(15353);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.zone, config.zone);
        assert_eq!(deserialized.port, config.port);
        assert_eq!(deserialized.bind_addr, config.bind_addr);
    }
    #[tokio::test]
    async fn test_service_discovery_local_cache() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        discovery.register("test-service", ip).await;
        let resolved = discovery.resolve("test-service").await;
        assert_eq!(resolved, Some(ip));
        discovery.unregister("test-service").await;
        let services = discovery.list_services().await;
        assert!(services.is_empty());
    }
    #[test]
    fn test_dns_server_creation() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.");
        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr(), addr);
        assert_eq!(server.zone_origin().to_string(), "overlay.local.");
    }
    #[test]
    fn test_dns_server_from_config() {
        let config =
            DnsConfig::new("test.local.", IpAddr::V4(Ipv4Addr::LOCALHOST)).with_port(15353);
        let server = DnsServer::from_config(&config);
        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr().port(), 15353);
        assert_eq!(server.zone_origin().to_string(), "test.local.");
    }
    #[test]
    fn test_dns_server_invalid_zone() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        // A single DNS label may be at most 63 octets, so 64 must be rejected.
        let zone = format!("{}.local.", "a".repeat(64));
        let server = DnsServer::new(addr, &zone);
        assert!(
            matches!(server, Err(DnsError::InvalidName(_))),
            "a 64-octet label must be rejected as an invalid zone",
        );
    }
    #[tokio::test]
    async fn test_dns_server_add_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let result = server
            .add_record("myservice", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
            .await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_dns_handle_add_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let handle = server.handle();
        let result = handle
            .add_record("service1", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
            .await;
        assert!(result.is_ok());
        let result = handle
            .add_record("service2", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)))
            .await;
        assert!(result.is_ok());
        assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
    }
    #[test]
    fn test_dns_client_creation() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
        let client = DnsClient::new(addr);
        assert_eq!(client.server_addr, addr);
    }
    #[tokio::test]
    async fn test_dns_handle_add_aaaa_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let handle = server.handle();
        let ipv6: IpAddr = "fd00::1".parse().unwrap();
        let result = handle.add_record("service-v6", ipv6).await;
        assert!(result.is_ok());
        let ipv6_2: IpAddr = "fd00::abcd".parse().unwrap();
        let result = handle.add_record("service-v6-2", ipv6_2).await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_dns_server_add_aaaa_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let ipv6: IpAddr = "fd00::42".parse().unwrap();
        let result = server.add_record("myservice-v6", ipv6).await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_dns_handle_remove_record_covers_both_types() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let handle = server.handle();
        let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        handle.add_record("dual-service", ipv4).await.unwrap();
        let removed = handle.remove_record("dual-service").await.unwrap();
        assert!(removed);
        let ipv6: IpAddr = "fd00::1".parse().unwrap();
        handle.add_record("v6-service", ipv6).await.unwrap();
        let removed = handle.remove_record("v6-service").await.unwrap();
        assert!(removed);
    }
    #[tokio::test]
    async fn test_service_discovery_local_cache_ipv6() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);
        let ipv6: IpAddr = "fd00::beef".parse().unwrap();
        discovery.register("v6-service", ipv6).await;
        let resolved = discovery.resolve("v6-service").await;
        assert_eq!(resolved, Some(ipv6));
        discovery.unregister("v6-service").await;
        let services = discovery.list_services().await;
        assert!(services.is_empty());
    }
    #[tokio::test]
    async fn test_service_discovery_mixed_v4_v6_cache() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);
        let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let ipv6: IpAddr = "fd00::1".parse().unwrap();
        discovery.register("svc-v4", ipv4).await;
        discovery.register("svc-v6", ipv6).await;
        assert_eq!(discovery.resolve("svc-v4").await, Some(ipv4));
        assert_eq!(discovery.resolve("svc-v6").await, Some(ipv6));
        let mut services = discovery.list_services().await;
        services.sort();
        assert_eq!(services, vec!["svc-v4", "svc-v6"]);
    }
    #[test]
    fn test_dns_config_with_ipv6_bind_addr() {
        let ipv6_bind: IpAddr = "fd00::1".parse().unwrap();
        let config = DnsConfig::new("overlay.local.", ipv6_bind);
        assert_eq!(config.bind_addr, ipv6_bind);
        assert_eq!(config.port, DEFAULT_DNS_PORT);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.bind_addr, ipv6_bind);
    }
    #[test]
    fn test_dns_server_creation_ipv6_bind() {
        let ipv6_addr: IpAddr = "::1".parse().unwrap();
        let addr = SocketAddr::new(ipv6_addr, 15353);
        let server = DnsServer::new(addr, "overlay.local.");
        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr(), addr);
    }
    #[tokio::test]
    async fn test_bind_windows_fallback_errors_or_shares_authority() {
        let primary = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        let server = DnsServer::new(primary, "overlay.local.").unwrap();
        let bind_ip: IpAddr = "127.0.0.2".parse().unwrap();
        match server.bind_windows_fallback(bind_ip).await {
            Ok(handle) => {
                assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
                handle
                    .add_record("dual", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
                    .await
                    .expect("add_record via fallback handle");
            }
            Err(DnsError::Io(_)) => {
                // Port 53 on 127.0.0.2 may be privileged or unavailable on
                // the test host; an I/O bind error is an acceptable outcome.
            }
            Err(other) => panic!("unexpected error from bind_windows_fallback: {other}"),
        }
    }
    #[test]
    fn test_peer_hostname_uniqueness() {
        let v4_a = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        let v4_b = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)));
        assert_ne!(v4_a, v4_b);
        let v6_a = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
        let v6_b = peer_hostname(IpAddr::V6("fd00::2".parse().unwrap()));
        assert_ne!(v6_a, v6_b);
        let v4 = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        let v6 = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
        assert_ne!(v4, v6);
    }
    #[test]
    fn test_parse_resolv_conf_filters_stub_and_loopback() {
        let contents = "\
            # generated by netbird\n\
            nameserver 127.0.0.53\n\
            nameserver 127.0.0.1\n\
            nameserver 192.168.1.1\n\
            search example.com\n\
            options edns0\n";
        let parsed = parse_resolv_conf(contents);
        assert_eq!(
            parsed,
            vec![SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
                53
            )],
            "127.0.0.53 stub and 127.0.0.1 loopback must be filtered out",
        );
    }
    #[test]
    fn test_parse_resolv_conf_dedup_and_comments() {
        let contents = "\
            ; a comment\n\
            nameserver 8.8.8.8\n\
            nameserver 8.8.8.8\n\
            nameserver fe80::1%eth0\n\
            nameserver 0.0.0.0\n";
        let parsed = parse_resolv_conf(contents);
        assert_eq!(parsed.len(), 2);
        assert_eq!(
            parsed[0],
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)
        );
        assert_eq!(parsed[1].ip(), "fe80::1".parse::<IpAddr>().unwrap());
    }
    #[test]
    fn test_resolve_upstreams_config_override_wins() {
        let explicit = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 9, 9, 9)), 5300);
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST))
            .with_upstreams(vec![explicit]);
        let resolved = resolve_upstreams(&config, "/nonexistent/resolv.conf");
        assert_eq!(resolved, vec![explicit]);
    }
    #[test]
    fn test_resolve_upstreams_falls_back_to_public_when_missing() {
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST));
        let resolved = resolve_upstreams(&config, "/definitely/not/a/real/resolv.conf");
        assert_eq!(
            resolved,
            vec![
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53),
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            ],
        );
    }
    /// Spawns a UDP stub upstream on loopback that answers every `A` question
    /// with `answer_ip` (TTL 60) and returns its address.
    async fn spawn_stub_upstream(answer_ip: Ipv4Addr) -> SocketAddr {
        use hickory_server::proto::op::{Message, MessageType, ResponseCode};
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
            .await
            .expect("bind stub upstream");
        let addr = sock.local_addr().expect("stub local_addr");
        tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            loop {
                let Ok((len, from)) = sock.recv_from(&mut buf).await else {
                    break;
                };
                let Ok(request) = Message::from_vec(&buf[..len]) else {
                    continue;
                };
                let mut resp = Message::new();
                resp.set_id(request.id());
                resp.set_message_type(MessageType::Response);
                resp.set_recursion_available(true);
                resp.set_response_code(ResponseCode::NoError);
                for q in request.queries() {
                    resp.add_query(q.clone());
                    if q.query_type() == RecordType::A {
                        let rec =
                            Record::from_rdata(q.name().clone(), 60, RData::A(A::from(answer_ip)));
                        resp.add_answer(rec);
                    }
                }
                if let Ok(bytes) = resp.to_vec() {
                    let _ = sock.send_to(&bytes, from).await;
                }
            }
        });
        addr
    }
    /// Sends a raw UDP `A` query for `name` to `server` and returns the first
    /// A answer, or the response code when it is not NOERROR.
    async fn raw_query_a(
        server: SocketAddr,
        name: &str,
    ) -> Result<Option<Ipv4Addr>, hickory_server::proto::op::ResponseCode> {
        use hickory_server::proto::op::{Message, MessageType, Query, ResponseCode};
        let client = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
            .await
            .expect("bind client");
        let qname = Name::from_str(name).expect("query name");
        let mut msg = Message::new();
        msg.set_id(0x1234);
        msg.set_message_type(MessageType::Query);
        msg.set_recursion_desired(true);
        msg.add_query(Query::query(qname, RecordType::A));
        let bytes = msg.to_vec().expect("encode query");
        client.send_to(&bytes, server).await.expect("send query");
        let mut buf = vec![0u8; 1500];
        let len = tokio::time::timeout(Duration::from_secs(12), client.recv(&mut buf))
            .await
            .expect("query timed out")
            .expect("recv response");
        let resp = Message::from_vec(&buf[..len]).expect("decode response");
        if resp.response_code() != ResponseCode::NoError {
            return Err(resp.response_code());
        }
        for ans in resp.answers() {
            if let Some(RData::A(a)) = ans.data() {
                return Ok(Some((*a).into()));
            }
        }
        Ok(None)
    }
    #[tokio::test]
    async fn test_forwarding_overlay_answered_and_nonoverlay_forwarded() {
        let upstream_answer = Ipv4Addr::new(203, 0, 113, 7);
        let upstream = spawn_stub_upstream(upstream_answer).await;
        // Reserve an ephemeral port, release it, then let the server bind it.
        let bound = {
            let probe = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .unwrap();
            let a = probe.local_addr().unwrap();
            drop(probe);
            a
        };
        let overlay_ip = Ipv4Addr::new(10, 200, 0, 5);
        let server =
            DnsServer::new_with_upstreams(bound, "overlay.local.", vec![upstream]).unwrap();
        let handle = server.handle();
        handle
            .add_record("svc", IpAddr::V4(overlay_ip))
            .await
            .unwrap();
        let _running = server.start().await.unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        let overlay = raw_query_a(bound, "svc.overlay.local.")
            .await
            .expect("overlay query should not SERVFAIL");
        assert_eq!(
            overlay,
            Some(overlay_ip),
            "overlay name must be answered from InMemoryAuthority",
        );
        let forwarded = raw_query_a(bound, "example.com.")
            .await
            .expect("forwarded query should not SERVFAIL");
        assert_eq!(
            forwarded,
            Some(upstream_answer),
            "non-overlay name must be forwarded to the upstream stub",
        );
    }
    #[tokio::test]
    async fn test_forwarding_total_upstream_failure_is_servfail_not_panic() {
        use hickory_server::proto::op::ResponseCode;
        // An address whose socket was closed, so nothing answers there.
        let dead_upstream = {
            let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .unwrap();
            let a = s.local_addr().unwrap();
            drop(s);
            a
        };
        let bound = {
            let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .unwrap();
            let a = s.local_addr().unwrap();
            drop(s);
            a
        };
        let server =
            DnsServer::new_with_upstreams(bound, "overlay.local.", vec![dead_upstream]).unwrap();
        let handle = server.handle();
        handle
            .add_record("svc", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 9)))
            .await
            .unwrap();
        let _running = server.start().await.unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        let overlay = raw_query_a(bound, "svc.overlay.local.")
            .await
            .expect("overlay query should still succeed");
        assert_eq!(overlay, Some(Ipv4Addr::new(10, 200, 0, 9)));
        match raw_query_a(bound, "example.com.").await {
            Err(ResponseCode::ServFail) => {}
            Err(other) => panic!("expected SERVFAIL, got {other:?}"),
            Ok(answer) => panic!("expected SERVFAIL, got answer {answer:?}"),
        }
    }
}