1use std::net::IpAddr;
34use std::collections::HashMap;
35use std::time::{Duration, Instant};
36use tracing::{debug, info, warn};
37
/// Cloudflare primary public resolver, as an `ip:port` string.
pub const CLOUDFLARE_DNS: &str = "1.1.1.1:53";
/// Cloudflare secondary public resolver, as an `ip:port` string.
pub const CLOUDFLARE_DNS2: &str = "1.0.0.1:53";
/// Google Public DNS primary resolver, as an `ip:port` string.
pub const GOOGLE_DNS: &str = "8.8.8.8:53";
/// Google Public DNS secondary resolver, as an `ip:port` string.
pub const GOOGLE_DNS2: &str = "8.8.4.4:53";
/// Quad9 primary resolver, as an `ip:port` string.
pub const QUAD9_DNS: &str = "9.9.9.9:53";
44
/// DNS resource-record type as carried in a question's QTYPE field.
///
/// Common types get a dedicated variant; every other code is preserved in
/// [`DnsQueryType::Other`]. Note that `Other(1)` is *not* equal to `A` even
/// though both encode as 1 — use [`DnsQueryType::from_u16`] to obtain the
/// canonical variant for a code.
// Variant names deliberately mirror the DNS RR mnemonics (AAAA, CNAME, ...);
// renaming them would break the public API.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsQueryType {
    /// IPv4 host address (code 1).
    A,
    /// IPv6 host address (code 28).
    AAAA,
    /// Canonical name / alias (code 5).
    CNAME,
    /// Mail exchange (code 15).
    MX,
    /// Text record (code 16).
    TXT,
    /// Domain-name pointer, used for reverse lookups (code 12).
    PTR,
    /// Authoritative name server (code 2).
    NS,
    /// Any code without a dedicated variant.
    Other(u16),
}

impl DnsQueryType {
    /// Maps a numeric QTYPE code to its canonical variant.
    pub fn from_u16(v: u16) -> Self {
        match v {
            1 => DnsQueryType::A,
            28 => DnsQueryType::AAAA,
            5 => DnsQueryType::CNAME,
            15 => DnsQueryType::MX,
            16 => DnsQueryType::TXT,
            12 => DnsQueryType::PTR,
            2 => DnsQueryType::NS,
            o => DnsQueryType::Other(o),
        }
    }

    /// Returns the numeric QTYPE code for this type.
    pub fn to_u16(&self) -> u16 {
        match self {
            DnsQueryType::A => 1,
            DnsQueryType::AAAA => 28,
            DnsQueryType::CNAME => 5,
            DnsQueryType::MX => 15,
            DnsQueryType::TXT => 16,
            DnsQueryType::PTR => 12,
            DnsQueryType::NS => 2,
            DnsQueryType::Other(o) => *o,
        }
    }
}

impl From<u16> for DnsQueryType {
    /// Same as [`DnsQueryType::from_u16`].
    fn from(v: u16) -> Self {
        DnsQueryType::from_u16(v)
    }
}

