1use crate::{Buffer, DnsMessage, DnsMessageError, MutBuffer};
2use crate::answer::DnsAClass;
3use crate::name::DnsName;
4use crate::parse::{Parse, ParseBytes};
5use crate::question::DnsQType;
6use crate::rdata::{DnsAType, RData};
7use crate::write::WriteBytes;
8
9pub struct DnsAdditionals<
11 const PTR_STORAGE: usize,
12 B: Buffer,
13> {
14 message: DnsMessage<PTR_STORAGE, 3, B>,
15 remaining: usize,
16}
17
18impl<
19 const PTR_STORAGE: usize,
20 B: Buffer,
21> DnsAdditionals<PTR_STORAGE, B> {
22 #[inline(always)]
23 pub(crate) fn new(message: DnsMessage<PTR_STORAGE, 3, B>) -> Self {
24 let remaining = message.header().unwrap().answer_count() as usize;
25 Self {
26 message,
27 remaining,
28 }
29 }
30
31 #[inline(always)]
33 pub fn iter(&mut self) -> Result<DnsAdditionalsIterator, DnsMessageError> {
34 let (bytes, position) = self.message.bytes_and_position();
35
36 Ok(DnsAdditionalsIterator {
37 buffer: bytes,
38 current_position: position,
39 remaining: &mut self.remaining,
40 })
41 }
42
43 #[inline(always)]
45 pub fn complete(mut self) -> Result<DnsMessage<PTR_STORAGE, 3, B>, DnsMessageError> {
46 if self.remaining != 0 {
47 for x in self.iter()? { x?; }
48 }
49
50 Ok(DnsMessage {
51 buffer: self.message.buffer,
52 position: self.message.position,
53 ptr_storage: self.message.ptr_storage,
54 ptr_len: self.message.ptr_len,
55 })
56 }
57}
58
59impl<
60 const PTR_STORAGE: usize,
61 B: MutBuffer + Buffer,
62> DnsAdditionals<PTR_STORAGE, B> {
63 pub fn append(&mut self, answer: DnsAdditional<DnsAType>) -> Result<(), DnsMessageError> {
66 self.message.truncate()?;
68 answer.write(&mut self.message)?;
69 let answer_count = self.message.header().unwrap().answer_count();
71 let answer_count = answer_count + 1 - self.remaining as u16;
72 self.message.header_mut()?.set_answer_count(answer_count);
73 self.message.header_mut()?.set_name_server_count(0);
74 self.message.header_mut()?.set_additional_records_count(0);
75 self.remaining = 0;
76
77 Ok(())
78 }
79}
80
81pub struct DnsAdditionalsIterator<'a> {
83 buffer: &'a [u8],
84 current_position: &'a mut usize,
85 remaining: &'a mut usize,
86}
87
88impl<'a> Iterator for DnsAdditionalsIterator<'a> {
89 type Item = Result<DnsAdditional<'a, RData<'a>>, DnsMessageError>;
90
91 #[inline]
92 fn next(&mut self) -> Option<Self::Item> {
93 if *self.remaining == 0 {
94 return None;
95 }
96
97 let additional = DnsAdditional::parse(
98 self.buffer, self.current_position
99 );
100 *self.remaining -= 1;
101
102 Some(additional)
103 }
104}
105
106#[derive(Debug, PartialEq)]
108pub struct DnsAdditional<'a, D> {
109 pub name: DnsName<'a>,
111 pub rdata: D,
113 pub cache_flush: bool,
115 pub aclass: DnsAClass,
117 pub ttl: u32,
119}
120
121impl<'a> DnsAdditional<'a, RData<'a>> {
122 #[inline(always)]
124 pub fn into_parsed(self) -> Result<DnsAdditional<'a, DnsAType<'a>>, DnsMessageError> {
125 Ok(DnsAdditional {
126 name: self.name,
127 rdata: self.rdata.into_parsed()?,
128 cache_flush: self.cache_flush,
129 aclass: self.aclass,
130 ttl: self.ttl,
131 })
132 }
133}
134
135impl<'a> ParseBytes<'a> for DnsAdditional<'a, RData<'a>> {
136 fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
137 let name = DnsName::parse(bytes, i)?;
138 let atype_id = u16::parse(bytes, i)?;
139 let atype = DnsQType::from_id(atype_id);
140 let cache_flush = atype_id & 0b1000_0000_0000_0000 != 0;
141 let aclass = DnsAClass::from_id(u16::parse(bytes, i)?);
142 let ttl = u32::parse(bytes, i)?;
143 let rdata = RData::parse(bytes, i, atype)?;
144
145 Ok(Self {
146 name,
147 rdata,
148 cache_flush,
149 aclass,
150 ttl,
151 })
152 }
153}
154
155impl<'a> WriteBytes for DnsAdditional<'a, DnsAType<'a>> {
156 fn write<
157 const PTR_STORAGE: usize,
158 const DNS_SECTION: usize,
159 B: MutBuffer + Buffer,
160 >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
161 let mut bytes = 0;
162 bytes += self.name.write(message)?;
164 bytes += self.rdata.id().write(message)?;
166 let mut aclass = self.aclass.id();
167 if self.cache_flush {
168 aclass |= 0b1000_0000;
169 }
170 bytes += aclass.write(message)?;
171 bytes += self.ttl.write(message)?;
173 let rdata_len_placeholder = message.write_placeholder::<2>()?;
174 let rdata_len = self.rdata.write(message)?;
176 bytes += rdata_len;
177 bytes += rdata_len_placeholder(message, (rdata_len as u16).to_be_bytes());
178
179 Ok(bytes)
180 }
181}