pcapsql_core/protocol/
dns.rs

//! DNS protocol parser built on the `simple-dns` crate.
2
3use std::collections::HashSet;
4
5use compact_str::CompactString;
6use simple_dns::{rdata::RData, Packet, PacketFlag, OPCODE, RCODE};
7use smallvec::SmallVec;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
/// Well-known port number assigned to DNS (used over both UDP and TCP).
pub const DNS_PORT: u16 = 53;

/// Stateless DNS parser; a unit struct, so copies are free.
#[derive(Clone, Copy, Debug)]
pub struct DnsProtocol;
18
19impl Protocol for DnsProtocol {
20    fn name(&self) -> &'static str {
21        "dns"
22    }
23
24    fn display_name(&self) -> &'static str {
25        "DNS"
26    }
27
28    fn can_parse(&self, context: &ParseContext) -> Option<u32> {
29        // Check for DNS port in either src_port or dst_port
30        let src_port = context.hint("src_port");
31        let dst_port = context.hint("dst_port");
32
33        match (src_port, dst_port) {
34            (Some(p), _) | (_, Some(p)) if p == DNS_PORT as u64 => Some(100),
35            _ => None,
36        }
37    }
38
39    fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
40        // Parse using simple-dns
41        let packet = match Packet::parse(data) {
42            Ok(p) => p,
43            Err(e) => return ParseResult::error(format!("DNS parse error: {e}"), data),
44        };
45
46        let mut fields = SmallVec::new();
47
48        // Extract header fields
49        extract_header_fields(&packet, &mut fields);
50
51        // Extract question fields (first question only)
52        extract_question_fields(&packet, &mut fields);
53
54        // Extract answer fields (as lists)
55        extract_answer_fields(&packet, &mut fields);
56
57        // Extract EDNS fields
58        extract_edns_fields(&packet, &mut fields);
59
60        ParseResult::success(fields, &[], SmallVec::new())
61    }
62
63    fn schema_fields(&self) -> Vec<FieldDescriptor> {
64        vec![
65            // Header fields (cheap)
66            FieldDescriptor::new("dns.transaction_id", DataKind::UInt16).set_nullable(true),
67            FieldDescriptor::new("dns.is_query", DataKind::Bool).set_nullable(true),
68            FieldDescriptor::new("dns.opcode", DataKind::UInt8).set_nullable(true),
69            FieldDescriptor::new("dns.is_authoritative", DataKind::Bool).set_nullable(true),
70            FieldDescriptor::new("dns.is_truncated", DataKind::Bool).set_nullable(true),
71            FieldDescriptor::new("dns.recursion_desired", DataKind::Bool).set_nullable(true),
72            FieldDescriptor::new("dns.recursion_available", DataKind::Bool).set_nullable(true),
73            FieldDescriptor::new("dns.response_code", DataKind::UInt8).set_nullable(true),
74            FieldDescriptor::new("dns.query_count", DataKind::UInt16).set_nullable(true),
75            FieldDescriptor::new("dns.answer_count", DataKind::UInt16).set_nullable(true),
76            FieldDescriptor::new("dns.authority_count", DataKind::UInt16).set_nullable(true),
77            FieldDescriptor::new("dns.additional_count", DataKind::UInt16).set_nullable(true),
78            // Question fields (expensive)
79            FieldDescriptor::new("dns.query_name", DataKind::String).set_nullable(true),
80            FieldDescriptor::new("dns.query_type", DataKind::UInt16).set_nullable(true),
81            FieldDescriptor::new("dns.query_class", DataKind::UInt16).set_nullable(true),
82            // Answer fields (lists) - NEW
83            FieldDescriptor::new(
84                "dns.answer_ip4s",
85                DataKind::List(Box::new(DataKind::UInt32)),
86            )
87            .set_nullable(true),
88            FieldDescriptor::new(
89                "dns.answer_ip6s",
90                DataKind::List(Box::new(DataKind::FixedBinary(16))),
91            )
92            .set_nullable(true),
93            FieldDescriptor::new(
94                "dns.answer_cnames",
95                DataKind::List(Box::new(DataKind::String)),
96            )
97            .set_nullable(true),
98            FieldDescriptor::new(
99                "dns.answer_types",
100                DataKind::List(Box::new(DataKind::UInt16)),
101            )
102            .set_nullable(true),
103            FieldDescriptor::new(
104                "dns.answer_ttls",
105                DataKind::List(Box::new(DataKind::UInt32)),
106            )
107            .set_nullable(true),
108            // EDNS fields - NEW
109            FieldDescriptor::new("dns.has_edns", DataKind::Bool).set_nullable(true),
110            FieldDescriptor::new("dns.edns_udp_size", DataKind::UInt16).set_nullable(true),
111        ]
112    }
113
114    fn child_protocols(&self) -> &[&'static str] {
115        &[]
116    }
117
118    fn dependencies(&self) -> &'static [&'static str] {
119        &["udp", "tcp"] // DNS runs over UDP (primarily) and TCP
120    }
121
122    fn parse_projected<'a>(
123        &self,
124        data: &'a [u8],
125        _context: &ParseContext,
126        requested_fields: Option<&HashSet<String>>,
127    ) -> ParseResult<'a> {
128        // If no projection, use full parse
129        let requested = match requested_fields {
130            None => return self.parse(data, _context),
131            Some(f) if f.is_empty() => return self.parse(data, _context),
132            Some(f) => f,
133        };
134
135        // Parse using simple-dns
136        let packet = match Packet::parse(data) {
137            Ok(p) => p,
138            Err(e) => return ParseResult::error(format!("DNS parse error: {e}"), data),
139        };
140
141        let mut fields = SmallVec::new();
142
143        // Check which field categories are needed
144        let need_header = requested.iter().any(|f| is_header_field(f));
145        let need_question = requested.iter().any(|f| is_question_field(f));
146        let need_answers = requested.iter().any(|f| is_answer_field(f));
147        let need_edns = requested.iter().any(|f| is_edns_field(f));
148
149        if need_header {
150            extract_header_fields_projected(&packet, &mut fields, requested);
151        }
152
153        if need_question {
154            extract_question_fields_projected(&packet, &mut fields, requested);
155        }
156
157        if need_answers {
158            extract_answer_fields_projected(&packet, &mut fields, requested);
159        }
160
161        if need_edns {
162            extract_edns_fields_projected(&packet, &mut fields, requested);
163        }
164
165        ParseResult::success(fields, &[], SmallVec::new())
166    }
167
168    fn cheap_fields(&self) -> &'static [&'static str] {
169        // Header fields are all cheap
170        &[
171            "transaction_id",
172            "is_query",
173            "opcode",
174            "is_authoritative",
175            "is_truncated",
176            "recursion_desired",
177            "recursion_available",
178            "response_code",
179            "query_count",
180            "answer_count",
181            "authority_count",
182            "additional_count",
183            "has_edns",
184            "edns_udp_size",
185        ]
186    }
187
188    fn expensive_fields(&self) -> &'static [&'static str] {
189        // Fields requiring parsing variable-length sections
190        &[
191            "query_name",
192            "query_type",
193            "query_class",
194            "answer_ip4s",
195            "answer_ip6s",
196            "answer_cnames",
197            "answer_types",
198            "answer_ttls",
199        ]
200    }
201}
202
/// Returns `true` when `field` (an unprefixed name such as `"opcode"`) is one
/// of the twelve values decoded from the fixed-size DNS message header.
fn is_header_field(field: &str) -> bool {
    const HEADER_FIELDS: [&str; 12] = [
        "transaction_id",
        "is_query",
        "opcode",
        "is_authoritative",
        "is_truncated",
        "recursion_desired",
        "recursion_available",
        "response_code",
        "query_count",
        "answer_count",
        "authority_count",
        "additional_count",
    ];
    HEADER_FIELDS.iter().any(|&name| name == field)
}
221
/// Returns `true` when `field` is taken from the first entry of the question
/// section (name, type or class).
fn is_question_field(field: &str) -> bool {
    ["query_name", "query_type", "query_class"].contains(&field)
}
226
/// Returns `true` when `field` is one of the list-valued columns built from
/// the answer section (`answer_ip4s`, `answer_ip6s`, `answer_cnames`,
/// `answer_types`, `answer_ttls`).
fn is_answer_field(field: &str) -> bool {
    matches!(
        field.strip_prefix("answer_"),
        Some("ip4s" | "ip6s" | "cnames" | "types" | "ttls")
    )
}
234
/// Returns `true` when `field` is derived from the EDNS OPT pseudo-record.
fn is_edns_field(field: &str) -> bool {
    field == "has_edns" || field == "edns_udp_size"
}
239
240/// Convert OPCODE to u8.
241fn opcode_to_u8(opcode: OPCODE) -> u8 {
242    match opcode {
243        OPCODE::StandardQuery => 0,
244        OPCODE::InverseQuery => 1,
245        OPCODE::ServerStatusRequest => 2,
246        OPCODE::Notify => 4,
247        OPCODE::Update => 5,
248        OPCODE::Reserved => 15, // Reserved
249    }
250}
251
252/// Convert RCODE to u8.
253fn rcode_to_u8(rcode: RCODE) -> u8 {
254    match rcode {
255        RCODE::NoError => 0,
256        RCODE::FormatError => 1,
257        RCODE::ServerFailure => 2,
258        RCODE::NameError => 3,
259        RCODE::NotImplemented => 4,
260        RCODE::Refused => 5,
261        RCODE::YXDOMAIN => 6,
262        RCODE::YXRRSET => 7,
263        RCODE::NXRRSET => 8,
264        RCODE::NOTAUTH => 9,
265        RCODE::NOTZONE => 10,
266        RCODE::BADVERS => 16,
267        RCODE::Reserved => 15,
268    }
269}
270
271/// Extract header fields from a DNS packet.
272fn extract_header_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
273    fields.push(("transaction_id", FieldValue::UInt16(packet.id())));
274    fields.push((
275        "is_query",
276        FieldValue::Bool(!packet.has_flags(PacketFlag::RESPONSE)),
277    ));
278    fields.push(("opcode", FieldValue::UInt8(opcode_to_u8(packet.opcode()))));
279    fields.push((
280        "is_authoritative",
281        FieldValue::Bool(packet.has_flags(PacketFlag::AUTHORITATIVE_ANSWER)),
282    ));
283    fields.push((
284        "is_truncated",
285        FieldValue::Bool(packet.has_flags(PacketFlag::TRUNCATION)),
286    ));
287    fields.push((
288        "recursion_desired",
289        FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_DESIRED)),
290    ));
291    fields.push((
292        "recursion_available",
293        FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_AVAILABLE)),
294    ));
295    fields.push((
296        "response_code",
297        FieldValue::UInt8(rcode_to_u8(packet.rcode())),
298    ));
299    fields.push((
300        "query_count",
301        FieldValue::UInt16(packet.questions.len() as u16),
302    ));
303    fields.push((
304        "answer_count",
305        FieldValue::UInt16(packet.answers.len() as u16),
306    ));
307    fields.push((
308        "authority_count",
309        FieldValue::UInt16(packet.name_servers.len() as u16),
310    ));
311    fields.push((
312        "additional_count",
313        FieldValue::UInt16(packet.additional_records.len() as u16),
314    ));
315}
316
317/// Extract header fields from a DNS packet (projected).
318fn extract_header_fields_projected(
319    packet: &Packet,
320    fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
321    requested: &HashSet<String>,
322) {
323    if requested.contains("transaction_id") {
324        fields.push(("transaction_id", FieldValue::UInt16(packet.id())));
325    }
326    if requested.contains("is_query") {
327        fields.push((
328            "is_query",
329            FieldValue::Bool(!packet.has_flags(PacketFlag::RESPONSE)),
330        ));
331    }
332    if requested.contains("opcode") {
333        fields.push(("opcode", FieldValue::UInt8(opcode_to_u8(packet.opcode()))));
334    }
335    if requested.contains("is_authoritative") {
336        fields.push((
337            "is_authoritative",
338            FieldValue::Bool(packet.has_flags(PacketFlag::AUTHORITATIVE_ANSWER)),
339        ));
340    }
341    if requested.contains("is_truncated") {
342        fields.push((
343            "is_truncated",
344            FieldValue::Bool(packet.has_flags(PacketFlag::TRUNCATION)),
345        ));
346    }
347    if requested.contains("recursion_desired") {
348        fields.push((
349            "recursion_desired",
350            FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_DESIRED)),
351        ));
352    }
353    if requested.contains("recursion_available") {
354        fields.push((
355            "recursion_available",
356            FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_AVAILABLE)),
357        ));
358    }
359    if requested.contains("response_code") {
360        fields.push((
361            "response_code",
362            FieldValue::UInt8(rcode_to_u8(packet.rcode())),
363        ));
364    }
365    if requested.contains("query_count") {
366        fields.push((
367            "query_count",
368            FieldValue::UInt16(packet.questions.len() as u16),
369        ));
370    }
371    if requested.contains("answer_count") {
372        fields.push((
373            "answer_count",
374            FieldValue::UInt16(packet.answers.len() as u16),
375        ));
376    }
377    if requested.contains("authority_count") {
378        fields.push((
379            "authority_count",
380            FieldValue::UInt16(packet.name_servers.len() as u16),
381        ));
382    }
383    if requested.contains("additional_count") {
384        fields.push((
385            "additional_count",
386            FieldValue::UInt16(packet.additional_records.len() as u16),
387        ));
388    }
389}
390
391/// Extract question fields from a DNS packet.
392fn extract_question_fields(
393    packet: &Packet,
394    fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
395) {
396    if let Some(question) = packet.questions.first() {
397        fields.push((
398            "query_name",
399            FieldValue::OwnedString(CompactString::new(question.qname.to_string())),
400        ));
401        fields.push(("query_type", FieldValue::UInt16(question.qtype.into())));
402        fields.push(("query_class", FieldValue::UInt16(question.qclass.into())));
403    } else {
404        fields.push(("query_name", FieldValue::Null));
405        fields.push(("query_type", FieldValue::Null));
406        fields.push(("query_class", FieldValue::Null));
407    }
408}
409
410/// Extract question fields from a DNS packet (projected).
411fn extract_question_fields_projected(
412    packet: &Packet,
413    fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
414    requested: &HashSet<String>,
415) {
416    if let Some(question) = packet.questions.first() {
417        if requested.contains("query_name") {
418            fields.push((
419                "query_name",
420                FieldValue::OwnedString(CompactString::new(question.qname.to_string())),
421            ));
422        }
423        if requested.contains("query_type") {
424            fields.push(("query_type", FieldValue::UInt16(question.qtype.into())));
425        }
426        if requested.contains("query_class") {
427            fields.push(("query_class", FieldValue::UInt16(question.qclass.into())));
428        }
429    } else {
430        if requested.contains("query_name") {
431            fields.push(("query_name", FieldValue::Null));
432        }
433        if requested.contains("query_type") {
434            fields.push(("query_type", FieldValue::Null));
435        }
436        if requested.contains("query_class") {
437            fields.push(("query_class", FieldValue::Null));
438        }
439    }
440}
441
442/// Extract answer fields from a DNS packet as lists.
443fn extract_answer_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
444    let mut ip4s: Vec<FieldValue> = Vec::new();
445    let mut ip6s: Vec<FieldValue> = Vec::new();
446    let mut cnames: Vec<FieldValue> = Vec::new();
447    let mut types: Vec<FieldValue> = Vec::new();
448    let mut ttls: Vec<FieldValue> = Vec::new();
449
450    for answer in &packet.answers {
451        // Record type
452        let rtype: u16 = answer.rdata.type_code().into();
453        types.push(FieldValue::UInt16(rtype));
454
455        // TTL
456        ttls.push(FieldValue::UInt32(answer.ttl));
457
458        // Extract type-specific data
459        match &answer.rdata {
460            RData::A(a) => {
461                ip4s.push(FieldValue::UInt32(a.address));
462            }
463            RData::AAAA(aaaa) => {
464                // Use IpAddr to avoid heap allocation for IPv6 address
465                ip6s.push(FieldValue::IpAddr(std::net::IpAddr::V6(
466                    std::net::Ipv6Addr::from(aaaa.address),
467                )));
468            }
469            RData::CNAME(cname) => {
470                cnames.push(FieldValue::OwnedString(CompactString::new(
471                    cname.0.to_string(),
472                )));
473            }
474            _ => {}
475        }
476    }
477
478    // Push list fields
479    fields.push(("answer_ip4s", FieldValue::List(ip4s)));
480    fields.push(("answer_ip6s", FieldValue::List(ip6s)));
481    fields.push(("answer_cnames", FieldValue::List(cnames)));
482    fields.push(("answer_types", FieldValue::List(types)));
483    fields.push(("answer_ttls", FieldValue::List(ttls)));
484}
485
486/// Extract answer fields from a DNS packet (projected).
487fn extract_answer_fields_projected(
488    packet: &Packet,
489    fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
490    requested: &HashSet<String>,
491) {
492    let need_ip4s = requested.contains("answer_ip4s");
493    let need_ip6s = requested.contains("answer_ip6s");
494    let need_cnames = requested.contains("answer_cnames");
495    let need_types = requested.contains("answer_types");
496    let need_ttls = requested.contains("answer_ttls");
497
498    let mut ip4s: Vec<FieldValue> = Vec::new();
499    let mut ip6s: Vec<FieldValue> = Vec::new();
500    let mut cnames: Vec<FieldValue> = Vec::new();
501    let mut types: Vec<FieldValue> = Vec::new();
502    let mut ttls: Vec<FieldValue> = Vec::new();
503
504    for answer in &packet.answers {
505        if need_types {
506            let rtype: u16 = answer.rdata.type_code().into();
507            types.push(FieldValue::UInt16(rtype));
508        }
509
510        if need_ttls {
511            ttls.push(FieldValue::UInt32(answer.ttl));
512        }
513
514        match &answer.rdata {
515            RData::A(a) if need_ip4s => {
516                ip4s.push(FieldValue::UInt32(a.address));
517            }
518            RData::AAAA(aaaa) if need_ip6s => {
519                // Use IpAddr to avoid heap allocation for IPv6 address
520                ip6s.push(FieldValue::IpAddr(std::net::IpAddr::V6(
521                    std::net::Ipv6Addr::from(aaaa.address),
522                )));
523            }
524            RData::CNAME(cname) if need_cnames => {
525                cnames.push(FieldValue::OwnedString(CompactString::new(
526                    cname.0.to_string(),
527                )));
528            }
529            _ => {}
530        }
531    }
532
533    if need_ip4s {
534        fields.push(("answer_ip4s", FieldValue::List(ip4s)));
535    }
536    if need_ip6s {
537        fields.push(("answer_ip6s", FieldValue::List(ip6s)));
538    }
539    if need_cnames {
540        fields.push(("answer_cnames", FieldValue::List(cnames)));
541    }
542    if need_types {
543        fields.push(("answer_types", FieldValue::List(types)));
544    }
545    if need_ttls {
546        fields.push(("answer_ttls", FieldValue::List(ttls)));
547    }
548}
549
550/// Extract EDNS fields from a DNS packet.
551fn extract_edns_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
552    if let Some(opt) = packet.opt() {
553        fields.push(("has_edns", FieldValue::Bool(true)));
554        fields.push(("edns_udp_size", FieldValue::UInt16(opt.udp_packet_size)));
555    } else {
556        fields.push(("has_edns", FieldValue::Bool(false)));
557        fields.push(("edns_udp_size", FieldValue::Null));
558    }
559}
560
561/// Extract EDNS fields from a DNS packet (projected).
562fn extract_edns_fields_projected(
563    packet: &Packet,
564    fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
565    requested: &HashSet<String>,
566) {
567    let has_edns = packet.opt().is_some();
568
569    if requested.contains("has_edns") {
570        fields.push(("has_edns", FieldValue::Bool(has_edns)));
571    }
572
573    if requested.contains("edns_udp_size") {
574        if let Some(opt) = packet.opt() {
575            fields.push(("edns_udp_size", FieldValue::UInt16(opt.udp_packet_size)));
576        } else {
577            fields.push(("edns_udp_size", FieldValue::Null));
578        }
579    }
580}
581
/// Numeric RR TYPE codes for commonly seen DNS record types.
// `dead_code` is allowed so this reference table may include codes that are
// not referenced anywhere.
#[allow(dead_code)]
pub mod record_type {
    /// IPv4 host address.
    pub const A: u16 = 0x0001;
    /// Authoritative name server.
    pub const NS: u16 = 0x0002;
    /// Canonical name of an alias.
    pub const CNAME: u16 = 0x0005;
    /// Start of a zone of authority.
    pub const SOA: u16 = 0x0006;
    /// Domain name pointer.
    pub const PTR: u16 = 0x000C;
    /// Mail exchange.
    pub const MX: u16 = 0x000F;
    /// Text strings.
    pub const TXT: u16 = 0x0010;
    /// IPv6 host address.
    pub const AAAA: u16 = 0x001C;
    /// Service locator.
    pub const SRV: u16 = 0x0021;
    /// EDNS(0) OPT pseudo-record.
    pub const OPT: u16 = 0x0029;
    /// Query for all records (`*`); valid only as a QTYPE.
    pub const ANY: u16 = 0x00FF;
}
597
/// DNS RCODE values, as listed in the IANA "DNS RCODEs" registry.
///
/// The message header only has room for 4 bits (0-15). Larger values are
/// either formed with the EDNS(0) extended-RCODE bits of the OPT record
/// (RFC 6891) or carried in the error field of TSIG/TKEY records.
// `dead_code` is allowed so this reference table may include codes that are
// not referenced anywhere.
#[allow(dead_code)]
pub mod rcode {
    // Base RCODEs (RFC 1035)
    pub const NOERROR: u8 = 0;
    pub const FORMERR: u8 = 1;
    pub const SERVFAIL: u8 = 2;
    pub const NXDOMAIN: u8 = 3;
    pub const NOTIMP: u8 = 4;
    pub const REFUSED: u8 = 5;