impl From<DnsQueryType> for u16 {
    /// Same as [`DnsQueryType::to_u16`].
    fn from(t: DnsQueryType) -> Self {
        t.to_u16()
    }
}
85
86#[derive(Debug, Clone)]
88pub struct DnsPacket {
89 pub id: u16,
91 pub is_query: bool,
93 pub domain: String,
95 pub query_type: DnsQueryType,
97 pub raw: Vec<u8>,
99}
100
101impl DnsPacket {
102 pub fn parse(raw: Vec<u8>) -> Option<Self> {
106 if raw.len() < 12 {
107 return None;
108 }
109
110 let id = u16::from_be_bytes([raw[0], raw[1]]);
111 let flags = u16::from_be_bytes([raw[2], raw[3]]);
112 let is_query = (flags >> 15) == 0;
113 let qdcount = u16::from_be_bytes([raw[4], raw[5]]);
114
115 if qdcount == 0 {
116 return Some(DnsPacket {
117 id,
118 is_query,
119 domain: String::new(),
120 query_type: DnsQueryType::A,
121 raw,
122 });
123 }
124
125 let (domain, offset) = parse_dns_name(&raw, 12)?;
127 if offset + 4 > raw.len() {
128 return None;
129 }
130 let qtype = u16::from_be_bytes([raw[offset], raw[offset + 1]]);
131
132 debug!(id, domain = %domain, is_query, "DNS packet parsed");
133
134 Some(DnsPacket {
135 id,
136 is_query,
137 domain,
138 query_type: DnsQueryType::from_u16(qtype),
139 raw,
140 })
141 }
142
143 pub fn is_query(&self) -> bool {
145 self.is_query
146 }
147}
148
/// Decodes a domain name in DNS wire format starting at `offset` in `data`.
///
/// Labels are joined with `.`; the root (empty) name yields `""`.
/// Compression pointers (RFC 1035 §4.1.4) are followed, so a name such as
/// `tracker` + pointer-to-`ads.com` decodes to `tracker.ads.com` rather than
/// being truncated at the pointer.
///
/// Returns the name together with the offset of the first byte after the
/// name *at its original position*. For a compressed name that is the byte
/// just past the first two-byte pointer, not the end of the pointed-to data.
///
/// Returns `None` if:
/// - the name runs past the end of `data`;
/// - a label uses a reserved type (top bits `01` or `10`);
/// - a pointer does not point strictly backwards;
/// - more than 128 labels/pointers would have to be processed;
/// - a label is not valid UTF-8.
fn parse_dns_name(data: &[u8], offset: usize) -> Option<(String, usize)> {
    // Upper bound on labels + pointers handled (129 steps including the
    // terminator); guards against pathological input.
    const MAX_STEPS: usize = 128;

    let mut labels: Vec<&str> = Vec::new();
    let mut pos = offset;
    // Set on the first pointer jump: where parsing resumes in the caller's stream.
    let mut resume_at: Option<usize> = None;

    for _ in 0..=MAX_STEPS {
        let tag = *data.get(pos)?;
        match tag & 0xC0 {
            // Zero-length root label terminates the name.
            0x00 if tag == 0 => {
                let end = resume_at.unwrap_or(pos + 1);
                return Some((labels.join("."), end));
            }
            // Ordinary label: length byte (<= 63) followed by that many bytes.
            0x00 => {
                let start = pos + 1;
                let end = start + usize::from(tag);
                let bytes = data.get(start..end)?;
                labels.push(std::str::from_utf8(bytes).ok()?);
                pos = end;
            }
            // Compression pointer: 14-bit offset spread over two bytes.
            0xC0 => {
                let low = *data.get(pos + 1)?;
                let target = usize::from(u16::from_be_bytes([tag & 0x3F, low]));
                // Pointers must refer to a prior occurrence; requiring a
                // strictly backward target also makes pointer loops impossible.
                if target >= pos {
                    return None;
                }
                if resume_at.is_none() {
                    resume_at = Some(pos + 2);
                }
                pos = target;
            }
            // 0x40 / 0x80: reserved / extended label types — not supported.
            _ => return None,
        }
    }
    None
}
182
/// Outcome of `DnsFilter::decide` for one intercepted query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsAction {
    /// Send the query to an upstream resolver through the tunnel.
    ForwardThroughTunnel,
    /// The domain matched the blocklist; how the query is refused is up to the caller.
    Block,
    /// Answer locally with this previously cached address.
    ReturnCached(IpAddr),
    /// The domain is on the split-DNS list; let it go out directly, outside the tunnel.
    AllowDirect,
}
195
/// One cached resolution: the address and the instant after which it is stale.
#[derive(Debug, Clone)]
struct CacheEntry {
    addr: IpAddr,
    expires_at: Instant,
}

