use std::collections::HashMap;
use std::io;
use std::net::{Ipv4Addr, SocketAddr, UdpSocket};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// TYPE/QTYPE code for an IPv4 address record.
const DNS_TYPE_A: u16 = 1;
/// CLASS/QCLASS code for the Internet class.
const DNS_CLASS_IN: u16 = 1;

// Bits of the 16-bit flags word in the DNS header (RFC 1035 §4.1.1).
/// QR: set in responses, clear in queries.
const DNS_FLAG_QR: u16 = 1 << 15;
/// AA: authoritative answer.
const DNS_FLAG_AA: u16 = 1 << 10;
/// RD: recursion desired.
const DNS_FLAG_RD: u16 = 1 << 8;
/// RA: recursion available.
const DNS_FLAG_RA: u16 = 1 << 7;

// RCODE values, carried in the low four bits of the flags word.
/// Server failure.
const DNS_RCODE_SERVFAIL: u16 = 2;
/// Name does not exist.
const DNS_RCODE_NXDOMAIN: u16 = 3;
/// A parsed DNS query: the header ID, the first question, and the packet bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
struct DnsQuery {
    /// Transaction ID from the header; echoed in every locally built response.
    id: u16,
    /// Name of the first question, lowercased, labels joined with `.`
    /// (no trailing dot; the root name is the empty string).
    qname: String,
    /// QTYPE of the first question.
    qtype: u16,
    /// QCLASS of the first question.
    qclass: u16,
    /// The whole packet as received; forwarded unchanged to upstream resolvers.
    raw: Vec<u8>,
}
/// Parses a raw UDP payload as a DNS query.
///
/// Returns `None` when the packet is shorter than the 12-byte header, has the
/// QR (response) bit set, carries no questions, or its first question is
/// malformed/truncated. Only the first question is decoded; the full packet is
/// kept in `raw`.
fn parse_dns_query(buf: &[u8]) -> Option<DnsQuery> {
    /// Reads a big-endian u16 at byte offset `at` of `bytes`.
    fn be16(bytes: &[u8], at: usize) -> u16 {
        u16::from_be_bytes([bytes[at], bytes[at + 1]])
    }

    let header = buf.get(..12)?;
    let is_response = be16(header, 2) & DNS_FLAG_QR != 0;
    let question_count = be16(header, 4);
    if is_response || question_count == 0 {
        return None;
    }

    // The first question starts right after the header.
    let (qname, tail_start) = parse_qname(buf, 12)?;
    // QTYPE and QCLASS follow the name, two bytes each.
    let tail = buf.get(tail_start..tail_start + 4)?;

    Some(DnsQuery {
        id: be16(header, 0),
        qname,
        qtype: be16(tail, 0),
        qclass: be16(tail, 2),
        raw: buf.to_vec(),
    })
}
/// Decodes an uncompressed domain name starting at `offset` in `buf`.
///
/// Returns the lowercased name (labels joined with `.`, root = `""`) and the
/// offset of the first byte after the terminating zero-length label.
///
/// Returns `None` when:
/// - the name runs past the end of `buf`;
/// - a length byte has either of its top two bits set (compression pointers and
///   extended label types are not supported);
/// - a label is not valid UTF-8, or contains a literal `.`. Such a label could
///   not be told apart from a label boundary once the name is joined as text;
/// - the encoded name exceeds the 255-octet limit of RFC 1035 §2.3.4.
fn parse_qname(buf: &[u8], mut offset: usize) -> Option<(String, usize)> {
    /// Maximum wire length of a name, including length octets and the root label.
    const MAX_NAME_WIRE_LEN: usize = 255;

    let start = offset;
    let mut labels = Vec::new();
    loop {
        let len = usize::from(*buf.get(offset)?);
        offset += 1;
        if len == 0 {
            break;
        }
        if len & 0xC0 != 0 {
            return None;
        }
        let label = std::str::from_utf8(buf.get(offset..offset + len)?).ok()?;
        if label.contains('.') {
            return None;
        }
        labels.push(label.to_ascii_lowercase());
        offset += len;
    }
    if offset - start > MAX_NAME_WIRE_LEN {
        return None;
    }
    Some((labels.join("."), offset))
}
/// Appends the 12-byte response header for `query` to `resp`.
///
/// `flags` carries the response-specific bits (QR/AA/RA/RCODE). The RD bit is
/// copied from the query, as RFC 1035 §4.1.1 requires, instead of always being
/// set. QDCOUNT is always 1 (only the first question is echoed), and
/// NSCOUNT/ARCOUNT are 0.
fn write_response_header(resp: &mut Vec<u8>, query: &DnsQuery, flags: u16, ancount: u16) {
    let query_rd = query
        .raw
        .get(2..4)
        .map_or(0, |f| u16::from_be_bytes([f[0], f[1]]) & DNS_FLAG_RD);
    resp.extend_from_slice(&query.id.to_be_bytes());
    resp.extend_from_slice(&(flags | query_rd).to_be_bytes());
    resp.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT
    resp.extend_from_slice(&ancount.to_be_bytes()); // ANCOUNT
    resp.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT
    resp.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT
}

