flex_dns/
answer.rs

1use crate::{Buffer, DnsMessage, DnsMessageError, MutBuffer};
2use crate::name::DnsName;
3use crate::parse::{Parse, ParseBytes};
4use crate::question::DnsQType;
5use crate::rdata::{DnsAType, RData};
6use crate::write::WriteBytes;
7
8/// A DNS message answers section.
9pub struct DnsAnswers<
10    const PTR_STORAGE: usize,
11    B: Buffer,
12> {
13    message: DnsMessage<PTR_STORAGE, 1, B>,
14    remaining: usize,
15}
16
17impl<
18    const PTR_STORAGE: usize,
19    B: Buffer,
20> DnsAnswers<PTR_STORAGE, B> {
21    #[inline(always)]
22    pub(crate) fn new(message: DnsMessage<PTR_STORAGE, 1, B>) -> Self {
23        let remaining = message.header().unwrap().answer_count() as usize;
24        Self {
25            message,
26            remaining,
27        }
28    }
29
30    /// Return an iterator over the answers section.
31    #[inline(always)]
32    pub fn iter(&mut self) -> Result<DnsAnswerIterator, DnsMessageError> {
33        let (bytes, position) = self.message.bytes_and_position();
34
35        Ok(DnsAnswerIterator {
36            buffer: bytes,
37            current_position: position,
38            remaining: &mut self.remaining,
39        })
40    }
41
42    /// Complete the message. This will overwrite the next answer or further
43    /// sections, if any.
44    #[inline(always)]
45    pub fn complete(mut self) -> Result<DnsMessage<PTR_STORAGE, 2, B>, DnsMessageError> {
46        if self.remaining != 0 {
47            for x in self.iter()? { x?; }
48        }
49
50        Ok(DnsMessage {
51            buffer: self.message.buffer,
52            position: self.message.position,
53            ptr_storage: self.message.ptr_storage,
54            ptr_len: self.message.ptr_len,
55        })
56    }
57}
58
59impl<
60    const PTR_STORAGE: usize,
61    B: MutBuffer + Buffer,
62> DnsAnswers<PTR_STORAGE, B> {
63    /// Append an answer to the message. This will overwrite the next
64    /// answer or further sections, if any.
65    pub fn append(&mut self, answer: DnsAnswer<DnsAType>) -> Result<(), DnsMessageError> {
66        // Truncate the buffer to the current position.
67        self.message.truncate()?;
68        answer.write(&mut self.message)?;
69        // Set answer_count in the header to the current answer count + 1.
70        let answer_count = self.message.header().unwrap().answer_count();
71        let answer_count = answer_count + 1 - self.remaining as u16;
72        self.message.header_mut()?.set_answer_count(answer_count);
73        self.message.header_mut()?.set_name_server_count(0);
74        self.message.header_mut()?.set_additional_records_count(0);
75        self.remaining = 0;
76
77        Ok(())
78    }
79}
80
/// An iterator over the answers section of a DNS message.
///
/// Created by [`DnsAnswers::iter`]. Each item is one parsed answer, or the
/// error encountered while parsing it.
pub struct DnsAnswerIterator<'a> {
    /// The raw message bytes.
    buffer: &'a [u8],
    /// Read cursor into `buffer`, shared with the owning message.
    current_position: &'a mut usize,
    /// Answers still to be read, shared with the owning [`DnsAnswers`].
    remaining: &'a mut usize,
}
87
88impl<'a> Iterator for DnsAnswerIterator<'a> {
89    type Item = Result<DnsAnswer<'a, RData<'a>>, DnsMessageError>;
90
91    #[inline]
92    fn next(&mut self) -> Option<Self::Item> {
93        if *self.remaining == 0 {
94            return None;
95        }
96
97        let answer = DnsAnswer::parse(
98            self.buffer, self.current_position
99        );
100        *self.remaining -= 1;
101
102        Some(answer)
103    }
104}
105
106/// A DNS message answer.
107#[derive(Debug, PartialEq)]
108pub struct DnsAnswer<'a, D> {
109    /// The name of the answer.
110    pub name: DnsName<'a>,
111    /// The answer data.
112    pub rdata: D,
113    /// Whether the answer should be cached.
114    pub cache_flush: bool,
115    /// The class of the answer.
116    pub aclass: DnsAClass,
117    /// The time to live of the answer.
118    pub ttl: u32,
119}
120
121impl<'a> DnsAnswer<'a, RData<'a>> {
122    /// Parse the rdata of the additional into a structured type.
123    #[inline(always)]
124    pub fn into_parsed(self) -> Result<DnsAnswer<'a, DnsAType<'a>>, DnsMessageError> {
125        Ok(DnsAnswer {
126            name: self.name,
127            rdata: self.rdata.into_parsed()?,
128            cache_flush: self.cache_flush,
129            aclass: self.aclass,
130            ttl: self.ttl,
131        })
132    }
133}
134
135impl<'a> ParseBytes<'a> for DnsAnswer<'a, RData<'a>> {
136    fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
137        let name = DnsName::parse(bytes, i)?;
138        let atype_id = u16::parse(bytes, i)?;
139        let atype = DnsQType::from_id(atype_id);
140        let cache_flush = atype_id & 0b1000_0000_0000_0000 != 0;
141        let aclass = DnsAClass::from_id(u16::parse(bytes, i)?);
142        let ttl = u32::parse(bytes, i)?;
143        let rdata = RData::parse(bytes, i, atype)?;
144
145        Ok(Self {
146            name,
147            rdata,
148            cache_flush,
149            aclass,
150            ttl,
151        })
152    }
153}
154
155impl<'a> WriteBytes for DnsAnswer<'a, DnsAType<'a>> {
156    fn write<
157        const PTR_STORAGE: usize,
158        const DNS_SECTION: usize,
159        B: MutBuffer + Buffer,
160    >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
161        let mut bytes = 0;
162        // Write the name to the buffer using the pointer storage for compression.
163        bytes += self.name.write(message)?;
164        // Write the atype and aclass to the buffer.
165        bytes += self.rdata.id().write(message)?;
166        let mut aclass = self.aclass.id();
167        if self.cache_flush {
168            aclass |= 0b1000_0000;
169        }
170        bytes += aclass.write(message)?;
171        // Write the ttl to the buffer.
172        bytes += self.ttl.write(message)?;
173        let rdata_len_placeholder = message.write_placeholder::<2>()?;
174        // Write the type specific data to the buffer.
175        let rdata_len = self.rdata.write(message)?;
176        bytes += rdata_len;
177        bytes += rdata_len_placeholder(message, (rdata_len as u16).to_be_bytes());
178
179        Ok(bytes)
180    }
181}
182
/// A DNS resource record class (RFC 1035 §3.2.4).
#[derive(Copy, Clone, Debug, PartialEq)]
#[repr(u16)]
pub enum DnsAClass {
    /// The Internet.
    IN = 1,
    /// CSNET (obsolete).
    CS = 2,
    /// CHAOS.
    CH = 3,
    /// Hesiod.
    HS = 4,
    /// Any class id other than 1 to 4. The original id is not retained.
    Reserved,
}

