Skip to main content

proxychains_masq/
dns.rs

1use std::{
2    net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
3    sync::{Arc, RwLock},
4};
5
6use anyhow::{bail, Result};
7use bimap::BiMap;
8
9// ─── DnsMap ───────────────────────────────────────────────────────────────────
10
11/// Thread-safe bidirectional map between fake IPv4 addresses and hostnames.
12///
13/// Fake IPs are allocated sequentially from `subnet.0.0.1` upwards.
14#[derive(Debug, Clone)]
15pub struct DnsMap {
16    subnet: u8,
17    inner: Arc<RwLock<DnsMapInner>>,
18}
19
20#[derive(Debug, Default)]
21struct DnsMapInner {
22    counter: u32,
23    map: BiMap<Ipv4Addr, String>,
24}
25
26impl DnsMap {
27    /// Create a new map that allocates from `subnet.x.x.x`.
28    pub fn new(subnet: u8) -> Self {
29        DnsMap {
30            subnet,
31            inner: Arc::new(RwLock::new(DnsMapInner::default())),
32        }
33    }
34
35    /// Get or allocate a fake IP for `hostname`.
36    ///
37    /// # Errors
38    ///
39    /// Returns `Err` when the address space is exhausted (>= 16 million entries).
40    pub fn get_or_alloc(&self, hostname: &str) -> Result<Ipv4Addr> {
41        // Fast path: already allocated.
42        {
43            let r = self.inner.read().unwrap();
44            if let Some(ip) = r.map.get_by_right(hostname) {
45                return Ok(*ip);
46            }
47        }
48        // Slow path: allocate a new entry.
49        let mut w = self.inner.write().unwrap();
50        // Double-check after lock upgrade.
51        if let Some(ip) = w.map.get_by_right(hostname) {
52            return Ok(*ip);
53        }
54        let index = w.counter;
55        if index >= 0xFF_FFFF {
56            bail!("dns map exhausted");
57        }
58        w.counter += 1;
59        let ip = make_fake_ip(self.subnet, index);
60        w.map.insert(ip, hostname.to_owned());
61        Ok(ip)
62    }
63
64    /// Reverse-lookup: hostname for a fake IP.
65    pub fn lookup_hostname(&self, ip: Ipv4Addr) -> Option<String> {
66        self.inner.read().unwrap().map.get_by_left(&ip).cloned()
67    }
68
69    /// Check whether an IP belongs to this map's subnet.
70    pub fn is_fake_ip(&self, ip: IpAddr) -> bool {
71        match ip {
72            IpAddr::V4(v4) => v4.octets()[0] == self.subnet,
73            _ => false,
74        }
75    }
76
77    /// Parse a raw DNS wire-format query and return a response that assigns a
78    /// fake IP for the queried hostname.
79    ///
80    /// Returns `None` if the packet is not a valid single-question A query.
81    pub fn handle_dns_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
82        dns_handle_query(packet, self)
83    }
84}
85
/// Turn a zero-based allocation `index` into the address `subnet.a.b.c`.
///
/// `a.b.c` is the big-endian encoding of `index + 1`, so index 0 becomes
/// `subnet.0.0.1`. Any bits of `index + 1` above the low 24 are discarded.
fn make_fake_ip(subnet: u8, index: u32) -> Ipv4Addr {
    let [_, hi, mid, lo] = (index + 1).to_be_bytes();
    Ipv4Addr::new(subnet, hi, mid, lo)
}
96
97// ─── DnsServer ────────────────────────────────────────────────────────────────
98
99/// Minimal UDP DNS server that answers A queries with fake IPs from a [`DnsMap`].
100///
101/// Runs in a dedicated blocking thread (DNS is inherently synchronous and
102/// low-throughput enough that this is optimal).
103pub struct DnsServer {
104    socket: UdpSocket,
105    map: DnsMap,
106}
107
108impl DnsServer {
109    /// Bind to `addr` and create a server backed by `map`.
110    pub fn bind(addr: SocketAddr, map: DnsMap) -> Result<Self> {
111        let socket = UdpSocket::bind(addr)?;
112        Ok(DnsServer { socket, map })
113    }
114
115    /// Return the local address the server is listening on.
116    pub fn local_addr(&self) -> SocketAddr {
117        self.socket.local_addr().unwrap()
118    }
119
120    /// Run the server loop (blocks the calling thread indefinitely).
121    pub fn run(self) {
122        let mut buf = [0u8; 512];
123        loop {
124            let (n, src) = match self.socket.recv_from(&mut buf) {
125                Ok(x) => x,
126                Err(_) => continue,
127            };
128            let packet = &buf[..n];
129            if let Some(response) = self.handle_query(packet) {
130                let _ = self.socket.send_to(&response, src);
131            }
132        }
133    }
134
135    /// Parse a raw DNS query `packet` and return a DNS response allocating a
136    /// fake IP from the map, or `None` if the packet is not a valid A query.
137    pub fn handle_query(&self, packet: &[u8]) -> Option<Vec<u8>> {
138        dns_handle_query(packet, &self.map)
139    }
140}
141
142// ─── Shared DNS query handler ─────────────────────────────────────────────────
143
144/// Parse a DNS wire-format query and build a response that allocates a fake IP
145/// for the queried hostname using `map`.  Returns `None` for non-A queries,
146/// malformed packets, or address-space exhaustion.
147fn dns_handle_query(packet: &[u8], map: &DnsMap) -> Option<Vec<u8>> {
148    if packet.len() < 12 {
149        return None;
150    }
151    let txid = &packet[0..2];
152    if u16::from_be_bytes([packet[4], packet[5]]) != 1 {
153        return None; // QDCOUNT must be 1
154    }
155
156    // Parse QNAME
157    let mut offset = 12usize;
158    let mut labels: Vec<String> = Vec::new();
159    loop {
160        if offset >= packet.len() {
161            return None;
162        }
163        let len = packet[offset] as usize;
164        if len == 0 {
165            offset += 1;
166            break;
167        }
168        if len & 0xC0 != 0 {
169            return None;
170        } // pointer compression not handled
171        offset += 1;
172        if offset + len > packet.len() {
173            return None;
174        }
175        labels.push(String::from_utf8_lossy(&packet[offset..offset + len]).into_owned());
176        offset += len;
177    }
178    let qname = labels.join(".");
179
180    if offset + 4 > packet.len() {
181        return None;
182    }
183    let qtype = u16::from_be_bytes([packet[offset], packet[offset + 1]]);
184    let qclass = u16::from_be_bytes([packet[offset + 2], packet[offset + 3]]);
185    offset += 4;
186
187    if qtype != 1 || qclass != 1 {
188        return None;
189    } // only A IN
190
191    let fake_ip = map.get_or_alloc(&qname).ok()?;
192    let question = &packet[12..offset];
193
194    let mut resp = Vec::with_capacity(offset + 16);
195    resp.extend_from_slice(txid);
196    resp.extend_from_slice(&[0x84, 0x00]); // QR=1 AA=1
197    resp.extend_from_slice(&[0x00, 0x01]); // QDCOUNT=1
198    resp.extend_from_slice(&[0x00, 0x01]); // ANCOUNT=1
199    resp.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
200    resp.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
201    resp.extend_from_slice(question);
202    resp.extend_from_slice(&[0xC0, 0x0C]); // name pointer to offset 12
203    resp.extend_from_slice(&[0x00, 0x01]); // TYPE A
204    resp.extend_from_slice(&[0x00, 0x01]); // CLASS IN
205    resp.extend_from_slice(&[0x00, 0x00, 0x00, 0x3C]); // TTL 60
206    resp.extend_from_slice(&[0x00, 0x04]); // RDLENGTH 4
207    resp.extend_from_slice(&fake_ip.octets());
208
209    Some(resp)
210}
211
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::{net::UdpSocket as StdUdpSocket, thread, time::Duration};

    /// Start a [`DnsServer`] on an ephemeral loopback port in a background
    /// thread and return its map together with the bound address.
    fn bind_server(subnet: u8) -> (DnsMap, SocketAddr) {
        let map = DnsMap::new(subnet);
        let server = DnsServer::bind("127.0.0.1:0".parse().unwrap(), map.clone()).unwrap();
        let addr = server.local_addr();
        thread::spawn(move || server.run());
        (map, addr)
    }

    /// Encode a single-question IN query for `name` with the given QTYPE
    /// (txid 0xABCD, RD=1).
    fn build_query(name: &str, qtype: u16) -> Vec<u8> {
        let mut pkt = Vec::new();
        pkt.extend_from_slice(&[0xAB, 0xCD]); // txid
        pkt.extend_from_slice(&[0x01, 0x00]); // QR=0 RD=1
        pkt.extend_from_slice(&[0x00, 0x01]); // QDCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ANCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // NSCOUNT
        pkt.extend_from_slice(&[0x00, 0x00]); // ARCOUNT
        for label in name.split('.') {
            pkt.push(label.len() as u8);
            pkt.extend_from_slice(label.as_bytes());
        }
        pkt.push(0); // root label
        pkt.extend_from_slice(&qtype.to_be_bytes()); // QTYPE
        pkt.extend_from_slice(&[0x00, 0x01]); // QCLASS IN
        pkt
    }

    /// Extract the IPv4 address from the first answer record of `resp`.
    ///
    /// Every read is bounds-checked. The result is `None`, not a panic, if the
    /// response is short or the answer is not a 4-byte record.
    fn parse_a_answer(resp: &[u8]) -> Option<Ipv4Addr> {
        // Skip the header (12 bytes) and the question name.
        let mut off = 12usize;
        loop {
            let l = usize::from(*resp.get(off)?);
            if l == 0 {
                off += 1;
                break;
            }
            if l & 0xC0 != 0 {
                off += 2;
                break;
            }
            off += 1 + l;
        }
        off += 4; // QTYPE + QCLASS
        // Answer RR: name pointer (2), TYPE (2), CLASS (2), TTL (4), then RDLENGTH.
        off += 2 + 2 + 2 + 4;
        let rdlen = resp.get(off..off + 2)?;
        if u16::from_be_bytes([rdlen[0], rdlen[1]]) != 4 {
            return None;
        }
        let rdata = resp.get(off + 2..off + 6)?;
        Some(Ipv4Addr::new(rdata[0], rdata[1], rdata[2], rdata[3]))
    }

    /// Send an A query for `name` to `server` over UDP and parse the answer.
    fn query_a(server: SocketAddr, name: &str) -> Option<Ipv4Addr> {
        let sock = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        sock.set_read_timeout(Some(Duration::from_secs(2))).unwrap();
        sock.send_to(&build_query(name, 1), server).ok()?;

        let mut buf = [0u8; 512];
        let (n, _) = sock.recv_from(&mut buf).ok()?;
        parse_a_answer(&buf[..n])
    }

    // T-25
    #[test]
    fn test_dns_a_query_returns_fake_ip() {
        let (_, addr) = bind_server(224);
        let ip = query_a(addr, "example.com").unwrap();
        assert_eq!(ip.octets()[0], 224);
    }

    // T-26
    #[test]
    fn test_dns_same_hostname_same_ip() {
        let (_, addr) = bind_server(224);
        let ip1 = query_a(addr, "example.com").unwrap();
        let ip2 = query_a(addr, "example.com").unwrap();
        assert_eq!(ip1, ip2);
    }

    // T-27
    #[test]
    fn test_dns_map_reverse_lookup() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("example.com").unwrap();
        assert_eq!(map.lookup_hostname(ip).as_deref(), Some("example.com"));
    }

    #[test]
    fn test_dns_map_different_hostnames_different_ips() {
        let map = DnsMap::new(224);
        let ip1 = map.get_or_alloc("a.example.com").unwrap();
        let ip2 = map.get_or_alloc("b.example.com").unwrap();
        assert_ne!(ip1, ip2);
        assert_eq!(map.lookup_hostname(ip1).as_deref(), Some("a.example.com"));
        assert_eq!(map.lookup_hostname(ip2).as_deref(), Some("b.example.com"));
    }

    #[test]
    fn test_dns_map_is_fake_ip() {
        let map = DnsMap::new(224);
        let ip = map.get_or_alloc("test.com").unwrap();
        assert!(map.is_fake_ip(IpAddr::V4(ip)));
        assert!(!map.is_fake_ip("8.8.8.8".parse().unwrap()));
    }

    // T-28
    #[test]
    fn test_dns_map_exhaustion() {
        let map = DnsMap::new(224);
        // Force counter to max
        map.inner.write().unwrap().counter = 0xFF_FFFF;
        let result = map.get_or_alloc("overflow.com");
        assert!(
            result.is_err(),
            "should fail when address space is exhausted"
        );
    }

    #[test]
    fn test_dns_map_last_address_before_exhaustion() {
        let map = DnsMap::new(224);
        map.inner.write().unwrap().counter = 0xFF_FFFE;
        let ip = map.get_or_alloc("last.com").unwrap();
        assert_eq!(ip, Ipv4Addr::new(224, 255, 255, 255));
        assert!(map.get_or_alloc("one-too-many.com").is_err());
        // Names allocated before exhaustion still resolve afterwards.
        assert_eq!(map.get_or_alloc("last.com").unwrap(), ip);
    }

    #[test]
    fn test_make_fake_ip_layout() {
        assert_eq!(make_fake_ip(224, 0), Ipv4Addr::new(224, 0, 0, 1));
        assert_eq!(make_fake_ip(224, 0xFF), Ipv4Addr::new(224, 0, 1, 0));
        assert_eq!(
            make_fake_ip(224, 0xFF_FFFE),
            Ipv4Addr::new(224, 255, 255, 255)
        );
    }

    #[test]
    fn test_dns_map_clones_share_state() {
        let map = DnsMap::new(224);
        let other = map.clone();
        let ip = map.get_or_alloc("shared.com").unwrap();
        assert_eq!(other.lookup_hostname(ip).as_deref(), Some("shared.com"));
        assert_eq!(other.get_or_alloc("shared.com").unwrap(), ip);
    }

    #[test]
    fn test_handle_dns_query_echoes_header_and_question() {
        let map = DnsMap::new(224);
        let query = build_query("example.com", 1);
        let resp = map.handle_dns_query(&query).unwrap();
        assert_eq!(&resp[0..2], &[0xABu8, 0xCD][..]); // txid echoed
        assert_ne!(resp[2] & 0x80, 0); // QR set
        assert_eq!(&resp[4..12], &[0u8, 1, 0, 1, 0, 0, 0, 0][..]); // QD=1 AN=1 NS=0 AR=0
        assert_eq!(&resp[12..query.len()], &query[12..]); // question echoed
        let ip = parse_a_answer(&resp).unwrap();
        assert_eq!(map.lookup_hostname(ip).as_deref(), Some("example.com"));
    }

    #[test]
    fn test_handle_dns_query_rejects_unsupported_or_malformed() {
        let map = DnsMap::new(224);
        assert!(map.handle_dns_query(&build_query("example.com", 28)).is_none()); // AAAA
        assert!(map.handle_dns_query(&[0u8; 11]).is_none()); // short header
        let query = build_query("example.com", 1);
        assert!(map.handle_dns_query(&query[..query.len() - 1]).is_none()); // truncated
        let mut two_questions = query.clone();
        two_questions[5] = 2;
        assert!(map.handle_dns_query(&two_questions).is_none()); // QDCOUNT != 1
        // None of the rejected packets allocated an address.
        assert!(map.lookup_hostname(Ipv4Addr::new(224, 0, 0, 1)).is_none());
    }
}