Skip to main content

hush_proxy/
dns.rs

1//! DNS packet parsing and domain extraction
2//!
3//! Provides utilities for extracting domain names from DNS queries
4//! for egress filtering.
5
6use crate::error::{Error, Result};
7use globset::GlobBuilder;
8
9/// Extract domain name from a DNS query packet
10pub fn extract_domain_from_query(packet: &[u8]) -> Result<Option<String>> {
11    // DNS header is 12 bytes minimum
12    if packet.len() < 12 {
13        return Err(Error::DnsParseError(
14            "Packet too short for DNS header".into(),
15        ));
16    }
17
18    // Check if it's a query (QR bit = 0)
19    let qr = (packet[2] >> 7) & 1;
20    if qr != 0 {
21        // This is a response, not a query
22        return Ok(None);
23    }
24
25    // Get question count
26    let qdcount = u16::from_be_bytes([packet[4], packet[5]]) as usize;
27    if qdcount == 0 {
28        return Ok(None);
29    }
30
31    // Parse the first question
32    let mut offset = 12;
33    let mut labels = Vec::new();
34
35    loop {
36        if offset >= packet.len() {
37            return Err(Error::DnsParseError("Unexpected end of packet".into()));
38        }
39
40        let len = packet[offset] as usize;
41        if len == 0 {
42            break;
43        }
44
45        // Check for compression pointer (starts with 0b11)
46        if len & 0xC0 == 0xC0 {
47            return Err(Error::DnsParseError(
48                "Compression pointers not supported in queries".into(),
49            ));
50        }
51
52        if len > 63 {
53            return Err(Error::DnsParseError("Label too long".into()));
54        }
55
56        offset += 1;
57        if offset + len > packet.len() {
58            return Err(Error::DnsParseError("Label extends beyond packet".into()));
59        }
60
61        let label = std::str::from_utf8(&packet[offset..offset + len])
62            .map_err(|_| Error::DnsParseError("Invalid UTF-8 in label".into()))?;
63        labels.push(label.to_string());
64        offset += len;
65    }
66
67    if labels.is_empty() {
68        return Ok(None);
69    }
70
71    Ok(Some(labels.join(".")))
72}
73
74/// Check if a domain matches a pattern (supports wildcards)
75pub fn domain_matches(domain: &str, pattern: &str) -> bool {
76    let Ok(glob) = GlobBuilder::new(pattern)
77        .case_insensitive(true)
78        .literal_separator(true)
79        .build()
80    else {
81        return false;
82    };
83
84    glob.compile_matcher().is_match(domain)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal DNS query packet. It has a 12-byte header with
    /// QDCOUNT = 1, followed by one question for `labels` (QTYPE A, QCLASS IN).
    fn build_query(labels: &[&str]) -> Vec<u8> {
        let mut packet = vec![
            0x12, 0x34, // ID
            0x01, 0x00, // flags: QR = 0 (query), RD = 1
            0x00, 0x01, // QDCOUNT = 1
            0x00, 0x00, // ANCOUNT
            0x00, 0x00, // NSCOUNT
            0x00, 0x00, // ARCOUNT
        ];
        for label in labels {
            packet.push(u8::try_from(label.len()).expect("test label fits in a length octet"));
            packet.extend_from_slice(label.as_bytes());
        }
        packet.push(0); // zero-length root label terminates the name
        packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]); // QTYPE = A, QCLASS = IN
        packet
    }

    #[test]
    fn test_domain_matches_exact() {
        assert!(domain_matches("example.com", "example.com"));
        assert!(domain_matches("Example.COM", "example.com"));
        assert!(!domain_matches("other.com", "example.com"));
    }

    #[test]
    fn test_domain_matches_wildcard() {
        assert!(domain_matches("sub.example.com", "*.example.com"));
        assert!(domain_matches("deep.sub.example.com", "*.example.com"));
        assert!(!domain_matches("example.com", "*.example.com"));
        assert!(!domain_matches("example.org", "*.example.com"));
    }

    #[test]
    fn test_domain_matches_glob_features() {
        assert!(domain_matches("api-1.example.com", "api-?.example.com"));
        assert!(domain_matches("api-a.example.com", "api-[a-z].example.com"));
        assert!(!domain_matches("api-aa.example.com", "api-?.example.com"));
    }

    #[test]
    fn test_extract_domain_short_packet() {
        let result = extract_domain_from_query(&[0; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_domain_simple_query() {
        let packet = build_query(&["www", "example", "com"]);
        assert!(matches!(
            extract_domain_from_query(&packet),
            Ok(Some(d)) if d == "www.example.com"
        ));
    }

    #[test]
    fn test_extract_domain_ignores_responses() {
        let mut packet = build_query(&["example", "com"]);
        packet[2] |= 0x80; // set QR
        assert!(matches!(extract_domain_from_query(&packet), Ok(None)));
    }

    #[test]
    fn test_extract_domain_no_questions() {
        let mut packet = build_query(&["example", "com"]);
        packet[5] = 0; // QDCOUNT = 0
        assert!(matches!(extract_domain_from_query(&packet), Ok(None)));
    }

    #[test]
    fn test_extract_domain_root_name() {
        let packet = build_query(&[]);
        assert!(matches!(extract_domain_from_query(&packet), Ok(None)));
    }

    #[test]
    fn test_extract_domain_rejects_compression_pointer() {
        let mut packet = build_query(&[]);
        packet.truncate(12);
        packet.extend_from_slice(&[0xC0, 0x0C]);
        assert!(extract_domain_from_query(&packet).is_err());
    }

    #[test]
    fn test_extract_domain_rejects_truncated_label() {
        let mut packet = build_query(&["example", "com"]);
        // Keep the length octet (7) but only 3 bytes of "example".
        packet.truncate(12 + 4);
        assert!(extract_domain_from_query(&packet).is_err());
    }
}