use std::collections::HashMap;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::{Duration, Instant};
use crate::error::{NetError, Result};
/// Well-known port for DNS.
pub const DNS_PORT: u16 = 53;

/// Upstream resolvers used by `DnsConfig::default` (Google Public DNS and
/// Cloudflare). `DnsConfig::new` replaces them with the nameservers from
/// `/etc/resolv.conf` whenever that file lists any.
pub const DEFAULT_UPSTREAM: &[Ipv4Addr] = &[
    Ipv4Addr::new(8, 8, 8, 8),
    Ipv4Addr::new(1, 1, 1, 1),
];

/// Default lifetime of cached upstream replies: five minutes.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
#[derive(Debug, Clone)]
pub struct DnsConfig {
pub listen_addr: SocketAddr,
pub upstream: Vec<SocketAddr>,
pub cache_ttl: Duration,
pub local_domain: Option<String>,
}
impl Default for DnsConfig {
fn default() -> Self {
Self {
listen_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DNS_PORT),
upstream: DEFAULT_UPSTREAM
.iter()
.map(|ip| SocketAddr::new(IpAddr::V4(*ip), DNS_PORT))
.collect(),
cache_ttl: DEFAULT_CACHE_TTL,
local_domain: Some("arcbox.local".to_string()),
}
}
}
impl DnsConfig {
#[must_use]
pub fn new(listen_addr: Ipv4Addr) -> Self {
let mut config = Self {
listen_addr: SocketAddr::new(IpAddr::V4(listen_addr), DNS_PORT),
..Default::default()
};
let detected = detect_system_upstream();
if !detected.is_empty() {
config.upstream = detected;
}
config
}
#[must_use]
pub fn with_listen_addr(mut self, addr: SocketAddr) -> Self {
self.listen_addr = addr;
self
}
#[must_use]
pub fn with_upstream(mut self, servers: Vec<SocketAddr>) -> Self {
self.upstream = servers;
self
}
#[must_use]
pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
self.cache_ttl = ttl;
self
}
#[must_use]
pub fn with_local_domain(mut self, domain: impl Into<String>) -> Self {
self.local_domain = Some(domain.into());
self
}
}
/// Extracts the `nameserver` entries from `resolv.conf`-formatted text.
///
/// Skips blank lines, `#`/`;` comment lines, other directives, and values that
/// do not parse as an IP address. Each address is paired with [`DNS_PORT`].
/// Duplicates are dropped while first-seen order is kept.
fn parse_resolv_conf_nameservers(contents: &str) -> Vec<SocketAddr> {
    let addresses = contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#') && !line.starts_with(';'))
        .filter_map(|line| {
            let mut fields = line.split_whitespace();
            match (fields.next(), fields.next()) {
                (Some("nameserver"), Some(value)) => value.parse::<IpAddr>().ok(),
                _ => None,
            }
        })
        .map(|ip| SocketAddr::new(ip, DNS_PORT));

    let mut unique = Vec::new();
    for addr in addresses {
        if !unique.contains(&addr) {
            unique.push(addr);
        }
    }
    unique
}
/// Reads upstream resolvers from `/etc/resolv.conf`.
///
/// Returns an empty list when the file cannot be read. `DnsConfig::new` then
/// keeps its default upstreams.
fn detect_system_upstream() -> Vec<SocketAddr> {
    match fs::read_to_string("/etc/resolv.conf") {
        Ok(contents) => parse_resolv_conf_nameservers(&contents),
        Err(_) => Vec::new(),
    }
}
/// DNS resource-record types this module understands, with their wire TYPE codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum DnsRecordType {
    /// IPv4 host address.
    A = 1,
    /// IPv6 host address.
    Aaaa = 28,
    /// Canonical-name alias.
    Cname = 5,
    /// Domain-name pointer (reverse lookup).
    Ptr = 12,
    /// Mail exchange.
    Mx = 15,
    /// Text strings.
    Txt = 16,
    /// Service locator.
    Srv = 33,
}

impl TryFrom<u16> for DnsRecordType {
    type Error = ();