impl DnsAClass {
    /// Create a class from its numeric id.
    ///
    /// Ids 1 to 4 map to their variants; every other id maps to
    /// [`DnsAClass::Reserved`]. The id is matched exactly, with no bits
    /// masked off.
    #[inline(always)]
    pub fn from_id(id: u16) -> Self {
        // A plain match replaces the former `unsafe` transmute. It is
        // trivially sound and compiles to the same code.
        match id {
            1 => DnsAClass::IN,
            2 => DnsAClass::CS,
            3 => DnsAClass::CH,
            4 => DnsAClass::HS,
            _ => DnsAClass::Reserved,
        }
    }

    /// Get the numeric id of this class.
    ///
    /// # Panics
    ///
    /// Panics for [`DnsAClass::Reserved`], which has no id of its own.
    #[inline(always)]
    pub fn id(&self) -> u16 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            DnsAClass::IN | DnsAClass::CS | DnsAClass::CH | DnsAClass::HS => *self as u16,
            DnsAClass::Reserved => panic!("Reserved QClass"),
        }
    }
}
213
214impl From<DnsAClass> for u16 {
215    #[inline(always)]
216    fn from(q: DnsAClass) -> Self {
217        DnsAClass::id(&q)
218    }
219}
220
221impl From<u16> for DnsAClass {
222    #[inline(always)]
223    fn from(n: u16) -> Self {
224        DnsAClass::from_id(n)
225    }
226}