simple_dns/dns/
packet.rs

1use super::{Header, PacketFlag, Question, ResourceRecord, WireFormat, OPCODE};
2use crate::{
3    bytes_buffer::BytesBuffer,
4    lib::{Seek, Vec, Write},
5    rdata::OPT,
6    RCODE,
7};
8
/// Represents a DNS message packet.
///
/// A packet is a header plus four sections: questions, answers, name servers and
/// additional records. The header is private and is reached through the accessor
/// methods of [`Packet`].
///
/// When working with EDNS packets, use [Packet::opt] and [Packet::opt_mut] to add or access [OPT] packet information
#[derive(Debug, Clone)]
pub struct Packet<'a> {
    /// Packet header (id, flags, opcode, response code and optional OPT information)
    header: Header<'a>,
    /// Questions section
    pub questions: Vec<Question<'a>>,
    /// Answers section
    pub answers: Vec<ResourceRecord<'a>>,
    /// Name servers section
    pub name_servers: Vec<ResourceRecord<'a>>,
    /// Additional records section.
    ///
    /// DO NOT use this field to add an OPT record, use [`Packet::opt_mut`] instead
    pub additional_records: Vec<ResourceRecord<'a>>,
}
26
27impl<'a> Packet<'a> {
28    /// Creates a new empty packet with a query header
29    pub fn new_query(id: u16) -> Self {
30        Self {
31            header: Header::new_query(id),
32            questions: Vec::new(),
33            answers: Vec::new(),
34            name_servers: Vec::new(),
35            additional_records: Vec::new(),
36        }
37    }
38
39    /// Creates a new empty packet with a reply header
40    pub fn new_reply(id: u16) -> Self {
41        Self {
42            header: Header::new_reply(id, OPCODE::StandardQuery),
43            questions: Vec::new(),
44            answers: Vec::new(),
45            name_servers: Vec::new(),
46            additional_records: Vec::new(),
47        }
48    }
49
50    /// Get packet id
51    pub fn id(&self) -> u16 {
52        self.header.id
53    }
54
55    /// Set packet id
56    pub fn set_id(&mut self, id: u16) {
57        self.header.id = id;
58    }
59
60    /// Set flags in the packet
61    pub fn set_flags(&mut self, flags: PacketFlag) {
62        self.header.set_flags(flags);
63    }
64
65    /// Remove flags present in the packet
66    pub fn remove_flags(&mut self, flags: PacketFlag) {
67        self.header.remove_flags(flags)
68    }
69
70    /// Check if the packet has flags set
71    pub fn has_flags(&self, flags: PacketFlag) -> bool {
72        self.header.has_flags(flags)
73    }
74
75    /// Get this packet [RCODE] information
76    pub fn rcode(&self) -> RCODE {
77        self.header.response_code
78    }
79
80    /// Get a mutable reference for  this packet [RCODE] information
81    /// Warning, if the [RCODE] value is greater than 15 (4 bits), you MUST provide an [OPT]
82    /// resource record through the [Packet::opt_mut] function
83    pub fn rcode_mut(&mut self) -> &mut RCODE {
84        &mut self.header.response_code
85    }
86
87    /// Get this packet [OPCODE] information
88    pub fn opcode(&self) -> OPCODE {
89        self.header.opcode
90    }
91
92    /// Get a mutable reference for this packet [OPCODE] information
93    pub fn opcode_mut(&mut self) -> &mut OPCODE {
94        &mut self.header.opcode
95    }
96
97    /// Get the [OPT] resource record for this packet, if present
98    pub fn opt(&self) -> Option<&OPT<'a>> {
99        self.header.opt.as_ref()
100    }
101
102    /// Get a mutable reference for this packet [OPT] resource record.  
103    pub fn opt_mut(&mut self) -> &mut Option<OPT<'a>> {
104        &mut self.header.opt
105    }
106
107    /// Changes this packet into a reply packet by replacing its header
108    pub fn into_reply(mut self) -> Self {
109        self.header = Header::new_reply(self.header.id, self.header.opcode);
110        self
111    }
112
113    /// Parses a packet from a slice of bytes
114    pub fn parse(data: &'a [u8]) -> crate::Result<Self> {
115        let mut data = BytesBuffer::new(data);
116        let mut header = Header::parse(&mut data)?;
117
118        let questions = Self::parse_section(&mut data, header.questions)?;
119        let answers = Self::parse_section(&mut data, header.answers)?;
120        let name_servers = Self::parse_section(&mut data, header.name_servers)?;
121        let mut additional_records: Vec<ResourceRecord> =
122            Self::parse_section(&mut data, header.additional_records)?;
123
124        header.extract_info_from_opt_rr(
125            additional_records
126                .iter()
127                .position(|rr| rr.rdata.type_code() == crate::TYPE::OPT)
128                .map(|i| additional_records.remove(i)),
129        );
130
131        Ok(Self {
132            header,
133            questions,
134            answers,
135            name_servers,
136            additional_records,
137        })
138    }
139
140    fn parse_section<T: WireFormat<'a>>(
141        data: &mut BytesBuffer<'a>,
142        items_count: u16,
143    ) -> crate::Result<Vec<T>> {
144        let mut section_items = Vec::with_capacity(items_count as usize);
145
146        for _ in 0..items_count {
147            section_items.push(T::parse(data)?);
148        }
149
150        Ok(section_items)
151    }
152
153    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
154    ///
155    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
156    pub fn build_bytes_vec(&self) -> crate::Result<Vec<u8>> {
157        let mut out = Vec::with_capacity(900);
158        self.write_to(&mut out)?;
159
160        Ok(out)
161    }
162
163    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
164    /// with compression enabled
165    ///
166    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
167    pub fn build_bytes_vec_compressed(&self) -> crate::Result<Vec<u8>> {
168        let mut out = crate::lib::Cursor::new(Vec::with_capacity(900));
169        self.write_compressed_to(&mut out)?;
170
171        Ok(out.into_inner())
172    }
173
174    /// Write the contents of this package in wire format into the provided writer
175    pub fn write_to<T: Write>(&self, out: &mut T) -> crate::Result<()> {
176        self.write_header(out)?;
177
178        for e in &self.questions {
179            e.write_to(out)?;
180        }
181        for e in &self.answers {
182            e.write_to(out)?;
183        }
184        for e in &self.name_servers {
185            e.write_to(out)?;
186        }
187
188        if let Some(rr) = self.header.opt_rr() {
189            rr.write_to(out)?;
190        }
191
192        for e in &self.additional_records {
193            e.write_to(out)?;
194        }
195
196        out.flush()?;
197        Ok(())
198    }
199
200    /// Write the contents of this package in wire format with enabled compression into the provided writer
201    pub fn write_compressed_to<T: Write + Seek>(&self, out: &mut T) -> crate::Result<()> {
202        self.write_header(out)?;
203
204        let mut name_refs = Default::default();
205        for e in &self.questions {
206            e.write_compressed_to(out, &mut name_refs)?;
207        }
208        for e in &self.answers {
209            e.write_compressed_to(out, &mut name_refs)?;
210        }
211        for e in &self.name_servers {
212            e.write_compressed_to(out, &mut name_refs)?;
213        }
214
215        if let Some(rr) = self.header.opt_rr() {
216            rr.write_to(out)?;
217        }
218
219        for e in &self.additional_records {
220            e.write_compressed_to(out, &mut name_refs)?;
221        }
222        out.flush()?;
223
224        Ok(())
225    }
226
227    fn write_header<T: Write>(&self, out: &mut T) -> crate::Result<()> {
228        self.header.write_to(
229            out,
230            self.questions.len() as u16,
231            self.answers.len() as u16,
232            self.name_servers.len() as u16,
233            self.additional_records.len() as u16 + u16::from(self.header.opt.is_some()),
234        )
235    }
236}
237
#[cfg(test)]
mod tests {
    use crate::{dns::CLASS, dns::TYPE, lib::ToString, SimpleDnsError};

    use super::*;

    /// Parsing an empty slice must report missing data instead of panicking.
    #[test]
    fn parse_without_data_should_not_panic() {
        let outcome = Packet::parse(&[]);

        assert!(matches!(outcome, Err(SimpleDnsError::InsufficientData)));
    }

    /// A query with two questions must survive a build/parse round trip.
    #[test]
    fn build_query_correct() {
        let expected_names = ["_srv._udp.local", "_srv2._udp.local"];

        let mut query = Packet::new_query(1);
        for name in expected_names {
            query.questions.push(Question::new(
                name.try_into().unwrap(),
                TYPE::TXT.into(),
                CLASS::IN.into(),
                false,
            ));
        }

        let bytes = query.build_bytes_vec().unwrap();

        let parsed = Packet::parse(&bytes);
        assert!(parsed.is_ok());
        let parsed = parsed.unwrap();

        assert_eq!(2, parsed.questions.len());
        for (expected, question) in expected_names.iter().zip(&parsed.questions) {
            assert_eq!(*expected, question.qname.to_string());
        }
    }
}