/// Returns the bytes of the first question (QNAME + QTYPE + QCLASS) of a raw
/// DNS packet, or `None` if it is truncated or its name uses compression.
fn original_question(raw: &[u8]) -> Option<&[u8]> {
    let mut pos = 12;
    loop {
        let len = usize::from(*raw.get(pos)?);
        pos += 1;
        if len == 0 {
            break;
        }
        if len & 0xC0 != 0 {
            return None;
        }
        pos += len;
    }
    raw.get(12..pos + 4)
}

/// Appends the question section of `query` to `resp`.
///
/// The question is copied verbatim from the received packet, so the client
/// gets its own name back with the original letter case. Resolvers using
/// DNS 0x20 case randomisation reject answers whose question differs. If the
/// original bytes cannot be located, this falls back to re-encoding the parsed
/// (lowercased) name.
fn write_question(resp: &mut Vec<u8>, query: &DnsQuery) {
    if let Some(question) = original_question(&query.raw) {
        resp.extend_from_slice(question);
        return;
    }
    encode_qname(resp, &query.qname);
    resp.extend_from_slice(&query.qtype.to_be_bytes());
    resp.extend_from_slice(&query.qclass.to_be_bytes());
}

/// Builds a header-plus-question response with no records and the given flags.
fn build_empty_response(query: &DnsQuery, flags: u16) -> Vec<u8> {
    let mut resp = Vec::with_capacity(32);
    write_response_header(&mut resp, query, flags, 0);
    write_question(&mut resp, query);
    resp
}

/// Builds an authoritative answer carrying a single A record for `ip`.
///
/// The answer's owner name is a compression pointer to the question name at
/// offset 12, and the TTL is fixed at 10 seconds.
fn build_a_response(query: &DnsQuery, ip: Ipv4Addr) -> Vec<u8> {
    const ANSWER_TTL_SECS: u32 = 10;
    let mut resp = Vec::with_capacity(64);
    write_response_header(&mut resp, query, DNS_FLAG_QR | DNS_FLAG_AA | DNS_FLAG_RA, 1);
    write_question(&mut resp, query);
    resp.extend_from_slice(&[0xC0, 0x0C]); // NAME: pointer to offset 12
    resp.extend_from_slice(&DNS_TYPE_A.to_be_bytes());
    resp.extend_from_slice(&DNS_CLASS_IN.to_be_bytes());
    resp.extend_from_slice(&ANSWER_TTL_SECS.to_be_bytes());
    resp.extend_from_slice(&4u16.to_be_bytes()); // RDLENGTH
    resp.extend_from_slice(&ip.octets());
    resp
}

/// Builds an authoritative NOERROR response with no answers ("NODATA").
///
/// Used when the name exists locally but has no record of the requested type.
fn build_nodata(query: &DnsQuery) -> Vec<u8> {
    build_empty_response(query, DNS_FLAG_QR | DNS_FLAG_AA | DNS_FLAG_RA)
}

