1use std::{
2 net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
3 sync::{Arc, RwLock},
4};
5
6use anyhow::{bail, Result};
7use bimap::BiMap;
8
9#[derive(Debug, Clone)]
15pub struct DnsMap {
16 subnet: u8,
17 inner: Arc<RwLock<DnsMapInner>>,
18}
19
20#[derive(Debug, Default)]
21struct DnsMapInner {
22 counter: u32,
23 map: BiMap<Ipv4Addr, String>,
24}
25
26impl DnsMap {
27 pub fn new(subnet: u8) -> Self {
29 DnsMap {
30 subnet,
31 inner: Arc::new(RwLock::new(DnsMapInner::default())),
32 }
33 }
34
35 pub fn get_or_alloc(&self, hostname: &str) -> Result<Ipv4Addr> {
41 {
43 let r = self.inner.read().unwrap();
44 if let Some(ip) = r.map.get_by_right(hostname) {
45 return Ok(*ip);
46 }
47 }
48 let mut w = self.inner.write().unwrap();
50 if let Some(ip) = w.map.get_by_right(hostname) {
52 return Ok(*ip);
53 }
54 let index = w.counter;
55 if index >= 0xFF_FFFF {
56 bail!("dns map exhausted");
57 }
58 w.counter += 1;
59 let ip = make_fake_ip(self.subnet, index);
60 w.map.insert(ip, hostname.to_owned());
61 Ok(ip)
62 }
63
64 pub fn lookup_hostname(&self, ip: Ipv4Addr) -> Option<String> {
66 self.inner.read().unwrap().map.get_by_left(&ip).cloned()
67 }
68
69 pub fn is_fake_ip(&self, ip: IpAddr) -> bool {
71 match ip {
72 IpAddr::V4(v4) => v4.octets()[0] == self.subnet,
73 _ => false,
74 }
75 }
76
77 pub fn handle_dns_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
82 dns_handle_query(packet, self)
83 }
84}
85
/// Builds the fake address for allocation number `index`.
///
/// The result is `subnet.x.y.z`, where `x.y.z` are the low 24 bits of
/// `index + 1`; the `+ 1` keeps index 0 off `subnet.0.0.0`. Bits above the low
/// 24 are discarded.
///
/// # Panics
///
/// In debug builds, panics on overflow when `index == u32::MAX`.
fn make_fake_ip(subnet: u8, index: u32) -> Ipv4Addr {
    let [_, hi, mid, lo] = (index + 1).to_be_bytes();
    Ipv4Addr::new(subnet, hi, mid, lo)
}
96
97pub struct DnsServer {
104 socket: UdpSocket,
105 map: DnsMap,
106}
107
108impl DnsServer {
109 pub fn bind(addr: SocketAddr, map: DnsMap) -> Result<Self> {
111 let socket = UdpSocket::bind(addr)?;
112 Ok(DnsServer { socket, map })
113 }
114
115 pub fn local_addr(&self) -> SocketAddr {
117 self.socket.local_addr().unwrap()
118 }
119
120 pub fn run(self) {
122 let mut buf = [0u8; 512];
123 loop {
124 let (n, src) = match self.socket.recv_from(&mut buf) {
125 Ok(x) => x,
126 Err(_) => continue,
127 };
128 let packet = &buf[..n];
129 if let Some(response) = self.handle_query(packet) {
130 let _ = self.socket.send_to(&response, src);
131 }
132 }
133 }
134
135 pub fn handle_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
138 dns_handle_query(packet, &self.map)
139 }
140}
141
142fn dns_handle_query(packet: &[u8], map: &DnsMap) -> Option<Vec<u8>> {
148 if packet.len() < 12 {
149 return None;
150 }
151 let txid = &packet[0..2];
152 if u16::from_be_bytes([packet[4], packet[5]]) != 1 {
153 return None; }
155
156 let mut offset = 12usize;
158 let mut labels: Vec<String> = Vec::new();
159 loop {
160 if offset >= packet.len() {
161 return None;
162 }
163 let len = packet[offset] as usize;
164 if len == 0 {
165 offset += 1;
166 break;
167 }
168 if len & 0xC0 != 0 {
169 return None;
170 } offset += 1;
172 if offset + len > packet.len() {
173 return None;
174 }
175 labels.push(String::from_utf8_lossy(&packet[offset..offset + len]).into_owned());
176 offset += len;
177 }
178 let qname = labels.join(".");
179
180 if offset + 4 > packet.len() {
181 return None;
182 }
183 let qtype = u16::from_be_bytes([packet[offset], packet[offset + 1]]);
184 let qclass = u16::from_be_bytes([packet[offset + 2], packet[offset + 3]]);
185 offset += 4;
186
187 if qtype != 1 || qclass != 1 {
188 return None;
189 } let fake_ip = map.get_or_alloc(&qname).ok()?;
192 let question = &packet[12..offset];
193
194 let mut resp = Vec::with_capacity(offset + 16);
195 resp.extend_from_slice(txid);
196 resp.extend_from_slice(&[0x84, 0x00]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x00]); resp.extend_from_slice(&[0x00, 0x00]); resp.extend_from_slice(question);
202 resp.extend_from_slice(&[0xC0, 0x0C]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x00, 0x00, 0x3C]); resp.extend_from_slice(&[0x00, 0x04]); resp.extend_from_slice(&fake_ip.octets());
208
209 Some(resp)
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::{net::UdpSocket as StdUdpSocket, thread, time::Duration};

    /// Starts a `DnsServer` for `subnet` on an ephemeral loopback port.
    ///
    /// The server runs in a background thread. Returns a handle to the server's
    /// map and the server's address.
    fn bind_server(subnet: u8) -> (DnsMap, SocketAddr) {
        let map = DnsMap::new(subnet);
        let server = DnsServer::bind("127.0.0.1:0".parse().unwrap(), map.clone()).unwrap();
        let addr = server.local_addr();
        thread::spawn(move || server.run());
        (map, addr)
    }

    /// Sends one recursive `A`/`IN` query for `name` to `server`.
    ///
    /// Returns the first answer's IPv4 address. Returns `None` if sending or
    /// receiving fails (including the 2 s read timeout), if the reply is too
    /// short, or if its first answer's RDATA is not 4 bytes. The answer name is
    /// assumed to be a 2-byte compression pointer.
    fn query_a(server: SocketAddr, name: &str) -> Option<Ipv4Addr> {
        let sock = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        sock.set_read_timeout(Some(Duration::from_secs(2))).unwrap();

        let mut pkt = Vec::new();
        pkt.extend_from_slice(&[0xAB, 0xCD]); // ID
        pkt.extend_from_slice(&[0x01, 0x00]); // flags: RD
        pkt.extend_from_slice(&[0x00, 0x01]); // QDCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ANCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
        for label in name.split('.') {
            pkt.push(label.len() as u8);
            pkt.extend_from_slice(label.as_bytes());
        }
        pkt.push(0); // root label terminates QNAME
        pkt.extend_from_slice(&[0x00, 0x01]); // QTYPE = A
        pkt.extend_from_slice(&[0x00, 0x01]); // QCLASS = IN
        sock.send_to(&pkt, server).ok()?;

        let mut buf = [0u8; 512];
        let (n, _) = sock.recv_from(&mut buf).ok()?;
        let resp = &buf[..n];

        // Skip the echoed QNAME in the question section.
        let mut off = 12usize;
        loop {
            let l = usize::from(*resp.get(off)?);
            if l == 0 {
                off += 1;
                break;
            }
            if l & 0xC0 != 0 {
                off += 2;
                break;
            }
            off += 1 + l;
        }
        // QTYPE + QCLASS, then the answer's NAME pointer + TYPE + CLASS + TTL.
        off += 4;
        off += 2 + 2 + 2 + 4;
        let rdlen = resp.get(off..off + 2)?;
        if u16::from_be_bytes([rdlen[0], rdlen[1]]) != 4 {
            return None;
        }
        off += 2;
        let rdata = resp.get(off..off + 4)?;
        Some(Ipv4Addr::new(rdata[0], rdata[1], rdata[2], rdata[3]))
    }

    /// An A query over UDP is answered with an address inside the fake subnet.
    #[test]
    fn test_dns_a_query_returns_fake_ip() {
        let (_, addr) = bind_server(224);
        let ip = query_a(addr, "example.com").unwrap();
        assert_eq!(ip.octets()[0], 224);
    }

    /// Repeated queries for the same name return the same fake address.
    #[test]
    fn test_dns_same_hostname_same_ip() {
        let (_, addr) = bind_server(224);
        let ip1 = query_a(addr, "example.com").unwrap();
        let ip2 = query_a(addr, "example.com").unwrap();
        assert_eq!(ip1, ip2);
    }

    /// An allocated address maps back to its hostname.
    #[test]
    fn test_dns_map_reverse_lookup() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("example.com").unwrap();
        assert_eq!(map.lookup_hostname(ip).as_deref(), Some("example.com"));
    }

    /// Distinct hostnames get distinct addresses, each reverse-resolvable.
    #[test]
    fn test_dns_map_different_hostnames_different_ips() {
        let map = DnsMap::new(224);
        let ip1 = map.get_or_alloc("a.example.com").unwrap();
        let ip2 = map.get_or_alloc("b.example.com").unwrap();
        assert_ne!(ip1, ip2);
        assert_eq!(map.lookup_hostname(ip1).as_deref(), Some("a.example.com"));
        assert_eq!(map.lookup_hostname(ip2).as_deref(), Some("b.example.com"));
    }

    /// Allocated addresses are recognized as fake; public addresses are not.
    #[test]
    fn test_dns_map_is_fake_ip() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("test.com").unwrap();
        assert!(map.is_fake_ip(IpAddr::V4(ip)));
        assert!(!map.is_fake_ip("8.8.8.8".parse().unwrap()));
    }

    /// Forcing the counter to its limit makes the next allocation fail.
    #[test]
    fn test_dns_map_exhaustion() {
        let map = DnsMap::new(224);
        map.inner.write().unwrap().counter = 0xFF_FFFF;
        let result = map.get_or_alloc("overflow.com");
        assert!(
            result.is_err(),
            "should fail when address space is exhausted"
        );
    }
}