oko_mdns/
response.rs

1use bstr::BString;
2use std::collections::HashMap;
3use std::net;
4use std::net::{IpAddr, SocketAddr};
5use unicase::UniCase;
6
7/// A DNS response.
8#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Clone, Default, Debug, PartialEq, Eq)]
10pub struct Response {
11    pub answers: Vec<Record>,
12    pub nameservers: Vec<Record>,
13    pub additional: Vec<Record>,
14}
15
16/// Any type of DNS record.
17#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Clone, Debug, PartialEq, Eq)]
19pub struct Record {
20    pub name: String,
21    #[serde(with = "serde_helpers::dns_class")]
22    pub class: dns_parser::Class,
23    pub ttl: u32,
24    pub kind: RecordKind,
25}
26
27/// A specific DNS record variant.
28#[derive(Clone, Debug, PartialEq, Eq)]
29#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
30pub enum RecordKind {
31    A(net::Ipv4Addr),
32    AAAA(net::Ipv6Addr),
33    CNAME(String),
34    MX {
35        preference: u16,
36        exchange: String,
37    },
38    NS(String),
39    SRV {
40        priority: u16,
41        weight: u16,
42        port: u16,
43        target: String,
44    },
45    #[serde(with = "serde_helpers::txt_records")]
46    TXT(HashMap<UniCase<String>, TxtRecordValue>),
47    PTR(String),
48    /// A record kind that hasn't been implemented by this library yet.
49    Unimplemented(Vec<u8>),
50}
51
52/// A TXT Record's Value for a present Attribute with following variants:
53/// - None:   Attribute present, with no value
54///           (e.g., "passreq" -- password required for this service)
55/// - Empty:  Attribute present, with empty value
56//            (e.g., "PlugIns=" -- the server supports plugins, but none are presently installed)
57/// - Value(BString): Attribute present, with non-empty value
58//                    (e.g., "PlugIns=JPEG,MPEG2,MPEG4")
59/// RFC ref: https://datatracker.ietf.org/doc/html/rfc6763#section-6.4
60#[derive(Clone, Debug, PartialEq, Eq)]
61#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
62pub enum TxtRecordValue {
63    None,
64    Empty,
65    #[serde(with = "serde_helpers::bstring")]
66    Value(BString),
67}
68
/// Custom serde (de)serializers for field types in this module that need a
/// specific encoding. Only compiled with the `with-serde` feature.
#[cfg(feature = "with-serde")]
pub(crate) mod serde_helpers {
    /// (De)serializes a `dns_parser::Class` as its numeric CLASS code.
    pub(crate) mod dns_class {
        /// Writes the class as a `u8` (`*class as u8`).
        pub fn serialize<S>(class: &dns_parser::Class, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            serializer.serialize_u8(*class as u8)
        }

        /// Reads a numeric CLASS code: 1 => `IN`, 2 => `CS`, 3 => `CH`, 4 => `HS`.
        ///
        /// # Errors
        /// Any other value (including negative numbers) yields an `invalid_value` error.
        pub fn deserialize<'de, D>(d: D) -> Result<dns_parser::Class, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_u8(DnsClassVisitor)
        }

        struct DnsClassVisitor;

        impl<'de> serde::de::Visitor<'de> for DnsClassVisitor {
            type Value = dns_parser::Class;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("DNS CLASS value according to RFC 1035")
            }

