1use std::{
9 collections::{BTreeMap, HashMap},
10 hash::{Hash, Hasher},
11 net::IpAddr,
12 sync::{Arc, Mutex, RwLock},
13 time::{Duration, Instant},
14};
15
16use etherparse::TransportSlice;
17use hickory_proto::{
18 op::{Message, MessageType, ResponseCode},
19 rr::{RData, RecordType},
20 serialize::binary::BinDecodable,
21};
22use hickory_resolver::{TokioResolver, name_server::TokioConnectionProvider};
23
24use crate::{packet::ParsedFrame, policy::DnsPinSet};
25
26use super::DnsFilter;
27
/// Idle time after which a tracked DNS-over-TCP stream is discarded.
const TCP_STREAM_IDLE_TIMEOUT: Duration = Duration::from_secs(30);

/// Cap on reassembled-but-unparsed request bytes per TCP stream: one
/// maximum-size DNS message (`u16::MAX` bytes) plus its 2-byte length prefix.
const MAX_TCP_REQUEST_BUFFER: usize = u16::MAX as usize + 2;

/// Maximum number of DNS-over-TCP streams tracked at the same time.
const MAX_TCP_STREAMS: usize = 64;

/// Maximum number of out-of-order segments buffered per TCP stream.
const MAX_PENDING_SEGMENTS: usize = 16;
42
/// Answers DNS queries (over UDP and TCP) that clients address to a gateway IP.
///
/// Queries for blocked domains are refused locally. Everything else is resolved
/// upstream through `resolver`. A/AAAA answers have rebind-blocked addresses
/// removed, and every remaining address is pinned into `pin_set`.
pub struct DnsInterceptor {
    // Upstream resolver built from the host's system DNS configuration (see `new`).
    resolver: TokioResolver,

    // Domain block list (`is_domain_blocked`) and DNS-rebinding filter (`is_rebind_blocked`).
    filter: DnsFilter,

    // Shared domain -> IP pins; every A/AAAA address handed to a client is recorded here.
    pin_set: Arc<RwLock<DnsPinSet>>,

    // Destination addresses whose port-53 traffic is intercepted.
    gateway_ips: Vec<IpAddr>,

    // Per-connection state of the minimal DNS-over-TCP responder, keyed by client 4-tuple.
    tcp_streams: Mutex<HashMap<TcpFlowKey, TcpStreamState>>,
}
64
/// Outcome of [`DnsInterceptor::maybe_intercept`] for a single frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnsInterceptResult {
    /// The frame was gateway DNS traffic and has been consumed, but nothing
    /// needs to be sent back (malformed data, partial TCP data, unknown stream, ...).
    Intercepted,
    /// The frame is not port-53 traffic to a gateway address; it is left to the caller.
    NotIntercepted,
    /// The frame was consumed and these reply packets should be sent to the
    /// client in the given order (TCP sequence numbers assume that order).
    Responses(Vec<DnsInterceptResponse>),
}

/// One reply packet to emit back to the client that sent the intercepted frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsInterceptResponse {
    /// Transport payload: the raw DNS message for UDP, a 2-byte big-endian
    /// length-prefixed DNS message for TCP, or empty for pure TCP control
    /// segments (SYN-ACK, FIN-ACK).
    pub payload: Vec<u8>,
    /// Sequence number of the reply segment; `None` for UDP replies.
    pub tcp_sequence_number: Option<u32>,
    /// Acknowledgment number of the reply segment; `None` for UDP replies.
    pub tcp_acknowledgment_number: Option<u32>,
    /// Control bits of the reply segment; `None` for UDP replies.
    pub tcp_flags: Option<TcpResponseFlags>,
}

/// TCP control bits to set on a synthesized reply segment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpResponseFlags {
    pub syn: bool,
    pub ack: bool,
    pub fin: bool,
    pub rst: bool,
    pub psh: bool,
}
88
/// Transport a DNS response is framed for (see `encode_dns_response`).
#[derive(Clone, Copy)]
enum DnsTransport {
    Tcp,
    Udp,
}

/// Identifies a client TCP connection by the addresses and ports of its
/// client-to-gateway frames (`src_*` is the client, `dst_*` the gateway).
///
/// The derived `Hash` also seeds the server's initial sequence number, so the
/// field order is significant.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct TcpFlowKey {
    src_ip: IpAddr,
    dst_ip: IpAddr,
    src_port: u16,
    dst_port: u16,
}
102
/// Reassembly and sequencing state of one intercepted DNS-over-TCP stream.
#[derive(Debug)]
struct TcpStreamState {
    // In-order client bytes that have not yet been split into complete queries.
    request_buffer: Vec<u8>,
    // Segments that arrived ahead of a gap, keyed by their first sequence number.
    pending_segments: BTreeMap<u32, BufferedTcpSegment>,
    // Next client sequence number expected in order; also the ACK number we send.
    next_client_sequence: u32,
    // Sequence number of the next byte the interceptor sends on this stream.
    next_server_sequence: u32,
    // When a segment for this stream was last processed; drives idle pruning.
    last_activity: Instant,
}

