flex_dns/
additional.rs

1use crate::{Buffer, DnsMessage, DnsMessageError, MutBuffer};
2use crate::answer::DnsAClass;
3use crate::name::DnsName;
4use crate::parse::{Parse, ParseBytes};
5use crate::question::DnsQType;
6use crate::rdata::{DnsAType, RData};
7use crate::write::WriteBytes;
8
9/// A DNS message additionals section.
10pub struct DnsAdditionals<
11    const PTR_STORAGE: usize,
12    B: Buffer,
13> {
14    message: DnsMessage<PTR_STORAGE, 3, B>,
15    remaining: usize,
16}
17
18impl<
19    const PTR_STORAGE: usize,
20    B: Buffer,
21> DnsAdditionals<PTR_STORAGE, B> {
22    #[inline(always)]
23    pub(crate) fn new(message: DnsMessage<PTR_STORAGE, 3, B>) -> Self {
24        let remaining = message.header().unwrap().answer_count() as usize;
25        Self {
26            message,
27            remaining,
28        }
29    }
30
31    /// Return an iterator over the additionals section.
32    #[inline(always)]
33    pub fn iter(&mut self) -> Result<DnsAdditionalsIterator, DnsMessageError> {
34        let (bytes, position) = self.message.bytes_and_position();
35
36        Ok(DnsAdditionalsIterator {
37            buffer: bytes,
38            current_position: position,
39            remaining: &mut self.remaining,
40        })
41    }
42
43    /// Complete writing to the additionals section and return the message.
44    #[inline(always)]
45    pub fn complete(mut self) -> Result<DnsMessage<PTR_STORAGE, 3, B>, DnsMessageError> {
46        if self.remaining != 0 {
47            for x in self.iter()? { x?; }
48        }
49
50        Ok(DnsMessage {
51            buffer: self.message.buffer,
52            position: self.message.position,
53            ptr_storage: self.message.ptr_storage,
54            ptr_len: self.message.ptr_len,
55        })
56    }
57}
58
59impl<
60    const PTR_STORAGE: usize,
61    B: MutBuffer + Buffer,
62> DnsAdditionals<PTR_STORAGE, B> {
63    /// Append an additional to the message. This will overwrite the next
64    /// additional or further sections, if any.
65    pub fn append(&mut self, answer: DnsAdditional<DnsAType>) -> Result<(), DnsMessageError> {
66        // Truncate the buffer to the current position.
67        self.message.truncate()?;
68        answer.write(&mut self.message)?;
69        // Set answer_count in the header to the current answer count + 1.
70        let answer_count = self.message.header().unwrap().answer_count();
71        let answer_count = answer_count + 1 - self.remaining as u16;
72        self.message.header_mut()?.set_answer_count(answer_count);
73        self.message.header_mut()?.set_name_server_count(0);
74        self.message.header_mut()?.set_additional_records_count(0);
75        self.remaining = 0;
76
77        Ok(())
78    }
79}
80
/// Lazy iterator over the records of a DNS message's additionals section.
///
/// Borrows the message bytes together with the message's position and the
/// owning section's remaining-record counter, so both advance as records are
/// parsed.
pub struct DnsAdditionalsIterator<'a> {
    /// The message bytes that the cursor indexes into.
    buffer: &'a [u8],
    /// Shared read cursor into `buffer`.
    current_position: &'a mut usize,
    /// Shared count of records not yet yielded.
    remaining: &'a mut usize,
}
87
88impl<'a> Iterator for DnsAdditionalsIterator<'a> {
89    type Item = Result<DnsAdditional<'a, RData<'a>>, DnsMessageError>;
90
91    #[inline]
92    fn next(&mut self) -> Option<Self::Item> {
93        if *self.remaining == 0 {
94            return None;
95        }
96
97        let additional = DnsAdditional::parse(
98            self.buffer, self.current_position
99        );
100        *self.remaining -= 1;
101
102        Some(additional)
103    }
104}
105
106/// A DNS message additional.
107#[derive(Debug, PartialEq)]
108pub struct DnsAdditional<'a, D> {
109    /// The name of the additional.
110    pub name: DnsName<'a>,
111    /// The data of the additional.
112    pub rdata: D,
113    /// Whether the additional should be cached.
114    pub cache_flush: bool,
115    /// The class of the additional.
116    pub aclass: DnsAClass,
117    /// The time to live of the additional.
118    pub ttl: u32,
119}
120
121impl<'a> DnsAdditional<'a, RData<'a>> {
122    /// Parse the rdata of the additional into a structured type.
123    #[inline(always)]
124    pub fn into_parsed(self) -> Result<DnsAdditional<'a, DnsAType<'a>>, DnsMessageError> {
125        Ok(DnsAdditional {
126            name: self.name,
127            rdata: self.rdata.into_parsed()?,
128            cache_flush: self.cache_flush,
129            aclass: self.aclass,
130            ttl: self.ttl,
131        })
132    }
133}
134
135impl<'a> ParseBytes<'a> for DnsAdditional<'a, RData<'a>> {
136    fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
137        let name = DnsName::parse(bytes, i)?;
138        let atype_id = u16::parse(bytes, i)?;
139        let atype = DnsQType::from_id(atype_id);
140        let cache_flush = atype_id & 0b1000_0000_0000_0000 != 0;
141        let aclass = DnsAClass::from_id(u16::parse(bytes, i)?);
142        let ttl = u32::parse(bytes, i)?;
143        let rdata = RData::parse(bytes, i, atype)?;
144
145        Ok(Self {
146            name,
147            rdata,
148            cache_flush,
149            aclass,
150            ttl,
151        })
152    }
153}
154
155impl<'a> WriteBytes for DnsAdditional<'a, DnsAType<'a>> {
156    fn write<
157        const PTR_STORAGE: usize,
158        const DNS_SECTION: usize,
159        B: MutBuffer + Buffer,
160    >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
161        let mut bytes = 0;
162        // Write the name to the buffer using the pointer storage for compression.
163        bytes += self.name.write(message)?;
164        // Write the atype and aclass to the buffer.
165        bytes += self.rdata.id().write(message)?;
166        let mut aclass = self.aclass.id();
167        if self.cache_flush {
168            aclass |= 0b1000_0000;
169        }
170        bytes += aclass.write(message)?;
171        // Write the ttl to the buffer.
172        bytes += self.ttl.write(message)?;
173        let rdata_len_placeholder = message.write_placeholder::<2>()?;
174        // Write the type specific data to the buffer.
175        let rdata_len = self.rdata.write(message)?;
176        bytes += rdata_len;
177        bytes += rdata_len_placeholder(message, (rdata_len as u16).to_be_bytes());
178
179        Ok(bytes)
180    }
181}