1use std::collections::HashSet;
4
5use compact_str::CompactString;
6use simple_dns::{rdata::RData, Packet, PacketFlag, OPCODE, RCODE};
7use smallvec::SmallVec;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
/// Well-known DNS server port, used for both UDP and TCP transport (RFC 1035 §4.2).
pub const DNS_PORT: u16 = 53;
14
/// Stateless decoder for DNS messages; all behaviour lives in its `Protocol` impl.
#[derive(Clone, Copy, Debug)]
pub struct DnsProtocol;
18
19impl Protocol for DnsProtocol {
20 fn name(&self) -> &'static str {
21 "dns"
22 }
23
24 fn display_name(&self) -> &'static str {
25 "DNS"
26 }
27
28 fn can_parse(&self, context: &ParseContext) -> Option<u32> {
29 let src_port = context.hint("src_port");
31 let dst_port = context.hint("dst_port");
32
33 match (src_port, dst_port) {
34 (Some(p), _) | (_, Some(p)) if p == DNS_PORT as u64 => Some(100),
35 _ => None,
36 }
37 }
38
39 fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
40 let packet = match Packet::parse(data) {
42 Ok(p) => p,
43 Err(e) => return ParseResult::error(format!("DNS parse error: {e}"), data),
44 };
45
46 let mut fields = SmallVec::new();
47
48 extract_header_fields(&packet, &mut fields);
50
51 extract_question_fields(&packet, &mut fields);
53
54 extract_answer_fields(&packet, &mut fields);
56
57 extract_edns_fields(&packet, &mut fields);
59
60 ParseResult::success(fields, &[], SmallVec::new())
61 }
62
63 fn schema_fields(&self) -> Vec<FieldDescriptor> {
64 vec![
65 FieldDescriptor::new("dns.transaction_id", DataKind::UInt16).set_nullable(true),
67 FieldDescriptor::new("dns.is_query", DataKind::Bool).set_nullable(true),
68 FieldDescriptor::new("dns.opcode", DataKind::UInt8).set_nullable(true),
69 FieldDescriptor::new("dns.is_authoritative", DataKind::Bool).set_nullable(true),
70 FieldDescriptor::new("dns.is_truncated", DataKind::Bool).set_nullable(true),
71 FieldDescriptor::new("dns.recursion_desired", DataKind::Bool).set_nullable(true),
72 FieldDescriptor::new("dns.recursion_available", DataKind::Bool).set_nullable(true),
73 FieldDescriptor::new("dns.response_code", DataKind::UInt8).set_nullable(true),
74 FieldDescriptor::new("dns.query_count", DataKind::UInt16).set_nullable(true),
75 FieldDescriptor::new("dns.answer_count", DataKind::UInt16).set_nullable(true),
76 FieldDescriptor::new("dns.authority_count", DataKind::UInt16).set_nullable(true),
77 FieldDescriptor::new("dns.additional_count", DataKind::UInt16).set_nullable(true),
78 FieldDescriptor::new("dns.query_name", DataKind::String).set_nullable(true),
80 FieldDescriptor::new("dns.query_type", DataKind::UInt16).set_nullable(true),
81 FieldDescriptor::new("dns.query_class", DataKind::UInt16).set_nullable(true),
82 FieldDescriptor::new(
84 "dns.answer_ip4s",
85 DataKind::List(Box::new(DataKind::UInt32)),
86 )
87 .set_nullable(true),
88 FieldDescriptor::new(
89 "dns.answer_ip6s",
90 DataKind::List(Box::new(DataKind::FixedBinary(16))),
91 )
92 .set_nullable(true),
93 FieldDescriptor::new(
94 "dns.answer_cnames",
95 DataKind::List(Box::new(DataKind::String)),
96 )
97 .set_nullable(true),
98 FieldDescriptor::new(
99 "dns.answer_types",
100 DataKind::List(Box::new(DataKind::UInt16)),
101 )
102 .set_nullable(true),
103 FieldDescriptor::new(
104 "dns.answer_ttls",
105 DataKind::List(Box::new(DataKind::UInt32)),
106 )
107 .set_nullable(true),
108 FieldDescriptor::new("dns.has_edns", DataKind::Bool).set_nullable(true),
110 FieldDescriptor::new("dns.edns_udp_size", DataKind::UInt16).set_nullable(true),
111 ]
112 }
113
114 fn child_protocols(&self) -> &[&'static str] {
115 &[]
116 }
117
118 fn dependencies(&self) -> &'static [&'static str] {
119 &["udp", "tcp"] }
121
122 fn parse_projected<'a>(
123 &self,
124 data: &'a [u8],
125 _context: &ParseContext,
126 requested_fields: Option<&HashSet<String>>,
127 ) -> ParseResult<'a> {
128 let requested = match requested_fields {
130 None => return self.parse(data, _context),
131 Some(f) if f.is_empty() => return self.parse(data, _context),
132 Some(f) => f,
133 };
134
135 let packet = match Packet::parse(data) {
137 Ok(p) => p,
138 Err(e) => return ParseResult::error(format!("DNS parse error: {e}"), data),
139 };
140
141 let mut fields = SmallVec::new();
142
143 let need_header = requested.iter().any(|f| is_header_field(f));
145 let need_question = requested.iter().any(|f| is_question_field(f));
146 let need_answers = requested.iter().any(|f| is_answer_field(f));
147 let need_edns = requested.iter().any(|f| is_edns_field(f));
148
149 if need_header {
150 extract_header_fields_projected(&packet, &mut fields, requested);
151 }
152
153 if need_question {
154 extract_question_fields_projected(&packet, &mut fields, requested);
155 }
156
157 if need_answers {
158 extract_answer_fields_projected(&packet, &mut fields, requested);
159 }
160
161 if need_edns {
162 extract_edns_fields_projected(&packet, &mut fields, requested);
163 }
164
165 ParseResult::success(fields, &[], SmallVec::new())
166 }
167
168 fn cheap_fields(&self) -> &'static [&'static str] {
169 &[
171 "transaction_id",
172 "is_query",
173 "opcode",
174 "is_authoritative",
175 "is_truncated",
176 "recursion_desired",
177 "recursion_available",
178 "response_code",
179 "query_count",
180 "answer_count",
181 "authority_count",
182 "additional_count",
183 "has_edns",
184 "edns_udp_size",
185 ]
186 }
187
188 fn expensive_fields(&self) -> &'static [&'static str] {
189 &[
191 "query_name",
192 "query_type",
193 "query_class",
194 "answer_ip4s",
195 "answer_ip6s",
196 "answer_cnames",
197 "answer_types",
198 "answer_ttls",
199 ]
200 }
201}
202
/// Reports whether `field` is one of the fields computed from the DNS message
/// header (ID, flag bits, opcode, response code and the four section counts).
fn is_header_field(field: &str) -> bool {
    [
        "transaction_id",
        "is_query",
        "opcode",
        "is_authoritative",
        "is_truncated",
        "recursion_desired",
        "recursion_available",
        "response_code",
        "query_count",
        "answer_count",
        "authority_count",
        "additional_count",
    ]
    .contains(&field)
}
221
/// Reports whether `field` is derived from the first question entry
/// (`query_name`, `query_type` or `query_class`).
fn is_question_field(field: &str) -> bool {
    matches!(field.strip_prefix("query_"), Some("name" | "type" | "class"))
}
226
/// Reports whether `field` is one of the per-answer list fields
/// (`answer_ip4s`, `answer_ip6s`, `answer_cnames`, `answer_types`, `answer_ttls`).
fn is_answer_field(field: &str) -> bool {
    matches!(
        field.strip_prefix("answer_"),
        Some("ip4s" | "ip6s" | "cnames" | "types" | "ttls")
    )
}
234
/// Reports whether `field` is derived from the EDNS OPT record
/// (`has_edns` or `edns_udp_size`).
fn is_edns_field(field: &str) -> bool {
    field == "has_edns" || field == "edns_udp_size"
}
239
240fn opcode_to_u8(opcode: OPCODE) -> u8 {
242 match opcode {
243 OPCODE::StandardQuery => 0,
244 OPCODE::InverseQuery => 1,
245 OPCODE::ServerStatusRequest => 2,
246 OPCODE::Notify => 4,
247 OPCODE::Update => 5,
248 OPCODE::Reserved => 15, }
250}
251
252fn rcode_to_u8(rcode: RCODE) -> u8 {
254 match rcode {
255 RCODE::NoError => 0,
256 RCODE::FormatError => 1,
257 RCODE::ServerFailure => 2,
258 RCODE::NameError => 3,
259 RCODE::NotImplemented => 4,
260 RCODE::Refused => 5,
261 RCODE::YXDOMAIN => 6,
262 RCODE::YXRRSET => 7,
263 RCODE::NXRRSET => 8,
264 RCODE::NOTAUTH => 9,
265 RCODE::NOTZONE => 10,
266 RCODE::BADVERS => 16,
267 RCODE::Reserved => 15,
268 }
269}
270
271fn extract_header_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
273 fields.push(("transaction_id", FieldValue::UInt16(packet.id())));
274 fields.push((
275 "is_query",
276 FieldValue::Bool(!packet.has_flags(PacketFlag::RESPONSE)),
277 ));
278 fields.push(("opcode", FieldValue::UInt8(opcode_to_u8(packet.opcode()))));
279 fields.push((
280 "is_authoritative",
281 FieldValue::Bool(packet.has_flags(PacketFlag::AUTHORITATIVE_ANSWER)),
282 ));
283 fields.push((
284 "is_truncated",
285 FieldValue::Bool(packet.has_flags(PacketFlag::TRUNCATION)),
286 ));
287 fields.push((
288 "recursion_desired",
289 FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_DESIRED)),
290 ));
291 fields.push((
292 "recursion_available",
293 FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_AVAILABLE)),
294 ));
295 fields.push((
296 "response_code",
297 FieldValue::UInt8(rcode_to_u8(packet.rcode())),
298 ));
299 fields.push((
300 "query_count",
301 FieldValue::UInt16(packet.questions.len() as u16),
302 ));
303 fields.push((
304 "answer_count",
305 FieldValue::UInt16(packet.answers.len() as u16),
306 ));
307 fields.push((
308 "authority_count",
309 FieldValue::UInt16(packet.name_servers.len() as u16),
310 ));
311 fields.push((
312 "additional_count",
313 FieldValue::UInt16(packet.additional_records.len() as u16),
314 ));
315}
316
317fn extract_header_fields_projected(
319 packet: &Packet,
320 fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
321 requested: &HashSet<String>,
322) {
323 if requested.contains("transaction_id") {
324 fields.push(("transaction_id", FieldValue::UInt16(packet.id())));
325 }
326 if requested.contains("is_query") {
327 fields.push((
328 "is_query",
329 FieldValue::Bool(!packet.has_flags(PacketFlag::RESPONSE)),
330 ));
331 }
332 if requested.contains("opcode") {
333 fields.push(("opcode", FieldValue::UInt8(opcode_to_u8(packet.opcode()))));
334 }
335 if requested.contains("is_authoritative") {
336 fields.push((
337 "is_authoritative",
338 FieldValue::Bool(packet.has_flags(PacketFlag::AUTHORITATIVE_ANSWER)),
339 ));
340 }
341 if requested.contains("is_truncated") {
342 fields.push((
343 "is_truncated",
344 FieldValue::Bool(packet.has_flags(PacketFlag::TRUNCATION)),
345 ));
346 }
347 if requested.contains("recursion_desired") {
348 fields.push((
349 "recursion_desired",
350 FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_DESIRED)),
351 ));
352 }
353 if requested.contains("recursion_available") {
354 fields.push((
355 "recursion_available",
356 FieldValue::Bool(packet.has_flags(PacketFlag::RECURSION_AVAILABLE)),
357 ));
358 }
359 if requested.contains("response_code") {
360 fields.push((
361 "response_code",
362 FieldValue::UInt8(rcode_to_u8(packet.rcode())),
363 ));
364 }
365 if requested.contains("query_count") {
366 fields.push((
367 "query_count",
368 FieldValue::UInt16(packet.questions.len() as u16),
369 ));
370 }
371 if requested.contains("answer_count") {
372 fields.push((
373 "answer_count",
374 FieldValue::UInt16(packet.answers.len() as u16),
375 ));
376 }
377 if requested.contains("authority_count") {
378 fields.push((
379 "authority_count",
380 FieldValue::UInt16(packet.name_servers.len() as u16),
381 ));
382 }
383 if requested.contains("additional_count") {
384 fields.push((
385 "additional_count",
386 FieldValue::UInt16(packet.additional_records.len() as u16),
387 ));
388 }
389}
390
391fn extract_question_fields(
393 packet: &Packet,
394 fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
395) {
396 if let Some(question) = packet.questions.first() {
397 fields.push((
398 "query_name",
399 FieldValue::OwnedString(CompactString::new(question.qname.to_string())),
400 ));
401 fields.push(("query_type", FieldValue::UInt16(question.qtype.into())));
402 fields.push(("query_class", FieldValue::UInt16(question.qclass.into())));
403 } else {
404 fields.push(("query_name", FieldValue::Null));
405 fields.push(("query_type", FieldValue::Null));
406 fields.push(("query_class", FieldValue::Null));
407 }
408}
409
410fn extract_question_fields_projected(
412 packet: &Packet,
413 fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
414 requested: &HashSet<String>,
415) {
416 if let Some(question) = packet.questions.first() {
417 if requested.contains("query_name") {
418 fields.push((
419 "query_name",
420 FieldValue::OwnedString(CompactString::new(question.qname.to_string())),
421 ));
422 }
423 if requested.contains("query_type") {
424 fields.push(("query_type", FieldValue::UInt16(question.qtype.into())));
425 }
426 if requested.contains("query_class") {
427 fields.push(("query_class", FieldValue::UInt16(question.qclass.into())));
428 }
429 } else {
430 if requested.contains("query_name") {
431 fields.push(("query_name", FieldValue::Null));
432 }
433 if requested.contains("query_type") {
434 fields.push(("query_type", FieldValue::Null));
435 }
436 if requested.contains("query_class") {
437 fields.push(("query_class", FieldValue::Null));
438 }
439 }
440}
441
442fn extract_answer_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
444 let mut ip4s: Vec<FieldValue> = Vec::new();
445 let mut ip6s: Vec<FieldValue> = Vec::new();
446 let mut cnames: Vec<FieldValue> = Vec::new();
447 let mut types: Vec<FieldValue> = Vec::new();
448 let mut ttls: Vec<FieldValue> = Vec::new();
449
450 for answer in &packet.answers {
451 let rtype: u16 = answer.rdata.type_code().into();
453 types.push(FieldValue::UInt16(rtype));
454
455 ttls.push(FieldValue::UInt32(answer.ttl));
457
458 match &answer.rdata {
460 RData::A(a) => {
461 ip4s.push(FieldValue::UInt32(a.address));
462 }
463 RData::AAAA(aaaa) => {
464 ip6s.push(FieldValue::IpAddr(std::net::IpAddr::V6(
466 std::net::Ipv6Addr::from(aaaa.address),
467 )));
468 }
469 RData::CNAME(cname) => {
470 cnames.push(FieldValue::OwnedString(CompactString::new(
471 cname.0.to_string(),
472 )));
473 }
474 _ => {}
475 }
476 }
477
478 fields.push(("answer_ip4s", FieldValue::List(ip4s)));
480 fields.push(("answer_ip6s", FieldValue::List(ip6s)));
481 fields.push(("answer_cnames", FieldValue::List(cnames)));
482 fields.push(("answer_types", FieldValue::List(types)));
483 fields.push(("answer_ttls", FieldValue::List(ttls)));
484}
485
486fn extract_answer_fields_projected(
488 packet: &Packet,
489 fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
490 requested: &HashSet<String>,
491) {
492 let need_ip4s = requested.contains("answer_ip4s");
493 let need_ip6s = requested.contains("answer_ip6s");
494 let need_cnames = requested.contains("answer_cnames");
495 let need_types = requested.contains("answer_types");
496 let need_ttls = requested.contains("answer_ttls");
497
498 let mut ip4s: Vec<FieldValue> = Vec::new();
499 let mut ip6s: Vec<FieldValue> = Vec::new();
500 let mut cnames: Vec<FieldValue> = Vec::new();
501 let mut types: Vec<FieldValue> = Vec::new();
502 let mut ttls: Vec<FieldValue> = Vec::new();
503
504 for answer in &packet.answers {
505 if need_types {
506 let rtype: u16 = answer.rdata.type_code().into();
507 types.push(FieldValue::UInt16(rtype));
508 }
509
510 if need_ttls {
511 ttls.push(FieldValue::UInt32(answer.ttl));
512 }
513
514 match &answer.rdata {
515 RData::A(a) if need_ip4s => {
516 ip4s.push(FieldValue::UInt32(a.address));
517 }
518 RData::AAAA(aaaa) if need_ip6s => {
519 ip6s.push(FieldValue::IpAddr(std::net::IpAddr::V6(
521 std::net::Ipv6Addr::from(aaaa.address),
522 )));
523 }
524 RData::CNAME(cname) if need_cnames => {
525 cnames.push(FieldValue::OwnedString(CompactString::new(
526 cname.0.to_string(),
527 )));
528 }
529 _ => {}
530 }
531 }
532
533 if need_ip4s {
534 fields.push(("answer_ip4s", FieldValue::List(ip4s)));
535 }
536 if need_ip6s {
537 fields.push(("answer_ip6s", FieldValue::List(ip6s)));
538 }
539 if need_cnames {
540 fields.push(("answer_cnames", FieldValue::List(cnames)));
541 }
542 if need_types {
543 fields.push(("answer_types", FieldValue::List(types)));
544 }
545 if need_ttls {
546 fields.push(("answer_ttls", FieldValue::List(ttls)));
547 }
548}
549
550fn extract_edns_fields(packet: &Packet, fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
552 if let Some(opt) = packet.opt() {
553 fields.push(("has_edns", FieldValue::Bool(true)));
554 fields.push(("edns_udp_size", FieldValue::UInt16(opt.udp_packet_size)));
555 } else {
556 fields.push(("has_edns", FieldValue::Bool(false)));
557 fields.push(("edns_udp_size", FieldValue::Null));
558 }
559}
560
561fn extract_edns_fields_projected(
563 packet: &Packet,
564 fields: &mut SmallVec<[(&'static str, FieldValue); 16]>,
565 requested: &HashSet<String>,
566) {
567 let has_edns = packet.opt().is_some();
568
569 if requested.contains("has_edns") {
570 fields.push(("has_edns", FieldValue::Bool(has_edns)));
571 }
572
573 if requested.contains("edns_udp_size") {
574 if let Some(opt) = packet.opt() {
575 fields.push(("edns_udp_size", FieldValue::UInt16(opt.udp_packet_size)));
576 } else {
577 fields.push(("edns_udp_size", FieldValue::Null));
578 }
579 }
580}
581
/// Numeric DNS resource-record TYPE codes (IANA "Resource Record (RR) TYPEs").
// Reference table: not every constant is used by this module, hence the allow.
#[allow(dead_code)]
pub mod record_type {
    /// IPv4 host address (RFC 1035).
    pub const A: u16 = 1;
    /// Authoritative name server (RFC 1035).
    pub const NS: u16 = 2;
    /// Canonical name of an alias (RFC 1035).
    pub const CNAME: u16 = 5;
    /// Start of a zone of authority (RFC 1035).
    pub const SOA: u16 = 6;
    /// Domain name pointer (RFC 1035).
    pub const PTR: u16 = 12;
    /// Mail exchange (RFC 1035).
    pub const MX: u16 = 15;
    /// Text strings (RFC 1035).
    pub const TXT: u16 = 16;
    /// IPv6 host address (RFC 3596).
    pub const AAAA: u16 = 28;
    /// Service locator (RFC 2782).
    pub const SRV: u16 = 33;
    /// EDNS(0) OPT pseudo-record (RFC 6891).
    pub const OPT: u16 = 41;
    /// Query type requesting all records, `*` (RFC 1035).
    pub const ANY: u16 = 255;
}
597
/// Numeric DNS response codes (IANA "DNS RCODEs").
///
/// Values above 15 do not fit in the 4-bit header field. They are either
/// extended RCODEs (upper bits in the EDNS OPT record) or appear in
/// TSIG/TKEY records.
// Reference table: not every constant is used by this module, hence the allow.
#[allow(dead_code)]
pub mod rcode {
    /// No error (RFC 1035).
    pub const NOERROR: u8 = 0;
    /// Format error (RFC 1035).
    pub const FORMERR: u8 = 1;
    /// Server failure (RFC 1035).
    pub const SERVFAIL: u8 = 2;
    /// Non-existent domain (RFC 1035).
    pub const NXDOMAIN: u8 = 3;
    /// Not implemented (RFC 1035).
    pub const NOTIMP: u8 = 4;
    /// Query refused (RFC 1035).
    pub const REFUSED: u8 = 5;

    /// Name exists when it should not (RFC 2136).
    pub const YXDOMAIN: u8 = 6;
    /// RR set exists when it should not (RFC 2136).
    pub const YXRRSET: u8 = 7;
    /// RR set that should exist does not (RFC 2136).
    pub const NXRRSET: u8 = 8;
    /// Server not authoritative for zone / not authorized (RFC 2136, RFC 8945).
    pub const NOTAUTH: u8 = 9;
    /// Name not contained in zone (RFC 2136).
    pub const NOTZONE: u8 = 10;

    /// DSO-TYPE not implemented (RFC 8490).
    pub const DSOTYPENI: u8 = 11;

    /// Bad OPT version (RFC 6891); the same value is BADSIG in TSIG (RFC 8945).
    pub const BADVERS: u8 = 16;
    /// Key not recognized (RFC 8945).
    pub const BADKEY: u8 = 17;
    /// Signature out of time window (RFC 8945).
    pub const BADTIME: u8 = 18;
    /// Bad TKEY mode (RFC 2930).
    pub const BADMODE: u8 = 19;
    /// Duplicate key name (RFC 2930).
    pub const BADNAME: u8 = 20;
    /// Algorithm not supported (RFC 2930).
    pub const BADALG: u8 = 21;
    /// Bad truncation (RFC 8945).
    pub const BADTRUNC: u8 = 22;
    /// Bad or missing server cookie (RFC 7873).
    pub const BADCOOKIE: u8 = 23;
}
629
// Unit tests for `DnsProtocol`, driven by hand-assembled DNS wire-format
// messages (header layout per RFC 1035 §4.1).
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Encodes `name` as length-prefixed labels followed by the zero-length
    /// root label. Empty labels (e.g. from a trailing dot) are skipped.
    /// Labels are assumed to be shorter than 256 bytes (`len() as u8`).
    fn encode_domain_name(name: &str) -> Vec<u8> {
        let mut result = Vec::new();
        for part in name.split('.') {
            if !part.is_empty() {
                result.push(part.len() as u8);
                result.extend_from_slice(part.as_bytes());
            }
        }
        // Terminate with the root label, then return the encoded name.
        result.push(0);
        result
    }

    /// Builds a single-question query for `query_name` (already label-encoded)
    /// with QTYPE A and QCLASS IN.
    fn create_dns_query(transaction_id: u16, query_name: &[u8]) -> Vec<u8> {
        let mut packet = Vec::new();

        // ID
        packet.extend_from_slice(&transaction_id.to_be_bytes());

        // Flags 0x0100: QR=0 (query), OPCODE=0, RD=1.
        packet.extend_from_slice(&[0x01, 0x00]);

        // QDCOUNT = 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // ANCOUNT, NSCOUNT, ARCOUNT = 0
        packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);

        // QNAME
        packet.extend_from_slice(query_name);

        // QTYPE = 1 (A)
        packet.extend_from_slice(&[0x00, 0x01]);

        // QCLASS = 1 (IN)
        packet.extend_from_slice(&[0x00, 0x01]);

        packet
    }

    /// Builds a response for `example.com` A/IN carrying one A answer for `ip`
    /// with a TTL of 300 seconds.
    fn create_dns_response_with_answer(transaction_id: u16, ip: [u8; 4]) -> Vec<u8> {
        let mut packet = Vec::new();

        // ID
        packet.extend_from_slice(&transaction_id.to_be_bytes());

        // Flags 0x8180: QR=1 (response), OPCODE=0, RD=1, RA=1, RCODE=0.
        packet.extend_from_slice(&[0x81, 0x80]);

        // QDCOUNT = 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // ANCOUNT = 1
        packet.extend_from_slice(&[0x00, 0x01]);

        // NSCOUNT = 0, ARCOUNT = 0
        packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);

        // QNAME: example.com
        packet.extend_from_slice(&[
            0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 0x03, b'c', b'o', b'm', 0x00,
        ]);

        // QTYPE = A, QCLASS = IN
        packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]);

        // Answer NAME: compression pointer to offset 12, i.e. the question
        // name that immediately follows the 12-byte header.
        packet.extend_from_slice(&[0xC0, 0x0C]);

        // TYPE = 1 (A)
        packet.extend_from_slice(&[0x00, 0x01]);

        // CLASS = 1 (IN)
        packet.extend_from_slice(&[0x00, 0x01]);

        // TTL = 0x12C = 300 seconds
        packet.extend_from_slice(&[0x00, 0x00, 0x01, 0x2C]);

        // RDLENGTH = 4
        packet.extend_from_slice(&[0x00, 0x04]);

        // RDATA: the IPv4 address
        packet.extend_from_slice(&ip);

        packet
    }

    #[test]
    fn test_parse_dns_query() {
        let query_name = encode_domain_name("example.com");
        let packet = create_dns_query(0x1234, &query_name);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);
        context.parent_protocol = Some("udp");

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0x1234))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("query_count"), Some(&FieldValue::UInt16(1)));
        assert_eq!(result.get("answer_count"), Some(&FieldValue::UInt16(0)));
        assert_eq!(
            result.get("recursion_desired"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(
            result.get("query_name"),
            Some(&FieldValue::OwnedString(CompactString::new("example.com")))
        );
        // QTYPE 1 is A and QCLASS 1 is IN, as written by `create_dns_query`.
        assert_eq!(result.get("query_type"), Some(&FieldValue::UInt16(1)));
        assert_eq!(result.get("query_class"), Some(&FieldValue::UInt16(1)));
    }

    #[test]
    fn test_parse_dns_response_with_answer() {
        let packet = create_dns_response_with_answer(0xABCD, [93, 184, 216, 34]);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("src_port", 53);
        context.parent_protocol = Some("udp");

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0xABCD))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(false)));
        assert_eq!(result.get("answer_count"), Some(&FieldValue::UInt16(1)));
        assert_eq!(result.get("response_code"), Some(&FieldValue::UInt8(0)));

        // A records are reported as the big-endian u32 form of the address.
        if let Some(FieldValue::List(ip4s)) = result.get("answer_ip4s") {
            assert_eq!(ip4s.len(), 1);
            let expected_ip = u32::from(Ipv4Addr::new(93, 184, 216, 34));
            assert_eq!(ip4s[0], FieldValue::UInt32(expected_ip));
        } else {
            panic!("Expected answer_ip4s to be a list");
        }

        if let Some(FieldValue::List(types)) = result.get("answer_types") {
            assert_eq!(types.len(), 1);
            // Type 1 is A.
            assert_eq!(types[0], FieldValue::UInt16(1));
        } else {
            panic!("Expected answer_types to be a list");
        }

        if let Some(FieldValue::List(ttls)) = result.get("answer_ttls") {
            assert_eq!(ttls.len(), 1);
            // 300 seconds, as written by `create_dns_response_with_answer`.
            assert_eq!(ttls[0], FieldValue::UInt32(300));
        } else {
            panic!("Expected answer_ttls to be a list");
        }
    }

    #[test]
    fn test_can_parse_dns() {
        let parser = DnsProtocol;

        // No port hints at all: not claimed.
        let ctx1 = ParseContext::new(1);
        assert!(parser.can_parse(&ctx1).is_none());

        // Destination port 53: claimed.
        let mut ctx2 = ParseContext::new(1);
        ctx2.insert_hint("dst_port", 53);
        assert!(parser.can_parse(&ctx2).is_some());

        // Source port 53: claimed.
        let mut ctx3 = ParseContext::new(1);
        ctx3.insert_hint("src_port", 53);
        assert!(parser.can_parse(&ctx3).is_some());

        // Unrelated port: not claimed.
        let mut ctx4 = ParseContext::new(1);
        ctx4.insert_hint("dst_port", 80);
        assert!(parser.can_parse(&ctx4).is_none());
    }

    #[test]
    fn test_parse_dns_too_short() {
        // Four bytes cannot hold the 12-byte DNS header.
        let short_packet = [0x12, 0x34, 0x00, 0x00];
        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        let result = parser.parse(&short_packet, &context);

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_dns_schema_fields() {
        let parser = DnsProtocol;
        let fields = parser.schema_fields();

        assert!(!fields.is_empty());

        // Schema names carry the "dns." prefix.
        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();
        assert!(field_names.contains(&"dns.transaction_id"));
        assert!(field_names.contains(&"dns.is_query"));
        assert!(field_names.contains(&"dns.query_name"));
        assert!(field_names.contains(&"dns.query_type"));
        assert!(field_names.contains(&"dns.answer_ip4s"));
        assert!(field_names.contains(&"dns.answer_ip6s"));
        assert!(field_names.contains(&"dns.answer_cnames"));
        assert!(field_names.contains(&"dns.has_edns"));
    }

    #[test]
    fn test_dns_projected_header_only() {
        let query_name = encode_domain_name("example.com");
        let packet = create_dns_query(0x1234, &query_name);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        // Request only header fields; question and answer fields must be absent.
        let fields: HashSet<String> = ["transaction_id", "is_query", "response_code"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let result = parser.parse_projected(&packet, &context, Some(&fields));

        assert!(result.is_ok());
        assert_eq!(
            result.get("transaction_id"),
            Some(&FieldValue::UInt16(0x1234))
        );
        assert_eq!(result.get("is_query"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("response_code"), Some(&FieldValue::UInt8(0)));
        assert!(result.get("query_name").is_none());
        assert!(result.get("answer_ip4s").is_none());
    }

    #[test]
    fn test_dns_aaaa_query() {
        let mut packet = Vec::new();

        // ID
        packet.extend_from_slice(&[0x12, 0x34]);

        // Flags 0x0100: standard query with RD=1.
        packet.extend_from_slice(&[0x01, 0x00]);

        // QDCOUNT = 1; ANCOUNT, NSCOUNT, ARCOUNT = 0.
        packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);

        // QNAME: ipv6.google.com
        packet.extend_from_slice(&[
            0x04, b'i', b'p', b'v', b'6', 0x06, b'g', b'o', b'o', b'g', b'l', b'e', 0x03, b'c',
            b'o', b'm', 0x00,
        ]);

        // QTYPE = 28 (AAAA)
        packet.extend_from_slice(&[0x00, 0x1C]);

        // QCLASS = 1 (IN)
        packet.extend_from_slice(&[0x00, 0x01]);

        let parser = DnsProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("dst_port", 53);

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("query_name"),
            Some(&FieldValue::OwnedString(CompactString::new(
                "ipv6.google.com"
            )))
        );
        assert_eq!(
            result.get("query_type"),
            Some(&FieldValue::UInt16(record_type::AAAA))
        );
    }
}