/// Owned copy of an out-of-order client segment awaiting the missing bytes before it.
#[derive(Debug)]
struct BufferedTcpSegment {
    payload: Vec<u8>,
    fin: bool,
}

/// A client segment with any already-received leading bytes trimmed off
/// (produced by `normalize_tcp_segment`).
struct NormalizedTcpSegment<'a> {
    sequence_number: u32,
    payload: &'a [u8],
    fin: bool,
}
123
/// The parts of a client TCP frame the responder needs, borrowed from the frame.
struct TcpFrameInfo<'a> {
    flow: TcpFlowKey,
    // TCP payload bytes (may be empty for pure control segments).
    payload: &'a [u8],
    sequence_number: u32,
    syn: bool,
    ack: bool,
    fin: bool,
    rst: bool,
}

/// DNS-relevant view of a frame, split by transport.
enum ParsedDnsPayload<'a> {
    Tcp(TcpFrameInfo<'a>),
    // Raw (unframed) DNS message carried in a UDP datagram.
    Udp(&'a [u8]),
}

/// What one client TCP segment unlocked, computed while the stream lock is held
/// and then acted on without it (resolution is async).
struct TcpInterceptWork {
    flow: TcpFlowKey,
    // Complete DNS messages, with their 2-byte length prefixes already stripped.
    queries: Vec<Vec<u8>>,
    // ACK number covering all in-order client data received so far.
    acknowledgment_number: u32,
    // Whether the client's FIN has been received in order.
    fin: bool,
}
145
146impl DnsInterceptor {
151 pub fn new(
155 filter: DnsFilter,
156 pin_set: Arc<RwLock<DnsPinSet>>,
157 gateway_ips: Vec<IpAddr>,
158 ) -> std::io::Result<Self> {
159 let resolver = TokioResolver::builder(TokioConnectionProvider::default())
160 .map_err(|e| std::io::Error::other(format!("failed to read system DNS config: {e}")))?
161 .build();
162
163 Ok(Self {
164 resolver,
165 filter,
166 pin_set,
167 gateway_ips,
168 tcp_streams: Mutex::new(HashMap::new()),
169 })
170 }
171
172 pub async fn maybe_intercept(&self, frame: &ParsedFrame<'_>) -> DnsInterceptResult {
178 self.prune_expired_tcp_streams();
179
180 if frame.dst_port() != Some(53) {
181 return DnsInterceptResult::NotIntercepted;
182 }
183
184 let Some(dst_ip) = frame.dst_ip() else {
185 return DnsInterceptResult::Intercepted;
186 };
187 if !self.gateway_ips.contains(&dst_ip) {
188 return DnsInterceptResult::NotIntercepted;
189 }
190
191 match dns_query_payload(frame) {
192 Some(ParsedDnsPayload::Udp(payload)) => match self.resolve_query(payload).await {
193 Some(payload) => {
194 let payload = match encode_dns_response(DnsTransport::Udp, payload) {
195 Some(payload) => payload,
196 None => return DnsInterceptResult::Intercepted,
197 };
198
199 DnsInterceptResult::Responses(vec![DnsInterceptResponse {
200 payload,
201 tcp_sequence_number: None,
202 tcp_acknowledgment_number: None,
203 tcp_flags: None,
204 }])
205 }
206 None => DnsInterceptResult::Intercepted,
207 },
208 Some(ParsedDnsPayload::Tcp(info)) => self.handle_tcp_frame(info).await,
209 None => DnsInterceptResult::Intercepted,
210 }
211 }
212
213 async fn resolve_and_filter(
215 &self,
216 query: &Message,
217 domain: &str,
218 record_type: RecordType,
219 ) -> Vec<u8> {
220 if record_type != RecordType::A && record_type != RecordType::AAAA {
223 return match self.resolver.lookup(domain, record_type).await {
224 Ok(lookup) => {
225 let mut response = Message::new();
226 response.set_id(query.id());
227 response.set_message_type(MessageType::Response);
228 response.set_response_code(ResponseCode::NoError);
229 response.add_queries(query.queries().to_vec());
230 for record in lookup.record_iter() {
231 response.add_answer(record.clone());
232 }
233 response.to_vec().unwrap_or_default()
234 }
235 Err(_) => build_nxdomain_response(query),
236 };
237 }
238
239 let lookup_result = match record_type {
240 RecordType::A => self
241 .resolver
242 .ipv4_lookup(domain)
243 .await
244 .map(|l| l.iter().map(|ip| IpAddr::V4(ip.0)).collect::<Vec<_>>()),
245 RecordType::AAAA => self
246 .resolver
247 .ipv6_lookup(domain)
248 .await
249 .map(|l| l.iter().map(|ip| IpAddr::V6(ip.0)).collect::<Vec<_>>()),
250 _ => unreachable!(),
251 };
252
253 match lookup_result {
254 Ok(ips) => {
255 let ips: Vec<IpAddr> = ips
256 .into_iter()
257 .filter(|ip| !self.filter.is_rebind_blocked(*ip))
258 .collect();
259
260 if let Ok(mut pin_set) = self.pin_set.write() {
261 for ip in &ips {
262 pin_set.pin(domain, *ip);
263 }
264 }
265
266 build_success_response(query, &ips, record_type)
267 }
268 Err(_) => build_nxdomain_response(query),
269 }
270 }
271
272 async fn handle_tcp_frame(&self, info: TcpFrameInfo<'_>) -> DnsInterceptResult {
273 if info.rst {
274 self.remove_tcp_state(&info.flow);
275 return DnsInterceptResult::Intercepted;
276 }
277
278 if info.syn && !info.ack {
279 return self.handle_tcp_syn(info.flow, info.sequence_number);
280 }
281
282 let Some(work) = self.collect_tcp_queries(&info) else {
283 return DnsInterceptResult::Intercepted;
284 };
285
286 if work.queries.is_empty() {
287 if work.fin {
288 let sequence_number = match self.remove_tcp_state(&work.flow) {
289 Some(state) => state.next_server_sequence,
290 None => return DnsInterceptResult::Intercepted,
291 };
292
293 return DnsInterceptResult::Responses(vec![DnsInterceptResponse {
294 payload: Vec::new(),
295 tcp_sequence_number: Some(sequence_number),
296 tcp_acknowledgment_number: Some(work.acknowledgment_number),
297 tcp_flags: Some(TcpResponseFlags {
298 syn: false,
299 ack: true,
300 fin: true,
301 rst: false,
302 psh: false,
303 }),
304 }]);
305 }
306
307 return DnsInterceptResult::Intercepted;
308 }
309
310 let mut sequence_number = match self.current_server_sequence(&work.flow) {
311 Some(sequence_number) => sequence_number,
312 None => return DnsInterceptResult::Intercepted,
313 };
314
315 let mut responses = Vec::new();
316 for query in &work.queries {
317 let Some(payload) = self.resolve_query(query).await else {
318 continue;
319 };
320 let response_len = payload.len();
321 let payload = match encode_dns_response(DnsTransport::Tcp, payload) {
322 Some(payload) => payload,
323 None => continue,
324 };
325
326 responses.push(DnsInterceptResponse {
327 payload,
328 tcp_sequence_number: Some(sequence_number),
329 tcp_acknowledgment_number: Some(work.acknowledgment_number),
330 tcp_flags: Some(TcpResponseFlags {
331 syn: false,
332 ack: true,
333 fin: false,
334 rst: false,
335 psh: response_len > 0,
336 }),
337 });
338
339 sequence_number =
340 sequence_number.wrapping_add(u32::try_from(response_len + 2).unwrap_or(0));
341 }
342
343 self.update_server_sequence(&work.flow, sequence_number);
344
345 if work.fin {
346 self.remove_tcp_state(&work.flow);
347 responses.push(DnsInterceptResponse {
348 payload: Vec::new(),
349 tcp_sequence_number: Some(sequence_number),
350 tcp_acknowledgment_number: Some(work.acknowledgment_number),
351 tcp_flags: Some(TcpResponseFlags {
352 syn: false,
353 ack: true,
354 fin: true,
355 rst: false,
356 psh: false,
357 }),
358 });
359 }
360
361 if responses.is_empty() {
362 DnsInterceptResult::Intercepted
363 } else {
364 DnsInterceptResult::Responses(responses)
365 }
366 }
367
368 fn handle_tcp_syn(&self, flow: TcpFlowKey, client_sequence_number: u32) -> DnsInterceptResult {
369 let server_initial_sequence = initial_server_sequence(&flow, client_sequence_number);
370 let next_client_sequence = client_sequence_number.wrapping_add(1);
371 let next_server_sequence = server_initial_sequence.wrapping_add(1);
372
373 {
374 let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
375
376 if streams.len() >= MAX_TCP_STREAMS {
378 return DnsInterceptResult::Intercepted;
379 }
380
381 streams.insert(
382 flow,
383 TcpStreamState {
384 request_buffer: Vec::new(),
385 pending_segments: BTreeMap::new(),
386 next_client_sequence,
387 next_server_sequence,
388 last_activity: Instant::now(),
389 },
390 );
391 }
392
393 DnsInterceptResult::Responses(vec![DnsInterceptResponse {
394 payload: Vec::new(),
395 tcp_sequence_number: Some(server_initial_sequence),
396 tcp_acknowledgment_number: Some(next_client_sequence),
397 tcp_flags: Some(TcpResponseFlags {
398 syn: true,
399 ack: true,
400 fin: false,
401 rst: false,
402 psh: false,
403 }),
404 }])
405 }
406
407 async fn resolve_query(&self, payload: &[u8]) -> Option<Vec<u8>> {
408 let query = Message::from_bytes(payload).ok()?;
409 if query.message_type() != MessageType::Query {
410 return None;
411 }
412
413 let question = query.queries().first()?;
414 let domain = question.name().to_string();
415 let record_type = question.query_type();
416
417 let response_payload = if self.filter.is_domain_blocked(&domain) {
418 build_refused_response(&query)
419 } else {
420 self.resolve_and_filter(&query, &domain, record_type).await
421 };
422
423 Some(response_payload)
424 }
425
426 fn collect_tcp_queries(&self, info: &TcpFrameInfo<'_>) -> Option<TcpInterceptWork> {
427 let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
428 let state = streams.get_mut(&info.flow)?;
429 state.last_activity = Instant::now();
430
431 let Some(segment) = normalize_tcp_segment(
432 info.sequence_number,
433 info.payload,
434 info.fin,
435 state.next_client_sequence,
436 ) else {
437 return Some(TcpInterceptWork {
438 flow: info.flow.clone(),
439 queries: Vec::new(),
440 acknowledgment_number: state.next_client_sequence,
441 fin: false,
442 });
443 };
444
445 if segment.sequence_number != state.next_client_sequence {
446 store_pending_segment(state, segment.sequence_number, segment.payload, segment.fin)?;
447 return Some(TcpInterceptWork {
448 flow: info.flow.clone(),
449 queries: Vec::new(),
450 acknowledgment_number: state.next_client_sequence,
451 fin: false,
452 });
453 }
454
455 let mut close_after_queries = append_tcp_segment(state, segment.payload, segment.fin)?;
456 close_after_queries |= drain_pending_segments(state)?;
457
458 let mut queries = Vec::new();
459 while state.request_buffer.len() >= 2 {
460 let dns_len = usize::from(u16::from_be_bytes([
461 state.request_buffer[0],
462 state.request_buffer[1],
463 ]));
464 if dns_len == 0 {
465 state.request_buffer.clear();
466 break;
467 }
468
469 if state.request_buffer.len() < dns_len + 2 {
470 break;
471 }
472
473 queries.push(state.request_buffer[2..dns_len + 2].to_vec());
474 state.request_buffer.drain(..dns_len + 2);
475 }
476
477 Some(TcpInterceptWork {
478 flow: info.flow.clone(),
479 queries,
480 acknowledgment_number: state.next_client_sequence,
481 fin: close_after_queries,
482 })
483 }
484
485 fn current_server_sequence(&self, flow: &TcpFlowKey) -> Option<u32> {
486 let streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
487 streams.get(flow).map(|state| state.next_server_sequence)
488 }
489
490 fn update_server_sequence(&self, flow: &TcpFlowKey, next_server_sequence: u32) {
491 let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
492 if let Some(state) = streams.get_mut(flow) {
493 state.next_server_sequence = next_server_sequence;
494 }
495 }
496
497 fn remove_tcp_state(&self, flow: &TcpFlowKey) -> Option<TcpStreamState> {
498 let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
499 streams.remove(flow)
500 }
501
502 fn prune_expired_tcp_streams(&self) {
503 let mut streams = self.tcp_streams.lock().unwrap_or_else(|e| e.into_inner());
504 let now = Instant::now();
505 streams
506 .retain(|_, state| now.duration_since(state.last_activity) <= TCP_STREAM_IDLE_TIMEOUT);
507 }
508}
509
510fn dns_query_payload<'a>(frame: &'a ParsedFrame<'a>) -> Option<ParsedDnsPayload<'a>> {
515 match &frame.sliced().transport {
516 Some(TransportSlice::Udp(_)) => {
517 let payload = frame.payload();
518 if payload.is_empty() {
519 None
520 } else {
521 Some(ParsedDnsPayload::Udp(payload))
522 }
523 }
524 Some(TransportSlice::Tcp(tcp)) => Some(ParsedDnsPayload::Tcp(TcpFrameInfo {
525 flow: TcpFlowKey {
526 src_ip: frame.src_ip()?,
527 dst_ip: frame.dst_ip()?,
528 src_port: frame.src_port()?,
529 dst_port: frame.dst_port()?,
530 },
531 payload: frame.payload(),
532 sequence_number: tcp.sequence_number(),
533 syn: tcp.syn(),
534 ack: tcp.ack(),
535 fin: tcp.fin(),
536 rst: tcp.rst(),
537 })),
538 _ => None,
539 }
540}
541
542fn initial_server_sequence(flow: &TcpFlowKey, client_sequence_number: u32) -> u32 {
543 let mut hasher = std::collections::hash_map::DefaultHasher::new();
544 flow.hash(&mut hasher);
545 client_sequence_number.hash(&mut hasher);
546 hasher.finish() as u32
547}
548
549fn append_tcp_segment(state: &mut TcpStreamState, payload: &[u8], fin: bool) -> Option<bool> {
550 let advance = tcp_sequence_advance(payload.len(), fin, false)?;
551 state.next_client_sequence = state.next_client_sequence.wrapping_add(advance);
552
553 if !payload.is_empty() {
554 if state.request_buffer.len() + payload.len() > MAX_TCP_REQUEST_BUFFER {
556 return None;
557 }
558 state.request_buffer.extend_from_slice(payload);
559 }
560
561 Some(fin)
562}
563
564fn drain_pending_segments(state: &mut TcpStreamState) -> Option<bool> {
565 let mut close_after_queries = false;
566
567 while let Some(segment) = state.pending_segments.remove(&state.next_client_sequence) {
570 close_after_queries |= append_tcp_segment(state, &segment.payload, segment.fin)?;
571 }
572
573 Some(close_after_queries)
574}
575
576fn normalize_tcp_segment<'a>(
577 sequence_number: u32,
578 payload: &'a [u8],
579 fin: bool,
580 next_client_sequence: u32,
581) -> Option<NormalizedTcpSegment<'a>> {
582 let diff = sequence_number.wrapping_sub(next_client_sequence) as i32;
585 if diff >= 0 {
586 return Some(NormalizedTcpSegment {
587 sequence_number,
588 payload,
589 fin,
590 });
591 }
592
593 let segment_advance = tcp_sequence_advance(payload.len(), fin, false)?;
594 let duplicate_advance = next_client_sequence.wrapping_sub(sequence_number);
595 if duplicate_advance >= segment_advance {
596 return None;
597 }
598
599 let duplicate_bytes = usize::try_from(duplicate_advance).ok()?.min(payload.len());
600 let fin_consumed = fin && duplicate_advance > payload.len() as u32;
601
602 Some(NormalizedTcpSegment {
603 sequence_number: next_client_sequence,
604 payload: &payload[duplicate_bytes..],
605 fin: fin && !fin_consumed,
606 })
607}
608
609fn store_pending_segment(
610 state: &mut TcpStreamState,
611 sequence_number: u32,
612 payload: &[u8],
613 fin: bool,
614) -> Option<()> {
615 if state.pending_segments.len() >= MAX_PENDING_SEGMENTS {
617 return None;
618 }
619
620 let new_advance = tcp_sequence_advance(payload.len(), fin, false)?;
621
622 match state.pending_segments.entry(sequence_number) {
623 std::collections::btree_map::Entry::Vacant(entry) => {
624 entry.insert(BufferedTcpSegment {
625 payload: payload.to_vec(),
626 fin,
627 });
628 }
629 std::collections::btree_map::Entry::Occupied(mut entry) => {
630 let existing = entry.get();
631 let existing_advance =
632 tcp_sequence_advance(existing.payload.len(), existing.fin, false)?;
633 if new_advance > existing_advance {
634 entry.insert(BufferedTcpSegment {
635 payload: payload.to_vec(),
636 fin,
637 });
638 }
639 }
640 }
641
642 Some(())
643}
644
/// Returns how much TCP sequence space a segment occupies: one unit per
/// payload byte, plus one each for FIN and SYN (wrapping modulo 2^32).
///
/// Returns `None` if `payload_len` does not fit in a `u32`.
fn tcp_sequence_advance(payload_len: usize, fin: bool, syn: bool) -> Option<u32> {
    let data_units = u32::try_from(payload_len).ok()?;
    let control_units = u32::from(fin) + u32::from(syn);
    Some(data_units.wrapping_add(control_units))
}
652
653fn encode_dns_response(transport: DnsTransport, payload: Vec<u8>) -> Option<Vec<u8>> {
654 match transport {
655 DnsTransport::Udp => Some(payload),
656 DnsTransport::Tcp => {
657 let payload_len = u16::try_from(payload.len()).ok()?;
658 let mut framed = Vec::with_capacity(payload.len() + 2);
659 framed.extend_from_slice(&payload_len.to_be_bytes());
660 framed.extend_from_slice(&payload);
661 Some(framed)
662 }
663 }
664}
665
666fn build_refused_response(query: &Message) -> Vec<u8> {
668 let mut response = Message::new();
669 response.set_id(query.id());
670 response.set_message_type(MessageType::Response);
671 response.set_response_code(ResponseCode::Refused);
672 response.add_queries(query.queries().to_vec());
673
674 response.to_vec().unwrap_or_default()
675}
676
677fn build_nxdomain_response(query: &Message) -> Vec<u8> {
679 let mut response = Message::new();
680 response.set_id(query.id());
681 response.set_message_type(MessageType::Response);
682 response.set_response_code(ResponseCode::NXDomain);
683 response.add_queries(query.queries().to_vec());
684
685 response.to_vec().unwrap_or_default()
686}
687
688fn build_success_response(query: &Message, ips: &[IpAddr], record_type: RecordType) -> Vec<u8> {
690 use hickory_proto::rr::{Name, Record};
691 use std::str::FromStr;
692
693 let mut response = Message::new();
694 response.set_id(query.id());
695 response.set_message_type(MessageType::Response);
696 response.set_response_code(ResponseCode::NoError);
697 response.add_queries(query.queries().to_vec());
698
699 let name = query
700 .queries()
701 .first()
702 .map(|q| q.name().clone())
703 .unwrap_or_else(|| Name::from_str(".").unwrap());
704
705 for ip in ips {
706 let rdata = match (ip, record_type) {
707 (IpAddr::V4(v4), RecordType::A) => RData::A((*v4).into()),
708 (IpAddr::V6(v6), RecordType::AAAA) => RData::AAAA((*v6).into()),
709 _ => continue,
710 };
711
712 let record = Record::from_rdata(name.clone(), 60, rdata);
713 response.add_answer(record);
714 }
715
716 response.to_vec().unwrap_or_default()
717}
718
// Unit tests for the interceptor. Every query uses the blocked domain, so
// responses come from the local REFUSED path rather than upstream resolution.
#[cfg(test)]
mod tests {
    use std::{
        net::Ipv4Addr,
        time::{Duration, Instant},
    };

