Skip to main content

microsandbox_network/dns/
interceptor.rs

1//! DNS query interception, filtering, and resolution.
2//!
3//! The [`DnsInterceptor`] bridges the smoltcp UDP socket (bound to gateway:53)
4//! and the host DNS resolvers. Queries are read from the socket, checked
5//! against the domain block list, forwarded to hickory-resolver for
6//! resolution, and responses are sent back through the socket.
7//!
8//! Because resolution is async and the poll loop is sync, queries are sent to
9//! a background tokio task via a channel. Responses come back through another
10//! channel and are written to the smoltcp socket on the next poll iteration.
11
12use std::collections::HashSet;
13use std::sync::Arc;
14
15use bytes::Bytes;
16use smoltcp::iface::SocketSet;
17use smoltcp::socket::udp;
18use smoltcp::storage::PacketMetadata;
19use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
20use tokio::sync::mpsc;
21
22use crate::config::DnsConfig;
23use crate::shared::SharedState;
24
25//--------------------------------------------------------------------------------------------------
26// Constants
27//--------------------------------------------------------------------------------------------------
28
/// Well-known DNS port; the interceptor's UDP socket is bound to it.
const DNS_PORT: u16 = 53;

/// Largest single DNS datagram, in bytes, that the interceptor buffers.
///
/// Also used (multiplied by [`DNS_SOCKET_PACKET_SLOTS`]) to size the payload
/// storage of the smoltcp socket buffers.
const DNS_MAX_SIZE: usize = 4096;

/// Number of datagrams each smoltcp UDP buffer (rx and tx) can hold at once.
const DNS_SOCKET_PACKET_SLOTS: usize = 16;