impl CacheEntry {
    /// `true` once the current time has moved strictly past `expires_at`.
    fn is_expired(&self) -> bool {
        self.expires_at < Instant::now()
    }
}
208
209#[derive(Debug, Clone)]
211pub struct DnsConfig {
212 pub upstream_servers: Vec<String>,
214 pub split_dns_domains: Vec<String>,
216 pub blocked_domains: Vec<String>,
218 pub enable_cache: bool,
220 pub cache_ttl: Duration,
222 pub max_cache_size: usize,
224}
225
226impl Default for DnsConfig {
227 fn default() -> Self {
228 DnsConfig {
229 upstream_servers: vec![
230 CLOUDFLARE_DNS.to_string(),
231 CLOUDFLARE_DNS2.to_string(),
232 ],
233 split_dns_domains: Vec::new(),
234 blocked_domains: Vec::new(),
235 enable_cache: true,
236 cache_ttl: Duration::from_secs(300),
237 max_cache_size: 1024,
238 }
239 }
240}
241
242impl DnsConfig {
243 pub fn cloudflare() -> Self {
245 DnsConfig::default()
246 }
247
248 pub fn google() -> Self {
250 DnsConfig {
251 upstream_servers: vec![
252 GOOGLE_DNS.to_string(),
253 GOOGLE_DNS2.to_string(),
254 ],
255 ..Default::default()
256 }
257 }
258
259 pub fn quad9() -> Self {
261 DnsConfig {
262 upstream_servers: vec![QUAD9_DNS.to_string()],
263 ..Default::default()
264 }
265 }
266
267 pub fn with_split_domain(mut self, domain: &str) -> Self {
269 self.split_dns_domains.push(domain.to_string());
270 self
271 }
272
273 pub fn with_blocked_domain(mut self, domain: &str) -> Self {
275 self.blocked_domains.push(domain.to_string());
276 self
277 }
278}
279
280pub struct DnsFilter {
285 config: DnsConfig,
286 cache: HashMap<String, CacheEntry>,
287 total_intercepted: u64,
289 total_blocked: u64,
291 total_cache_hits: u64,
293 total_forwarded: u64,
295}
296
297impl DnsFilter {
298 pub fn new(config: DnsConfig) -> Self {
300 info!(
301 upstream = ?config.upstream_servers,
302 blocked_count = config.blocked_domains.len(),
303 "DnsFilter created"
304 );
305 DnsFilter {
306 config,
307 cache: HashMap::new(),
308 total_intercepted: 0,
309 total_blocked: 0,
310 total_cache_hits: 0,
311 total_forwarded: 0,
312 }
313 }
314
315 pub fn is_dns_packet(data: &[u8]) -> bool {
319 if data.len() < 12 {
320 return false;
321 }
322 let opcode = (data[2] >> 3) & 0x0F;
324 opcode <= 2
325 }
326
327 pub fn decide(&mut self, domain: &str, query_type: &DnsQueryType) -> DnsAction {
331 self.total_intercepted += 1;
332
333 if self.total_intercepted % 100 == 0 {
335 self.evict_expired();
336 }
337
338 if self.config.enable_cache {
340 if let Some(entry) = self.cache.get(domain) {
341 if !entry.is_expired() {
342 self.total_cache_hits += 1;
343 debug!(domain, "DNS cache hit");
344 return DnsAction::ReturnCached(entry.addr);
345 }
346 }
347 }
348
349 if self.is_blocked(domain) {
351 self.total_blocked += 1;
352 warn!(domain, "DNS query blocked");
353 return DnsAction::Block;
354 }
355
356 if self.is_split_dns(domain) {
358 debug!(domain, "DNS split — allowing direct");
359 return DnsAction::AllowDirect;
360 }
361
362 self.total_forwarded += 1;
364 debug!(domain, query_type = ?query_type, "DNS forwarding through tunnel");
365 DnsAction::ForwardThroughTunnel
366 }
367
368 pub fn cache_response(&mut self, domain: &str, addr: IpAddr) {
370 if !self.config.enable_cache {
371 return;
372 }
373 if self.cache.len() >= self.config.max_cache_size {
374 self.evict_expired();
375 if self.cache.len() >= self.config.max_cache_size {
377 if let Some(key) = self.cache.keys().next().cloned() {
378 self.cache.remove(&key);
379 }
380 }
381 }
382 self.cache.insert(domain.to_string(), CacheEntry {
383 addr,
384 expires_at: Instant::now() + self.config.cache_ttl,
385 });
386 debug!(domain, addr = %addr, "DNS response cached");
387 }
388
389 pub fn is_blocked(&self, domain: &str) -> bool {
393 let domain_lower = domain.to_lowercase();
394 self.config.blocked_domains.iter().any(|blocked| {
395 let b = blocked.to_lowercase();
396 domain_lower == b || domain_lower.ends_with(&format!(".{}", b))
397 })
398 }
399
400 pub fn is_split_dns(&self, domain: &str) -> bool {
402 let domain_lower = domain.to_lowercase();
403 self.config.split_dns_domains.iter().any(|split| {
404 let s = split.to_lowercase();
405 domain_lower == s || domain_lower.ends_with(&format!(".{}", s))
406 })
407 }
408
409 pub fn block_domain(&mut self, domain: &str) {
411 info!(domain, "DNS domain blocked");
412 self.config.blocked_domains.push(domain.to_string());
413 }
414
415 pub fn add_split_domain(&mut self, domain: &str) {
417 info!(domain, "DNS split domain added");
418 self.config.split_dns_domains.push(domain.to_string());
419 }
420
421 pub fn primary_upstream(&self) -> Option<&str> {
423 self.config.upstream_servers.first().map(|s| s.as_str())
424 }
425
426 pub fn evict_expired(&mut self) {
428 let before = self.cache.len();
429 self.cache.retain(|_, v| !v.is_expired());
430 let removed = before - self.cache.len();
431 if removed > 0 {
432 debug!(removed, "DNS cache eviction");
433 }
434 }
435
436 pub fn clear_cache(&mut self) {
438 self.cache.clear();
439 debug!("DNS cache cleared");
440 }
441
442 pub fn cache_size(&self) -> usize {
444 self.cache.len()
445 }
446
447 pub fn total_intercepted(&self) -> u64 {
449 self.total_intercepted
450 }
451
452 pub fn total_blocked(&self) -> u64 {
454 self.total_blocked
455 }
456
457 pub fn total_cache_hits(&self) -> u64 {
459 self.total_cache_hits
460 }
461
462 pub fn total_forwarded(&self) -> u64 {
464 self.total_forwarded
465 }
466
467 pub fn config(&self) -> &DnsConfig {
469 &self.config
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Wire-format query for `domain`: ID 1, RD flag set, one question with
    /// QTYPE A and QCLASS IN.
    fn minimal_dns_query(domain: &str) -> Vec<u8> {
        let header: [u8; 12] = [
            0x00, 0x01, // ID = 1
            0x01, 0x00, // flags: RD
            0x00, 0x01, // QDCOUNT = 1
            0x00, 0x00, // ANCOUNT
            0x00, 0x00, // NSCOUNT
            0x00, 0x00, // ARCOUNT
        ];
        // Length-prefixed labels; test domains keep every label under 256 bytes.
        let name = domain
            .split('.')
            .flat_map(|label| std::iter::once(label.len() as u8).chain(label.bytes()));
        // Root label, then QTYPE = A (1), then QCLASS = IN (1).
        let trailer: [u8; 5] = [0x00, 0x00, 0x01, 0x00, 0x01];
        header
            .iter()
            .copied()
            .chain(name)
            .chain(trailer.iter().copied())
            .collect()
    }

    #[test]
    fn test_is_dns_packet_valid() {
        assert!(DnsFilter::is_dns_packet(&minimal_dns_query("example.com")));
    }

    #[test]
    fn test_is_dns_packet_too_short() {
        assert!(!DnsFilter::is_dns_packet(&[]));
        assert!(!DnsFilter::is_dns_packet(&[0u8; 5]));
    }

    #[test]
    fn test_dns_packet_parse() {
        let pkt = DnsPacket::parse(minimal_dns_query("example.com"))
            .expect("well-formed query must parse");
        assert_eq!(pkt.id, 1);
        assert!(pkt.is_query());
        assert_eq!(pkt.domain, "example.com");
        assert_eq!(pkt.query_type, DnsQueryType::A);
    }

    #[test]
    fn test_dns_packet_parse_too_short() {
        assert!(DnsPacket::parse(vec![0u8; 5]).is_none());
    }

    #[test]
    fn test_dns_config_default() {
        let cfg = DnsConfig::default();
        assert!(cfg.enable_cache);
        assert!(cfg.upstream_servers.iter().any(|s| s == CLOUDFLARE_DNS));
    }

    #[test]
    fn test_dns_config_google() {
        let cfg = DnsConfig::google();
        assert!(cfg.upstream_servers.iter().any(|s| s == GOOGLE_DNS));
    }

    #[test]
    fn test_dns_config_quad9() {
        let cfg = DnsConfig::quad9();
        assert!(cfg.upstream_servers.iter().any(|s| s == QUAD9_DNS));
    }

    #[test]
    fn test_dns_config_with_blocked() {
        let cfg = DnsConfig::default().with_blocked_domain("ads.com");
        assert!(cfg.blocked_domains.iter().any(|d| d == "ads.com"));
    }

    #[test]
    fn test_dns_config_with_split() {
        let cfg = DnsConfig::default().with_split_domain("corp.internal");
        assert!(cfg.split_dns_domains.iter().any(|d| d == "corp.internal"));
    }

    #[test]
    fn test_filter_forward() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        assert_eq!(
            filter.decide("example.com", &DnsQueryType::A),
            DnsAction::ForwardThroughTunnel
        );
        assert_eq!(filter.total_forwarded(), 1);
    }

    #[test]
    fn test_filter_block() {
        let mut filter = DnsFilter::new(DnsConfig::default().with_blocked_domain("ads.com"));
        assert_eq!(filter.decide("ads.com", &DnsQueryType::A), DnsAction::Block);
        assert_eq!(filter.total_blocked(), 1);
    }

    #[test]
    fn test_filter_block_subdomain() {
        let mut filter = DnsFilter::new(DnsConfig::default().with_blocked_domain("ads.com"));
        assert_eq!(
            filter.decide("tracker.ads.com", &DnsQueryType::A),
            DnsAction::Block
        );
    }

    #[test]
    fn test_filter_split_dns() {
        let mut filter =
            DnsFilter::new(DnsConfig::default().with_split_domain("corp.internal"));
        assert_eq!(
            filter.decide("server.corp.internal", &DnsQueryType::A),
            DnsAction::AllowDirect
        );
    }

    #[test]
    fn test_filter_cache_hit() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        let addr: IpAddr = "1.2.3.4".parse().unwrap();
        filter.cache_response("example.com", addr);
        assert_eq!(
            filter.decide("example.com", &DnsQueryType::A),
            DnsAction::ReturnCached(addr)
        );
        assert_eq!(filter.total_cache_hits(), 1);
    }

    #[test]
    fn test_filter_cache_size() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        for (name, ip) in [("a.com", "1.1.1.1"), ("b.com", "2.2.2.2")].iter() {
            filter.cache_response(name, ip.parse().unwrap());
        }
        assert_eq!(filter.cache_size(), 2);
    }

    #[test]
    fn test_filter_clear_cache() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        filter.cache_response("a.com", "1.1.1.1".parse().unwrap());
        filter.clear_cache();
        assert_eq!(filter.cache_size(), 0);
    }

    #[test]
    fn test_filter_block_runtime() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        filter.block_domain("evil.com");
        assert_eq!(filter.decide("evil.com", &DnsQueryType::A), DnsAction::Block);
    }

    #[test]
    fn test_filter_split_runtime() {
        let mut filter = DnsFilter::new(DnsConfig::default());
        filter.add_split_domain("local.net");
        assert_eq!(
            filter.decide("host.local.net", &DnsQueryType::A),
            DnsAction::AllowDirect
        );
    }

    #[test]
    fn test_is_blocked_exact() {
        let filter = DnsFilter::new(DnsConfig::default().with_blocked_domain("bad.com"));
        assert!(filter.is_blocked("bad.com"));
        assert!(!filter.is_blocked("good.com"));
    }

    #[test]
    fn test_is_blocked_subdomain() {
        let filter = DnsFilter::new(DnsConfig::default().with_blocked_domain("bad.com"));
        for name in ["sub.bad.com", "deep.sub.bad.com"].iter() {
            assert!(filter.is_blocked(name));
        }
    }

    #[test]
    fn test_is_split_dns() {
        let filter = DnsFilter::new(DnsConfig::default().with_split_domain("internal"));
        assert!(filter.is_split_dns("host.internal"));
        assert!(!filter.is_split_dns("external.com"));
    }

    #[test]
    fn test_primary_upstream() {
        let filter = DnsFilter::new(DnsConfig::cloudflare());
        assert_eq!(filter.primary_upstream(), Some(CLOUDFLARE_DNS));
    }

    #[test]
    fn test_stats() {
        let mut filter = DnsFilter::new(DnsConfig::default().with_blocked_domain("bad.com"));
        filter.decide("example.com", &DnsQueryType::A);
        filter.decide("bad.com", &DnsQueryType::A);
        assert_eq!(filter.total_intercepted(), 2);
        assert_eq!(filter.total_blocked(), 1);
        assert_eq!(filter.total_forwarded(), 1);
    }

    #[test]
    fn test_query_type_from_u16() {
        assert_eq!(DnsQueryType::from_u16(1), DnsQueryType::A);
        assert_eq!(DnsQueryType::from_u16(28), DnsQueryType::AAAA);
        assert_eq!(DnsQueryType::from_u16(99), DnsQueryType::Other(99));
    }

    #[test]
    fn test_query_type_to_u16() {
        assert_eq!(DnsQueryType::A.to_u16(), 1);
        assert_eq!(DnsQueryType::AAAA.to_u16(), 28);
        assert_eq!(DnsQueryType::Other(99).to_u16(), 99);
    }

    #[test]
    fn test_evict_expired() {
        let cfg = DnsConfig {
            cache_ttl: Duration::from_millis(1),
            ..DnsConfig::default()
        };
        let mut filter = DnsFilter::new(cfg);
        filter.cache_response("a.com", "1.1.1.1".parse().unwrap());
        // Let the 1 ms TTL lapse before sweeping.
        std::thread::sleep(Duration::from_millis(5));
        filter.evict_expired();
        assert_eq!(filter.cache_size(), 0);
    }
}