/// Builds a non-authoritative SERVFAIL response (e.g. all upstreams failed).
fn build_servfail(query: &DnsQuery) -> Vec<u8> {
    build_empty_response(query, DNS_FLAG_QR | DNS_FLAG_RA | DNS_RCODE_SERVFAIL)
}

/// Builds an authoritative NXDOMAIN response.
fn build_nxdomain(query: &DnsQuery) -> Vec<u8> {
    build_empty_response(
        query,
        DNS_FLAG_QR | DNS_FLAG_AA | DNS_FLAG_RA | DNS_RCODE_NXDOMAIN,
    )
}

/// Appends `name` to `buf` in uncompressed wire format.
///
/// Empty labels (leading, trailing or doubled dots) are skipped, and a root
/// terminator is always written. Each label's length is written as a single
/// byte, so callers must pass labels of at most 63 bytes. Names produced by
/// `parse_qname` satisfy this.
fn encode_qname(buf: &mut Vec<u8>, name: &str) {
    for label in name.split('.').filter(|l| !l.is_empty()) {
        buf.push(label.len() as u8);
        buf.extend_from_slice(label.as_bytes());
    }
    buf.push(0);
}
/// One network's DNS configuration, parsed from a file in the config directory.
#[derive(Debug, Clone)]
struct NetworkConfig {
    /// Address the server binds (port 53) to serve this network.
    listen_ip: Ipv4Addr,
    /// Resolvers that queries not answered locally are forwarded to, in order.
    upstream: Vec<Ipv4Addr>,
    /// Host name (lowercased) → IPv4 address.
    entries: HashMap<String, Ipv4Addr>,
}

/// Parses a network config file.
///
/// Blank lines are ignored. The first remaining line is the header
/// `<listen_ip> [<upstream>,<upstream>,...]`. If the upstream field is absent,
/// `8.8.8.8,1.1.1.1` is used; unparsable upstream addresses are dropped. Every
/// later line is `<name> <ipv4>`, and lines with fewer fields or a bad address
/// are skipped. Later duplicates override earlier ones.
///
/// Names are lowercased on load. Query names are lowercased by `parse_qname`,
/// so an entry keeping uppercase letters could never be matched.
///
/// Returns `None` if there is no header line or its listen address is invalid.
fn parse_network_config(content: &str) -> Option<NetworkConfig> {
    const DEFAULT_UPSTREAM: &str = "8.8.8.8,1.1.1.1";

    let mut lines = content.lines().filter(|l| !l.trim().is_empty());
    let mut header = lines.next()?.split_whitespace();
    let listen_ip: Ipv4Addr = header.next()?.parse().ok()?;
    let upstream: Vec<Ipv4Addr> = header
        .next()
        .unwrap_or(DEFAULT_UPSTREAM)
        .split(',')
        .filter_map(|s| s.trim().parse().ok())
        .collect();

    let entries: HashMap<String, Ipv4Addr> = lines
        .filter_map(|line| {
            let mut fields = line.split_whitespace();
            let name = fields.next()?;
            let ip = fields.next()?.parse::<Ipv4Addr>().ok()?;
            Some((name.to_ascii_lowercase(), ip))
        })
        .collect();

    Some(NetworkConfig {
        listen_ip,
        upstream,
        entries,
    })
}
/// A bound, non-blocking UDP listener and the gateway address it serves.
#[derive(Debug)]
struct ListenSocket {
    socket: UdpSocket,
    gateway_ip: Ipv4Addr,
}

/// Daemon state: the loaded network configs and the sockets serving them.
struct ServerState {
    /// Directory scanned on every reload.
    config_dir: PathBuf,
    /// File name → parsed config (only configs with at least one entry).
    configs: HashMap<String, NetworkConfig>,
    /// One listener per distinct `listen_ip` in `configs` that bound successfully.
    sockets: Vec<ListenSocket>,
}

