Skip to main content

microsandbox_network/dns/
interceptor.rs

1//! DNS interception engine.
2//!
3//! Intercepts DNS queries destined for the sandbox gateway, resolves
4//! them via host nameservers using `hickory-resolver`, applies domain
5//! and rebind filters, records A/AAAA answers in the pin set, and
6//! synthesizes DNS response frames.
7
8use std::{
9    collections::{BTreeMap, HashMap},
10    hash::{Hash, Hasher},
11    net::IpAddr,
12    sync::{Arc, Mutex, RwLock},
13    time::{Duration, Instant},
14};
15
16use etherparse::TransportSlice;
17use hickory_proto::{
18    op::{Message, MessageType, ResponseCode},
19    rr::{RData, RecordType},
20    serialize::binary::BinDecodable,
21};
22use hickory_resolver::{TokioResolver, name_server::TokioConnectionProvider};
23
24use crate::{packet::ParsedFrame, policy::DnsPinSet};
25
26use super::DnsFilter;
27
28//--------------------------------------------------------------------------------------------------
29// Constants
30//--------------------------------------------------------------------------------------------------
31
/// How long a DNS-over-TCP session may sit idle before it is pruned.
const TCP_STREAM_IDLE_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Maximum TCP DNS reassembly buffer size: the largest DNS message
/// (`u16::MAX` bytes) plus its 2-byte length prefix.
const MAX_TCP_REQUEST_BUFFER: usize = u16::MAX as usize + 2;

/// Maximum concurrent TCP DNS streams, bounding memory under a SYN flood.
const MAX_TCP_STREAMS: usize = 64;

/// Maximum number of out-of-order TCP segments buffered per stream.
const MAX_PENDING_SEGMENTS: usize = 16;
42
43//--------------------------------------------------------------------------------------------------
44// Types
45//--------------------------------------------------------------------------------------------------
46
/// DNS interceptor that resolves queries via host nameservers.
///
/// Answers UDP and TCP DNS queries addressed to one of `gateway_ips`,
/// applying `filter` and recording resolved A/AAAA addresses in `pin_set`.
pub struct DnsInterceptor {
    /// Host DNS resolver, built from the system resolver configuration.
    resolver: TokioResolver,

    /// Domain and rebind filter.
    filter: DnsFilter,

    /// Shared pin set for recording resolved IPs.
    pin_set: Arc<RwLock<DnsPinSet>>,

    /// Gateway IP addresses — DNS queries to these IPs are intercepted.
    gateway_ips: Vec<IpAddr>,