    /// Maps a wire TYPE code to its variant. Unknown codes yield `Err(())`.
    fn try_from(value: u16) -> std::result::Result<Self, Self::Error> {
        const KNOWN: [DnsRecordType; 7] = [
            DnsRecordType::A,
            DnsRecordType::Aaaa,
            DnsRecordType::Cname,
            DnsRecordType::Ptr,
            DnsRecordType::Mx,
            DnsRecordType::Txt,
            DnsRecordType::Srv,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|&rtype| rtype as u16 == value)
            .ok_or(())
    }
}
/// DNS CLASS values. Queries for any class other than `IN` are rejected
/// during parsing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u16)]
pub enum DnsClass {
    /// The Internet class (`IN`), wire value 1.
    In = 1,
}
/// DNS response codes (the RCODE nibble of the header flags).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum DnsResponseCode {
    /// No error condition.
    NoError = 0,
    /// The server could not interpret the query.
    FormErr = 1,
    /// The server failed to process the query.
    ServFail = 2,
    /// The queried name does not exist.
    NxDomain = 3,
    /// The server does not support this kind of query.
    NotImp = 4,
    /// The server refused to answer.
    Refused = 5,
}
/// Payload (RDATA) of a resource record.
#[derive(Clone, Debug)]
pub enum DnsRdata {
    /// Address carried by an `A` record.
    A(Ipv4Addr),
    /// Address carried by an `AAAA` record.
    Aaaa(Ipv6Addr),
    /// Target name carried by a `CNAME` record.
    Cname(String),
    /// Undecoded RDATA bytes.
    Raw(Vec<u8>),
}
/// A single resource record as carried in a DNS answer section.
#[derive(Clone, Debug)]
pub struct DnsRecord {
    /// Owner name of the record.
    pub name: String,
    /// Record type.
    pub rtype: DnsRecordType,
    /// Record class.
    pub class: DnsClass,
    /// Time to live, in seconds.
    pub ttl: u32,
    /// Type-specific payload.
    pub rdata: DnsRdata,
}
#[derive(Debug, Clone)]
struct CacheEntry {
records: Vec<DnsRecord>,
cached_at: Instant,
ttl: Duration,
}
impl CacheEntry {
fn is_expired(&self) -> bool {
self.cached_at.elapsed() >= self.ttl
}
}
pub struct DnsForwarder {
config: DnsConfig,
local_hosts: HashMap<String, IpAddr>,
cache: HashMap<(String, DnsRecordType), CacheEntry>,
}
impl DnsForwarder {
#[must_use]
pub fn new(config: DnsConfig) -> Self {
Self {
config,
local_hosts: HashMap::new(),
cache: HashMap::new(),
}
}
#[must_use]
pub fn config(&self) -> &DnsConfig {
&self.config
}
pub fn add_local_host(&mut self, hostname: &str, ip: IpAddr) {
let hostname = hostname.to_lowercase();
self.local_hosts.insert(hostname.clone(), ip);
if let Some(ref domain) = self.config.local_domain {
let fqdn = format!("{}.{}", hostname, domain);
self.local_hosts.insert(fqdn, ip);
}
tracing::debug!("Added local host: {} -> {}", hostname, ip);
}
pub fn remove_local_host(&mut self, hostname: &str) {
let hostname = hostname.to_lowercase();
self.local_hosts.remove(&hostname);
if let Some(ref domain) = self.config.local_domain {
let fqdn = format!("{}.{}", hostname, domain);
self.local_hosts.remove(&fqdn);
}
}
#[must_use]
pub fn resolve_local(&self, hostname: &str) -> Option<IpAddr> {
let hostname = hostname.to_lowercase();
self.local_hosts.get(&hostname).copied()
}
pub fn try_resolve_locally(&self, data: &[u8]) -> Option<Vec<u8>> {
let query = DnsQuery::parse(data).ok()?;
let ip = self.resolve_local(&query.name)?;
self.build_local_response(&query, ip).ok()
}
pub fn try_resolve_locally_or_nxdomain(&self, data: &[u8]) -> Option<Vec<u8>> {
let query = DnsQuery::parse(data).ok()?;
if let Some(ip) = self.resolve_local(&query.name) {
return self.build_local_response(&query, ip).ok();
}
if let Some(ref domain) = self.config.local_domain {
let name_lower = query.name.to_lowercase();
if name_lower == *domain || name_lower.ends_with(&format!(".{domain}")) {
return Some(Self::build_nxdomain_response(&query));
}
}
None
}
fn build_nxdomain_response(query: &DnsQuery) -> Vec<u8> {
let mut response = Vec::with_capacity(query.raw_header.len() + query.raw_question.len());
response.extend_from_slice(&query.raw_header);
response[2] = 0x85;
response[3] = 0x83;
response[6] = 0x00;
response[7] = 0x00;
response[8] = 0x00;
response[9] = 0x00;
response[10] = 0x00;
response[11] = 0x00;
response.extend_from_slice(&query.raw_question);
response
}
#[must_use]
pub fn upstream(&self) -> &[SocketAddr] {
&self.config.upstream
}
pub fn handle_query(&mut self, data: &[u8]) -> Result<Vec<u8>> {
let query = DnsQuery::parse(data)?;
if let Some(ip) = self.resolve_local(&query.name) {
return self.build_local_response(&query, ip);
}
if let Some(cached) = self.check_cache(&query.name, query.qtype) {
return self.build_cached_response(&query, &cached);
}
self.forward_query(data)
}
fn check_cache(&mut self, name: &str, qtype: DnsRecordType) -> Option<Vec<DnsRecord>> {
let key = (name.to_lowercase(), qtype);
self.cache.retain(|_, v| !v.is_expired());
self.cache.get(&key).map(|e| e.records.clone())
}
fn build_local_response(&self, query: &DnsQuery, ip: IpAddr) -> Result<Vec<u8>> {
let mut response = Vec::with_capacity(512);
response.extend_from_slice(&query.raw_header);
response[2] = 0x81; response[3] = 0x80; response[6] = 0x00; response[7] = 0x01;
response.extend_from_slice(&query.raw_question);
response.extend_from_slice(&[0xc0, 0x0c]);
match ip {
IpAddr::V4(v4) => {
response.extend_from_slice(&[0x00, 0x01]);
response.extend_from_slice(&[0x00, 0x01]);
response.extend_from_slice(&[0x00, 0x00, 0x01, 0x2c]);
response.extend_from_slice(&[0x00, 0x04]);
response.extend_from_slice(&v4.octets());
}
IpAddr::V6(v6) => {
response.extend_from_slice(&[0x00, 0x1c]);
response.extend_from_slice(&[0x00, 0x01]);
response.extend_from_slice(&[0x00, 0x00, 0x01, 0x2c]);
response.extend_from_slice(&[0x00, 0x10]);
response.extend_from_slice(&v6.octets());
}
}
Ok(response)
}
fn build_cached_response(&self, query: &DnsQuery, records: &[DnsRecord]) -> Result<Vec<u8>> {
let mut response = Vec::with_capacity(512);
response.extend_from_slice(&query.raw_header);
response[2] = 0x81;
response[3] = 0x80;
response[6] = 0x00;
response[7] = records.len() as u8;
response.extend_from_slice(&query.raw_question);
for record in records {
response.extend_from_slice(&[0xc0, 0x0c]);
response.extend_from_slice(&(record.rtype as u16).to_be_bytes());
response.extend_from_slice(&[0x00, 0x01]);
response.extend_from_slice(&record.ttl.to_be_bytes());
match &record.rdata {
DnsRdata::A(ip) => {
response.extend_from_slice(&[0x00, 0x04]);
response.extend_from_slice(&ip.octets());
}
DnsRdata::Aaaa(ip) => {
response.extend_from_slice(&[0x00, 0x10]);
response.extend_from_slice(&ip.octets());
}
DnsRdata::Raw(data) => {
response.extend_from_slice(&(data.len() as u16).to_be_bytes());
response.extend_from_slice(data);
}
DnsRdata::Cname(_) => {
}
}
}
Ok(response)
}
fn forward_query(&mut self, data: &[u8]) -> Result<Vec<u8>> {
use std::net::UdpSocket;
let socket = UdpSocket::bind("0.0.0.0:0")
.map_err(|e| NetError::Dns(format!("failed to bind socket: {}", e)))?;
socket
.set_read_timeout(Some(Duration::from_secs(2)))
.map_err(|e| NetError::Dns(format!("failed to set timeout: {}", e)))?;
for upstream in &self.config.upstream {
if socket.send_to(data, upstream).is_err() {
continue;
}
let mut buf = [0u8; 512];
match socket.recv_from(&mut buf) {
Ok((len, _)) => {
let response = buf[..len].to_vec();
if let Ok(query) = DnsQuery::parse(data) {
self.cache_response(&query.name, query.qtype, &response);
}
return Ok(response);
}
Err(_) => continue,
}
}
Err(NetError::Dns("all upstream servers failed".to_string()))
}
fn cache_response(&mut self, name: &str, qtype: DnsRecordType, _response: &[u8]) {
let key = (name.to_lowercase(), qtype);
let entry = CacheEntry {
records: Vec::new(), cached_at: Instant::now(),
ttl: self.config.cache_ttl,
};
self.cache.insert(key, entry);
}
pub fn clear_cache(&mut self) {
self.cache.clear();
}
}
/// A parsed DNS query with one supported question, plus the raw bytes needed
/// to echo the header and question back in a response.
#[derive(Debug)]
struct DnsQuery {
    /// QNAME labels joined with `.`, case preserved, with no trailing dot.
    name: String,
    /// Requested record type (QTYPE).
    qtype: DnsRecordType,
    // Parsed only to validate the query (non-IN classes are rejected); it is
    // not read afterwards.
    #[allow(dead_code)]
    qclass: DnsClass,
    /// The 12-byte header exactly as received.
    raw_header: Vec<u8>,
    /// The first question (QNAME, QTYPE, QCLASS) exactly as received.
    raw_question: Vec<u8>,
}