impl ServerState {
    /// Creates an empty state; call [`ServerState::reload`] to load configs.
    fn new(config_dir: PathBuf) -> Self {
        ServerState {
            config_dir,
            configs: HashMap::new(),
            sockets: Vec::new(),
        }
    }

    /// Re-reads every regular file in `config_dir` (except `pid`) and rebinds
    /// sockets to match.
    ///
    /// Files that cannot be read, fail to parse, or have no entries are skipped.
    /// If the directory itself cannot be read, the error is logged, all configs
    /// are dropped and all sockets are closed.
    fn reload(&mut self) {
        self.configs.clear();
        match std::fs::read_dir(&self.config_dir) {
            Ok(dir) => {
                for entry in dir.flatten() {
                    let path = entry.path();
                    if !path.is_file() {
                        continue;
                    }
                    let name = match path.file_name().and_then(|n| n.to_str()) {
                        Some(n) if n != "pid" => n.to_string(),
                        _ => continue,
                    };
                    let config = std::fs::read_to_string(&path)
                        .ok()
                        .and_then(|content| parse_network_config(&content));
                    if let Some(config) = config {
                        if !config.entries.is_empty() {
                            self.configs.insert(name, config);
                        }
                    }
                }
            }
            Err(e) => eprintln!(
                "pelagos-dns: cannot read config dir {}: {}",
                self.config_dir.display(),
                e
            ),
        }
        self.rebind_sockets();
    }

    /// Closes listeners whose gateway is no longer configured and binds
    /// `<gateway>:53` for each newly configured one.
    ///
    /// A socket that cannot be made non-blocking is discarded. The main loop
    /// polls every socket in turn, so a blocking one would stall the daemon.
    fn rebind_sockets(&mut self) {
        let desired: std::collections::HashSet<Ipv4Addr> =
            self.configs.values().map(|cfg| cfg.listen_ip).collect();
        self.sockets.retain(|s| desired.contains(&s.gateway_ip));
        for ip in desired {
            if self.sockets.iter().any(|s| s.gateway_ip == ip) {
                continue;
            }
            let bind_addr = SocketAddr::new(std::net::IpAddr::V4(ip), 53);
            let sock = match UdpSocket::bind(bind_addr) {
                Ok(sock) => sock,
                Err(e) => {
                    eprintln!("pelagos-dns: failed to bind {}: {}", bind_addr, e);
                    continue;
                }
            };
            if let Err(e) = sock.set_nonblocking(true) {
                eprintln!("pelagos-dns: failed to make {} non-blocking: {}", bind_addr, e);
                continue;
            }
            self.sockets.push(ListenSocket {
                socket: sock,
                gateway_ip: ip,
            });
            eprintln!("pelagos-dns: listening on {}", bind_addr);
        }
    }

    /// True if any loaded config has at least one entry.
    fn has_entries(&self) -> bool {
        self.configs.values().any(|c| !c.entries.is_empty())
    }

    /// Looks up `name` among the configs served on `gateway`.
    ///
    /// If several configs on that gateway define the name, the one whose file
    /// name sorts first wins, so the answer does not depend on HashMap order.
    fn lookup(&self, gateway: Ipv4Addr, name: &str) -> Option<Ipv4Addr> {
        self.configs
            .iter()
            .filter(|(_, cfg)| cfg.listen_ip == gateway)
            .filter_map(|(file, cfg)| cfg.entries.get(name).map(|&ip| (file, ip)))
            .min_by_key(|(file, _)| *file)
            .map(|(_, ip)| ip)
    }

