microsandbox_network/dns/
interceptor.rs1use std::collections::HashSet;
13use std::sync::Arc;
14
15use bytes::Bytes;
16use smoltcp::iface::SocketSet;
17use smoltcp::socket::udp;
18use smoltcp::storage::PacketMetadata;
19use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
20use tokio::sync::mpsc;
21
22use crate::config::DnsConfig;
23use crate::shared::SharedState;
24
/// UDP port the interceptor's socket is bound to.
const DNS_PORT: u16 = 53;

/// Size in bytes of the receive scratch buffer used per datagram; each socket
/// payload buffer holds `DNS_SOCKET_PACKET_SLOTS` times this many bytes.
const DNS_MAX_SIZE: usize = 4 * 1024;

/// Number of packet-metadata entries in each direction of the DNS socket.
const DNS_SOCKET_PACKET_SLOTS: usize = 16;

/// Bound of both the query channel and the response channel.
const CHANNEL_CAPACITY: usize = 64;
40
41pub struct DnsInterceptor {
55 socket_handle: smoltcp::iface::SocketHandle,
57 query_tx: mpsc::Sender<DnsQuery>,
59 response_rx: mpsc::Receiver<DnsResponse>,
61}
62
/// DNS policy with blocklist entries pre-normalized for cheap per-query matching.
struct NormalizedDnsConfig {
    /// Reject answers that contain private / local addresses.
    rebind_protection: bool,
    /// Exact domain names to block, lowercased.
    blocked_domains: HashSet<String>,
    /// Suffixes to block, lowercased, without a leading `.`.
    blocked_suffixes: Vec<String>,
    /// The same suffixes with a leading `.`, index-aligned with `blocked_suffixes`.
    blocked_suffixes_dotted: Vec<String>,
}
73
74struct DnsQuery {
76 data: Bytes,
78 source: IpEndpoint,
80}
81
82struct DnsResponse {
84 data: Bytes,
86 dest: IpEndpoint,
88}
89
90impl DnsInterceptor {
95 pub fn new(
100 sockets: &mut SocketSet<'_>,
101 dns_config: DnsConfig,
102 shared: Arc<SharedState>,
103 tokio_handle: &tokio::runtime::Handle,
104 ) -> Self {
105 let rx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
107 let rx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
108 let tx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
109 let tx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
110
111 let mut socket = udp::Socket::new(
112 udp::PacketBuffer::new(rx_meta, rx_payload),
113 udp::PacketBuffer::new(tx_meta, tx_payload),
114 );
115 socket
116 .bind(IpListenEndpoint {
117 addr: None,
118 port: DNS_PORT,
119 })
120 .expect("failed to bind DNS socket to port 53");
121
122 let socket_handle = sockets.add(socket);
123
124 let (query_tx, query_rx) = mpsc::channel(CHANNEL_CAPACITY);
126 let (response_tx, response_rx) = mpsc::channel(CHANNEL_CAPACITY);
127
128 let suffixes: Vec<String> = dns_config
130 .blocked_suffixes
131 .iter()
132 .map(|s| s.to_lowercase().trim_start_matches('.').to_string())
133 .collect();
134 let suffixes_dotted: Vec<String> = suffixes.iter().map(|s| format!(".{s}")).collect();
135 let normalized = Arc::new(NormalizedDnsConfig {
136 blocked_domains: dns_config
137 .blocked_domains
138 .iter()
139 .map(|d| d.to_lowercase())
140 .collect(),
141 blocked_suffixes: suffixes,
142 blocked_suffixes_dotted: suffixes_dotted,
143 rebind_protection: dns_config.rebind_protection,
144 });
145
146 tokio_handle.spawn(dns_resolver_task(query_rx, response_tx, normalized, shared));
148
149 Self {
150 socket_handle,
151 query_tx,
152 response_rx,
153 }
154 }
155
156 pub fn process(&mut self, sockets: &mut SocketSet<'_>) {
162 let socket = sockets.get_mut::<udp::Socket>(self.socket_handle);
163
164 let mut buf = [0u8; DNS_MAX_SIZE];
166 while socket.can_recv() {
167 match socket.recv_slice(&mut buf) {
168 Ok((n, meta)) => {
169 let query = DnsQuery {
170 data: Bytes::copy_from_slice(&buf[..n]),
171 source: meta.endpoint,
172 };
173 if self.query_tx.try_send(query).is_err() {
174 tracing::debug!("DNS query channel full, dropping query");
176 }
177 }
178 Err(_) => break,
179 }
180 }
181
182 while socket.can_send() {
186 match self.response_rx.try_recv() {
187 Ok(response) => {
188 let _ = socket.send_slice(&response.data, response.dest);
189 }
190 Err(_) => break,
191 }
192 }
193 }
194}
195
196async fn dns_resolver_task(
205 mut query_rx: mpsc::Receiver<DnsQuery>,
206 response_tx: mpsc::Sender<DnsResponse>,
207 dns_config: Arc<NormalizedDnsConfig>,
208 shared: Arc<SharedState>,
209) {
210 let resolver = match hickory_resolver::Resolver::builder_tokio().map(|b| b.build()) {
212 Ok(r) => r,
213 Err(e) => {
214 tracing::error!(error = %e, "failed to create DNS resolver");
215 return;
216 }
217 };
218
219 while let Some(query) = query_rx.recv().await {
220 let response_tx = response_tx.clone();
221 let dns_config = dns_config.clone();
222 let shared = shared.clone();
223 let resolver = resolver.clone();
224
225 tokio::spawn(async move {
227 let result = resolve_query(&query.data, &dns_config, &resolver).await;
228 match result {
229 Some(response_data) => {
230 let response = DnsResponse {
231 data: response_data,
232 dest: query.source,
233 };
234 if response_tx.send(response).await.is_ok() {
235 shared.proxy_wake.wake();
236 }
237 }
238 None => {
239 if let Some(servfail) = build_refused(&query.data) {
241 let response = DnsResponse {
242 data: servfail,
243 dest: query.source,
244 };
245 if response_tx.send(response).await.is_ok() {
246 shared.proxy_wake.wake();
247 }
248 }
249 }
250 }
251 });
252 }
253}
254
255async fn resolve_query(
258 raw_query: &[u8],
259 dns_config: &NormalizedDnsConfig,
260 resolver: &hickory_resolver::TokioResolver,
261) -> Option<Bytes> {
262 use hickory_proto::op::Message;
263 use hickory_proto::rr::RData;
264 use hickory_proto::serialize::binary::BinDecodable;
265
266 let query_msg = Message::from_bytes(raw_query).ok()?;
268 let query_id = query_msg.id();
269
270 let question = query_msg.queries().first()?;
272 let domain = question.name().to_string();
273 let domain = domain.trim_end_matches('.');
274
275 if is_domain_blocked(domain, dns_config) {
277 tracing::debug!(domain = %domain, "DNS query blocked");
278 return None;
279 }
280
281 let record_type = question.query_type();
284
285 let lookup = resolver
286 .lookup(question.name().clone(), record_type)
287 .await
288 .ok()?;
289
290 if dns_config.rebind_protection {
292 for record in lookup.records() {
293 let is_private = match record.data() {
294 RData::A(a) => is_private_ipv4((*a).into()),
295 RData::AAAA(aaaa) => is_private_ipv6((*aaaa).into()),
296 _ => false,
297 };
298 if is_private {
299 tracing::debug!(
300 domain = %domain,
301 "DNS rebind protection: response contains private IP"
302 );
303 return None;
304 }
305 }
306 }
307
308 let mut response_msg = Message::new();
310 response_msg.set_id(query_id);
311 response_msg.set_message_type(hickory_proto::op::MessageType::Response);
312 response_msg.set_op_code(query_msg.op_code());
313 response_msg.set_response_code(hickory_proto::op::ResponseCode::NoError);
314 response_msg.set_recursion_desired(query_msg.recursion_desired());
315 response_msg.set_recursion_available(true);
316 response_msg.add_query(question.clone());
317
318 let answers: Vec<_> = lookup.records().to_vec();
320 response_msg.insert_answers(answers);
321
322 use hickory_proto::serialize::binary::BinEncodable;
324 let response_bytes = response_msg.to_bytes().ok()?;
325
326 Some(Bytes::from(response_bytes))
327}
328
/// Returns `true` for IPv4 addresses that rebind protection must not hand to the guest.
///
/// These are:
/// - loopback (127.0.0.0/8);
/// - unspecified (0.0.0.0);
/// - RFC 1918 private ranges (10/8, 172.16/12, 192.168/16);
/// - RFC 6598 shared / CGNAT space (100.64.0.0/10);
/// - link-local (169.254.0.0/16).
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = addr.octets();
    // 100.64.0.0/10: second octet in 64..=127.
    let shared_cgnat = first == 100 && (second & 0xc0) == 64;
    addr.is_loopback()
        || addr.is_unspecified()
        || addr.is_private()
        || addr.is_link_local()
        || shared_cgnat
}

