1use thiserror::Error;
2
3use crate::{
4 context::{DeserializeContext, SerializeContext},
5 Header, Name, Question, Record, TsigData, Type, TypeData,
6};
7
8#[derive(Default, Clone, Debug)]
9pub struct Packet {
10 pub header: Header,
11 pub questions: Vec<Question>,
12 pub answers: Vec<Record>,
13 pub nameservers: Vec<Record>,
14 pub additional_records: Vec<Record>,
15}
16
17#[derive(Error, Debug)]
18pub enum PacketParseError {
19 #[error("the packet header was truncated")]
20 HeaderTruncated,
21 #[error("the packet was truncated")]
22 Truncated,
23 #[error("the header was invalid")]
24 InvalidHeader,
25 #[error("unexpected EOF")]
26 UnexpectedEOF,
27 #[error("corrupt name, invalid label tag, length, or ptr")]
28 CorruptName,
29 #[error("invalid UTF8 in name: {0}")]
30 UTF8Error(#[from] std::str::Utf8Error),
31 #[error("invalid record bytes")]
32 CorruptRecord,
33}
34
35pub struct ValidatableTsig<'a> {
36 pub name: Name,
37 pub data: TsigData,
38 pub hmac_slice: &'a [u8],
39}
40
41impl Packet {
42 pub fn parse(bytes: &[u8]) -> Result<(Packet, Option<ValidatableTsig<'_>>), PacketParseError> {
43 if bytes.len() < Header::LENGTH {
44 return Err(PacketParseError::HeaderTruncated);
45 }
46 let header = Header::parse(bytes[..Header::LENGTH].try_into().unwrap());
47 if !header.validate() {
48 return Err(PacketParseError::InvalidHeader);
49 }
50 if header.is_truncated {
51 return Err(PacketParseError::Truncated);
52 }
53 let mut packet = Packet {
54 questions: Vec::with_capacity(header.question_count as usize),
55 answers: Vec::with_capacity(header.answer_count as usize),
56 nameservers: Vec::with_capacity(header.nameserver_count as usize),
57 additional_records: Vec::with_capacity(header.additional_record_count as usize),
58 header,
59 };
60 let mut context = DeserializeContext::new_post_header(bytes);
61 for _ in 0..packet.header.question_count {
62 packet.questions.push(Question::parse(&mut context)?);
63 }
64 for _ in 0..packet.header.answer_count {
65 packet.answers.push(Record::parse(&mut context)?);
66 }
67 for _ in 0..packet.header.nameserver_count {
68 packet.nameservers.push(Record::parse(&mut context)?);
69 }
70 let mut tsig = None;
71 for i in 0..packet.header.additional_record_count {
72 let index = context.index();
73 let record = Record::parse(&mut context)?;
74 if i == packet.header.additional_record_count - 1 && record.type_ == Type::TSIG {
75 let data = match record.data {
76 TypeData::TSIG(data) => data,
77 _ => unreachable!(),
78 };
79 tsig = Some(ValidatableTsig {
80 name: record.name,
81 data,
82 hmac_slice: &bytes[..index],
83 });
84 continue;
85 }
86 packet.additional_records.push(record);
87 }
88
89 Ok((packet, tsig))
90 }
91
92 pub(crate) fn serialize_open(&self) -> (Header, SerializeContext) {
93 let mut context = SerializeContext::default();
94
95 let mut header = self.header.clone();
96 header.question_count = self.questions.len().try_into().unwrap();
97 header.answer_count = self.answers.len().try_into().unwrap();
98 header.nameserver_count = self.nameservers.len().try_into().unwrap();
99 header.additional_record_count = self.additional_records.len().try_into().unwrap();
100 context.write_blob(header.to_bytes());
101
102 for question in &self.questions {
103 question.serialize(&mut context);
104 }
105 for record in &self.answers {
106 record.serialize(&mut context);
107 }
108 for record in &self.nameservers {
109 record.serialize(&mut context);
110 }
111 for record in &self.additional_records {
112 record.serialize(&mut context);
113 }
114
115 (header, context)
116 }
117
118 pub fn serialize(&self, max_size: usize) -> Vec<u8> {
119 let (mut header, context) = self.serialize_open();
120
121 let mut out = context.finalize();
122 if out.len() > max_size {
123 out.truncate(max_size);
124 header.is_truncated = true;
125 out[..Header::LENGTH].copy_from_slice(&header.to_bytes());
126 }
127 out
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_data::*, Class, Type, TypeData};

    /// Checks that `question` asks for the `IN A` records of `google.com`,
    /// which is what both captured fixtures contain.
    fn assert_google_a_question(question: &Question) {
        assert_eq!(question.name.as_ref(), "google.com");
        assert_eq!(question.type_, Type::A);
        assert_eq!(question.class, Class::IN);
    }

    #[test]
    fn test_packet_parse() {
        // Query fixture: a single question, and a byte-exact round trip.
        let query = Packet::parse(&DNS_QUERY).unwrap().0;
        assert_eq!(query.questions.len(), 1);
        assert_google_a_question(&query.questions[0]);
        assert_eq!(&DNS_QUERY[..], &query.serialize(512));

        // Response fixture: the same question plus one A answer.
        let response = Packet::parse(&DNS_RESPONSE).unwrap().0;
        assert_eq!(response.questions.len(), 1);
        assert_eq!(response.answers.len(), 1);
        assert_google_a_question(&response.questions[0]);

        let answer = &response.answers[0];
        assert_eq!(answer.name.as_ref(), "google.com");
        assert_eq!(answer.type_, Type::A);
        assert_eq!(answer.class, Class::IN);
        assert_eq!(answer.data, TypeData::A("142.250.189.174".parse().unwrap()));

        assert_eq!(&DNS_RESPONSE[..], &response.serialize(512));
    }
}