    /// Upstream resolvers for `gateway`; empty if no config serves it.
    ///
    /// When several configs share the gateway, the one whose file name sorts
    /// first is used.
    fn upstream(&self, gateway: Ipv4Addr) -> Vec<Ipv4Addr> {
        self.configs
            .iter()
            .filter(|(_, cfg)| cfg.listen_ip == gateway)
            .min_by_key(|(file, _)| *file)
            .map(|(_, cfg)| cfg.upstream.clone())
            .unwrap_or_default()
    }
}
/// Relays a raw DNS query to each upstream server in turn and returns the
/// first valid reply.
///
/// Each server (`<ip>:53`) is tried on a fresh ephemeral socket that is
/// `connect`ed to it, so datagrams from any other source are not delivered. A
/// datagram is accepted only if it is at least a header long, has QR set, and
/// carries the query's transaction ID. Stray or spoofed datagrams are ignored
/// until that server's 3-second deadline expires.
///
/// Returns `None` if `raw` is too short to carry an ID, `upstream` is empty,
/// or every server fails or times out.
fn forward_upstream(raw: &[u8], upstream: &[Ipv4Addr]) -> Option<Vec<u8>> {
    const PER_SERVER_TIMEOUT: Duration = Duration::from_secs(3);
    let query_id: [u8; 2] = [*raw.first()?, *raw.get(1)?];
    upstream.iter().find_map(|&server| {
        let addr = SocketAddr::new(std::net::IpAddr::V4(server), 53);
        query_one_upstream(raw, query_id, addr, PER_SERVER_TIMEOUT)
    })
}

