flex_dns/
question.rs

1use crate::{Buffer, DnsMessage, DnsMessageError, MutBuffer};
2use crate::name::DnsName;
3use crate::parse::{Parse, ParseBytes};
4use crate::write::WriteBytes;
5
6/// A DNS message questions section.
7pub struct DnsQuestions<
8    const PTR_STORAGE: usize,
9    B: Buffer,
10> {
11    message: DnsMessage<PTR_STORAGE, 0, B>,
12    remaining: usize,
13}
14
15impl<
16    const PTR_STORAGE: usize,
17    B: Buffer,
18> DnsQuestions<PTR_STORAGE, B> {
19    #[inline(always)]
20    pub(crate) fn new(message: DnsMessage<PTR_STORAGE, 0, B>) -> Self {
21        let remaining = message.header().unwrap().question_count() as usize;
22        Self {
23            message,
24            remaining,
25        }
26    }
27
28    /// Return an iterator over the question section.
29    #[inline(always)]
30    pub fn iter(&mut self) -> Result<DnsQuestionIterator, DnsMessageError> {
31        let (bytes, position) = self.message.bytes_and_position();
32
33        Ok(DnsQuestionIterator {
34            buffer: bytes,
35            current_position: position,
36            remaining: &mut self.remaining,
37        })
38    }
39
40    /// Complete the message. This will check the remaining questions and
41    /// return the message if successful.
42    #[inline(always)]
43    pub fn complete(mut self) -> Result<DnsMessage<PTR_STORAGE, 1, B>, DnsMessageError> {
44        if self.remaining != 0 {
45            for x in self.iter()? { x?; }
46        }
47
48        Ok(DnsMessage {
49            buffer: self.message.buffer,
50            position: self.message.position,
51            ptr_storage: self.message.ptr_storage,
52            ptr_len: self.message.ptr_len,
53        })
54    }
55}
56
57impl<
58    const PTR_STORAGE: usize,
59    B: MutBuffer + Buffer,
60> DnsQuestions<PTR_STORAGE, B> {
61    /// Append a question to the message. This will override the next
62    /// question or further sections, if any.
63    pub fn append(&mut self, question: DnsQuestion) -> Result<(), DnsMessageError> {
64        // Truncate the buffer to the current position.
65        self.message.truncate()?;
66        question.write(&mut self.message)?;
67        // Set question_count in the header to the current question count + 1.
68        let question_count = self.message.header().unwrap().question_count();
69        let question_count = question_count + 1 - self.remaining as u16;
70        self.message.header_mut()?.set_question_count(question_count);
71        self.message.header_mut()?.set_answer_count(0);
72        self.message.header_mut()?.set_name_server_count(0);
73        self.message.header_mut()?.set_additional_records_count(0);
74        self.remaining = 0;
75
76        Ok(())
77    }
78}
79
/// Iterator over the questions of a DNS message, obtained from
/// [`DnsQuestions::iter`].
///
/// Reading a question advances the cursor shared with the owning message.
/// It also decrements the shared count of unread questions.
pub struct DnsQuestionIterator<'a> {
    /// Message bytes the questions are parsed from.
    buffer: &'a [u8],
    /// Read cursor into `buffer`, shared with the owning message.
    current_position: &'a mut usize,
    /// Questions still to be yielded, shared with the owning `DnsQuestions`.
    remaining: &'a mut usize,
}
86
87impl<'a> Iterator for DnsQuestionIterator<'a> {
88    type Item = Result<DnsQuestion<'a>, DnsMessageError>;
89
90    #[inline]
91    fn next(&mut self) -> Option<Self::Item> {
92        if *self.remaining == 0 {
93            return None;
94        }
95
96        let question = DnsQuestion::parse(
97            self.buffer, self.current_position
98        );
99        *self.remaining -= 1;
100
101        Some(question)
102    }
103}
104
105/// A DNS message question.
106#[derive(Debug, PartialEq)]
107pub struct DnsQuestion<'a> {
108    /// The domain name being queried.
109    pub name: DnsName<'a>,
110    /// The type of the query.
111    pub qtype: DnsQType,
112    /// The class of the query.
113    pub qclass: DnsQClass,
114}
115
116impl<'a> ParseBytes<'a> for DnsQuestion<'a> {
117    fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
118        let name = DnsName::parse(bytes, i)?;
119        let qtype = u16::parse(bytes, i)?.into();
120        let qclass = u16::parse(bytes, i)?.into();
121
122        Ok(Self {
123            name,
124            qtype,
125            qclass,
126        })
127    }
128}
129
130impl<'a> WriteBytes for DnsQuestion<'a> {
131    fn write<
132        const PTR_STORAGE: usize,
133        const DNS_SECTION: usize,
134        B: MutBuffer + Buffer,
135    >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
136        let mut bytes = 0;
137
138        bytes += self.name.write(message)?;
139        bytes += self.qtype.id().write(message)?;
140        bytes += self.qclass.id().write(message)?;
141
142        Ok(bytes)
143    }
144}
145
/// The kind of a DNS query.
///
/// According to [RFC 1035 Section 3.2.2](https://tools.ietf.org/rfc/rfc1035#section-3.2.2)
/// and [RFC 1035 Section 3.2.3](https://tools.ietf.org/rfc/rfc1035#section-3.2.3).
/// The numeric IDs of the remaining variants follow the IANA
/// "Resource Record (RR) TYPEs" registry.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum DnsQType {
    A = 1,
    NS = 2,
    CNAME = 5,
    SOA = 6,
    PTR = 12,
    HINFO = 13,
    MX = 15,
    TXT = 16,
    RP = 17,
    AFSDB = 18,
    SIG = 24,
    KEY = 25,
    AAAA = 28,
    LOC = 29,
    SRV = 33,
    NAPTR = 35,
    KX = 36,
    CERT = 37,
    DNAME = 39,
    OPT = 41,
    APL = 42,
    DS = 43,
    SSHFP = 44,
    IPSECKEY = 45,
    RRSIG = 46,
    NSEC = 47,
    DNSKEY = 48,
    DHCID = 49,
    NSEC3 = 50,
    NSEC3PARAM = 51,
    TLSA = 52,
    SMIMEA = 53,
    HIP = 55,
    CDS = 59,
    CDNSKEY = 60,
    OPENPGPKEY = 61,
    CSYNC = 62,
    ZONEMD = 63,
    SVCB = 64,
    HTTPS = 65,
    EUI48 = 108,
    EUI64 = 109,
    TKEY = 249,
    TSIG = 250,
    IXFR = 251,
    AXFR = 252,
    ALL = 255,
    URI = 256,
    CAA = 257,
    TA = 32768,
    DLV = 32769,
    /// Any ID without a dedicated variant.
    ///
    /// The original numeric ID is not preserved, so this variant has no
    /// `id()`.
    Reserved,
}