impl DnsQuery {
    /// Length of the fixed DNS message header.
    const HEADER_LEN: usize = 12;
    /// Longest ordinary label. A length byte with either of its top two bits
    /// set is a compression pointer or reserved label type, not a label length.
    const MAX_LABEL_LEN: usize = 63;

    /// Parses the header and first question of a raw DNS query.
    ///
    /// Only the first question is parsed, even when QDCOUNT is larger.
    ///
    /// # Errors
    ///
    /// Returns `NetError::Dns` when the packet:
    /// - is shorter than a header,
    /// - has QDCOUNT = 0,
    /// - has a malformed or overrunning QNAME,
    /// - ends before QTYPE/QCLASS, or
    /// - asks for an unsupported type or a class other than IN.
    fn parse(data: &[u8]) -> Result<Self> {
        if data.len() < Self::HEADER_LEN {
            return Err(NetError::Dns("query too short".to_string()));
        }
        let raw_header = data[..Self::HEADER_LEN].to_vec();
        if u16::from_be_bytes([data[4], data[5]]) == 0 {
            return Err(NetError::Dns("query has no question".to_string()));
        }

        let mut offset = Self::HEADER_LEN;
        let mut name_parts = Vec::new();
        while offset < data.len() {
            let len = usize::from(data[offset]);
            if len == 0 {
                offset += 1;
                break;
            }
            if len > Self::MAX_LABEL_LEN || offset + 1 + len > data.len() {
                return Err(NetError::Dns("invalid name".to_string()));
            }
            let label = String::from_utf8_lossy(&data[offset + 1..offset + 1 + len]);
            name_parts.push(label.into_owned());
            offset += 1 + len;
        }
        // Also covers a QNAME that ran to the end without its zero terminator.
        if offset + 4 > data.len() {
            return Err(NetError::Dns("query truncated".to_string()));
        }

        let name = name_parts.join(".");
        let qtype_raw = u16::from_be_bytes([data[offset], data[offset + 1]]);
        let qclass_raw = u16::from_be_bytes([data[offset + 2], data[offset + 3]]);
        let qtype = DnsRecordType::try_from(qtype_raw)
            .map_err(|()| NetError::Dns(format!("unsupported query type: {}", qtype_raw)))?;
        let qclass = if qclass_raw == 1 {
            DnsClass::In
        } else {
            return Err(NetError::Dns(format!("unsupported class: {}", qclass_raw)));
        };
        let raw_question = data[Self::HEADER_LEN..offset + 4].to_vec();

        Ok(Self {
            name,
            qtype,
            qclass,
            raw_header,
            raw_question,
        })
    }
}
/// Legacy IPv4-oriented DNS server facade over a [`DnsForwarder`].
pub struct DnsServer {
    /// Address passed to the constructor; also used (at [`DNS_PORT`]) as the
    /// forwarder config's listen address.
    listen_addr: Ipv4Addr,
    // Kept as passed in. The forwarder's config holds the upstreams actually used.
    #[allow(dead_code)]
    upstream: Vec<Ipv4Addr>,
    /// Forwarder that owns the host table, cache and upstream configuration.
    forwarder: DnsForwarder,
}