/// Bound on queued messages for both the query and the response channel.
const CHANNEL_CAPACITY: usize = 64;
40
41//--------------------------------------------------------------------------------------------------
42// Types
43//--------------------------------------------------------------------------------------------------
44
45/// DNS query/response interceptor.
46///
47/// Owns the smoltcp UDP socket handle and channels to the async resolver
48/// task. The poll loop calls [`process()`] each iteration to:
49///
50/// 1. Read pending queries from the smoltcp socket → send to resolver task.
51/// 2. Read resolved responses from the channel → write to smoltcp socket.
52///
53/// [`process()`]: DnsInterceptor::process
54pub struct DnsInterceptor {
55    /// Handle to the smoltcp UDP socket bound to gateway:53.
56    socket_handle: smoltcp::iface::SocketHandle,
57    /// Sends queries to the background resolver task.
58    query_tx: mpsc::Sender<DnsQuery>,
59    /// Receives responses from the background resolver task.
60    response_rx: mpsc::Receiver<DnsResponse>,
61}
62
/// Pre-processed DNS config with lowercased block lists (avoids per-query allocations).
///
/// Invariant: `blocked_suffixes` and `blocked_suffixes_dotted` are
/// index-aligned — entry `i` of the dotted list is entry `i` of the bare list
/// with a single `.` prepended. Matching code iterates both in lockstep.
#[derive(Debug)]
struct NormalizedDnsConfig {
    /// Lowercased domains that are blocked on exact match (O(1) lookup).
    blocked_domains: HashSet<String>,
    /// Lowercased suffixes WITHOUT a leading dot (compared for equality so
    /// that the suffix itself is blocked, not only its subdomains).
    blocked_suffixes: Vec<String>,
    /// The same suffixes WITH a leading dot (for `ends_with` matching without
    /// a per-query `format!`).
    blocked_suffixes_dotted: Vec<String>,
    /// When `true`, answers containing private/reserved addresses are rejected.
    rebind_protection: bool,
}
73
74/// A DNS query extracted from the smoltcp socket.
75struct DnsQuery {
76    /// Raw DNS message bytes.
77    data: Bytes,
78    /// Source endpoint (guest IP:port) for routing the response back.
79    source: IpEndpoint,
80}
81
82/// A resolved DNS response ready to send back to the guest.
83struct DnsResponse {
84    /// Raw DNS response bytes.
85    data: Bytes,
86    /// Destination endpoint (guest IP:port).
87    dest: IpEndpoint,
88}
89
90//--------------------------------------------------------------------------------------------------
91// Methods
92//--------------------------------------------------------------------------------------------------
93
94impl DnsInterceptor {
95    /// Create the DNS interceptor.
96    ///
97    /// Binds a smoltcp UDP socket to port 53, creates the channel pair, and
98    /// spawns the background resolver task.
99    pub fn new(
100        sockets: &mut SocketSet<'_>,
101        dns_config: DnsConfig,
102        shared: Arc<SharedState>,
103        tokio_handle: &tokio::runtime::Handle,
104    ) -> Self {
105        // Create and bind the smoltcp UDP socket.
106        let rx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
107        let rx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
108        let tx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
109        let tx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
110
111        let mut socket = udp::Socket::new(
112            udp::PacketBuffer::new(rx_meta, rx_payload),
113            udp::PacketBuffer::new(tx_meta, tx_payload),
114        );
115        socket
116            .bind(IpListenEndpoint {
117                addr: None,
118                port: DNS_PORT,
119            })
120            .expect("failed to bind DNS socket to port 53");
121
122        let socket_handle = sockets.add(socket);
123
124        // Create channels.
125        let (query_tx, query_rx) = mpsc::channel(CHANNEL_CAPACITY);
126        let (response_tx, response_rx) = mpsc::channel(CHANNEL_CAPACITY);
127
128        // Pre-lowercase block lists once to avoid per-query allocations.
129        let suffixes: Vec<String> = dns_config
130            .blocked_suffixes
131            .iter()
132            .map(|s| s.to_lowercase().trim_start_matches('.').to_string())
133            .collect();
134        let suffixes_dotted: Vec<String> = suffixes.iter().map(|s| format!(".{s}")).collect();
135        let normalized = Arc::new(NormalizedDnsConfig {
136            blocked_domains: dns_config
137                .blocked_domains
138                .iter()
139                .map(|d| d.to_lowercase())
140                .collect(),
141            blocked_suffixes: suffixes,
142            blocked_suffixes_dotted: suffixes_dotted,
143            rebind_protection: dns_config.rebind_protection,
144        });
145
146        // Spawn background resolver task.
147        tokio_handle.spawn(dns_resolver_task(query_rx, response_tx, normalized, shared));
148
149        Self {
150            socket_handle,
151            query_tx,
152            response_rx,
153        }
154    }
155
156    /// Process DNS queries and responses.
157    ///
158    /// Called by the poll loop each iteration:
159    /// 1. Reads queries from the smoltcp socket → sends to resolver task.
160    /// 2. Reads responses from the resolver → writes to smoltcp socket.
161    pub fn process(&mut self, sockets: &mut SocketSet<'_>) {
162        let socket = sockets.get_mut::<udp::Socket>(self.socket_handle);
163
164        // Read queries from the smoltcp socket.
165        let mut buf = [0u8; DNS_MAX_SIZE];
166        while socket.can_recv() {
167            match socket.recv_slice(&mut buf) {
168                Ok((n, meta)) => {
169                    let query = DnsQuery {
170                        data: Bytes::copy_from_slice(&buf[..n]),
171                        source: meta.endpoint,
172                    };
173                    if self.query_tx.try_send(query).is_err() {
174                        // Channel full — drop query. Guest will retry.
175                        tracing::debug!("DNS query channel full, dropping query");
176                    }
177                }
178                Err(_) => break,
179            }
180        }
181
182        // Write responses to the smoltcp socket.
183        // Check can_send() BEFORE consuming from the channel so
184        // undeliverable responses remain for the next poll iteration.
185        while socket.can_send() {
186            match self.response_rx.try_recv() {
187                Ok(response) => {
188                    let _ = socket.send_slice(&response.data, response.dest);
189                }
190                Err(_) => break,
191            }
192        }
193    }
194}
195
196//--------------------------------------------------------------------------------------------------
197// Functions
198//--------------------------------------------------------------------------------------------------
199
200/// Background task that resolves DNS queries using the host's resolvers.
201///
202/// Reads queries from the channel, applies domain filtering, resolves via
203/// hickory-resolver, and sends responses back.
204async fn dns_resolver_task(
205    mut query_rx: mpsc::Receiver<DnsQuery>,
206    response_tx: mpsc::Sender<DnsResponse>,
207    dns_config: Arc<NormalizedDnsConfig>,
208    shared: Arc<SharedState>,
209) {
210    // Create a system resolver that uses the host's /etc/resolv.conf.
211    let resolver = match hickory_resolver::Resolver::builder_tokio().map(|b| b.build()) {
212        Ok(r) => r,
213        Err(e) => {
214            tracing::error!(error = %e, "failed to create DNS resolver");
215            return;
216        }
217    };
218
219    while let Some(query) = query_rx.recv().await {
220        let response_tx = response_tx.clone();
221        let dns_config = dns_config.clone();
222        let shared = shared.clone();
223        let resolver = resolver.clone();
224
225        // Spawn a task per query for concurrency.
226        tokio::spawn(async move {
227            let result = resolve_query(&query.data, &dns_config, &resolver).await;
228            match result {
229                Some(response_data) => {
230                    let response = DnsResponse {
231                        data: response_data,
232                        dest: query.source,
233                    };
234                    if response_tx.send(response).await.is_ok() {
235                        shared.proxy_wake.wake();
236                    }
237                }
238                None => {
239                    // Query was blocked or failed — send REFUSED.
240                    if let Some(servfail) = build_refused(&query.data) {
241                        let response = DnsResponse {
242                            data: servfail,
243                            dest: query.source,
244                        };
245                        if response_tx.send(response).await.is_ok() {
246                            shared.proxy_wake.wake();
247                        }
248                    }
249                }
250            }
251        });
252    }
253}
254
255/// Resolve a single DNS query. Returns `None` if the domain is blocked
256/// or contains rebind-protected addresses.
257async fn resolve_query(
258    raw_query: &[u8],
259    dns_config: &NormalizedDnsConfig,
260    resolver: &hickory_resolver::TokioResolver,
261) -> Option<Bytes> {
262    use hickory_proto::op::Message;
263    use hickory_proto::rr::RData;
264    use hickory_proto::serialize::binary::BinDecodable;
265
266    // Parse the DNS query.
267    let query_msg = Message::from_bytes(raw_query).ok()?;
268    let query_id = query_msg.id();
269
270    // Extract the queried domain name.
271    let question = query_msg.queries().first()?;
272    let domain = question.name().to_string();
273    let domain = domain.trim_end_matches('.');
274
275    // Check domain block lists.
276    if is_domain_blocked(domain, dns_config) {
277        tracing::debug!(domain = %domain, "DNS query blocked");
278        return None;
279    }
280
281    // Forward the raw query to the host resolver by performing a lookup.
282    // We use the parsed question to do a proper lookup via hickory-resolver.
283    let record_type = question.query_type();
284
285    let lookup = resolver
286        .lookup(question.name().clone(), record_type)
287        .await
288        .ok()?;
289
290    // DNS rebind protection: reject responses containing private/reserved IPs.
291    if dns_config.rebind_protection {
292        for record in lookup.records() {
293            let is_private = match record.data() {
294                RData::A(a) => is_private_ipv4((*a).into()),
295                RData::AAAA(aaaa) => is_private_ipv6((*aaaa).into()),
296                _ => false,
297            };
298            if is_private {
299                tracing::debug!(
300                    domain = %domain,
301                    "DNS rebind protection: response contains private IP"
302                );
303                return None;
304            }
305        }
306    }
307
308    // Build a fresh DNS response (avoids cloning the entire query message).
309    let mut response_msg = Message::new();
310    response_msg.set_id(query_id);
311    response_msg.set_message_type(hickory_proto::op::MessageType::Response);
312    response_msg.set_op_code(query_msg.op_code());
313    response_msg.set_response_code(hickory_proto::op::ResponseCode::NoError);
314    response_msg.set_recursion_desired(query_msg.recursion_desired());
315    response_msg.set_recursion_available(true);
316    response_msg.add_query(question.clone());
317
318    // Add answer records.
319    let answers: Vec<_> = lookup.records().to_vec();
320    response_msg.insert_answers(answers);
321
322    // Serialize the response.
323    use hickory_proto::serialize::binary::BinEncodable;
324    let response_bytes = response_msg.to_bytes().ok()?;
325
326    Some(Bytes::from(response_bytes))
327}
328
/// Check if an IPv4 address is in a private/reserved range (for rebind protection).
///
/// Covered ranges:
///
/// * `0.0.0.0/8` — "this network", including the unspecified address
/// * `127.0.0.0/8` — loopback
/// * `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16` — RFC 1918 private
/// * `100.64.0.0/10` — carrier-grade NAT shared space
/// * `169.254.0.0/16` — link-local
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    let [first, second, _, _] = addr.octets();
    first == 0                                        // 0.0.0.0/8
        || addr.is_loopback()                         // 127.0.0.0/8
        || addr.is_private()                          // RFC 1918
        || addr.is_link_local()                       // 169.254.0.0/16
        || (first == 100 && (second & 0xc0) == 64) // 100.64.0.0/10 (CGNAT)
}