/// Sends `raw` to `addr` and waits up to `timeout` for a matching reply.
///
/// Returns `None` on any socket error, including ICMP port-unreachable
/// (reported on connected sockets), or when the deadline passes.
fn query_one_upstream(
    raw: &[u8],
    query_id: [u8; 2],
    addr: SocketAddr,
    timeout: Duration,
) -> Option<Vec<u8>> {
    let sock = UdpSocket::bind("0.0.0.0:0").ok()?;
    sock.connect(addr).ok()?;
    sock.send(raw).ok()?;

    let deadline = std::time::Instant::now() + timeout;
    let mut buf = [0u8; 4096];
    loop {
        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
        // A zero read timeout is rejected by set_read_timeout, so stop here.
        if remaining.is_zero() {
            return None;
        }
        sock.set_read_timeout(Some(remaining)).ok()?;
        match sock.recv(&mut buf) {
            Ok(n) => {
                let reply = &buf[..n];
                let is_match = reply.len() >= 12 && reply[..2] == query_id && reply[2] & 0x80 != 0;
                if is_match {
                    return Some(reply.to_vec());
                }
                // Otherwise a stale or forged datagram: keep waiting.
            }
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(_) => return None,
        }
    }
}
fn install_sighup_handler(flag: Arc<AtomicBool>) {
unsafe {
RELOAD_FLAG.store(Arc::into_raw(flag) as *mut bool as usize, Ordering::SeqCst);
libc::signal(
libc::SIGHUP,
sighup_handler as *const () as libc::sighandler_t,
);
}
}
static RELOAD_FLAG: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
extern "C" fn sighup_handler(_sig: libc::c_int) {
let ptr = RELOAD_FLAG.load(Ordering::SeqCst);
if ptr != 0 {
let flag = unsafe { &*(ptr as *const AtomicBool) };
flag.store(true, Ordering::SeqCst);
}
}
fn main() {
let args: Vec<String> = std::env::args().collect();
let config_dir = if let Some(pos) = args.iter().position(|a| a == "--config-dir") {
args.get(pos + 1).map(PathBuf::from).unwrap_or_else(|| {
eprintln!("pelagos-dns: --config-dir requires a path argument");
std::process::exit(1);
})
} else {
eprintln!("Usage: pelagos-dns --config-dir <dir>");
std::process::exit(1);
};
let pid_file = config_dir.join("pid");
if let Err(e) = std::fs::write(&pid_file, format!("{}", unsafe { libc::getpid() })) {
eprintln!("pelagos-dns: failed to write PID file: {}", e);
std::process::exit(1);
}
let reload_flag = Arc::new(AtomicBool::new(false));
install_sighup_handler(reload_flag.clone());
let mut state = ServerState::new(config_dir.clone());
state.reload();
if !state.has_entries() {
eprintln!("pelagos-dns: no entries found, exiting");
let _ = std::fs::remove_file(&pid_file);
return;
}
eprintln!(
"pelagos-dns: started with {} network(s)",
state.configs.len()
);
let mut buf = [0u8; 4096];
loop {
if reload_flag.swap(false, Ordering::SeqCst) {
state.reload();
if !state.has_entries() {
eprintln!("pelagos-dns: no entries remaining, exiting");
break;
}
eprintln!("pelagos-dns: reloaded, {} network(s)", state.configs.len());
}
let mut activity = false;
for i in 0..state.sockets.len() {
let recv = state.sockets[i].socket.recv_from(&mut buf);
match recv {
Ok((n, src)) => {
activity = true;
let gateway_ip = state.sockets[i].gateway_ip;
if let Some(response) = handle_query(&buf[..n], gateway_ip, &state) {
let _ = state.sockets[i].socket.send_to(&response, src);
}
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(_) => {}
}
}
if !activity {
std::thread::sleep(Duration::from_millis(50));
}
}
let _ = std::fs::remove_file(&pid_file);
eprintln!("pelagos-dns: stopped");
}
/// Produces the response for one incoming packet received on `gateway`.
///
/// Returns `None` (send nothing) when the packet is not a parseable query.
/// Otherwise a `.pelagos` suffix is stripped, and then:
/// - A name configured for this gateway gets an A record for A/IN queries, and
///   NODATA for any other type or class.
/// - An unknown single-label name, or any unknown name under `.pelagos`, gets
///   NXDOMAIN. These are answered locally only and never leak upstream.
/// - Any other name is forwarded to the gateway's upstreams. It gets NXDOMAIN
///   if there are none, and SERVFAIL if all of them fail.
fn handle_query(packet: &[u8], gateway: Ipv4Addr, state: &ServerState) -> Option<Vec<u8>> {
    let query = parse_dns_query(packet)?;
    let (name, in_pelagos_zone) = match query.qname.strip_suffix(".pelagos") {
        Some(stripped) => (stripped, true),
        None => (query.qname.as_str(), false),
    };
    let name = name.strip_suffix('.').unwrap_or(name);

    if let Some(ip) = state.lookup(gateway, name) {
        let is_a_in = query.qtype == DNS_TYPE_A && query.qclass == DNS_CLASS_IN;
        return Some(if is_a_in {
            build_a_response(&query, ip)
        } else {
            build_nodata(&query)
        });
    }

    if in_pelagos_zone || !name.contains('.') {
        return Some(build_nxdomain(&query));
    }

    let upstream = state.upstream(gateway);
    if upstream.is_empty() {
        return Some(build_nxdomain(&query));
    }
    Some(forward_upstream(&query.raw, &upstream).unwrap_or_else(|| build_servfail(&query)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal A/IN query packet (ID 0x1234, RD set) for `name`.
    fn make_a_query(name: &str) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&[0x12, 0x34]); // ID
        buf.extend_from_slice(&[0x01, 0x00]); // flags: RD
        buf.extend_from_slice(&[0x00, 0x01]); // QDCOUNT
        buf.extend_from_slice(&[0x00, 0x00]); // ANCOUNT
        buf.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
        buf.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
        encode_qname(&mut buf, name);
        buf.extend_from_slice(&DNS_TYPE_A.to_be_bytes());
        buf.extend_from_slice(&DNS_CLASS_IN.to_be_bytes());
        buf
    }

    #[test]
    fn test_parse_valid_a_query() {
        let packet = make_a_query("mycontainer");
        let query = parse_dns_query(&packet).expect("should parse");
        assert_eq!(query.id, 0x1234);
        assert_eq!(query.qname, "mycontainer");
        assert_eq!(query.qtype, DNS_TYPE_A);
        assert_eq!(query.qclass, DNS_CLASS_IN);
    }

    #[test]
    fn test_parse_dotted_name() {
        let packet = make_a_query("mycontainer.pelagos");
        let query = parse_dns_query(&packet).expect("should parse");
        assert_eq!(query.qname, "mycontainer.pelagos");
    }

    #[test]
    fn test_parse_rejects_truncated() {
        assert!(parse_dns_query(&[0; 5]).is_none());
        // A full-size header but QDCOUNT = 0.
        assert!(parse_dns_query(&[0; 12]).is_none());
    }

    #[test]
    fn test_parse_rejects_response() {
        let mut packet = make_a_query("test");
        packet[2] |= 0x80;
        assert!(parse_dns_query(&packet).is_none());
    }

    #[test]
    fn test_build_a_response() {
        let packet = make_a_query("redis");
        let query = parse_dns_query(&packet).unwrap();
        let response = build_a_response(&query, Ipv4Addr::new(172, 19, 0, 5));
        assert_eq!(response[0], 0x12);
        assert_eq!(response[1], 0x34);
        assert_eq!(response[2] & 0x80, 0x80);
        assert_eq!(u16::from_be_bytes([response[6], response[7]]), 1);
        let len = response.len();
        assert_eq!(&response[len - 4..], &[172, 19, 0, 5]);
    }

    #[test]
    fn test_build_nodata() {
        let packet = make_a_query("app");
        let query = parse_dns_query(&packet).unwrap();
        let response = build_nodata(&query);
        let flags = u16::from_be_bytes([response[2], response[3]]);
        assert_eq!(flags & 0x000F, 0);
        assert!(flags & DNS_FLAG_AA != 0);
        assert_eq!(u16::from_be_bytes([response[6], response[7]]), 0);
    }

    #[test]
    fn test_build_nxdomain() {
        let packet = make_a_query("nonexistent");
        let query = parse_dns_query(&packet).unwrap();
        let response = build_nxdomain(&query);
        let flags = u16::from_be_bytes([response[2], response[3]]);
        assert_eq!(flags & 0x000F, DNS_RCODE_NXDOMAIN);
        assert_eq!(u16::from_be_bytes([response[6], response[7]]), 0);
    }

    #[test]
    fn test_parse_qname_labels() {
        let mut buf = Vec::new();
        buf.push(b"app".len() as u8);
        buf.extend_from_slice(b"app");
        // "pelagos" is 7 bytes; a wrong length byte here desynchronises the parse.
        buf.push(b"pelagos".len() as u8);
        buf.extend_from_slice(b"pelagos");
        buf.push(0);
        let (name, pos) = parse_qname(&buf, 0).expect("should parse");
        assert_eq!(name, "app.pelagos");
        assert_eq!(pos, buf.len());
    }

    #[test]
    fn test_config_parse_roundtrip() {
        let content = "\
172.19.0.1 8.8.8.8,1.1.1.1
redis 172.19.0.2
app 172.19.0.3
proxy 172.19.0.4
";
        let config = parse_network_config(content).expect("should parse");
        assert_eq!(config.listen_ip, Ipv4Addr::new(172, 19, 0, 1));
        assert_eq!(config.upstream.len(), 2);
        assert_eq!(config.upstream[0], Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(config.upstream[1], Ipv4Addr::new(1, 1, 1, 1));
        assert_eq!(config.entries.len(), 3);
        assert_eq!(config.entries["redis"], Ipv4Addr::new(172, 19, 0, 2));
        assert_eq!(config.entries["app"], Ipv4Addr::new(172, 19, 0, 3));
        assert_eq!(config.entries["proxy"], Ipv4Addr::new(172, 19, 0, 4));
    }

    #[test]
    fn test_config_parse_empty() {
        assert!(parse_network_config("").is_none());
        assert!(parse_network_config(" \n ").is_none());
    }

    #[test]
    fn test_config_parse_header_only() {
        let config = parse_network_config("172.19.0.1 8.8.8.8\n");
        let config = config.expect("should parse header");
        assert!(config.entries.is_empty());
    }
}