impl DnsServer {
    /// Creates a server whose forwarder config listens on `listen_addr` and
    /// forwards to `upstream`, each at [`DNS_PORT`].
    #[must_use]
    pub fn new(listen_addr: Ipv4Addr, upstream: Vec<Ipv4Addr>) -> Self {
        let upstream_addrs: Vec<SocketAddr> = upstream
            .iter()
            .map(|&ip| SocketAddr::from((ip, DNS_PORT)))
            .collect();
        // NOTE(review): `with_upstream` replaces the detected/default resolvers
        // even when `upstream` is empty, leaving nothing to forward to.
        let config = DnsConfig::new(listen_addr).with_upstream(upstream_addrs);

        Self {
            forwarder: DnsForwarder::new(config),
            listen_addr,
            upstream,
        }
    }

    /// Returns the address this server was created with.
    #[must_use]
    pub fn listen_addr(&self) -> Ipv4Addr {
        self.listen_addr
    }

    /// Registers a local host with the underlying forwarder.
    pub fn add_host(&mut self, hostname: &str, ip: IpAddr) {
        self.forwarder.add_local_host(hostname, ip);
    }

    /// Resolves `hostname` from the local host table. IPv6 entries are
    /// reported as `None`.
    #[must_use]
    pub fn resolve(&self, hostname: &str) -> Option<Ipv4Addr> {
        self.forwarder
            .resolve_local(hostname)
            .and_then(|ip| match ip {
                IpAddr::V4(v4) => Some(v4),
                IpAddr::V6(_) => None,
            })
    }