    use etherparse::PacketBuilder;
    use hickory_proto::{op::Query, rr::Name};

    use super::*;

    /// Serializes an A-record query (ID 7) for `domain`.
    fn build_dns_query(domain: &str) -> Vec<u8> {
        let mut message = Message::new();
        message.set_id(7);
        message.set_message_type(MessageType::Query);
        message.add_query(Query::query(
            Name::from_ascii(domain).unwrap(),
            RecordType::A,
        ));
        message.to_vec().unwrap()
    }

    /// Builds an Ethernet/IPv4/UDP frame from 100.96.0.2:51000 to the gateway 100.96.0.1:53.
    fn build_udp_frame(payload: &[u8]) -> Vec<u8> {
        let mut frame = Vec::new();
        PacketBuilder::ethernet2(
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
        )
        .ipv4([100, 96, 0, 2], [100, 96, 0, 1], 64)
        .udp(51000, 53)
        .write(&mut frame, payload)
        .unwrap();
        frame
    }

    /// Builds an Ethernet/IPv4/TCP frame on the same 4-tuple as `build_udp_frame`.
    ///
    /// ACK is set only when `acknowledgment_number` is given; PSH is set
    /// whenever there is payload.
    fn build_tcp_frame(
        payload: &[u8],
        sequence_number: u32,
        acknowledgment_number: Option<u32>,
        syn: bool,
        fin: bool,
    ) -> Vec<u8> {
        let mut builder = PacketBuilder::ethernet2(
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
        )
        .ipv4([100, 96, 0, 2], [100, 96, 0, 1], 64)
        .tcp(51000, 53, sequence_number, 200);

        if let Some(acknowledgment_number) = acknowledgment_number {
            builder = builder.ack(acknowledgment_number);
        }
        if syn {
            builder = builder.syn();
        }
        if fin {
            builder = builder.fin();
        }
        if !payload.is_empty() {
            builder = builder.psh();
        }

        let mut frame = Vec::new();
        builder.write(&mut frame, payload).unwrap();
        frame
    }

