1use bstr::BString;
2use std::collections::HashMap;
3use std::net;
4use std::net::{IpAddr, SocketAddr};
5use unicase::UniCase;
6
7#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Clone, Default, Debug, PartialEq, Eq)]
10pub struct Response {
11 pub answers: Vec<Record>,
12 pub nameservers: Vec<Record>,
13 pub additional: Vec<Record>,
14}
15
16#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Clone, Debug, PartialEq, Eq)]
19pub struct Record {
20 pub name: String,
21 #[serde(with = "serde_helpers::dns_class")]
22 pub class: dns_parser::Class,
23 pub ttl: u32,
24 pub kind: RecordKind,
25}
26
27#[derive(Clone, Debug, PartialEq, Eq)]
29#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
30pub enum RecordKind {
31 A(net::Ipv4Addr),
32 AAAA(net::Ipv6Addr),
33 CNAME(String),
34 MX {
35 preference: u16,
36 exchange: String,
37 },
38 NS(String),
39 SRV {
40 priority: u16,
41 weight: u16,
42 port: u16,
43 target: String,
44 },
45 #[serde(with = "serde_helpers::txt_records")]
46 TXT(HashMap<UniCase<String>, TxtRecordValue>),
47 PTR(String),
48 Unimplemented(Vec<u8>),
50}
51
52#[derive(Clone, Debug, PartialEq, Eq)]
61#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
62pub enum TxtRecordValue {
63 None,
64 Empty,
65 #[serde(with = "serde_helpers::bstring")]
66 Value(BString),
67}
68
/// Adapters used through `#[serde(with = "...")]` for fields whose types do not
/// implement serde's traits themselves, or that need a specific representation.
#[cfg(feature = "with-serde")]
pub(crate) mod serde_helpers {
    /// Upper bound on how many elements a deserializer-supplied `size_hint` may
    /// pre-allocate. Hints come from the input, so an unbounded hint would let
    /// malformed data request an arbitrarily large allocation up front; the
    /// collections still grow past this when the data really is larger.
    const MAX_PREALLOCATION: usize = 4096;

    /// (De)serializes `dns_parser::Class` as its numeric CLASS value.
    pub(crate) mod dns_class {
        pub fn serialize<S>(class: &dns_parser::Class, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            serializer.serialize_u8(*class as u8)
        }

        pub fn deserialize<'de, D>(d: D) -> Result<dns_parser::Class, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_u8(DnsClassVisitor)
        }

        struct DnsClassVisitor;