    /// Gives mutable access to the underlying forwarder.
    pub fn forwarder_mut(&mut self) -> &mut DnsForwarder {
        &mut self.forwarder
    }
}
// Unit tests for configuration defaults, record-type conversion, the local host
// table, local-zone responses, and resolv.conf parsing. No test calls
// `handle_query`, so no upstream traffic is sent.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration listens on port 53 and has upstream resolvers.
#[test]
fn test_dns_config_default() {
let config = DnsConfig::default();
assert_eq!(config.listen_addr.port(), DNS_PORT);
assert!(!config.upstream.is_empty());
}
// Known TYPE codes map to their variants; unknown codes are rejected.
#[test]
fn test_dns_record_type_conversion() {
assert_eq!(DnsRecordType::try_from(1), Ok(DnsRecordType::A));
assert_eq!(DnsRecordType::try_from(28), Ok(DnsRecordType::Aaaa));
assert!(DnsRecordType::try_from(999).is_err());
}
// Registration is case-insensitive and adds the `<host>.arcbox.local` alias;
// removal makes the bare name unresolvable again.
#[test]
fn test_dns_forwarder_local_hosts() {
let config = DnsConfig::default();
let mut forwarder = DnsForwarder::new(config);
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 64, 10));
forwarder.add_local_host("myvm", ip);
assert_eq!(forwarder.resolve_local("myvm"), Some(ip));
assert_eq!(forwarder.resolve_local("MYVM"), Some(ip)); assert_eq!(forwarder.resolve_local("myvm.arcbox.local"), Some(ip));
forwarder.remove_local_host("myvm");
assert_eq!(forwarder.resolve_local("myvm"), None);
}
// The legacy constructor keeps the listen address it was given. Note that
// `DnsServer::new` goes through `DnsConfig::new`, which reads /etc/resolv.conf.
#[test]
fn test_dns_server_legacy() {
let server = DnsServer::new(
Ipv4Addr::new(192, 168, 64, 1),
vec![Ipv4Addr::new(8, 8, 8, 8)],
);
assert_eq!(server.listen_addr(), Ipv4Addr::new(192, 168, 64, 1));
}
// Builds a minimal query packet: ID 0xABCD, flags 0x0100 (RD), QDCOUNT=1, all
// other counts 0, then `name` as length-prefixed labels, a zero terminator,
// QTYPE=A (1) and QCLASS=IN (1).
fn build_test_query(name: &str) -> Vec<u8> {
let mut packet = Vec::with_capacity(64);
packet.extend_from_slice(&[0xAB, 0xCD, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00]);
packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
for label in name.split('.') {
packet.push(label.len() as u8);
packet.extend_from_slice(label.as_bytes());
}
packet.push(0x00); packet.extend_from_slice(&[0x00, 0x01]); packet.extend_from_slice(&[0x00, 0x01]); packet
}
// A registered host under the default domain gets a NOERROR reply with one answer.
#[test]
fn test_try_resolve_locally_or_nxdomain_returns_response_for_registered() {
let config = DnsConfig::default();
let mut forwarder = DnsForwarder::new(config);
let ip = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 2));
forwarder.add_local_host("my-nginx", ip);
let query = build_test_query("my-nginx.arcbox.local");
let response = forwarder
.try_resolve_locally_or_nxdomain(&query)
.expect("should resolve registered host");
assert_eq!(response[2] & 0x80, 0x80, "QR bit");
assert_eq!(response[3] & 0x0F, 0, "RCODE=NoError");
assert_eq!(response[7], 1, "ANCOUNT=1");
}
// An unregistered name inside the local domain gets NXDOMAIN with no answers.
#[test]
fn test_try_resolve_locally_or_nxdomain_returns_nxdomain() {
let config = DnsConfig::default();
let forwarder = DnsForwarder::new(config);
let query = build_test_query("nonexistent.arcbox.local");
let response = forwarder
.try_resolve_locally_or_nxdomain(&query)
.expect("should return NXDOMAIN for unregistered local host");
assert_eq!(response[2] & 0x80, 0x80, "QR bit");
assert_eq!(response[3] & 0x0F, 3, "RCODE=NXDOMAIN");
assert_eq!(response[7], 0, "ANCOUNT=0");
}
// Names outside the local domain are not answered locally (left for forwarding).
#[test]
fn test_try_resolve_locally_or_nxdomain_returns_none_for_external() {
let config = DnsConfig::default();
let forwarder = DnsForwarder::new(config);
let query = build_test_query("google.com");
let result = forwarder.try_resolve_locally_or_nxdomain(&query);
assert!(result.is_none(), "should return None for non-local domains");
}
// A query for the local domain itself also gets NXDOMAIN.
#[test]
fn test_try_resolve_locally_or_nxdomain_bare_domain() {
let config = DnsConfig::default();
let forwarder = DnsForwarder::new(config);
let query = build_test_query("arcbox.local");
let response = forwarder
.try_resolve_locally_or_nxdomain(&query)
.expect("bare domain should return NXDOMAIN");
assert_eq!(response[3] & 0x0F, 3, "RCODE=NXDOMAIN");
}
// With a custom local domain, answers and NXDOMAIN follow the new domain, and
// the default `arcbox.local` domain is no longer treated as local.
#[test]
fn test_custom_domain_nxdomain() {
let config = DnsConfig::default().with_local_domain("myorg.test");
let mut forwarder = DnsForwarder::new(config);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5));
forwarder.add_local_host("web", ip);
let query = build_test_query("web.myorg.test");
let response = forwarder
.try_resolve_locally_or_nxdomain(&query)
.expect("should resolve registered host under custom domain");
assert_eq!(response[3] & 0x0F, 0, "RCODE=NoError");
assert_eq!(response[7], 1, "ANCOUNT=1");
let query = build_test_query("unknown.myorg.test");
let response = forwarder
.try_resolve_locally_or_nxdomain(&query)
.expect("should NXDOMAIN for unregistered custom-domain host");
assert_eq!(response[3] & 0x0F, 3, "RCODE=NXDOMAIN");
let query = build_test_query("something.arcbox.local");
assert!(
forwarder.try_resolve_locally_or_nxdomain(&query).is_none(),
"old default domain should not be handled after domain change"
);
}
// resolv.conf parsing skips comments, other directives and invalid addresses,
// accepts IPv6, and de-duplicates while keeping first-seen order.
#[test]
fn test_parse_resolv_conf_nameservers() {
let conf = r#"
# comment
nameserver 10.0.0.2
search local
nameserver 2001:4860:4860::8888
nameserver invalid
nameserver 10.0.0.2
"#;
let servers = parse_resolv_conf_nameservers(conf);
assert_eq!(
servers,
vec![
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), DNS_PORT),
SocketAddr::new(
IpAddr::V6("2001:4860:4860::8888".parse().unwrap()),
DNS_PORT
)
]
);
}
}