impl DnsQType {
    /// Create a QType from its numeric ID.
    ///
    /// IDs without a dedicated variant map to [`DnsQType::Reserved`].
    ///
    /// This is a plain match rather than a `transmute`. Soundness therefore
    /// no longer depends on a hand-kept list of valid discriminants.
    #[inline(always)]
    pub fn from_id(id: u16) -> Self {
        match id {
            1 => Self::A,
            2 => Self::NS,
            5 => Self::CNAME,
            6 => Self::SOA,
            12 => Self::PTR,
            13 => Self::HINFO,
            15 => Self::MX,
            16 => Self::TXT,
            17 => Self::RP,
            18 => Self::AFSDB,
            24 => Self::SIG,
            25 => Self::KEY,
            28 => Self::AAAA,
            29 => Self::LOC,
            33 => Self::SRV,
            35 => Self::NAPTR,
            36 => Self::KX,
            37 => Self::CERT,
            39 => Self::DNAME,
            41 => Self::OPT,
            42 => Self::APL,
            43 => Self::DS,
            44 => Self::SSHFP,
            45 => Self::IPSECKEY,
            46 => Self::RRSIG,
            47 => Self::NSEC,
            48 => Self::DNSKEY,
            49 => Self::DHCID,
            50 => Self::NSEC3,
            51 => Self::NSEC3PARAM,
            52 => Self::TLSA,
            53 => Self::SMIMEA,
            55 => Self::HIP,
            59 => Self::CDS,
            60 => Self::CDNSKEY,
            61 => Self::OPENPGPKEY,
            62 => Self::CSYNC,
            63 => Self::ZONEMD,
            64 => Self::SVCB,
            65 => Self::HTTPS,
            108 => Self::EUI48,
            109 => Self::EUI64,
            249 => Self::TKEY,
            250 => Self::TSIG,
            251 => Self::IXFR,
            252 => Self::AXFR,
            255 => Self::ALL,
            256 => Self::URI,
            257 => Self::CAA,
            32768 => Self::TA,
            32769 => Self::DLV,
            _ => Self::Reserved,
        }
    }