        // `deserialize_u8` is only a hint: self-describing formats (e.g. JSON)
        // report integers through `visit_u64`/`visit_i64`. serde's default
        // `visit_u8/u16/u32` forward to `visit_u64` and `visit_i8/i16/i32` to
        // `visit_i64`, so implementing those two covers every integer width.
        impl<'de> serde::de::Visitor<'de> for DnsClassVisitor {
            type Value = dns_parser::Class;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("DNS CLASS value according to RFC 1035")
            }

            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                use dns_parser::Class::*;
                match v {
                    1 => Ok(IN),
                    2 => Ok(CS),
                    3 => Ok(CH),
                    4 => Ok(HS),
                    _ => Err(E::invalid_value(serde::de::Unexpected::Unsigned(v), &self)),
                }
            }

            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if v < 0 {
                    Err(E::invalid_value(serde::de::Unexpected::Signed(v), &self))
                } else {
                    // Non-negative, so the conversion to u64 is lossless.
                    self.visit_u64(v as u64)
                }
            }
        }
    }

    /// (De)serializes a `BString` as a byte string.
    pub(crate) mod bstring {
        use bstr::{BString, ByteSlice};

        pub fn serialize<S>(bstring: &BString, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            serializer.serialize_bytes(bstring.as_bytes())
        }

        pub fn deserialize<'de, D>(d: D) -> Result<BString, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_bytes(BStringVisitor)
        }

        struct BStringVisitor;

        impl<'de> serde::de::Visitor<'de> for BStringVisitor {
            type Value = BString;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("BString")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            /// Takes ownership of an already-owned buffer instead of copying it.
            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            /// Accepts textual input from formats that hand byte strings back as strings.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            /// Accepts a sequence of byte values. Some formats (e.g. `serde_json`)
            /// write `serialize_bytes` output as an array of numbers and return
            /// that array through `visit_seq`, so without this the value cannot
            /// be read back.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let capacity = seq.size_hint().unwrap_or(0).min(super::MAX_PREALLOCATION);
                let mut bytes = Vec::with_capacity(capacity);
                while let Some(byte) = seq.next_element::<u8>()? {
                    bytes.push(byte);
                }
                Ok(BString::from(bytes))
            }
        }
    }

    /// (De)serializes a TXT attribute map as a map from plain string keys to
    /// [`TxtRecordValue`]s, restoring case-insensitive keys on the way in.
    pub(crate) mod txt_records {
        use crate::TxtRecordValue;
        use serde::{de::MapAccess, ser::SerializeMap};
        use std::collections::HashMap;
        use unicase::UniCase;

        pub fn serialize<S>(
            records: &HashMap<UniCase<String>, TxtRecordValue>,
            serializer: S,
        ) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            let mut map = serializer.serialize_map(Some(records.len()))?;
            for (key, value) in records {
                map.serialize_entry(key.as_str(), value)?;
            }
            map.end()
        }

        pub fn deserialize<'de, D>(
            d: D,
        ) -> Result<HashMap<UniCase<String>, TxtRecordValue>, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_map(TxtRecordVisitor)
        }

        struct TxtRecordVisitor;

        impl<'de> serde::de::Visitor<'de> for TxtRecordVisitor {
            type Value = HashMap<UniCase<String>, TxtRecordValue>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(
                    "TXT Records map containing case-insensitive Key String and Value BString",
                )
            }

            fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let capacity = access.size_hint().unwrap_or(0).min(super::MAX_PREALLOCATION);
                let mut map = HashMap::with_capacity(capacity);
                while let Some((key, value)) = access.next_entry::<String, TxtRecordValue>()? {
                    // Keys differing only in case collide; keep the first one seen,
                    // the same rule applied when TXT strings are parsed off the wire.
                    map.entry(UniCase::new(key)).or_insert(value);
                }
                Ok(map)
            }
        }
    }
}
215
216impl Response {
217 pub fn from_packet(packet: &dns_parser::Packet) -> Self {
218 Response {
219 answers: packet
220 .answers
221 .iter()
222 .map(Record::from_resource_record)
223 .collect(),
224 nameservers: packet
225 .nameservers
226 .iter()
227 .map(Record::from_resource_record)
228 .collect(),
229 additional: packet
230 .additional
231 .iter()
232 .map(Record::from_resource_record)
233 .collect(),
234 }
235 }
236
237 pub fn records(&self) -> impl Iterator<Item = &Record> {
238 self.answers
239 .iter()
240 .chain(self.nameservers.iter())
241 .chain(self.additional.iter())
242 }
243
244 pub fn is_empty(&self) -> bool {
245 self.answers.is_empty() && self.nameservers.is_empty() && self.additional.is_empty()
246 }
247
248 pub fn ip_addr(&self) -> Option<IpAddr> {
249 self.records().find_map(|record| match record.kind {
250 RecordKind::A(addr) => Some(addr.into()),
251 RecordKind::AAAA(addr) => Some(addr.into()),
252 _ => None,
253 })
254 }
255
256 pub fn hostname(&self) -> Option<&str> {
257 self.records().find_map(|record| match record.kind {
258 RecordKind::PTR(ref host) => Some(host.as_str()),
259 _ => None,
260 })
261 }
262
263 pub fn port(&self) -> Option<u16> {
264 self.records().find_map(|record| match record.kind {
265 RecordKind::SRV { port, .. } => Some(port),
266 _ => None,
267 })
268 }
269
270 pub fn socket_address(&self) -> Option<SocketAddr> {
271 Some((self.ip_addr()?, self.port()?).into())
272 }
273
274 pub fn txt_records(&self) -> impl Iterator<Item = (&str, &TxtRecordValue)> {
275 self.records()
276 .filter_map(|record| match record.kind {
277 RecordKind::TXT(ref txt) => Some(txt),
278 _ => None,
279 })
280 .flat_map(|txt| txt.iter())
281 .map(|(key, value)| (key.as_str(), value))
282 }
283}
284
285impl Record {
286 fn from_resource_record(rr: &dns_parser::ResourceRecord) -> Self {
287 Record {
288 name: rr.name.to_string(),
289 class: rr.cls,
290 ttl: rr.ttl,
291 kind: RecordKind::from_rr_data(&rr.data),
292 }
293 }
294}
295
296impl RecordKind {
297 fn from_rr_data(data: &dns_parser::RData) -> Self {
298 use dns_parser::RData;
299
300 match *data {
301 RData::A(dns_parser::rdata::a::Record(addr)) => RecordKind::A(addr),
302 RData::AAAA(dns_parser::rdata::aaaa::Record(addr)) => RecordKind::AAAA(addr),
303 RData::CNAME(ref name) => RecordKind::CNAME(name.to_string()),
304 RData::MX(dns_parser::rdata::mx::Record {
305 preference,
306 ref exchange,
307 }) => RecordKind::MX {
308 preference,
309 exchange: exchange.to_string(),
310 },
311 RData::NS(ref name) => RecordKind::NS(name.to_string()),
312 RData::PTR(ref name) => RecordKind::PTR(name.to_string()),
313 RData::SRV(dns_parser::rdata::srv::Record {
314 priority,
315 weight,
316 port,
317 ref target,
318 }) => RecordKind::SRV {
319 priority,
320 weight,
321 port,
322 target: target.to_string(),
323 },
324 RData::TXT(ref txt) => {
325 let mut txt_records: HashMap<UniCase<String>, TxtRecordValue> = HashMap::new();
326 for txt_record in txt.iter() {
327 let mut kv_split = txt_record.split(|c| c == &b'=');
328 if let Some(key_bytes) = kv_split.next() {
329 let key = UniCase::new(String::from_utf8_lossy(key_bytes).into_owned());
330 if txt_records.contains_key(&key) {
331 continue;
335 }
336 let value = if let Some(value_bytes) = kv_split.next() {
337 if value_bytes.is_empty() {
338 TxtRecordValue::Empty
339 } else {
340 TxtRecordValue::Value(BString::from(value_bytes))
341 }
342 } else {
343 TxtRecordValue::None
344 };
345 txt_records.insert(key, value);
346 }
347 }
348 RecordKind::TXT(txt_records)
349 }
350 RData::SOA(..) => {
351 RecordKind::Unimplemented("SOA record handling is not implemented".into())
352 }
353 RData::Unknown(data) => RecordKind::Unimplemented(data.to_owned()),
354 }
355 }
356}