            // Self-describing formats (e.g. JSON) hand integers to `visit_u64` /
            // `visit_i64` regardless of the `deserialize_u8` hint, and serde's
            // default `visit_u8`/`visit_i8` (etc.) forward to these two. Handling
            // only the 8-bit variants made such formats reject every class value.
            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                use dns_parser::Class::*;
                match v {
                    1 => Ok(IN),
                    2 => Ok(CS),
                    3 => Ok(CH),
                    4 => Ok(HS),
                    _ => Err(serde::de::Error::invalid_value(
                        serde::de::Unexpected::Unsigned(v),
                        &self,
                    )),
                }
            }

            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if v < 0 {
                    return Err(serde::de::Error::invalid_value(
                        serde::de::Unexpected::Signed(v),
                        &self,
                    ));
                }
                self.visit_u64(v as u64)
            }
        }
    }

    /// (De)serializes a `bstr::BString` as a byte sequence.
    pub(crate) mod bstring {
        use bstr::{BString, ByteSlice};

        pub fn serialize<S>(bstring: &BString, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            serializer.serialize_bytes(bstring.as_bytes())
        }

        pub fn deserialize<'de, D>(d: D) -> Result<BString, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_bytes(BStringVisitor)
        }

        struct BStringVisitor;

        impl<'de> serde::de::Visitor<'de> for BStringVisitor {
            type Value = BString;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("BString")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            // Owned buffers are taken over without copying.
            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            // Human-readable formats may carry the value as a string.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(BString::from(v))
            }

            // Formats without a native bytes type (e.g. JSON) write the output of
            // `serialize_bytes` as a sequence of integers and hand it back here;
            // without this method such values could not be read back.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // The size hint comes from the input: cap the preallocation so a
                // forged length cannot force a huge allocation up front.
                let mut bytes = Vec::with_capacity(seq.size_hint().unwrap_or(0).min(4096));
                while let Some(byte) = seq.next_element::<u8>()? {
                    bytes.push(byte);
                }
                Ok(BString::from(bytes))
            }
        }
    }

    /// (De)serializes TXT attributes as a map from key string to `TxtRecordValue`.
    pub(crate) mod txt_records {
        // Resolved relative to this file so it does not depend on how (or
        // whether) the crate root re-exports `TxtRecordValue`.
        use super::super::TxtRecordValue;
        use serde::{de::MapAccess, ser::SerializeMap};
        use std::collections::HashMap;
        use unicase::UniCase;

        pub fn serialize<S>(
            records: &HashMap<UniCase<String>, TxtRecordValue>,
            serializer: S,
        ) -> Result<S::Ok, S::Error>
        where
            S: serde::ser::Serializer,
        {
            let mut map = serializer.serialize_map(Some(records.len()))?;
            for (k, v) in records {
                map.serialize_entry(&k.as_ref(), v)?;
            }
            map.end()
        }

        pub fn deserialize<'de, D>(
            d: D,
        ) -> Result<HashMap<UniCase<String>, TxtRecordValue>, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            d.deserialize_map(TxtRecordVisitor)
        }

        struct TxtRecordVisitor;

        impl<'de> serde::de::Visitor<'de> for TxtRecordVisitor {
            type Value = HashMap<UniCase<String>, TxtRecordValue>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(
                    "TXT Records map containing case-insensitive Key String and Value BString",
                )
            }

            fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // The size hint is input-controlled; cap the preallocation. The
                // map still grows as needed for genuinely large inputs.
                let capacity = access.size_hint().unwrap_or(0).min(256);
                let mut map: HashMap<UniCase<String>, TxtRecordValue> =
                    HashMap::with_capacity(capacity);
                while let Some((key, value)) = access.next_entry::<String, TxtRecordValue>()? {
                    // Keys compare case-insensitively; keep the first occurrence,
                    // matching RFC 6763 section 6.4 for TXT records.
                    map.entry(UniCase::new(key)).or_insert(value);
                }
                Ok(map)
            }
        }
    }
}
215
216impl Response {
217    pub fn from_packet(packet: &dns_parser::Packet) -> Self {
218        Response {
219            answers: packet
220                .answers
221                .iter()
222                .map(Record::from_resource_record)
223                .collect(),
224            nameservers: packet
225                .nameservers
226                .iter()
227                .map(Record::from_resource_record)
228                .collect(),
229            additional: packet
230                .additional
231                .iter()
232                .map(Record::from_resource_record)
233                .collect(),
234        }
235    }
236
237    pub fn records(&self) -> impl Iterator<Item = &Record> {
238        self.answers
239            .iter()
240            .chain(self.nameservers.iter())
241            .chain(self.additional.iter())
242    }
243
244    pub fn is_empty(&self) -> bool {
245        self.answers.is_empty() && self.nameservers.is_empty() && self.additional.is_empty()
246    }
247
248    pub fn ip_addr(&self) -> Option<IpAddr> {
249        self.records().find_map(|record| match record.kind {
250            RecordKind::A(addr) => Some(addr.into()),
251            RecordKind::AAAA(addr) => Some(addr.into()),
252            _ => None,
253        })
254    }
255
256    pub fn hostname(&self) -> Option<&str> {
257        self.records().find_map(|record| match record.kind {
258            RecordKind::PTR(ref host) => Some(host.as_str()),
259            _ => None,
260        })
261    }
262
263    pub fn port(&self) -> Option<u16> {
264        self.records().find_map(|record| match record.kind {
265            RecordKind::SRV { port, .. } => Some(port),
266            _ => None,
267        })
268    }
269
270    pub fn socket_address(&self) -> Option<SocketAddr> {
271        Some((self.ip_addr()?, self.port()?).into())
272    }
273
274    pub fn txt_records(&self) -> impl Iterator<Item = (&str, &TxtRecordValue)> {
275        self.records()
276            .filter_map(|record| match record.kind {
277                RecordKind::TXT(ref txt) => Some(txt),
278                _ => None,
279            })
280            .flat_map(|txt| txt.iter())
281            .map(|(key, value)| (key.as_str(), value))
282    }
283}
284
285impl Record {
286    fn from_resource_record(rr: &dns_parser::ResourceRecord) -> Self {
287        Record {
288            name: rr.name.to_string(),
289            class: rr.cls,
290            ttl: rr.ttl,
291            kind: RecordKind::from_rr_data(&rr.data),
292        }
293    }
294}
295
296impl RecordKind {
297    fn from_rr_data(data: &dns_parser::RData) -> Self {
298        use dns_parser::RData;
299
300        match *data {
301            RData::A(dns_parser::rdata::a::Record(addr)) => RecordKind::A(addr),
302            RData::AAAA(dns_parser::rdata::aaaa::Record(addr)) => RecordKind::AAAA(addr),
303            RData::CNAME(ref name) => RecordKind::CNAME(name.to_string()),
304            RData::MX(dns_parser::rdata::mx::Record {
305                preference,
306                ref exchange,
307            }) => RecordKind::MX {
308                preference,
309                exchange: exchange.to_string(),
310            },
311            RData::NS(ref name) => RecordKind::NS(name.to_string()),
312            RData::PTR(ref name) => RecordKind::PTR(name.to_string()),
313            RData::SRV(dns_parser::rdata::srv::Record {
314                priority,
315                weight,
316                port,
317                ref target,
318            }) => RecordKind::SRV {
319                priority,
320                weight,
321                port,
322                target: target.to_string(),
323            },
324            RData::TXT(ref txt) => {
325                let mut txt_records: HashMap<UniCase<String>, TxtRecordValue> = HashMap::new();
326                for txt_record in txt.iter() {
327                    let mut kv_split = txt_record.split(|c| c == &b'=');
328                    if let Some(key_bytes) = kv_split.next() {
329                        let key = UniCase::new(String::from_utf8_lossy(key_bytes).into_owned());
330                        if txt_records.contains_key(&key) {
331                            // RFC 6763 Section 6.4: If a client receives a TXT record containing
332                            // the same key more than once, then the client MUST silently ignore
333                            // all but the first occurrence of that attribute.
334                            continue;
335                        }
336                        let value = if let Some(value_bytes) = kv_split.next() {
337                            if value_bytes.is_empty() {
338                                TxtRecordValue::Empty
339                            } else {
340                                TxtRecordValue::Value(BString::from(value_bytes))
341                            }
342                        } else {
343                            TxtRecordValue::None
344                        };
345                        txt_records.insert(key, value);
346                    }
347                }
348                RecordKind::TXT(txt_records)
349            }
350            RData::SOA(..) => {
351                RecordKind::Unimplemented("SOA record handling is not implemented".into())
352            }
353            RData::Unknown(data) => RecordKind::Unimplemented(data.to_owned()),
354        }
355    }
356}