flex_dns/
name_servers.rs

1use crate::{Buffer, DnsMessage, DnsMessageError, MutBuffer};
2use crate::answer::DnsAClass;
3use crate::name::DnsName;
4use crate::parse::{Parse, ParseBytes};
5use crate::question::DnsQType;
6use crate::rdata::{DnsAType, RData};
7use crate::write::WriteBytes;
8
9/// A DNS message name servers section.
10pub struct DnsNameServers<
11    const PTR_STORAGE: usize,
12    B: Buffer,
13> {
14    message: DnsMessage<PTR_STORAGE, 2, B>,
15    remaining: usize,
16}
17
18impl<
19    const PTR_STORAGE: usize,
20    B: Buffer,
21> DnsNameServers<PTR_STORAGE, B> {
22    #[inline(always)]
23    pub(crate) fn new(message: DnsMessage<PTR_STORAGE, 2, B>) -> Self {
24        let remaining = message.header().unwrap().answer_count() as usize;
25        Self {
26            message,
27            remaining,
28        }
29    }
30
31    /// Return an iterator over the answer section.
32    #[inline(always)]
33    pub fn iter(&mut self) -> Result<DnsNameServersIterator, DnsMessageError> {
34        let (bytes, position) = self.message.bytes_and_position();
35
36        Ok(DnsNameServersIterator {
37            buffer: bytes,
38            current_position: position,
39            remaining: &mut self.remaining,
40        })
41    }
42
43    /// Complete the message. This will read and check the remaining answers
44    /// and return the message if successful.
45    #[inline(always)]
46    pub fn complete(mut self) -> Result<DnsMessage<PTR_STORAGE, 3, B>, DnsMessageError> {
47        if self.remaining != 0 {
48            for x in self.iter()? { x?; }
49        }
50
51        Ok(DnsMessage {
52            buffer: self.message.buffer,
53            position: self.message.position,
54            ptr_storage: self.message.ptr_storage,
55            ptr_len: self.message.ptr_len,
56        })
57    }
58}
59
60impl<
61    const PTR_STORAGE: usize,
62    B: MutBuffer + Buffer,
63> DnsNameServers<PTR_STORAGE, B> {
64    /// Append a name server to the message. This will override the next
65    /// name server or further sections, if any.
66    pub fn append(&mut self, answer: NameServer<DnsAType>) -> Result<(), DnsMessageError> {
67        // Truncate the buffer to the current position.
68        self.message.truncate()?;
69        answer.write(&mut self.message)?;
70        // Set answer_count in the header to the current answer count + 1.
71        let answer_count = self.message.header().unwrap().answer_count();
72        let answer_count = answer_count + 1 - self.remaining as u16;
73        self.message.header_mut()?.set_answer_count(answer_count);
74        self.message.header_mut()?.set_name_server_count(0);
75        self.message.header_mut()?.set_additional_records_count(0);
76        self.remaining = 0;
77
78        Ok(())
79    }
80}
81
/// Iterator over the records of a DNS message's name servers section.
///
/// Created by [`DnsNameServers::iter`]. Parsing a record advances the shared
/// position cursor and decrements the shared remaining-record count.
pub struct DnsNameServersIterator<'a> {
    /// Message bytes that records are parsed from.
    buffer: &'a [u8],
    /// Read cursor into `buffer`, borrowed from the owning message.
    current_position: &'a mut usize,
    /// Records still to be yielded, borrowed from the owning section.
    remaining: &'a mut usize,
}
88
89impl<'a> Iterator for DnsNameServersIterator<'a> {
90    type Item = Result<NameServer<'a, RData<'a>>, DnsMessageError>;
91
92    #[inline]
93    fn next(&mut self) -> Option<Self::Item> {
94        if *self.remaining == 0 {
95            return None;
96        }
97
98        let name_server = NameServer::parse(
99            self.buffer, self.current_position
100        );
101        *self.remaining -= 1;
102
103        Some(name_server)
104    }
105}
106
107/// A DNS message name server.
108#[derive(Debug, PartialEq)]
109pub struct NameServer<'a, D> {
110    /// The name of the name server.
111    pub name: DnsName<'a>,
112    /// The data of the name server.
113    pub rdata: D,
114    /// Whether the name server is authoritative.
115    pub cache_flush: bool,
116    /// The class of the name server.
117    pub aclass: DnsAClass,
118    /// The time to live of the name server.
119    pub ttl: u32,
120}
121
122impl<'a> NameServer<'a, RData<'a>> {
123    /// Parse the rdata of the additional into a structured type.
124    #[inline(always)]
125    pub fn into_parsed(self) -> Result<NameServer<'a, DnsAType<'a>>, DnsMessageError> {
126        Ok(NameServer {
127            name: self.name,
128            rdata: self.rdata.into_parsed()?,
129            cache_flush: self.cache_flush,
130            aclass: self.aclass,
131            ttl: self.ttl,
132        })
133    }
134}
135
136impl<'a> ParseBytes<'a> for NameServer<'a, RData<'a>> {
137    fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
138        let name = DnsName::parse(bytes, i)?;
139        let atype_id = u16::parse(bytes, i)?;
140        let atype = DnsQType::from_id(atype_id);
141        let cache_flush = atype_id & 0b1000_0000_0000_0000 != 0;
142        let aclass = DnsAClass::from_id(u16::parse(bytes, i)?);
143        let ttl = u32::parse(bytes, i)?;
144        let rdata = RData::parse(bytes, i, atype)?;
145
146        Ok(Self {
147            name,
148            rdata,
149            cache_flush,
150            aclass,
151            ttl,
152        })
153    }
154}
155
156impl<'a> WriteBytes for NameServer<'a, DnsAType<'a>> {
157    fn write<
158        const PTR_STORAGE: usize,
159        const DNS_SECTION: usize,
160        B: MutBuffer + Buffer,
161    >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
162        let mut bytes = 0;
163        // Write the name to the buffer using the pointer storage for compression.
164        bytes += self.name.write(message)?;
165        // Write the atype and aclass to the buffer.
166        bytes += self.rdata.id().write(message)?;
167        let mut aclass = self.aclass.id();
168        if self.cache_flush {
169            aclass |= 0b1000_0000;
170        }
171        bytes += aclass.write(message)?;
172        // Write the ttl to the buffer.
173        bytes += self.ttl.write(message)?;
174        let rdata_len_placeholder = message.write_placeholder::<2>()?;
175        // Write the type specific data to the buffer.
176        let rdata_len = self.rdata.write(message)?;
177        bytes += rdata_len;
178        bytes += rdata_len_placeholder(message, (rdata_len as u16).to_be_bytes());
179
180        Ok(bytes)
181    }
182}