use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
sync::{Arc, RwLock},
};
use anyhow::{bail, Result};
use bimap::BiMap;
/// Mapping between hostnames and fake IPv4 addresses whose first octet is
/// `subnet` (i.e. addresses inside `subnet.0.0.0/8`).
///
/// Cloning is cheap: clones share the same underlying table through an `Arc`,
/// and all access goes through a `RwLock`, so the map can be used from
/// several threads at once.
#[derive(Debug, Clone)]
pub struct DnsMap {
/// First octet of every fake address handed out by this map.
subnet: u8,
/// Shared allocation state, guarded by a reader/writer lock.
inner: Arc<RwLock<DnsMapInner>>,
}
/// Mutable state that lives behind `DnsMap`'s lock.
#[derive(Debug, Default)]
struct DnsMapInner {
/// Number of addresses allocated so far; also the allocation index used for
/// the next new hostname.
counter: u32,
/// Bidirectional table: fake IP (left) <-> hostname (right).
map: BiMap<Ipv4Addr, String>,
}
impl DnsMap {
    /// Creates an empty map whose fake addresses all have `subnet` as their
    /// first octet.
    pub fn new(subnet: u8) -> Self {
        let inner = Arc::new(RwLock::new(DnsMapInner::default()));
        Self { subnet, inner }
    }

    /// Returns the fake IP assigned to `hostname`, allocating the next free
    /// address the first time the hostname is seen.
    ///
    /// # Errors
    /// Fails with "dns map exhausted" once the allocation counter reaches
    /// `0xFF_FFFF`.
    ///
    /// # Panics
    /// Panics if the internal lock has been poisoned.
    pub fn get_or_alloc(&self, hostname: &str) -> Result<Ipv4Addr> {
        // Fast path: a shared lock is enough when the hostname is already
        // known. The read guard is a temporary and is released at the end of
        // this statement, before the write lock is requested below.
        let cached = self.inner.read().unwrap().map.get_by_right(hostname).copied();
        if let Some(ip) = cached {
            return Ok(ip);
        }

        // Slow path: another thread may have allocated this hostname between
        // releasing the read lock and acquiring the write lock, so look again.
        let mut guard = self.inner.write().unwrap();
        if let Some(&ip) = guard.map.get_by_right(hostname) {
            return Ok(ip);
        }

        let next = guard.counter;
        if next >= 0xFF_FFFF {
            bail!("dns map exhausted");
        }
        guard.counter = next + 1;

        let ip = make_fake_ip(self.subnet, next);
        guard.map.insert(ip, hostname.to_owned());
        Ok(ip)
    }

    /// Returns the hostname previously assigned to `ip`, if any.
    ///
    /// # Panics
    /// Panics if the internal lock has been poisoned.
    pub fn lookup_hostname(&self, ip: Ipv4Addr) -> Option<String> {
        let inner = self.inner.read().unwrap();
        inner.map.get_by_left(&ip).cloned()
    }

    /// Reports whether `ip` is an IPv4 address whose first octet equals this
    /// map's subnet. IPv6 addresses are never fake.
    ///
    /// This only checks the subnet; it does not check that the address has
    /// actually been allocated.
    pub fn is_fake_ip(&self, ip: IpAddr) -> bool {
        matches!(ip, IpAddr::V4(v4) if v4.octets()[0] == self.subnet)
    }

    /// Parses a raw DNS query packet and returns the encoded response, or
    /// `None` when no reply should be sent.
    pub fn handle_dns_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
        dns_handle_query(packet, self)
    }
}
/// Turns allocation `index` into the address `subnet.x.y.z`, where `x.y.z` are
/// the low 24 bits of `index + 1` in big-endian order (index 0 becomes
/// `subnet.0.0.1`).
///
/// Bits of `index + 1` above the low 24 are discarded. `DnsMap::get_or_alloc`
/// keeps `index` below `0xFF_FFFF`, so that never happens for allocated
/// addresses.
fn make_fake_ip(subnet: u8, index: u32) -> Ipv4Addr {
    let [_, high, middle, low] = (index + 1).to_be_bytes();
    Ipv4Addr::new(subnet, high, middle, low)
}
/// A UDP DNS responder that answers queries using a shared `DnsMap`.
pub struct DnsServer {
/// Socket that queries are received on and replies are sent from.
socket: UdpSocket,
/// Hostname <-> fake-IP table used to build answers.
map: DnsMap,
}
impl DnsServer {
    /// Binds a UDP socket on `addr` that will answer queries from `map`.
    ///
    /// # Errors
    /// Returns the I/O error if the socket cannot be bound.
    pub fn bind(addr: SocketAddr, map: DnsMap) -> Result<Self> {
        Ok(Self {
            socket: UdpSocket::bind(addr)?,
            map,
        })
    }

    /// Returns the address the socket is bound to (useful after binding to
    /// port 0).
    ///
    /// # Panics
    /// Panics if the OS cannot report the socket's local address.
    pub fn local_addr(&self) -> SocketAddr {
        self.socket.local_addr().unwrap()
    }