/// IPv6 counterpart of [`is_private_ipv4`].
///
/// Matches loopback (`::1`), unspecified (`::`), unique local (fc00::/7),
/// link-local (fe80::/10), and IPv4-mapped addresses (`::ffff:a.b.c.d`) whose
/// embedded IPv4 address is private.
///
/// The IPv4-mapped case closes a bypass. Dual-stack sockets generally treat a
/// mapped address as its embedded IPv4 address, so an AAAA record of
/// `::ffff:127.0.0.1` would otherwise slip past the IPv4 checks.
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    if let Some(v4) = addr.to_ipv4_mapped() {
        return is_private_ipv4(v4);
    }
    let first_segment = addr.segments()[0];
    addr.is_loopback()
        || addr.is_unspecified()
        || (first_segment & 0xfe00) == 0xfc00
        || (first_segment & 0xffc0) == 0xfe80
}
349
350fn is_domain_blocked(domain: &str, config: &NormalizedDnsConfig) -> bool {
355 let domain_lower = domain.to_lowercase();
356
357 if config.blocked_domains.contains(&domain_lower) {
359 return true;
360 }
361
362 for (suffix, dotted) in config
364 .blocked_suffixes
365 .iter()
366 .zip(config.blocked_suffixes_dotted.iter())
367 {
368 if domain_lower == *suffix || domain_lower.ends_with(dotted.as_str()) {
369 return true;
370 }
371 }
372
373 false
374}
375
376fn build_refused(raw_query: &[u8]) -> Option<Bytes> {
382 use hickory_proto::op::Message;
383 use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
384
385 let query_msg = Message::from_bytes(raw_query).ok()?;
386 let mut response = Message::new();
387 response.set_id(query_msg.id());
388 for q in query_msg.queries() {
389 response.add_query(q.clone());
390 }
391 response.set_message_type(hickory_proto::op::MessageType::Response);
392 response.set_response_code(hickory_proto::op::ResponseCode::Refused);
393 response.set_recursion_available(true);
394
395 let bytes = response.to_bytes().ok()?;
396 Some(Bytes::from(bytes))
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from raw blocklist entries.
    ///
    /// Every entry is lowercased, suffixes lose any leading `.`, and the
    /// dotted suffix list is derived from the bare one. Rebind protection is off.
    fn normalized(domains: Vec<&str>, suffixes: Vec<&str>) -> NormalizedDnsConfig {
        let bare: Vec<String> = suffixes
            .into_iter()
            .map(|suffix| suffix.to_lowercase().trim_start_matches('.').to_owned())
            .collect();
        NormalizedDnsConfig {
            blocked_domains: domains.into_iter().map(str::to_lowercase).collect(),
            blocked_suffixes_dotted: bare.iter().map(|s| format!(".{s}")).collect(),
            blocked_suffixes: bare,
            rebind_protection: false,
        }
    }

    #[test]
    fn test_exact_domain_blocked() {
        let config = normalized(vec!["evil.com"], vec![]);
        for hit in ["evil.com", "Evil.COM"] {
            assert!(is_domain_blocked(hit, &config));
        }
        for miss in ["not-evil.com", "sub.evil.com"] {
            assert!(!is_domain_blocked(miss, &config));
        }
    }

    #[test]
    fn test_suffix_domain_blocked() {
        let config = normalized(vec![], vec![".evil.com"]);
        for hit in ["sub.evil.com", "deep.sub.evil.com", "evil.com"] {
            assert!(is_domain_blocked(hit, &config));
        }
        assert!(!is_domain_blocked("notevil.com", &config));
    }

    #[test]
    fn test_no_blocks_nothing_blocked() {
        let config = normalized(Vec::new(), Vec::new());
        assert!(!is_domain_blocked("anything.com", &config));
    }
}