    /// Get the numeric ID of the QType.
    ///
    /// # Panics
    ///
    /// Panics on [`DnsQType::Reserved`], which has no meaningful ID.
    #[inline(always)]
    pub fn id(&self) -> u16 {
        match self {
            // `Reserved`'s implicit discriminant (32770) is not a real ID.
            DnsQType::Reserved => panic!("Reserved QType"),
            _ => *self as u16,
        }
    }
}

impl From<DnsQType> for u16 {
    /// Convert to the numeric ID.
    ///
    /// # Panics
    ///
    /// Panics on [`DnsQType::Reserved`]; see [`DnsQType::id`].
    #[inline(always)]
    fn from(q: DnsQType) -> Self {
        q.id()
    }
}

impl From<u16> for DnsQType {
    /// Convert from a numeric ID. Unknown IDs become [`DnsQType::Reserved`].
    #[inline(always)]
    fn from(n: u16) -> Self {
        DnsQType::from_id(n)
    }
}
245
/// The class of a DNS query.
///
/// According to [RFC 1035 Section 3.2.4](https://tools.ietf.org/rfc/rfc1035#section-3.2.4)
/// (CLASS values) and [RFC 1035 Section 3.2.5](https://tools.ietf.org/rfc/rfc1035#section-3.2.5)
/// (QCLASS values, which add `ANY`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum DnsQClass {
    /// Internet
    IN = 1,
    /// CSNET
    CS = 2,
    /// CHAOS
    CH = 3,
    /// Hesiod
    HS = 4,
    /// Any
    ANY = 255,
    /// Any ID without a dedicated variant.
    ///
    /// The original numeric ID is not preserved, so this variant has no
    /// `id()`.
    Reserved,
}

impl DnsQClass {
    /// Create a QClass from its numeric ID.
    ///
    /// IDs without a dedicated variant map to [`DnsQClass::Reserved`]. This
    /// is a plain match rather than a `transmute`, so no `unsafe` is needed.
    #[inline(always)]
    pub fn from_id(id: u16) -> Self {
        match id {
            1 => Self::IN,
            2 => Self::CS,
            3 => Self::CH,
            4 => Self::HS,
            255 => Self::ANY,
            _ => Self::Reserved,
        }
    }

    /// Get the numeric ID of the QClass.
    ///
    /// # Panics
    ///
    /// Panics on [`DnsQClass::Reserved`], which has no meaningful ID.
    #[inline(always)]
    pub fn id(&self) -> u16 {
        match self {
            // `Reserved`'s implicit discriminant (256) is not a real ID.
            DnsQClass::Reserved => panic!("Reserved QClass"),
            _ => *self as u16,
        }
    }
}

impl From<DnsQClass> for u16 {
    /// Convert to the numeric ID.
    ///
    /// # Panics
    ///
    /// Panics on [`DnsQClass::Reserved`]; see [`DnsQClass::id`].
    #[inline(always)]
    fn from(q: DnsQClass) -> Self {
        q.id()
    }
}

impl From<u16> for DnsQClass {
    /// Convert from a numeric ID. Unknown IDs become [`DnsQClass::Reserved`].
    #[inline(always)]
    fn from(n: u16) -> Self {
        DnsQClass::from_id(n)
    }
}