adns_proto/
packet.rs

1use thiserror::Error;
2
3use crate::{
4    context::{DeserializeContext, SerializeContext},
5    Header, Name, Question, Record, TsigData, Type, TypeData,
6};
7
8#[derive(Default, Clone, Debug)]
9pub struct Packet {
10    pub header: Header,
11    pub questions: Vec<Question>,
12    pub answers: Vec<Record>,
13    pub nameservers: Vec<Record>,
14    pub additional_records: Vec<Record>,
15}
16
17#[derive(Error, Debug)]
18pub enum PacketParseError {
19    #[error("the packet header was truncated")]
20    HeaderTruncated,
21    #[error("the packet was truncated")]
22    Truncated,
23    #[error("the header was invalid")]
24    InvalidHeader,
25    #[error("unexpected EOF")]
26    UnexpectedEOF,
27    #[error("corrupt name, invalid label tag, length, or ptr")]
28    CorruptName,
29    #[error("invalid UTF8 in name: {0}")]
30    UTF8Error(#[from] std::str::Utf8Error),
31    #[error("invalid record bytes")]
32    CorruptRecord,
33}
34
35pub struct ValidatableTsig<'a> {
36    pub name: Name,
37    pub data: TsigData,
38    pub hmac_slice: &'a [u8],
39}
40
41impl Packet {
42    pub fn parse(bytes: &[u8]) -> Result<(Packet, Option<ValidatableTsig<'_>>), PacketParseError> {
43        if bytes.len() < Header::LENGTH {
44            return Err(PacketParseError::HeaderTruncated);
45        }
46        let header = Header::parse(bytes[..Header::LENGTH].try_into().unwrap());
47        if !header.validate() {
48            return Err(PacketParseError::InvalidHeader);
49        }
50        if header.is_truncated {
51            return Err(PacketParseError::Truncated);
52        }
53        let mut packet = Packet {
54            questions: Vec::with_capacity(header.question_count as usize),
55            answers: Vec::with_capacity(header.answer_count as usize),
56            nameservers: Vec::with_capacity(header.nameserver_count as usize),
57            additional_records: Vec::with_capacity(header.additional_record_count as usize),
58            header,
59        };
60        let mut context = DeserializeContext::new_post_header(bytes);
61        for _ in 0..packet.header.question_count {
62            packet.questions.push(Question::parse(&mut context)?);
63        }
64        for _ in 0..packet.header.answer_count {
65            packet.answers.push(Record::parse(&mut context)?);
66        }
67        for _ in 0..packet.header.nameserver_count {
68            packet.nameservers.push(Record::parse(&mut context)?);
69        }
70        let mut tsig = None;
71        for i in 0..packet.header.additional_record_count {
72            let index = context.index();
73            let record = Record::parse(&mut context)?;
74            if i == packet.header.additional_record_count - 1 && record.type_ == Type::TSIG {
75                let data = match record.data {
76                    TypeData::TSIG(data) => data,
77                    _ => unreachable!(),
78                };
79                tsig = Some(ValidatableTsig {
80                    name: record.name,
81                    data,
82                    hmac_slice: &bytes[..index],
83                });
84                continue;
85            }
86            packet.additional_records.push(record);
87        }
88
89        Ok((packet, tsig))
90    }
91
92    pub(crate) fn serialize_open(&self) -> (Header, SerializeContext) {
93        let mut context = SerializeContext::default();
94
95        let mut header = self.header.clone();
96        header.question_count = self.questions.len().try_into().unwrap();
97        header.answer_count = self.answers.len().try_into().unwrap();
98        header.nameserver_count = self.nameservers.len().try_into().unwrap();
99        header.additional_record_count = self.additional_records.len().try_into().unwrap();
100        context.write_blob(header.to_bytes());
101
102        for question in &self.questions {
103            question.serialize(&mut context);
104        }
105        for record in &self.answers {
106            record.serialize(&mut context);
107        }
108        for record in &self.nameservers {
109            record.serialize(&mut context);
110        }
111        for record in &self.additional_records {
112            record.serialize(&mut context);
113        }
114
115        (header, context)
116    }
117
118    pub fn serialize(&self, max_size: usize) -> Vec<u8> {
119        let (mut header, context) = self.serialize_open();
120
121        let mut out = context.finalize();
122        if out.len() > max_size {
123            out.truncate(max_size);
124            header.is_truncated = true;
125            out[..Header::LENGTH].copy_from_slice(&header.to_bytes());
126        }
127        out
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_data::*, Class, Question, Type, TypeData};

    /// Checks that `question` is the `google.com` / `A` / `IN` query used by both fixtures.
    fn assert_google_a_question(question: &Question) {
        assert_eq!(question.name.as_ref(), "google.com");
        assert_eq!(question.type_, Type::A);
        assert_eq!(question.class, Class::IN);
    }

    #[test]
    fn test_packet_parse() {
        // Query fixture: a single question that round-trips byte-for-byte.
        let (query, _) = Packet::parse(&DNS_QUERY).unwrap();
        assert_eq!(query.questions.len(), 1);
        assert_google_a_question(query.questions.first().unwrap());
        assert_eq!(&DNS_QUERY[..], &query.serialize(512));

        // Response fixture: the same question plus one A answer, also round-tripping.
        let (response, _) = Packet::parse(&DNS_RESPONSE).unwrap();
        assert_eq!(response.questions.len(), 1);
        assert_eq!(response.answers.len(), 1);
        assert_google_a_question(response.questions.first().unwrap());

        let answer = response.answers.first().unwrap();
        assert_eq!(answer.name.as_ref(), "google.com");
        assert_eq!(answer.type_, Type::A);
        assert_eq!(answer.class, Class::IN);
        assert_eq!(answer.data, TypeData::A("142.250.189.174".parse().unwrap()));

        assert_eq!(&DNS_RESPONSE[..], &response.serialize(512));
    }
}