    /// Interceptor for gateway 100.96.0.1 that blocks `blocked.example.`.
    ///
    /// NOTE(review): `DnsInterceptor::new` reads the host's system DNS config,
    /// so these tests need one to be present. The meaning of `DnsFilter::new`'s
    /// remaining arguments is assumed here — confirm against `DnsFilter`.
    fn build_interceptor() -> DnsInterceptor {
        DnsInterceptor::new(
            DnsFilter::new(vec!["blocked.example.".to_string()], vec![], false),
            Arc::new(RwLock::new(DnsPinSet::new())),
            vec![IpAddr::V4(Ipv4Addr::new(100, 96, 0, 1))],
        )
        .unwrap()
    }

    /// Performs the SYN / SYN-ACK exchange and returns the server's next
    /// sequence number (its ISN + 1).
    async fn establish_tcp_session(
        interceptor: &DnsInterceptor,
        client_sequence_number: u32,
    ) -> u32 {
        let syn = build_tcp_frame(&[], client_sequence_number, None, true, false);
        let parsed = ParsedFrame::parse(&syn).unwrap();

        let responses = match interceptor.maybe_intercept(&parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected SYN-ACK, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        let response = &responses[0];
        let tcp_flags = response.tcp_flags.unwrap();
        assert!(tcp_flags.syn);
        assert!(tcp_flags.ack);
        assert_eq!(
            response.tcp_acknowledgment_number,
            Some(client_sequence_number.wrapping_add(1)),
        );
        response.tcp_sequence_number.unwrap().wrapping_add(1)
    }

    // A UDP frame yields its raw DNS message unchanged.
    #[test]
    fn test_dns_query_payload_extracts_udp_message() {
        let query = build_dns_query("blocked.example.");
        let frame = build_udp_frame(&query);
        let parsed = ParsedFrame::parse(&frame).unwrap();

        match dns_query_payload(&parsed).unwrap() {
            ParsedDnsPayload::Udp(payload) => assert_eq!(payload, query.as_slice()),
            ParsedDnsPayload::Tcp(_) => panic!("expected UDP payload"),
        }
    }

    // A bare SYN is surfaced with its flags and sequence number.
    #[test]
    fn test_dns_query_payload_extracts_tcp_syn() {
        let frame = build_tcp_frame(&[], 10, None, true, false);
        let parsed = ParsedFrame::parse(&frame).unwrap();

        match dns_query_payload(&parsed).unwrap() {
            ParsedDnsPayload::Tcp(info) => {
                assert!(info.syn);
                assert!(!info.ack);
                assert_eq!(info.sequence_number, 10);
            }
            ParsedDnsPayload::Udp(_) => panic!("expected TCP payload"),
        }
    }

    #[tokio::test]
    async fn test_maybe_intercept_establishes_tcp_session_with_syn_ack() {
        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;
        assert_ne!(server_next_sequence, 0);
    }

    // One complete length-prefixed query yields one framed REFUSED answer that
    // acknowledges the whole request.
    #[tokio::test]
    async fn test_maybe_intercept_returns_tcp_framed_response() {
        let query = build_dns_query("blocked.example.");
        let mut payload = Vec::with_capacity(query.len() + 2);
        payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
        payload.extend_from_slice(&query);

        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let frame = build_tcp_frame(&payload, 11, Some(server_next_sequence), false, false);
        let parsed = ParsedFrame::parse(&frame).unwrap();

        let responses = match interceptor.maybe_intercept(&parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected TCP response, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        let response = &responses[0];
        let response_len = usize::from(u16::from_be_bytes([
            response.payload[0],
            response.payload[1],
        ]));
        let message = Message::from_bytes(&response.payload[2..]).unwrap();

        assert_eq!(response_len, response.payload.len() - 2);
        assert_eq!(message.response_code(), ResponseCode::Refused);
        assert_eq!(message.message_type(), MessageType::Response);
        assert_eq!(response.tcp_sequence_number, Some(server_next_sequence));
        assert_eq!(
            response.tcp_acknowledgment_number,
            Some(11u32.wrapping_add(payload.len() as u32)),
        );
    }

    // A query split across two in-order segments is answered only once complete.
    #[tokio::test]
    async fn test_maybe_intercept_buffers_split_tcp_query() {
        let query = build_dns_query("blocked.example.");
        let mut full_payload = Vec::with_capacity(query.len() + 2);
        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
        full_payload.extend_from_slice(&query);

        let split_at = 5;
        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let first = build_tcp_frame(
            &full_payload[..split_at],
            11,
            Some(server_next_sequence),
            false,
            false,
        );
        let second = build_tcp_frame(
            &full_payload[split_at..],
            11 + split_at as u32,
            Some(server_next_sequence),
            false,
            false,
        );

        let first_parsed = ParsedFrame::parse(&first).unwrap();
        assert_eq!(
            interceptor.maybe_intercept(&first_parsed).await,
            DnsInterceptResult::Intercepted,
        );

        let second_parsed = ParsedFrame::parse(&second).unwrap();
        let responses = match interceptor.maybe_intercept(&second_parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected buffered TCP response, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
        assert_eq!(
            responses[0].tcp_acknowledgment_number,
            Some(11u32.wrapping_add(full_payload.len() as u32)),
        );
    }

    // Segments sent 1, 3, 2: the third is parked until the second fills the gap.
    #[tokio::test]
    async fn test_maybe_intercept_recovers_after_out_of_order_tcp_segment() {
        let query = build_dns_query("blocked.example.");
        let mut full_payload = Vec::with_capacity(query.len() + 2);
        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
        full_payload.extend_from_slice(&query);

        let first_end = 5;
        let second_end = 10;

        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let first = build_tcp_frame(
            &full_payload[..first_end],
            11,
            Some(server_next_sequence),
            false,
            false,
        );
        let third = build_tcp_frame(
            &full_payload[second_end..],
            11 + second_end as u32,
            Some(server_next_sequence),
            false,
            false,
        );
        let second = build_tcp_frame(
            &full_payload[first_end..second_end],
            11 + first_end as u32,
            Some(server_next_sequence),
            false,
            false,
        );

        let first_parsed = ParsedFrame::parse(&first).unwrap();
        assert_eq!(
            interceptor.maybe_intercept(&first_parsed).await,
            DnsInterceptResult::Intercepted,
        );

        let third_parsed = ParsedFrame::parse(&third).unwrap();
        assert_eq!(
            interceptor.maybe_intercept(&third_parsed).await,
            DnsInterceptResult::Intercepted,
        );

        let second_parsed = ParsedFrame::parse(&second).unwrap();
        let responses = match interceptor.maybe_intercept(&second_parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected reordered TCP response, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
        assert_eq!(
            responses[0].tcp_acknowledgment_number,
            Some(11u32.wrapping_add(full_payload.len() as u32)),
        );
    }

    // A retransmit overlapping already-received bytes is trimmed, not duplicated.
    #[tokio::test]
    async fn test_maybe_intercept_recovers_after_overlapping_tcp_retransmit() {
        let query = build_dns_query("blocked.example.");
        let mut full_payload = Vec::with_capacity(query.len() + 2);
        full_payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
        full_payload.extend_from_slice(&query);

        let first_end = 8;
        let overlap_start = 5;

        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let first = build_tcp_frame(
            &full_payload[..first_end],
            11,
            Some(server_next_sequence),
            false,
            false,
        );
        let overlapping = build_tcp_frame(
            &full_payload[overlap_start..],
            11 + overlap_start as u32,
            Some(server_next_sequence),
            false,
            false,
        );

        let first_parsed = ParsedFrame::parse(&first).unwrap();
        assert_eq!(
            interceptor.maybe_intercept(&first_parsed).await,
            DnsInterceptResult::Intercepted,
        );

        let overlapping_parsed = ParsedFrame::parse(&overlapping).unwrap();
        let responses = match interceptor.maybe_intercept(&overlapping_parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected overlapping TCP response, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
        assert_eq!(
            responses[0].tcp_acknowledgment_number,
            Some(11u32.wrapping_add(full_payload.len() as u32)),
        );
    }

    // Two pipelined queries in one segment produce two answers with advancing sequence numbers.
    #[tokio::test]
    async fn test_maybe_intercept_handles_two_queries_in_one_tcp_segment() {
        let query = build_dns_query("blocked.example.");
        let framed_query = {
            let mut payload = Vec::with_capacity(query.len() + 2);
            payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
            payload.extend_from_slice(&query);
            payload
        };

        let mut combined = Vec::with_capacity(framed_query.len() * 2);
        combined.extend_from_slice(&framed_query);
        combined.extend_from_slice(&framed_query);

        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let frame = build_tcp_frame(&combined, 11, Some(server_next_sequence), false, false);
        let parsed = ParsedFrame::parse(&frame).unwrap();

        let responses = match interceptor.maybe_intercept(&parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected pipelined TCP responses, got {other:?}"),
        };

        assert_eq!(responses.len(), 2);
        assert_eq!(responses[0].tcp_sequence_number, Some(server_next_sequence));
        // NOTE(review): plain `>` would misjudge a wrap past u32::MAX; fine for
        // the ISN this flow happens to get.
        assert!(
            responses[1].tcp_sequence_number.unwrap() > responses[0].tcp_sequence_number.unwrap()
        );
    }

    // A bare FIN is answered with FIN-ACK.
    #[tokio::test]
    async fn test_maybe_intercept_closes_tcp_session_on_fin() {
        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let fin = build_tcp_frame(&[], 11, Some(server_next_sequence), false, true);
        let parsed = ParsedFrame::parse(&fin).unwrap();

        let responses = match interceptor.maybe_intercept(&parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected FIN-ACK, got {other:?}"),
        };

        assert_eq!(responses.len(), 1);
        let response = &responses[0];
        let tcp_flags = response.tcp_flags.unwrap();
        assert!(tcp_flags.fin);
        assert!(tcp_flags.ack);
    }

    // A query carried on a FIN segment is answered first. The FIN-ACK follows
    // and is sequenced after the answer's bytes.
    #[tokio::test]
    async fn test_maybe_intercept_answers_fin_with_query_before_closing() {
        let query = build_dns_query("blocked.example.");
        let mut payload = Vec::with_capacity(query.len() + 2);
        payload.extend_from_slice(&(query.len() as u16).to_be_bytes());
        payload.extend_from_slice(&query);

        let interceptor = build_interceptor();
        let server_next_sequence = establish_tcp_session(&interceptor, 10).await;

        let fin_with_query = build_tcp_frame(&payload, 11, Some(server_next_sequence), false, true);
        let parsed = ParsedFrame::parse(&fin_with_query).unwrap();

        let responses = match interceptor.maybe_intercept(&parsed).await {
            DnsInterceptResult::Responses(responses) => responses,
            other => panic!("expected DNS response and FIN-ACK, got {other:?}"),
        };

        assert_eq!(responses.len(), 2);

        let dns_response = &responses[0];
        let dns_message = Message::from_bytes(&dns_response.payload[2..]).unwrap();
        assert_eq!(dns_message.response_code(), ResponseCode::Refused);

        let fin_response = &responses[1];
        let fin_flags = fin_response.tcp_flags.unwrap();
        assert!(fin_flags.fin);
        assert!(fin_flags.ack);
        assert_eq!(
            fin_response.tcp_sequence_number,
            Some(
                dns_response
                    .tcp_sequence_number
                    .unwrap()
                    .wrapping_add(dns_response.payload.len() as u32)
            )
        );
    }

    // Streams idle past TCP_STREAM_IDLE_TIMEOUT are dropped when the next DNS frame arrives.
    #[tokio::test]
    async fn test_maybe_intercept_prunes_idle_tcp_sessions() {
        let interceptor = build_interceptor();
        establish_tcp_session(&interceptor, 10).await;

        {
            let mut streams = interceptor.tcp_streams.lock().unwrap();
            let stale_age = TCP_STREAM_IDLE_TIMEOUT + Duration::from_secs(1);
            for state in streams.values_mut() {
                state.last_activity = Instant::now() - stale_age;
            }
        }

        let query = build_dns_query("blocked.example.");
        let frame = build_udp_frame(&query);
        let parsed = ParsedFrame::parse(&frame).unwrap();
        let _ = interceptor.maybe_intercept(&parsed).await;

        let streams = interceptor.tcp_streams.lock().unwrap();
        assert!(streams.is_empty());
    }

    #[test]
    fn test_encode_dns_response_tcp_prefixes_length() {
        let encoded = encode_dns_response(DnsTransport::Tcp, vec![1, 2, 3]).unwrap();
        assert_eq!(encoded, vec![0, 3, 1, 2, 3]);
    }

    // NOTE(review): despite the name, this only checks field access; no
    // serialization is exercised.
    #[test]
    fn test_engine_tcp_flags_shape_is_serializable() {
        let flags = TcpResponseFlags {
            syn: true,
            ack: true,
            fin: false,
            rst: false,
            psh: false,
        };
        assert!(flags.syn && flags.ack);
    }
}