    /// Serves queries forever: receive a datagram (up to 512 bytes), build a
    /// reply, send it back to the sender.
    ///
    /// Receive errors are ignored and the loop carries on. Packets that get no
    /// reply are dropped. Send failures are ignored (best effort).
    pub fn run(self) {
        let mut buf = [0u8; 512];
        loop {
            if let Ok((len, peer)) = self.socket.recv_from(&mut buf) {
                if let Some(reply) = self.handle_query(&buf[..len]) {
                    let _ = self.socket.send_to(&reply, peer);
                }
            }
        }
    }

    /// Builds the reply for one raw query packet; `None` means no reply
    /// should be sent.
    pub fn handle_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
        dns_handle_query(packet, &self.map)
    }
}
/// Builds the DNS response for a single-question standard query, answering
/// `A`/`IN` questions with a fake IP taken from `map`.
///
/// Returns `None` (send nothing) when the packet is too short, is itself a
/// response (QR=1), uses an opcode other than QUERY, does not carry exactly
/// one question, has a malformed or compressed QNAME, or asks about a class
/// other than `IN`.
///
/// Otherwise a response is always produced, so clients fail fast instead of
/// timing out:
/// * `A` questions get one answer record (TTL 60) with the fake IP;
/// * other `IN` types (e.g. `AAAA`) get NOERROR with an empty answer
///   section (NODATA), without allocating an address;
/// * if `map` cannot allocate an address, the reply is SERVFAIL.
///
/// The question section is echoed byte-for-byte (preserving the client's
/// letter case), and the query's RD bit is copied into the response.
fn dns_handle_query(packet: &[u8], map: &DnsMap) -> Option<Vec<u8>> {
    const HEADER_LEN: usize = 12;
    const TYPE_A: u16 = 1;
    const CLASS_IN: u16 = 1;
    const RCODE_NOERROR: u8 = 0;
    const RCODE_SERVFAIL: u8 = 2;
    const ANSWER_TTL_SECS: u32 = 60;

    if packet.len() < HEADER_LEN {
        return None;
    }
    let txid = &packet[0..2];
    let flags_hi = packet[2];
    // Never answer a response (QR=1): doing so can set up a reflection loop
    // between two responders. Only the standard QUERY opcode (0) is served.
    if flags_hi & 0x80 != 0 || (flags_hi >> 3) & 0x0F != 0 {
        return None;
    }
    // QDCOUNT must be exactly 1.
    if u16::from_be_bytes([packet[4], packet[5]]) != 1 {
        return None;
    }

    // QNAME: a sequence of length-prefixed labels terminated by a zero byte.
    let mut offset = HEADER_LEN;
    let mut labels: Vec<String> = Vec::new();
    loop {
        let len = *packet.get(offset)? as usize;
        offset += 1;
        if len == 0 {
            break;
        }
        // Compression pointers and reserved label types are not accepted in
        // the question.
        if len & 0xC0 != 0 {
            return None;
        }
        let label = packet.get(offset..offset + len)?;
        labels.push(String::from_utf8_lossy(label).into_owned());
        offset += len;
    }
    // DNS names are case-insensitive (RFC 4343). Normalise them so that
    // differently-cased queries for one host (including 0x20-randomised
    // ones) share a single fake IP.
    let qname = labels.join(".").to_ascii_lowercase();

    let fixed = packet.get(offset..offset + 4)?;
    let qtype = u16::from_be_bytes([fixed[0], fixed[1]]);
    let qclass = u16::from_be_bytes([fixed[2], fixed[3]]);
    offset += 4;
    if qclass != CLASS_IN {
        return None;
    }
    let question = &packet[HEADER_LEN..offset];

    // Only A questions get an address; any other type is answered with NODATA.
    let (rcode, answer_ip) = if qtype == TYPE_A {
        match map.get_or_alloc(&qname) {
            Ok(ip) => (RCODE_NOERROR, Some(ip)),
            Err(_) => (RCODE_SERVFAIL, None),
        }
    } else {
        (RCODE_NOERROR, None)
    };

    let mut resp = Vec::with_capacity(offset + 16);
    resp.extend_from_slice(txid);
    // QR=1 and AA=1; RD is copied from the query (RFC 1035 section 4.1.1).
    resp.push(0x84 | (flags_hi & 0x01));
    // RA=0, Z=0, RCODE.
    resp.push(rcode);
    resp.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT
    resp.extend_from_slice(&u16::from(answer_ip.is_some()).to_be_bytes()); // ANCOUNT
    resp.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
    resp.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
    resp.extend_from_slice(question);
    if let Some(ip) = answer_ip {
        resp.extend_from_slice(&[0xC0, 0x0C]); // NAME: pointer to the QNAME at offset 12
        resp.extend_from_slice(&TYPE_A.to_be_bytes());
        resp.extend_from_slice(&CLASS_IN.to_be_bytes());
        resp.extend_from_slice(&ANSWER_TTL_SECS.to_be_bytes());
        resp.extend_from_slice(&4u16.to_be_bytes()); // RDLENGTH
        resp.extend_from_slice(&ip.octets());
    }
    Some(resp)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{net::UdpSocket as StdUdpSocket, thread, time::Duration};

    /// Starts a `DnsServer` for `subnet` on an ephemeral loopback port in a
    /// background thread, returning the shared map and the server address.
    fn bind_server(subnet: u8) -> (DnsMap, SocketAddr) {
        let map = DnsMap::new(subnet);
        let server = DnsServer::bind("127.0.0.1:0".parse().unwrap(), map.clone()).unwrap();
        let addr = server.local_addr();
        thread::spawn(move || server.run());
        (map, addr)
    }

    /// Sends an `A`/`IN` query for `name` to `server` and returns the address
    /// in the first answer record.
    ///
    /// Returns `None` on send/receive failure or timeout, a transaction-ID
    /// mismatch, an empty answer section, or a response too short to hold an
    /// `A` record. The answer NAME is assumed to be a 2-byte compression
    /// pointer, which is what `dns_handle_query` emits.
    fn query_a(server: SocketAddr, name: &str) -> Option<Ipv4Addr> {
        let sock = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        sock.set_read_timeout(Some(Duration::from_secs(2))).unwrap();

        let mut pkt = Vec::new();
        pkt.extend_from_slice(&[0xAB, 0xCD]); // ID
        pkt.extend_from_slice(&[0x01, 0x00]); // flags: RD=1
        pkt.extend_from_slice(&[0x00, 0x01]); // QDCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ANCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
        for label in name.split('.') {
            pkt.push(label.len() as u8);
            pkt.extend_from_slice(label.as_bytes());
        }
        pkt.push(0);
        pkt.extend_from_slice(&[0x00, 0x01]); // QTYPE = A
        pkt.extend_from_slice(&[0x00, 0x01]); // QCLASS = IN

        sock.send_to(&pkt, server).ok()?;
        let mut buf = [0u8; 512];
        let (n, _) = sock.recv_from(&mut buf).ok()?;
        let resp = &buf[..n];

        // The header must echo our ID and announce at least one answer.
        let header = resp.get(..12)?;
        if u16::from_be_bytes([header[0], header[1]]) != 0xABCD
            || u16::from_be_bytes([header[6], header[7]]) == 0
        {
            return None;
        }

        // Skip the echoed QNAME.
        let mut off = 12usize;
        loop {
            let l = *resp.get(off)? as usize;
            if l == 0 {
                off += 1;
                break;
            }
            if l & 0xC0 != 0 {
                off += 2;
                break;
            }
            off += 1 + l;
        }
        // Skip QTYPE/QCLASS, then the answer's NAME pointer, TYPE, CLASS, TTL.
        off += 4;
        off += 2 + 2 + 2 + 4;

        let rdlen_bytes = resp.get(off..off + 2)?;
        let rdlen = u16::from_be_bytes([rdlen_bytes[0], rdlen_bytes[1]]);
        off += 2;
        if rdlen != 4 {
            return None;
        }
        let rdata = resp.get(off..off + 4)?;
        Some(Ipv4Addr::new(rdata[0], rdata[1], rdata[2], rdata[3]))
    }

    #[test]
    fn test_dns_a_query_returns_fake_ip() {
        let (_, addr) = bind_server(224);
        let ip = query_a(addr, "example.com").unwrap();
        assert_eq!(ip.octets()[0], 224);
    }

    #[test]
    fn test_dns_same_hostname_same_ip() {
        let (_, addr) = bind_server(224);
        let ip1 = query_a(addr, "example.com").unwrap();
        let ip2 = query_a(addr, "example.com").unwrap();
        assert_eq!(ip1, ip2);
    }

    #[test]
    fn test_dns_map_reverse_lookup() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("example.com").unwrap();
        assert_eq!(map.lookup_hostname(ip).as_deref(), Some("example.com"));
    }

    #[test]
    fn test_dns_map_different_hostnames_different_ips() {
        let map = DnsMap::new(224);
        let ip1 = map.get_or_alloc("a.example.com").unwrap();
        let ip2 = map.get_or_alloc("b.example.com").unwrap();
        assert_ne!(ip1, ip2);
        assert_eq!(map.lookup_hostname(ip1).as_deref(), Some("a.example.com"));
        assert_eq!(map.lookup_hostname(ip2).as_deref(), Some("b.example.com"));
    }

    #[test]
    fn test_dns_map_is_fake_ip() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("test.com").unwrap();
        assert!(map.is_fake_ip(IpAddr::V4(ip)));
        assert!(!map.is_fake_ip("8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn test_dns_map_exhaustion() {
        let map = DnsMap::new(224);
        map.inner.write().unwrap().counter = 0xFF_FFFF;
        let result = map.get_or_alloc("overflow.com");
        assert!(
            result.is_err(),
            "should fail when address space is exhausted"
        );
    }
}