/// Check if an IPv6 address is in a private/reserved range (for rebind protection).
///
/// IPv4-mapped addresses (`::ffff:a.b.c.d`) are classified by their embedded
/// IPv4 address; otherwise an AAAA record such as `::ffff:127.0.0.1` would
/// slip past the check while still reaching an internal IPv4 host.
///
/// Native IPv6 ranges covered: `::1` (loopback), `::` (unspecified),
/// `fc00::/7` (unique local) and `fe80::/10` (link-local).
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    if let Some(mapped) = addr.to_ipv4_mapped() {
        return is_private_ipv4(mapped);
    }

    let first = addr.segments()[0];
    addr.is_loopback()                    // ::1
        || addr.is_unspecified()          // ::
        || (first & 0xfe00) == 0xfc00     // fc00::/7 (ULA)
        || (first & 0xffc0) == 0xfe80 // fe80::/10 (link-local)
}
349
350/// Check if a domain is blocked by the DNS config.
351///
352/// Block lists are pre-lowercased in [`NormalizedDnsConfig`], so only the
353/// queried domain needs lowercasing (once per query instead of per entry).
354fn is_domain_blocked(domain: &str, config: &NormalizedDnsConfig) -> bool {
355    let domain_lower = domain.to_lowercase();
356
357    // Check exact domain matches — O(1) via HashSet.
358    if config.blocked_domains.contains(&domain_lower) {
359        return true;
360    }
361
362    // Check suffix matches (already lowercased with pre-computed dot-prefixed forms).
363    for (suffix, dotted) in config
364        .blocked_suffixes
365        .iter()
366        .zip(config.blocked_suffixes_dotted.iter())
367    {
368        if domain_lower == *suffix || domain_lower.ends_with(dotted.as_str()) {
369            return true;
370        }
371    }
372
373    false
374}
375
376/// Build a REFUSED response for a query that was blocked by policy.
377///
378/// Uses REFUSED (RCODE 5) instead of SERVFAIL (RCODE 2) because the
379/// refusal is a policy decision, not a server failure. Most stub resolvers
380/// do not retry REFUSED, avoiding unnecessary latency.
381fn build_refused(raw_query: &[u8]) -> Option<Bytes> {
382    use hickory_proto::op::Message;
383    use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
384
385    let query_msg = Message::from_bytes(raw_query).ok()?;
386    let mut response = Message::new();
387    response.set_id(query_msg.id());
388    for q in query_msg.queries() {
389        response.add_query(q.clone());
390    }
391    response.set_message_type(hickory_proto::op::MessageType::Response);
392    response.set_response_code(hickory_proto::op::ResponseCode::Refused);
393    response.set_recursion_available(true);
394
395    let bytes = response.to_bytes().ok()?;
396    Some(Bytes::from(bytes))
397}
398
399//--------------------------------------------------------------------------------------------------
400// Tests
401//--------------------------------------------------------------------------------------------------
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`NormalizedDnsConfig`] from raw block lists: entries are
    /// lowercased and suffixes lose their leading dots. Rebind protection
    /// is off.
    fn normalized(domains: Vec<&str>, suffixes: Vec<&str>) -> NormalizedDnsConfig {
        let blocked_suffixes: Vec<String> = suffixes
            .iter()
            .map(|s| s.to_lowercase().trim_start_matches('.').to_string())
            .collect();
        let blocked_suffixes_dotted = blocked_suffixes.iter().map(|s| format!(".{s}")).collect();
        NormalizedDnsConfig {
            blocked_domains: domains
                .iter()
                .map(|d| d.to_lowercase())
                .collect::<HashSet<_>>(),
            blocked_suffixes,
            blocked_suffixes_dotted,
            rebind_protection: false,
        }
    }

    #[test]
    fn test_exact_domain_blocked() {
        let config = normalized(vec!["evil.com"], vec![]);
        assert!(is_domain_blocked("evil.com", &config));
        assert!(is_domain_blocked("Evil.COM", &config));
        assert!(!is_domain_blocked("not-evil.com", &config));
        assert!(!is_domain_blocked("sub.evil.com", &config));
    }

    #[test]
    fn test_suffix_domain_blocked() {
        let config = normalized(vec![], vec![".evil.com"]);
        assert!(is_domain_blocked("sub.evil.com", &config));
        assert!(is_domain_blocked("deep.sub.evil.com", &config));
        assert!(is_domain_blocked("evil.com", &config));
        assert!(!is_domain_blocked("notevil.com", &config));
    }

    #[test]
    fn test_suffix_match_is_case_insensitive() {
        // Both the configured suffix and the queried name may be mixed-case.
        let config = normalized(vec![], vec![".Evil.COM"]);
        assert!(is_domain_blocked("SUB.EVIL.com", &config));
        assert!(is_domain_blocked("EVIL.COM", &config));
        assert!(!is_domain_blocked("NOTEVIL.COM", &config));
    }

    #[test]
    fn test_no_blocks_nothing_blocked() {
        let config = normalized(vec![], vec![]);
        assert!(!is_domain_blocked("anything.com", &config));
    }

    #[test]
    fn test_private_ipv4_ranges() {
        let private = |s: &str| is_private_ipv4(s.parse().unwrap());
        assert!(private("127.0.0.1"));
        assert!(private("10.0.0.1"));
        assert!(private("172.16.0.1"));
        assert!(private("172.31.255.255"));
        assert!(private("192.168.1.1"));
        assert!(private("100.64.0.1"));
        assert!(private("169.254.169.254"));
        assert!(private("0.0.0.0"));
    }

    #[test]
    fn test_public_ipv4_not_private() {
        let private = |s: &str| is_private_ipv4(s.parse().unwrap());
        assert!(!private("8.8.8.8"));
        assert!(!private("172.32.0.1"));
        assert!(!private("100.128.0.1"));
        assert!(!private("192.169.0.1"));
    }

    #[test]
    fn test_private_ipv6_ranges() {
        let private = |s: &str| is_private_ipv6(s.parse().unwrap());
        assert!(private("::1"));
        assert!(private("::"));
        assert!(private("fc00::1"));
        assert!(private("fd00::1"));
        assert!(private("fe80::1"));
        assert!(!private("2001:4860:4860::8888"));
    }
}