    // DNS UPDATE RCODEs (RFC 2136)
    pub const YXDOMAIN: u8 = 6;
    pub const YXRRSET: u8 = 7;
    pub const NXRRSET: u8 = 8;
    pub const NOTAUTH: u8 = 9;
    pub const NOTZONE: u8 = 10;

    // DNS Stateful Operations: DSO-TYPE not implemented (RFC 8490)
    pub const DSOTYPENI: u8 = 11;

    // EDNS version not supported (RFC 6891). The same value means BADSIG
    // inside a TSIG record (RFC 8945, formerly RFC 2845).
    pub const BADVERS: u8 = 16;

    // TSIG errors (RFC 8945, formerly RFC 2845)
    pub const BADKEY: u8 = 17;
    pub const BADTIME: u8 = 18;

    // TKEY errors (RFC 2930)
    pub const BADMODE: u8 = 19;
    pub const BADNAME: u8 = 20;
    pub const BADALG: u8 = 21;

    // TSIG truncation error (RFC 4635, now part of RFC 8945)
    pub const BADTRUNC: u8 = 22;

    // Bad/missing server cookie (RFC 7873)
    pub const BADCOOKIE: u8 = 23;
}
629
// Unit tests for `DnsProtocol`, driven by hand-assembled DNS wire messages.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Encode a domain name in DNS format.
    ///
    /// Each non-empty dot-separated label is written as a length byte followed
    /// by its bytes, and the name ends with a zero byte (the root label). No
    /// compression is used and label lengths are not validated (`as u8`
    /// truncates lengths above 255).
    fn encode_domain_name(name: &str) -> Vec<u8> {
        let mut result = Vec::new();
        for part in name.split('.') {
            // Skip empty labels produced by leading/trailing/double dots.
            if !part.is_empty() {
                result.push(part.len() as u8);
                result.extend_from_slice(part.as_bytes());
            }
        }
        result.push(0); // Null terminator
        result
    }

    /// Create a complete DNS query message with a single question.
    ///
    /// `query_name` must already be wire-encoded (see `encode_domain_name`).
    /// The header has only RD set and QDCOUNT = 1; the question asks for
    /// QTYPE A, QCLASS IN.
    fn create_dns_query(transaction_id: u16, query_name: &[u8]) -> Vec<u8> {
        let mut packet = Vec::new();

        // Transaction ID
        packet.extend_from_slice(&transaction_id.to_be_bytes());

        // Flags: Standard query (0x0100 = RD set)
        packet.extend_from_slice(&[0x01, 0x00]);

        // Question count: 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // Answer, Authority, Additional counts: 0
        packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);

        // Query name
        packet.extend_from_slice(query_name);

        // QTYPE: A (1)
        packet.extend_from_slice(&[0x00, 0x01]);

        // QCLASS: IN (1)
        packet.extend_from_slice(&[0x00, 0x01]);

        packet
    }

    /// Create a complete DNS response for `example.com` carrying one A record.
    ///
    /// The header has QR, RD and RA set, QDCOUNT = 1 and ANCOUNT = 1. The answer
    /// reuses the question name through a compression pointer, has a TTL of 300
    /// seconds and uses `ip` as its RDATA.
    fn create_dns_response_with_answer(transaction_id: u16, ip: [u8; 4]) -> Vec<u8> {
        let mut packet = Vec::new();

        // Transaction ID
        packet.extend_from_slice(&transaction_id.to_be_bytes());

        // Flags: Response (0x8180 = QR set, RD set, RA set)
        packet.extend_from_slice(&[0x81, 0x80]);

        // Question count: 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // Answer count: 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // Authority, Additional counts: 0
        packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);

        // Question section (example.com)
        packet.extend_from_slice(&[
            0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', // "example"
            0x03, b'c', b'o', b'm', // "com"
            0x00, // null terminator
        ]);

        // QTYPE: A, QCLASS: IN
        packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]);

        // Answer section
        // Name: compression pointer to question name
        // (0xC0 marks a pointer; offset 0x0C = 12 is where the question name
        // starts, right after the fixed header)
        packet.extend_from_slice(&[0xC0, 0x0C]);

        // TYPE: A (1)
        packet.extend_from_slice(&[0x00, 0x01]);

        // CLASS: IN (1)
        packet.extend_from_slice(&[0x00, 0x01]);

        // TTL: 300 seconds
        packet.extend_from_slice(&[0x00, 0x00, 0x01, 0x2C]);

        // RDLENGTH: 4
        packet.extend_from_slice(&[0x00, 0x04]);

        // RDATA: IP address
        packet.extend_from_slice(&ip);

        packet
    }

    // A plain query: checks header flags/counts and the decoded question.
    #[test]
    fn test_parse_dns_query() {
        let query_name = encode_domain_name("example.com");
        let packet = create_dns_query(0x1234, &query_name);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);
        context.parent_protocol = Some("udp");

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0x1234))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("query_count"), Some(&FieldValue::UInt16(1)));
        assert_eq!(result.get("answer_count"), Some(&FieldValue::UInt16(0)));
        assert_eq!(
            result.get("recursion_desired"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(
            result.get("query_name"),
            Some(&FieldValue::OwnedString(CompactString::new("example.com")))
        );
        assert_eq!(result.get("query_type"), Some(&FieldValue::UInt16(1))); // A record
        assert_eq!(result.get("query_class"), Some(&FieldValue::UInt16(1))); // IN class
    }

    // A response with one A answer: checks the answer list columns.
    #[test]
    fn test_parse_dns_response_with_answer() {
        let packet = create_dns_response_with_answer(0xABCD, [93, 184, 216, 34]);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("src_port", 53);
        context.parent_protocol = Some("udp");

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0xABCD))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(false)));
        assert_eq!(result.get("answer_count"), Some(&FieldValue::UInt16(1)));
        assert_eq!(result.get("response_code"), Some(&FieldValue::UInt8(0)));

        // Check answer_ip4s list
        if let Some(FieldValue::List(ip4s)) = result.get("answer_ip4s") {
            assert_eq!(ip4s.len(), 1);
            // 93.184.216.34 as u32: (93 << 24) | (184 << 16) | (216 << 8) | 34
            let expected_ip = u32::from(Ipv4Addr::new(93, 184, 216, 34));
            assert_eq!(ip4s[0], FieldValue::UInt32(expected_ip));
        } else {
            panic!("Expected answer_ip4s to be a list");
        }

        // Check answer_types list
        if let Some(FieldValue::List(types)) = result.get("answer_types") {
            assert_eq!(types.len(), 1);
            assert_eq!(types[0], FieldValue::UInt16(1)); // A record
        } else {
            panic!("Expected answer_types to be a list");
        }

        // Check answer_ttls list
        if let Some(FieldValue::List(ttls)) = result.get("answer_ttls") {
            assert_eq!(ttls.len(), 1);
            assert_eq!(ttls[0], FieldValue::UInt32(300)); // 300 seconds TTL
        } else {
            panic!("Expected answer_ttls to be a list");
        }
    }

    // can_parse keys purely off the src_port/dst_port hints.
    #[test]
    fn test_can_parse_dns() {
        let parser = DnsProtocol;

        // Without hint
        let ctx1 = ParseContext::new(1);
        assert!(parser.can_parse(&ctx1).is_none());

        // With dst_port 53
        let mut ctx2 = ParseContext::new(1);
        ctx2.insert_hint("dst_port", 53);
        assert!(parser.can_parse(&ctx2).is_some());

        // With src_port 53
        let mut ctx3 = ParseContext::new(1);
        ctx3.insert_hint("src_port", 53);
        assert!(parser.can_parse(&ctx3).is_some());

        // With different port
        let mut ctx4 = ParseContext::new(1);
        ctx4.insert_hint("dst_port", 80);
        assert!(parser.can_parse(&ctx4).is_none());
    }

    // Input shorter than the 12-byte header must yield an error result.
    #[test]
    fn test_parse_dns_too_short() {
        let short_packet = [0x12, 0x34, 0x00, 0x00]; // Only 4 bytes

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        let result = parser.parse(&short_packet, &context);

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    // Schema column names carry the "dns." prefix.
    #[test]
    fn test_dns_schema_fields() {
        let parser = DnsProtocol;
        let fields = parser.schema_fields();

        assert!(!fields.is_empty());

        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();
        assert!(field_names.contains(&"dns.transaction_id"));
        assert!(field_names.contains(&"dns.is_query"));
        assert!(field_names.contains(&"dns.query_name"));
        assert!(field_names.contains(&"dns.query_type"));
        // New fields
        assert!(field_names.contains(&"dns.answer_ip4s"));
        assert!(field_names.contains(&"dns.answer_ip6s"));
        assert!(field_names.contains(&"dns.answer_cnames"));
        assert!(field_names.contains(&"dns.has_edns"));
    }

    // A projection of header-only fields must not emit question/answer fields.
    #[test]
    fn test_dns_projected_header_only() {
        let query_name = encode_domain_name("example.com");
        let packet = create_dns_query(0x1234, &query_name);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        // Only request header fields - skip expensive parsing
        let fields: HashSet<String> = ["transaction_id", "is_query", "response_code"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let result = parser.parse_projected(&packet, &context, Some(&fields));

        assert!(result.is_ok());
        // Requested fields are present
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0x1234))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("response_code"), Some(&FieldValue::UInt8(0)));
        // Expensive fields are NOT present
        assert!(result.get("query_name").is_none());
        assert!(result.get("answer_ip4s").is_none());
    }

    // An AAAA question (QTYPE 28) built inline, byte by byte.
    #[test]
    fn test_dns_aaaa_query() {
        let mut packet = Vec::new();

        // Transaction ID
        packet.extend_from_slice(&[0x12, 0x34]);

        // Flags: Standard query with RD
        packet.extend_from_slice(&[0x01, 0x00]);

        // Counts: 1 question
        packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);

        // Query name: ipv6.google.com
        packet.extend_from_slice(&[
            0x04, b'i', b'p', b'v', b'6', // "ipv6"
            0x06, b'g', b'o', b'o', b'g', b'l', b'e', // "google"
            0x03, b'c', b'o', b'm', // "com"
            0x00, // null terminator
        ]);

        // QTYPE: AAAA (28)
        packet.extend_from_slice(&[0x00, 0x1C]);

        // QCLASS: IN (1)
        packet.extend_from_slice(&[0x00, 0x01]);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("query_name"),
            Some(&FieldValue::OwnedString(CompactString::new(
                "ipv6.google.com"
            )))
        );
        assert_eq!(
            result.get("query_type"),
            Some(&FieldValue::UInt16(record_type::AAAA))
        );
    }
}
937}