    /// Reassembly and TCP session state for DNS-over-TCP interception, keyed
    /// by the connection's addresses and ports.
    tcp_streams: Mutex<HashMap<TcpFlowKey, TcpStreamState>>,
}
64
/// Outcome of [`DnsInterceptor::maybe_intercept`] for a single frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnsInterceptResult {
    /// The frame was consumed locally and nothing should be sent back yet.
    Intercepted,
    /// The frame is not gateway DNS traffic and should continue to the backend.
    NotIntercepted,
    /// The frame was consumed and these responses should be sent to the guest.
    /// They are listed in the order produced; TCP segments carry ascending
    /// sequence numbers.
    Responses(Vec<DnsInterceptResponse>),
}
71
/// A synthesized response to send back to the guest.
///
/// The `tcp_*` fields are `None` for UDP responses and `Some` for TCP segments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsInterceptResponse {
    /// Transport payload: the raw DNS message for UDP, the 2-byte
    /// length-prefixed message for TCP, or empty for pure control segments
    /// (SYN-ACK, FIN-ACK).
    pub payload: Vec<u8>,
    /// Sequence number for the outgoing TCP segment.
    pub tcp_sequence_number: Option<u32>,
    /// Acknowledgment number for the outgoing TCP segment.
    pub tcp_acknowledgment_number: Option<u32>,
    /// Control flags for the outgoing TCP segment.
    pub tcp_flags: Option<TcpResponseFlags>,
}
79
/// TCP control flags to set on a synthesized response segment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpResponseFlags {
    /// SYN flag (this module sets it only on the handshake SYN-ACK).
    pub syn: bool,
    /// ACK flag.
    pub ack: bool,
    /// FIN flag, closing the server side of the session.
    pub fin: bool,
    /// RST flag.
    pub rst: bool,
    /// PSH flag, set on segments that carry a DNS message.
    pub psh: bool,
}
88
/// Transport a DNS response is framed for (see `encode_dns_response`).
#[derive(Clone, Copy)]
enum DnsTransport {
    /// Message is prefixed with its 2-byte big-endian length.
    Tcp,
    /// Message is sent as-is.
    Udp,
}
94
/// Identifies a DNS-over-TCP connection by the addresses and ports as they
/// appear in the client's frames.
///
/// Field order feeds the derived `Hash`, which seeds the server's initial
/// sequence number.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct TcpFlowKey {
    src_ip: IpAddr,
    dst_ip: IpAddr,
    src_port: u16,
    dst_port: u16,
}
102
/// Per-connection state for an intercepted DNS-over-TCP session.
#[derive(Debug)]
struct TcpStreamState {
    /// In-order client bytes not yet split into complete DNS messages.
    request_buffer: Vec<u8>,
    /// Out-of-order client segments, keyed by their starting sequence number.
    pending_segments: BTreeMap<u32, BufferedTcpSegment>,
    /// Next client sequence number expected; also the ACK number sent back.
    next_client_sequence: u32,
    /// Sequence number to use for the next server segment.
    next_server_sequence: u32,
    /// When the last segment for this flow was seen; used for idle pruning.
    last_activity: Instant,
}
111
/// An out-of-order client segment held until the gap before it is filled.
#[derive(Debug)]
struct BufferedTcpSegment {
    /// Owned copy of the segment payload.
    payload: Vec<u8>,
    /// Whether the segment carried FIN.
    fin: bool,
}
117
/// A client segment with any already-received prefix trimmed off.
struct NormalizedTcpSegment<'a> {
    /// Sequence number of the first untrimmed byte (or of the FIN).
    sequence_number: u32,
    /// Remaining, not-yet-received payload bytes.
    payload: &'a [u8],
    /// Whether a not-yet-consumed FIN follows the payload.
    fin: bool,
}
123
/// TCP fields extracted from a client frame for session handling.
struct TcpFrameInfo<'a> {
    flow: TcpFlowKey,
    /// TCP payload; empty for pure control segments.
    payload: &'a [u8],
    sequence_number: u32,
    syn: bool,
    ack: bool,
    fin: bool,
    rst: bool,
}
133
/// DNS-relevant contents of a frame, by transport.
enum ParsedDnsPayload<'a> {
    /// A TCP segment, which may or may not carry DNS bytes.
    Tcp(TcpFrameInfo<'a>),
    /// A non-empty UDP payload (one raw DNS message).
    Udp(&'a [u8]),
}
138
/// Result of reassembling one client TCP segment.
struct TcpInterceptWork {
    flow: TcpFlowKey,
    /// Complete DNS messages, without their 2-byte length prefixes.
    queries: Vec<Vec<u8>>,
    /// ACK number to send: the next client sequence number expected.
    acknowledgment_number: u32,
    /// The client's FIN has been consumed; close after answering `queries`.
    fin: bool,
}
145
146//--------------------------------------------------------------------------------------------------
147// Methods
148//--------------------------------------------------------------------------------------------------
149
150impl DnsInterceptor {
151    /// Creates a new DNS interceptor.
152    ///
153    /// Returns an error if the system DNS resolver configuration cannot be read.
154    pub fn new(
155        filter: DnsFilter,
156        pin_set: Arc<RwLock<DnsPinSet>>,
157        gateway_ips: Vec<IpAddr>,
158    ) -> std::io::Result<Self> {
159        let resolver = TokioResolver::builder(TokioConnectionProvider::default())
160            .map_err(|e| std::io::Error::other(format!("failed to read system DNS config: {e}")))?
161            .build();
162
163        Ok(Self {
164            resolver,
165            filter,
166            pin_set,
167            gateway_ips,
168            tcp_streams: Mutex::new(HashMap::new()),
169        })
170    }
171
172    /// Checks if a frame is a DNS query to the gateway and resolves it locally.
173    ///
174    /// Returns `NotIntercepted` if the frame should continue to the backend,
175    /// `Intercepted` if the frame was consumed locally without a response yet,
176    /// or one or more synthesized DNS responses.
177    pub async fn maybe_intercept(&self, frame: &ParsedFrame<'_>) -> DnsInterceptResult {
178        self.prune_expired_tcp_streams();
179
180        if frame.dst_port() != Some(53) {
181            return DnsInterceptResult::NotIntercepted;
182        }
183
184        let Some(dst_ip) = frame.dst_ip() else {
185            return DnsInterceptResult::Intercepted;
186        };
187        if !self.gateway_ips.contains(&dst_ip) {
188            return DnsInterceptResult::NotIntercepted;
189        }
190
191        match dns_query_payload(frame) {
192            Some(ParsedDnsPayload::Udp(payload)) => match self.resolve_query(payload).await {
193                Some(payload) => {
194                    let payload = match encode_dns_response(DnsTransport::Udp, payload) {
195                        Some(payload) => payload,
196                        None => return DnsInterceptResult::Intercepted,
197                    };
198
199                    DnsInterceptResult::Responses(vec![DnsInterceptResponse {
200                        payload,
201                        tcp_sequence_number: None,
202                        tcp_acknowledgment_number: None,
203                        tcp_flags: None,
204                    }])
205                }
206                None => DnsInterceptResult::Intercepted,
207            },
208            Some(ParsedDnsPayload::Tcp(info)) => self.handle_tcp_frame(info).await,
209            None => DnsInterceptResult::Intercepted,
210        }
211    }
212
213    /// Resolves a DNS query and applies rebind filtering.
214    async fn resolve_and_filter(
215        &self,
216        query: &Message,
217        domain: &str,
218        record_type: RecordType,
219    ) -> Vec<u8> {
220        // For non-A/AAAA queries (MX, TXT, SRV, etc.), forward via the
221        // general lookup API rather than synthesizing NXDOMAIN.
222        if record_type != RecordType::A && record_type != RecordType::AAAA {
223            return match self.resolver.lookup(domain, record_type).await {
224                Ok(lookup) => {
225                    let mut response = Message::new();
226                    response.set_id(query.id());
227                    response.set_message_type(MessageType::Response);
228                    response.set_response_code(ResponseCode::NoError);
229                    response.add_queries(query.queries().to_vec());
230                    for record in lookup.record_iter() {
231                        response.add_answer(record.clone());
232                    }
233                    response.to_vec().unwrap_or_default()
234                }
235                Err(_) => build_nxdomain_response(query),
236            };
237        }
238
239        let lookup_result = match record_type {
240            RecordType::A => self
241                .resolver
242                .ipv4_lookup(domain)
243                .await
244                .map(|l| l.iter().map(|ip| IpAddr::V4(ip.0)).collect::<Vec<_>>()),
245            RecordType::AAAA => self
246                .resolver
247                .ipv6_lookup(domain)
248                .await
249                .map(|l| l.iter().map(|ip| IpAddr::V6(ip.0)).collect::<Vec<_>>()),
250            _ => unreachable!(),
251        };
252
253        match lookup_result {
254            Ok(ips) => {
255                let ips: Vec<IpAddr> = ips
256                    .into_iter()
257                    .filter(|ip| !self.filter.is_rebind_blocked(*ip))
258                    .collect();
259
260                if let Ok(mut pin_set) = self.pin_set.write() {
261                    for ip in &ips {
262                        pin_set.pin(domain, *ip);
263                    }
264                }
265
266                build_success_response(query, &ips, record_type)
267            }
268            Err(_) => build_nxdomain_response(query),
269        }
270    }
271
272    async fn handle_tcp_frame(&self, info: TcpFrameInfo<'_>) -> DnsInterceptResult {
273        if info.rst {
274            self.remove_tcp_state(&info.flow);
275            return DnsInterceptResult::Intercepted;
276        }
277
278        if info.syn && !info.ack {
279            return self.handle_tcp_syn(info.flow, info.sequence_number);
280        }
281
282        let Some(work) = self.collect_tcp_queries(&info) else {
283            return DnsInterceptResult::Intercepted;
284        };
285
286        if work.queries.is_empty() {
287            if work.fin {
288                let sequence_number = match self.remove_tcp_state(&work.flow) {
289                    Some(state) => state.next_server_sequence,
290                    None => return DnsInterceptResult::Intercepted,
291                };
292
293                return DnsInterceptResult::Responses(vec![DnsInterceptResponse {
294                    payload: Vec::new(),
295                    tcp_sequence_number: Some(sequence_number),
296                    tcp_acknowledgment_number: Some(work.acknowledgment_number),
297                    tcp_flags: Some(TcpResponseFlags {
298                        syn: false,
299                        ack: true,
300                        fin: true,
301                        rst: false,
302                        psh: false,
303                    }),
304                }]);
305            }
306
307            return DnsInterceptResult::Intercepted;
308        }
309
310        let mut sequence_number = match self.current_server_sequence(&work.flow) {
311            Some(sequence_number) => sequence_number,
312            None => return DnsInterceptResult::Intercepted,
313        };
314
315        let mut responses = Vec::new();
316        for query in &work.queries {
317            let Some(payload) = self.resolve_query(query).await else {
318                continue;
319            };
320            let response_len = payload.len();
321            let payload = match encode_dns_response(DnsTransport::Tcp, payload) {
322                Some(payload) => payload,
323                None => continue,
324            };
325
326            responses.push(DnsInterceptResponse {
327                payload,
328                tcp_sequence_number: Some(sequence_number),
329                tcp_acknowledgment_number: Some(work.acknowledgment_number),
330                tcp_flags: Some(TcpResponseFlags {
331                    syn: false,
332                    ack: true,
333                    fin: false,
334                    rst: false,
335                    psh: response_len > 0,
336                }),
337            });
338
339            sequence_number =
340                sequence_number.wrapping_add(u32::try_from(response_len + 2).unwrap_or(0));
341        }
342
343        self.update_server_sequence(&work.flow, sequence_number);
344
345        if work.fin {
346            self.remove_tcp_state(&work.flow);
347            responses.push(DnsInterceptResponse {
348                payload: Vec::new(),
349                tcp_sequence_number: Some(sequence_number),
350                tcp_acknowledgment_number: Some(work.acknowledgment_number),
351                tcp_flags: Some(TcpResponseFlags {
352                    syn: false,
353                    ack: true,
354                    fin: true,
355                    rst: false,
356                    psh: false,
357                }),
358            });
359        }
360
361        if responses.is_empty() {
362            DnsInterceptResult::Intercepted
363        } else {
364            DnsInterceptResult::Responses(responses)
365        }
366    }
367
368    fn handle_tcp_syn(&self, flow: TcpFlowKey, client_sequence_number: u32) -> DnsInterceptResult {
369        let server_initial_sequence = initial_server_sequence(&flow, client_sequence_number);
370        let next_client_sequence = client_sequence_number.wrapping_add(1);
371        let next_server_sequence = server_initial_sequence.wrapping_add(1);
372
373        {
374            let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
375
376            // Enforce stream count cap to prevent SYN-flood memory exhaustion.
377            if streams.len() >= MAX_TCP_STREAMS {
378                return DnsInterceptResult::Intercepted;
379            }
380
381            streams.insert(
382                flow,
383                TcpStreamState {
384                    request_buffer: Vec::new(),
385                    pending_segments: BTreeMap::new(),
386                    next_client_sequence,
387                    next_server_sequence,
388                    last_activity: Instant::now(),
389                },
390            );
391        }
392
393        DnsInterceptResult::Responses(vec![DnsInterceptResponse {
394            payload: Vec::new(),
395            tcp_sequence_number: Some(server_initial_sequence),
396            tcp_acknowledgment_number: Some(next_client_sequence),
397            tcp_flags: Some(TcpResponseFlags {
398                syn: true,
399                ack: true,
400                fin: false,
401                rst: false,
402                psh: false,
403            }),
404        }])
405    }
406
407    async fn resolve_query(&self, payload: &[u8]) -> Option<Vec<u8>> {
408        let query = Message::from_bytes(payload).ok()?;
409        if query.message_type() != MessageType::Query {
410            return None;
411        }
412
413        let question = query.queries().first()?;
414        let domain = question.name().to_string();
415        let record_type = question.query_type();
416
417        let response_payload = if self.filter.is_domain_blocked(&domain) {
418            build_refused_response(&query)
419        } else {
420            self.resolve_and_filter(&query, &domain, record_type).await
421        };
422
423        Some(response_payload)
424    }
425
426    fn collect_tcp_queries(&self, info: &TcpFrameInfo<'_>) -> Option<TcpInterceptWork> {
427        let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
428        let state = streams.get_mut(&info.flow)?;
429        state.last_activity = Instant::now();
430
431        let Some(segment) = normalize_tcp_segment(
432            info.sequence_number,
433            info.payload,
434            info.fin,
435            state.next_client_sequence,
436        ) else {
437            return Some(TcpInterceptWork {
438                flow: info.flow.clone(),
439                queries: Vec::new(),
440                acknowledgment_number: state.next_client_sequence,
441                fin: false,
442            });
443        };
444
445        if segment.sequence_number != state.next_client_sequence {
446            store_pending_segment(state, segment.sequence_number, segment.payload, segment.fin)?;
447            return Some(TcpInterceptWork {
448                flow: info.flow.clone(),
449                queries: Vec::new(),
450                acknowledgment_number: state.next_client_sequence,
451                fin: false,
452            });
453        }
454
455        let mut close_after_queries = append_tcp_segment(state, segment.payload, segment.fin)?;
456        close_after_queries |= drain_pending_segments(state)?;
457
458        let mut queries = Vec::new();
459        while state.request_buffer.len() >= 2 {
460            let dns_len = usize::from(u16::from_be_bytes([
461                state.request_buffer[0],
462                state.request_buffer[1],
463            ]));
464            if dns_len == 0 {
465                state.request_buffer.clear();
466                break;
467            }
468
469            if state.request_buffer.len() < dns_len + 2 {
470                break;
471            }
472
473            queries.push(state.request_buffer[2..dns_len + 2].to_vec());
474            state.request_buffer.drain(..dns_len + 2);
475        }
476
477        Some(TcpInterceptWork {
478            flow: info.flow.clone(),
479            queries,
480            acknowledgment_number: state.next_client_sequence,
481            fin: close_after_queries,
482        })
483    }
484
485    fn current_server_sequence(&self, flow: &TcpFlowKey) -> Option<u32> {
486        let streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
487        streams.get(flow).map(|state| state.next_server_sequence)
488    }
489
490    fn update_server_sequence(&self, flow: &TcpFlowKey, next_server_sequence: u32) {
491        let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
492        if let Some(state) = streams.get_mut(flow) {
493            state.next_server_sequence = next_server_sequence;
494        }
495    }
496
497    fn remove_tcp_state(&self, flow: &TcpFlowKey) -> Option<TcpStreamState> {
498        let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
499        streams.remove(flow)
500    }
501
502    fn prune_expired_tcp_streams(&self) {
503        let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
504        let now = Instant::now();
505        streams
506            .retain(|_, state| now.duration_since(state.last_activity) <= TCP_STREAM_IDLE_TIMEOUT);
507    }
508}
509
510//--------------------------------------------------------------------------------------------------
511// Functions
512//--------------------------------------------------------------------------------------------------
513
514fn dns_query_payload<'a>(frame: &'a ParsedFrame<'a>) -> Option<ParsedDnsPayload<'a>> {
515    match &frame.sliced().transport {
516        Some(TransportSlice::Udp(_)) => {
517            let payload = frame.payload();
518            if payload.is_empty() {
519                None
520            } else {
521                Some(ParsedDnsPayload::Udp(payload))
522            }
523        }
524        Some(TransportSlice::Tcp(tcp)) => Some(ParsedDnsPayload::Tcp(TcpFrameInfo {
525            flow: TcpFlowKey {
526                src_ip: frame.src_ip()?,
527                dst_ip: frame.dst_ip()?,
528                src_port: frame.src_port()?,
529                dst_port: frame.dst_port()?,
530            },
531            payload: frame.payload(),
532            sequence_number: tcp.sequence_number(),
533            syn: tcp.syn(),
534            ack: tcp.ack(),
535            fin: tcp.fin(),
536            rst: tcp.rst(),
537        })),
538        _ => None,
539    }
540}
541
542fn initial_server_sequence(flow: &TcpFlowKey, client_sequence_number: u32) -> u32 {
543    let mut hasher = std::collections::hash_map::DefaultHasher::new();
544    flow.hash(&mut hasher);
545    client_sequence_number.hash(&mut hasher);
546    hasher.finish() as u32
547}
548
549fn append_tcp_segment(state: &mut TcpStreamState, payload: &[u8], fin: bool) -> Option<bool> {
550    let advance = tcp_sequence_advance(payload.len(), fin, false)?;
551    state.next_client_sequence = state.next_client_sequence.wrapping_add(advance);
552
553    if !payload.is_empty() {
554        // Enforce buffer cap to prevent guest-triggered OOM.
555        if state.request_buffer.len() + payload.len() > MAX_TCP_REQUEST_BUFFER {
556            return None;
557        }
558        state.request_buffer.extend_from_slice(payload);
559    }
560
561    Some(fin)
562}
563
564fn drain_pending_segments(state: &mut TcpStreamState) -> Option<bool> {
565    let mut close_after_queries = false;
566
567    // Use direct key lookup instead of iter().next() to avoid BTreeMap
568    // natural ordering issues at TCP sequence number wraparound.
569    while let Some(segment) = state.pending_segments.remove(&state.next_client_sequence) {
570        close_after_queries |= append_tcp_segment(state, &segment.payload, segment.fin)?;
571    }
572
573    Some(close_after_queries)
574}
575
576fn normalize_tcp_segment<'a>(
577    sequence_number: u32,
578    payload: &'a [u8],
579    fin: bool,
580    next_client_sequence: u32,
581) -> Option<NormalizedTcpSegment<'a>> {
582    // Use wrapping subtraction + signed comparison to handle sequence
583    // number wraparound at u32::MAX correctly (standard TCP practice).
584    let diff = sequence_number.wrapping_sub(next_client_sequence) as i32;
585    if diff >= 0 {
586        return Some(NormalizedTcpSegment {
587            sequence_number,
588            payload,
589            fin,
590        });
591    }
592
593    let segment_advance = tcp_sequence_advance(payload.len(), fin, false)?;
594    let duplicate_advance = next_client_sequence.wrapping_sub(sequence_number);
595    if duplicate_advance >= segment_advance {
596        return None;
597    }
598
599    let duplicate_bytes = usize::try_from(duplicate_advance).ok()?.min(payload.len());
600    let fin_consumed = fin && duplicate_advance > payload.len() as u32;
601
602    Some(NormalizedTcpSegment {
603        sequence_number: next_client_sequence,
604        payload: &payload[duplicate_bytes..],
605        fin: fin && !fin_consumed,
606    })
607}
608
609fn store_pending_segment(
610    state: &mut TcpStreamState,
611    sequence_number: u32,
612    payload: &[u8],
613    fin: bool,
614) -> Option<()> {
615    // Enforce pending segment cap to prevent guest-triggered OOM.
616    if state.pending_segments.len() >= MAX_PENDING_SEGMENTS {
617        return None;
618    }
619
620    let new_advance = tcp_sequence_advance(payload.len(), fin, false)?;
621
622    match state.pending_segments.entry(sequence_number) {
623        std::collections::btree_map::Entry::Vacant(entry) => {
624            entry.insert(BufferedTcpSegment {
625                payload: payload.to_vec(),
626                fin,
627            });
628        }
629        std::collections::btree_map::Entry::Occupied(mut entry) => {
630            let existing = entry.get();
631            let existing_advance =
632                tcp_sequence_advance(existing.payload.len(), existing.fin, false)?;
633            if new_advance > existing_advance {
634                entry.insert(BufferedTcpSegment {
635                    payload: payload.to_vec(),
636                    fin,
637                });
638            }
639        }
640    }
641
642    Some(())
643}
644
/// Computes how much TCP sequence space a segment occupies.
///
/// Every payload byte counts as one, and the SYN and FIN flags each count as
/// one more (wrapping at `u32::MAX`). Returns `None` if `payload_len` does not
/// fit in a `u32`.
fn tcp_sequence_advance(payload_len: usize, fin: bool, syn: bool) -> Option<u32> {
    let data_span = u32::try_from(payload_len).ok()?;
    let control_span = u32::from(fin) + u32::from(syn);
    Some(data_span.wrapping_add(control_span))
}
652
653fn encode_dns_response(transport: DnsTransport, payload: Vec<u8>) -> Option<Vec<u8>> {
654    match transport {
655        DnsTransport::Udp => Some(payload),
656        DnsTransport::Tcp => {
657            let payload_len = u16::try_from(payload.len()).ok()?;
658            let mut framed = Vec::with_capacity(payload.len() + 2);
659            framed.extend_from_slice(&payload_len.to_be_bytes());
660            framed.extend_from_slice(&payload);
661            Some(framed)
662        }
663    }
664}
665
666/// Builds a DNS REFUSED response.
667fn build_refused_response(query: &Message) -> Vec<u8> {
668    let mut response = Message::new();
669    response.set_id(query.id());
670    response.set_message_type(MessageType::Response);
671    response.set_response_code(ResponseCode::Refused);
672    response.add_queries(query.queries().to_vec());
673
674    response.to_vec().unwrap_or_default()
675}
676
677/// Builds a DNS NXDOMAIN response.
678fn build_nxdomain_response(query: &Message) -> Vec<u8> {
679    let mut response = Message::new();
680    response.set_id(query.id());
681    response.set_message_type(MessageType::Response);
682    response.set_response_code(ResponseCode::NXDomain);
683    response.add_queries(query.queries().to_vec());
684
685    response.to_vec().unwrap_or_default()
686}
687
688/// Builds a DNS success response with the given IP addresses.
689fn build_success_response(query: &Message, ips: &[IpAddr], record_type: RecordType) -> Vec<u8> {
690    use hickory_proto::rr::{Name, Record};
691    use std::str::FromStr;
692
693    let mut response = Message::new();
694    response.set_id(query.id());
695    response.set_message_type(MessageType::Response);
696    response.set_response_code(ResponseCode::NoError);
697    response.add_queries(query.queries().to_vec());
698
699    let name = query
700        .queries()
701        .first()
702        .map(|q| q.name().clone())
703        .unwrap_or_else(|| Name::from_str(".").unwrap());
704
705    for ip in ips {
706        let rdata = match (ip, record_type) {
707            (IpAddr::V4(v4), RecordType::A) => RData::A((*v4).into()),
708            (IpAddr::V6(v6), RecordType::AAAA) => RData::AAAA((*v6).into()),
709            _ => continue,
710        };
711
712        let record = Record::from_rdata(name.clone(), 60, rdata);
713        response.add_answer(record);
714    }
715
716    response.to_vec().unwrap_or_default()
717}
718
719//--------------------------------------------------------------------------------------------------
720// Tests
721//--------------------------------------------------------------------------------------------------
722
723#[cfg(test)]
724mod tests {
725    use std::{
726        net::Ipv4Addr,
727        time::{Duration, Instant},
728    };
729
730    use etherparse::PacketBuilder;
731    use hickory_proto::{op::Query, rr::Name};
732
733    use super::*;
734
735    fn build_dns_query(domain: &str) -> Vec<u8> {
736        let mut message = Message::new();
737        message.set_id(7);
738        message.set_message_type(MessageType::Query);
739        message.add_query(Query::query(
740            Name::from_ascii(domain).unwrap(),
741            RecordType::A,
742        ));
743        message.to_vec().unwrap()
744    }
745
746    fn build_udp_frame(payload: &[u8]) -> Vec<u8> {
747        let mut frame = Vec::new();
748        PacketBuilder::ethernet2(
749            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
750            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
751        )
752        .ipv4([100, 96, 0, 2], [100, 96, 0, 1], 64)
753        .udp(51000, 53)
754        .write(&mut frame, payload)
755        .unwrap();
756        frame
757    }
758
759    fn build_tcp_frame(
760        payload: &[u8],
761        sequence_number: u32,
762        acknowledgment_number: Option<u32>,
763        syn: bool,
764        fin: bool,
765    ) -> Vec<u8> {
766        let mut builder = PacketBuilder::ethernet2(
767            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
768            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
769        )
770        .ipv4([100, 96, 0, 2], [100, 96, 0, 1], 64)
771        .tcp(51000, 53, sequence_number, 200);
772
773        if let Some(acknowledgment_number) = acknowledgment_number {
774            builder = builder.ack(acknowledgment_number);
775        }
776        if syn {
777            builder = builder.syn();
778        }
779        if fin {
780            builder = builder.fin();
781        }
782        if !payload.is_empty() {
783            builder = builder.psh();
784        }
785
786        let mut frame = Vec::new();
787        builder.write(&mut frame, payload).unwrap();
788        frame
789    }
790
791    fn build_interceptor() -> DnsInterceptor {
792        DnsInterceptor::new(
793            DnsFilter::new(vec!["blocked.example.".to_string()], vec![], false),
794            Arc::new(RwLock::new(DnsPinSet::new())),
795            vec![IpAddr::V4(Ipv4Addr::new(100, 96, 0, 1))],
796        )
797        .unwrap()
798    }
799
800    async fn establish_tcp_session(
801        interceptor: &DnsInterceptor,
802        client_sequence_number: u32,
803    ) -> u32 {
804        let syn = build_tcp_frame(&[], client_sequence_number, None, true, false);
805        let parsed = ParsedFrame::parse(&syn).unwrap();
806
807        let responses = match interceptor.maybe_intercept(&parsed).await {
808            DnsInterceptResult::Responses(responses) => responses,
809            other => panic!("expected SYN-ACK, got {other:?}"),
810        };
811
812        assert_eq!(responses.len(), 1);
813        let response = &responses[0];
814        let tcp_flags = response.tcp_flags.unwrap();
815        assert!(tcp_flags.syn);
816        assert!(tcp_flags.ack);
817        assert_eq!(
818            response.tcp_acknowledgment_number,
819            Some(client_sequence_number.wrapping_add(1)),
820        );
821        response.tcp_sequence_number.unwrap().wrapping_add(1)
822    }
823
824    #[test]
825    fn test_dns_query_payload_extracts_udp_message() {
826        let query = build_dns_query("blocked.example.");
827        let frame = build_udp_frame(&query);
828        let parsed = ParsedFrame::parse(&frame).unwrap();
829
830        match dns_query_payload(&parsed).unwrap() {
831            ParsedDnsPayload::Udp(payload) => assert_eq!(payload, query.as_slice()),
832            ParsedDnsPayload::Tcp(_) => panic!("expected UDP payload"),
833        }
834    }
835
836    #[test]
837    fn test_dns_query_payload_extracts_tcp_syn() {
838        let frame = build_tcp_frame(&[], 10, None, true, false);
839        let parsed = ParsedFrame::parse(&frame).unwrap();
840
841        match dns_query_payload(&parsed).unwrap() {
842            ParsedDnsPayload::Tcp(info) => {
843                assert!(info.syn);
844                assert!(!info.ack);
845                assert_eq!(info.sequence_number, 10);
846            }
847            ParsedDnsPayload::Udp(_) => panic!("expected TCP payload"),
848        }
849    }
850
851    #[tokio::test]
852    async fn test_maybe_intercept_establishes_tcp_session_with_syn_ack() {
853        let interceptor = build_interceptor();
854        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
855        assert_ne!(server_next_sequence, 0);
856    }
857
858    #[tokio::test]
859    async fn test_maybe_intercept_returns_tcp_framed_response() {
860        let query = build_dns_query("blocked.example.");
861        let mut payload = Vec::with_capacity(query.len() + 2);
862        payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
863        payload.extend_from_slice(&query);
864
865        let interceptor = build_interceptor();
866        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
867
868        let frame = build_tcp_frame(&payload, 11, Some(server_next_sequence), false, false);
869        let parsed = ParsedFrame::parse(&frame).unwrap();
870
871        let responses = match interceptor.maybe_intercept(&parsed).await {
872            DnsInterceptResult::Responses(responses) => responses,
873            other => panic!("expected TCP response, got {other:?}"),
874        };
875
876        assert_eq!(responses.len(), 1);
877        let response = &responses[0];
878        let response_len = usize::from(u16::from_be_bytes([
879            response.payload[0],
880            response.payload[1],
881        ]));
882        let message = Message::from_bytes(&response.payload[2..]).unwrap();
883
884        assert_eq!(response_len, response.payload.len() - 2);
885        assert_eq!(message.response_code(), ResponseCode::Refused);
886        assert_eq!(message.message_type(), MessageType::Response);
887        assert_eq!(response.tcp_sequence_number, Some(server_next_sequence));
888        assert_eq!(
889            response.tcp_acknowledgment_number,
890            Some(11u32.wrapping_add(payload.len() as u32)),
891        );
892    }
893
894    #[tokio::test]
895    async fn test_maybe_intercept_buffers_split_tcp_query() {
896        let query = build_dns_query("blocked.example.");
897        let mut full_payload = Vec::with_capacity(query.len() + 2);
898        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
899        full_payload.extend_from_slice(&query);
900
901        let split_at = 5;
902        let interceptor = build_interceptor();
903        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
904
905        let first = build_tcp_frame(
906            &full_payload[..split_at],
907            11,
908            Some(server_next_sequence),
909            false,
910            false,
911        );
912        let second = build_tcp_frame(
913            &full_payload[split_at..],
914            11 + split_at as u32,
915            Some(server_next_sequence),
916            false,
917            false,
918        );
919
920        let first_parsed = ParsedFrame::parse(&first).unwrap();
921        assert_eq!(
922            interceptor.maybe_intercept(&first_parsed).await,
923            DnsInterceptResult::Intercepted,
924        );
925
926        let second_parsed = ParsedFrame::parse(&second).unwrap();
927        let responses = match interceptor.maybe_intercept(&second_parsed).await {
928            DnsInterceptResult::Responses(responses) => responses,
929            other => panic!("expected buffered TCP response, got {other:?}"),
930        };
931
932        assert_eq!(responses.len(), 1);
933        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
934        assert_eq!(
935            responses[0].tcp_acknowledgment_number,
936            Some(11u32.wrapping_add(full_payload.len() as u32)),
937        );
938    }
939
940    #[tokio::test]
941    async fn test_maybe_intercept_recovers_after_out_of_order_tcp_segment() {
942        let query = build_dns_query("blocked.example.");
943        let mut full_payload = Vec::with_capacity(query.len() + 2);
944        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
945        full_payload.extend_from_slice(&query);
946
947        let first_end = 5;
948        let second_end = 10;
949
950        let interceptor = build_interceptor();
951        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
952
953        let first = build_tcp_frame(
954            &full_payload[..first_end],
955            11,
956            Some(server_next_sequence),
957            false,
958            false,
959        );
960        let third = build_tcp_frame(
961            &full_payload[second_end..],
962            11 + second_end as u32,
963            Some(server_next_sequence),
964            false,
965            false,
966        );
967        let second = build_tcp_frame(
968            &full_payload[first_end..second_end],
969            11 + first_end as u32,
970            Some(server_next_sequence),
971            false,
972            false,
973        );
974
975        let first_parsed = ParsedFrame::parse(&first).unwrap();
976        assert_eq!(
977            interceptor.maybe_intercept(&first_parsed).await,
978            DnsInterceptResult::Intercepted,
979        );
980
981        let third_parsed = ParsedFrame::parse(&third).unwrap();
982        assert_eq!(
983            interceptor.maybe_intercept(&third_parsed).await,
984            DnsInterceptResult::Intercepted,
985        );
986
987        let second_parsed = ParsedFrame::parse(&second).unwrap();
988        let responses = match interceptor.maybe_intercept(&second_parsed).await {
989            DnsInterceptResult::Responses(responses) => responses,
990            other => panic!("expected reordered TCP response, got {other:?}"),
991        };
992
993        assert_eq!(responses.len(), 1);
994        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
995        assert_eq!(
996            responses[0].tcp_acknowledgment_number,
997            Some(11u32.wrapping_add(full_payload.len() as u32)),
998        );
999    }
1000
1001    #[tokio::test]
1002    async fn test_maybe_intercept_recovers_after_overlapping_tcp_retransmit() {
1003        let query = build_dns_query("blocked.example.");
1004        let mut full_payload = Vec::with_capacity(query.len() + 2);
1005        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
1006        full_payload.extend_from_slice(&query);
1007
1008        let first_end = 8;
1009        let overlap_start = 5;
1010
1011        let interceptor = build_interceptor();
1012        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
1013
1014        let first = build_tcp_frame(
1015            &full_payload[..first_end],
1016            11,
1017            Some(server_next_sequence),
1018            false,
1019            false,
1020        );
1021        let overlapping = build_tcp_frame(
1022            &full_payload[overlap_start..],
1023            11 + overlap_start as u32,
1024            Some(server_next_sequence),
1025            false,
1026            false,
1027        );
1028
1029        let first_parsed = ParsedFrame::parse(&first).unwrap();
1030        assert_eq!(
1031            interceptor.maybe_intercept(&first_parsed).await,
1032            DnsInterceptResult::Intercepted,
1033        );
1034
1035        let overlapping_parsed = ParsedFrame::parse(&overlapping).unwrap();
1036        let responses = match interceptor.maybe_intercept(&overlapping_parsed).await {
1037            DnsInterceptResult::Responses(responses) => responses,
1038            other => panic!("expected overlapping TCP response, got {other:?}"),
1039        };
1040
1041        assert_eq!(responses.len(), 1);
1042        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
1043        assert_eq!(
1044            responses[0].tcp_acknowledgment_number,
1045            Some(11u32.wrapping_add(full_payload.len() as u32)),
1046        );
1047    }
1048
1049    #[tokio::test]
1050    async fn test_maybe_intercept_handles_two_queries_in_one_tcp_segment() {
1051        let query = build_dns_query("blocked.example.");
1052        let framed_query = {
1053            let mut payload = Vec::with_capacity(query.len() + 2);
1054            payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
1055            payload.extend_from_slice(&query);
1056            payload
1057        };
1058
1059        let mut combined = Vec::with_capacity(framed_query.len() * 2);
1060        combined.extend_from_slice(&framed_query);
1061        combined.extend_from_slice(&framed_query);
1062
1063        let interceptor = build_interceptor();
1064        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
1065
1066        let frame = build_tcp_frame(&combined, 11, Some(server_next_sequence), false, false);
1067        let parsed = ParsedFrame::parse(&frame).unwrap();
1068
1069        let responses = match interceptor.maybe_intercept(&parsed).await {
1070            DnsInterceptResult::Responses(responses) => responses,
1071            other => panic!("expected pipelined TCP responses, got {other:?}"),
1072        };
1073
1074        assert_eq!(responses.len(), 2);
1075        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
1076        assert!(
1077            responses[1].tcp_sequence_number.unwrap() > responses[0].tcp_sequence_number.unwrap()
1078        );
1079    }
1080
1081    #[tokio::test]
1082    async fn test_maybe_intercept_closes_tcp_session_on_fin() {
1083        let interceptor = build_interceptor();
1084        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
1085
1086        let fin = build_tcp_frame(&[], 11, Some(server_next_sequence), false, true);
1087        let parsed = ParsedFrame::parse(&fin).unwrap();
1088
1089        let responses = match interceptor.maybe_intercept(&parsed).await {
1090            DnsInterceptResult::Responses(responses) => responses,
1091            other => panic!("expected FIN-ACK, got {other:?}"),
1092        };
1093
1094        assert_eq!(responses.len(), 1);
1095        let response = &responses[0];
1096        let tcp_flags = response.tcp_flags.unwrap();
1097        assert!(tcp_flags.fin);
1098        assert!(tcp_flags.ack);
1099    }
1100
1101    #[tokio::test]
1102    async fn test_maybe_intercept_answers_fin_with_query_before_closing() {
1103        let query = build_dns_query("blocked.example.");
1104        let mut payload = Vec::with_capacity(query.len() + 2);
1105        payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
1106        payload.extend_from_slice(&query);
1107
1108        let interceptor = build_interceptor();
1109        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
1110
1111        let fin_with_query = build_tcp_frame(&payload, 11, Some(server_next_sequence), false, true);
1112        let parsed = ParsedFrame::parse(&fin_with_query).unwrap();
1113
1114        let responses = match interceptor.maybe_intercept(&parsed).await {
1115            DnsInterceptResult::Responses(responses) => responses,
1116            other => panic!("expected DNS response and FIN-ACK, got {other:?}"),
1117        };
1118
1119        assert_eq!(responses.len(), 2);
1120
1121        let dns_response = &responses[0];
1122        let dns_message = Message::from_bytes(&dns_response.payload[2..]).unwrap();
1123        assert_eq!(dns_message.response_code(), ResponseCode::Refused);
1124
1125        let fin_response = &responses[1];
1126        let fin_flags = fin_response.tcp_flags.unwrap();
1127        assert!(fin_flags.fin);
1128        assert!(fin_flags.ack);
1129        assert_eq!(
1130            fin_response.tcp_sequence_number,
1131            Some(
1132                dns_response
1133                    .tcp_sequence_number
1134                    .unwrap()
1135                    .wrapping_add(dns_response.payload.len() as u32)
1136            )
1137        );
1138    }
1139
1140    #[tokio::test]
1141    async fn test_maybe_intercept_prunes_idle_tcp_sessions() {
1142        let interceptor = build_interceptor();
1143        establish_tcp_session(&interceptor, 10).await;
1144
1145        {
1146            let mut streams = interceptor.tcp_streams.lock().unwrap();
1147            let stale_age = TCP_STREAM_IDLE_TIMEOUT + Duration::from_secs(1);
1148            for state in streams.values_mut() {
1149                state.last_activity = Instant::now() - stale_age;
1150            }
1151        }
1152
1153        let query = build_dns_query("blocked.example.");
1154        let frame = build_udp_frame(&query);
1155        let parsed = ParsedFrame::parse(&frame).unwrap();
1156        let _ = interceptor.maybe_intercept(&parsed).await;
1157
1158        let streams = interceptor.tcp_streams.lock().unwrap();
1159        assert!(streams.is_empty());
1160    }
1161
1162    #[test]
1163    fn test_encode_dns_response_tcp_prefixes_length() {
1164        let encoded = encode_dns_response(DnsTransport::Tcp, vec![1, 2, 3]).unwrap();
1165        assert_eq!(encoded, vec![0, 3, 1, 2, 3]);
1166    }
1167
1168    #[test]
1169    fn test_engine_tcp_flags_shape_is_serializable() {
1170        let flags = TcpResponseFlags {
1171            syn: true,
1172            ack: true,
1173            fin: false,
1174            rst: false,
1175            psh: false,
1176        };
1177        assert!(flags.syn && flags.ack);
1178    }
1179}