1use crate::error::{Error, Result};
7use globset::GlobBuilder;
8
9pub fn extract_domain_from_query(packet: &[u8]) -> Result<Option<String>> {
11 if packet.len() < 12 {
13 return Err(Error::DnsParseError(
14 "Packet too short for DNS header".into(),
15 ));
16 }
17
18 let qr = (packet[2] >> 7) & 1;
20 if qr != 0 {
21 return Ok(None);
23 }
24
25 let qdcount = u16::from_be_bytes([packet[4], packet[5]]) as usize;
27 if qdcount == 0 {
28 return Ok(None);
29 }
30
31 let mut offset = 12;
33 let mut labels = Vec::new();
34
35 loop {
36 if offset >= packet.len() {
37 return Err(Error::DnsParseError("Unexpected end of packet".into()));
38 }
39
40 let len = packet[offset] as usize;
41 if len == 0 {
42 break;
43 }
44
45 if len & 0xC0 == 0xC0 {
47 return Err(Error::DnsParseError(
48 "Compression pointers not supported in queries".into(),
49 ));
50 }
51
52 if len > 63 {
53 return Err(Error::DnsParseError("Label too long".into()));
54 }
55
56 offset += 1;
57 if offset + len > packet.len() {
58 return Err(Error::DnsParseError("Label extends beyond packet".into()));
59 }
60
61 let label = std::str::from_utf8(&packet[offset..offset + len])
62 .map_err(|_| Error::DnsParseError("Invalid UTF-8 in label".into()))?;
63 labels.push(label.to_string());
64 offset += len;
65 }
66
67 if labels.is_empty() {
68 return Ok(None);
69 }
70
71 Ok(Some(labels.join(".")))
72}
73
74pub fn domain_matches(domain: &str, pattern: &str) -> bool {
76 let Ok(glob) = GlobBuilder::new(pattern)
77 .case_insensitive(true)
78 .literal_separator(true)
79 .build()
80 else {
81 return false;
82 };
83
84 glob.compile_matcher().is_match(domain)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a DNS message made of:
    /// * a 12-byte header with a fixed ID, the given first flags byte and
    ///   QDCOUNT, and all other counts set to zero;
    /// * `rest`, appended verbatim.
    fn packet(flags: u8, qdcount: u16, rest: &[u8]) -> Vec<u8> {
        let mut p = vec![0x12, 0x34, flags, 0x00];
        p.extend_from_slice(&qdcount.to_be_bytes());
        p.extend_from_slice(&[0; 6]); // ANCOUNT, NSCOUNT, ARCOUNT
        p.extend_from_slice(rest);
        p
    }

    /// Encodes `labels` as a question entry: length-prefixed labels, the
    /// terminating root octet, then QTYPE=A and QCLASS=IN.
    fn question(labels: &[&str]) -> Vec<u8> {
        let mut out = Vec::new();
        for label in labels {
            out.push(u8::try_from(label.len()).expect("test label fits in a u8"));
            out.extend_from_slice(label.as_bytes());
        }
        out.extend_from_slice(&[0, 0x00, 0x01, 0x00, 0x01]);
        out
    }

    #[test]
    fn test_domain_matches_exact() {
        assert!(domain_matches("example.com", "example.com"));
        assert!(domain_matches("Example.COM", "example.com"));
        assert!(!domain_matches("other.com", "example.com"));
    }

    #[test]
    fn test_domain_matches_wildcard() {
        assert!(domain_matches("sub.example.com", "*.example.com"));
        assert!(domain_matches("deep.sub.example.com", "*.example.com"));
        assert!(!domain_matches("example.com", "*.example.com"));
        assert!(!domain_matches("example.org", "*.example.com"));
    }

    #[test]
    fn test_domain_matches_glob_features() {
        assert!(domain_matches("api-1.example.com", "api-?.example.com"));
        assert!(domain_matches("api-a.example.com", "api-[a-z].example.com"));
        assert!(!domain_matches("api-aa.example.com", "api-?.example.com"));
    }

    #[test]
    fn test_extract_domain_short_packet() {
        let result = extract_domain_from_query(&[0; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_domain_simple_query() {
        let pkt = packet(0x01, 1, &question(&["www", "example", "com"]));
        assert_eq!(
            extract_domain_from_query(&pkt).ok().flatten().as_deref(),
            Some("www.example.com")
        );
    }

    #[test]
    fn test_extract_domain_ignores_responses() {
        // 0x80 sets the QR bit, marking the message as a response.
        let pkt = packet(0x81, 1, &question(&["example", "com"]));
        assert!(matches!(extract_domain_from_query(&pkt), Ok(None)));
    }

    #[test]
    fn test_extract_domain_no_questions() {
        let pkt = packet(0x01, 0, &question(&["example", "com"]));
        assert!(matches!(extract_domain_from_query(&pkt), Ok(None)));
    }

    #[test]
    fn test_extract_domain_root_name() {
        let pkt = packet(0x01, 1, &question(&[]));
        assert!(matches!(extract_domain_from_query(&pkt), Ok(None)));
    }

    #[test]
    fn test_extract_domain_rejects_malformed_names() {
        // Compression pointer.
        assert!(extract_domain_from_query(&packet(0x01, 1, &[0xC0, 0x0C])).is_err());
        // Reserved label type (0b01 prefix).
        assert!(extract_domain_from_query(&packet(0x01, 1, &[0x40, 0x00])).is_err());
        // Label length runs past the end of the packet.
        assert!(extract_domain_from_query(&packet(0x01, 1, &[5, b'a', b'b'])).is_err());
        // Name is never terminated.
        assert!(extract_domain_from_query(&packet(0x01, 1, &[1, b'a'])).is_err());
        // Label bytes are not UTF-8.
        assert!(extract_domain_from_query(&packet(0x01, 1, &[2, 0xFF, 0xFE, 0])).is_err());
    }
}