wez_mdns/
lib.rs

1pub mod dns_parser;
2pub mod mac_addr;
3pub mod net_utils;
4
5pub use dns_parser::QueryType;
6use dns_parser::{Builder, Packet, QueryClass, RData, ResourceRecord};
7use smol::channel::{bounded, Receiver};
8use smol::net::UdpSocket;
9use smol::prelude::*;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::time::{Duration, Instant};
12use thiserror::*;
13
/// IPv4 multicast group that mDNS queries are sent to and that the
/// listening socket joins.
const MULTICAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port used for mDNS traffic, both as the local bind port and as the
/// destination port of outgoing queries.
const MULTICAST_PORT: u16 = 5353;
16
17/// Errors that may occur during resolution/discovery
18#[derive(Debug, Error)]
19pub enum Error {
20    #[error(transparent)]
21    Io(#[from] std::io::Error),
22    #[error(transparent)]
23    ChanRecv(#[from] smol::channel::RecvError),
24    #[error(transparent)]
25    ChanSend(#[from] smol::channel::SendError<Response>),
26    #[error("failed to build DNS packet")]
27    DnsPacketBuildError,
28    #[error("Timed out")]
29    Timeout,
30    #[error("The QueryParameters were invalid")]
31    InvalidQueryParams,
32    #[error("Unable to determine local interface")]
33    LocalInterfaceUnknown,
34}
35
36pub type Result<T> = std::result::Result<T, Error>;
37
/// Combines an IPv4 address and a port into a [`SocketAddr`].
fn sockaddr(ip: Ipv4Addr, port: u16) -> SocketAddr {
    SocketAddr::new(IpAddr::V4(ip), port)
}
42
43async fn create_socket() -> Result<UdpSocket> {
44    let socket = socket2::Socket::new(
45        socket2::Domain::IPV4,
46        socket2::Type::DGRAM,
47        Some(socket2::Protocol::UDP),
48    )?;
49    socket.set_reuse_address(true)?;
50    let _ = socket.set_reuse_port(true);
51
52    let addr = sockaddr(Ipv4Addr::UNSPECIFIED, MULTICAST_PORT);
53    socket.bind(&addr.into())?;
54
55    let socket = UdpSocket::from(smol::Async::new(socket.into())?);
56    socket.set_multicast_loop_v4(false)?;
57    socket.join_multicast_v4(MULTICAST_ADDR, Ipv4Addr::UNSPECIFIED)?;
58    Ok(socket)
59}
60
61/// An mDNS query response
62#[derive(Debug)]
63pub struct Response {
64    pub answers: Vec<Record>,
65    pub nameservers: Vec<Record>,
66    pub additional: Vec<Record>,
67}
68
/// The resolved information about a host (or rather, a service).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Host {
    /// A friendly name for this instance.
    pub name: String,
    /// The mDNS `A` (or `AAAA`) resolvable hostname for this host.
    /// May be different from `name`.
    pub host_name: Option<String>,
    /// The set of addresses.
    pub ip_address: Vec<IpAddr>,
    /// The set of addresses with port numbers.
    /// May be empty if no SRV record was resolved.
    pub socket_address: Vec<SocketAddr>,
    /// The instant at which this information is no longer valid.
    pub expires: Instant,
}

impl Host {
    /// Reports whether this entry has not yet reached its `expires`
    /// instant, i.e. whether it is still within the TTL given by the mDNS
    /// response it came from.
    pub fn valid(&self) -> bool {
        self.expires > Instant::now()
    }
}
93
94impl Response {
95    fn new(packet: &Packet) -> Self {
96        Self {
97            answers: packet.answers.iter().map(Record::new).collect(),
98            nameservers: packet.nameservers.iter().map(Record::new).collect(),
99            additional: packet.additional.iter().map(Record::new).collect(),
100        }
101    }
102
103    fn all_records(&self) -> impl Iterator<Item = &Record> {
104        self.answers
105            .iter()
106            .chain(self.additional.iter())
107            .chain(self.nameservers.iter())
108    }
109
110    /// Compose the response as an array of Host structs
111    pub fn hosts(&self) -> Vec<Host> {
112        let mut result = vec![];
113
114        for ans in &self.answers {
115            match &ans.kind {
116                RecordKind::A(addr) => {
117                    result.push(Host {
118                        name: ans.name.clone(),
119                        host_name: Some(ans.name.clone()),
120                        ip_address: vec![(*addr).into()],
121                        socket_address: vec![],
122                        expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
123                    });
124                }
125                RecordKind::AAAA(addr) => {
126                    result.push(Host {
127                        name: ans.name.clone(),
128                        host_name: Some(ans.name.clone()),
129                        ip_address: vec![(*addr).into()],
130                        socket_address: vec![],
131                        expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
132                    });
133                }
134                RecordKind::PTR(name) => {
135                    let name = name.clone();
136                    let mut found_port = None;
137                    let mut host_name = None;
138                    let mut ip_address = vec![];
139                    let mut socket_address = vec![];
140
141                    for r in self.all_records() {
142                        if r.name != name {
143                            continue;
144                        }
145
146                        match &r.kind {
147                            RecordKind::SRV { port, target, .. } => {
148                                found_port.replace(*port);
149                                host_name.replace(target.clone());
150                            }
151                            _ => {}
152                        }
153                    }
154
155                    if let Some(host_name) = host_name.as_ref() {
156                        for r in self.all_records() {
157                            if &r.name != host_name {
158                                continue;
159                            }
160
161                            match &r.kind {
162                                RecordKind::A(addr) => {
163                                    ip_address.push(addr.clone().into());
164                                }
165                                RecordKind::AAAA(addr) => {
166                                    ip_address.push(addr.clone().into());
167                                }
168                                _ => {}
169                            }
170                        }
171                    }
172
173                    if let Some(port) = found_port {
174                        for addr in &ip_address {
175                            socket_address.push(SocketAddr::new(*addr, port));
176                        }
177                    }
178
179                    result.push(Host {
180                        name,
181                        host_name,
182                        ip_address,
183                        socket_address,
184                        expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
185                    });
186                }
187                _ => {}
188            }
189        }
190        result
191    }
192}
193
194/// mDNS Records compose into a [Response](struct.Record.html)
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct Record {
197    pub name: String,
198    pub class: dns_parser::Class,
199    pub ttl: u32,
200    pub kind: RecordKind,
201}
202
203impl Record {
204    fn new(rr: &ResourceRecord) -> Self {
205        Self {
206            name: rr.name.to_string(),
207            class: rr.cls,
208            ttl: rr.ttl,
209            kind: RecordKind::new(&rr.data),
210        }
211    }
212}
213
214/// mDNS record data of various kinds
215#[derive(Debug, Clone, PartialEq, Eq)]
216pub enum RecordKind {
217    A(Ipv4Addr),
218    AAAA(Ipv6Addr),
219    CNAME(String),
220    PTR(String),
221    NS(String),
222    MX {
223        preference: u16,
224        exchange: String,
225    },
226    SRV {
227        priority: u16,
228        weight: u16,
229        port: u16,
230        target: String,
231    },
232    SOA {
233        primary_ns: String,
234        mailbox: String,
235        serial: u32,
236        refresh: u32,
237        retry: u32,
238        expire: u32,
239        minimum_ttl: u32,
240    },
241    TXT(Vec<String>),
242    Unimplemented {
243        kind: dns_parser::Type,
244        data: Vec<u8>,
245    },
246}
247
248impl RecordKind {
249    fn new(data: &RData) -> Self {
250        match data {
251            RData::A(dns_parser::rdata::a::Record(addr)) => Self::A(*addr),
252            RData::AAAA(dns_parser::rdata::aaaa::Record(addr)) => Self::AAAA(*addr),
253            RData::CNAME(name) => Self::CNAME(name.to_string()),
254            RData::NS(name) => Self::NS(name.to_string()),
255            RData::PTR(name) => Self::PTR(name.to_string()),
256            RData::MX(dns_parser::rdata::mx::Record {
257                preference,
258                exchange,
259            }) => Self::MX {
260                preference: *preference,
261                exchange: exchange.to_string(),
262            },
263            RData::SRV(dns_parser::rdata::srv::Record {
264                priority,
265                weight,
266                port,
267                target,
268            }) => Self::SRV {
269                priority: *priority,
270                weight: *weight,
271                port: *port,
272                target: target.to_string(),
273            },
274            RData::TXT(txt) => Self::TXT(
275                txt.iter()
276                    .map(|b| String::from_utf8_lossy(b).into_owned())
277                    .collect(),
278            ),
279            RData::SOA(dns_parser::rdata::soa::Record {
280                primary_ns,
281                mailbox,
282                serial,
283                refresh,
284                retry,
285                expire,
286                minimum_ttl,
287            }) => Self::SOA {
288                primary_ns: primary_ns.to_string(),
289                mailbox: mailbox.to_string(),
290                serial: *serial,
291                refresh: *refresh,
292                retry: *retry,
293                expire: *expire,
294                minimum_ttl: *minimum_ttl,
295            },
296            RData::Unknown(kind, data) => Self::Unimplemented {
297                kind: *kind,
298                data: data.to_vec(),
299            },
300        }
301    }
302}
303
304/// Resolve a single host using an mDNS request.
305/// Returns a `Response` if found within the specified timeout,
306/// otherwise yields an Error.
307pub async fn resolve_one<S: AsRef<str>>(
308    service_name: S,
309    params: QueryParameters,
310) -> Result<Response> {
311    let responses = resolve(service_name, params).await?;
312    let response = responses.recv().await?;
313    Ok(response)
314}
315
316/// Controls how to perform the query.
317/// You will typically use one of the associated constants
318/// [DISCOVERY](#associatedconstant.DISCOVERY),
319/// [SERVICE_LOOKUP](#associatedconstant.SERVICE_LOOKUP),
320/// [HOST_LOOKUP](#associatedconstant.HOST_LOOKUP)
321#[derive(Clone, Copy, Debug, PartialEq, Eq)]
322pub struct QueryParameters {
323    pub query_type: QueryType,
324    /// If specified, the query will be re-issued after this duration
325    pub base_repeat_interval: Option<Duration>,
326    /// The maximum interval between retries
327    pub max_repeat_interval: Option<Duration>,
328    /// If true, repeat interval will be doubled on each iteration
329    /// until it reaches the max_repeat_interval.
330    /// If false, it will increment by base_repeat_interval on each
331    /// iteration until it reaches the max_repeat_interval.
332    pub exponential_backoff: bool,
333    /// If set, specifies the upper bound on total time spent
334    /// processing the request.
335    /// Otherwise, the request will keep going forever, subject to
336    /// the repeat interval.
337    pub timeout_after: Option<Duration>,
338}
339
340impl QueryParameters {
341    /// Parameters suitable for performing long-running discovery.
342    /// Repeatedly performs a PTR lookup with exponential backoff
343    /// ranging from 2 seconds up to 5 minutes.
344    pub const DISCOVERY: QueryParameters = QueryParameters {
345        query_type: QueryType::PTR,
346        base_repeat_interval: Some(Duration::from_secs(2)),
347        exponential_backoff: true,
348        max_repeat_interval: Some(Duration::from_secs(300)),
349        timeout_after: None,
350    };
351
352    /// Parameters suitable for performing short-lived discovery.
353    /// Repeatedly performs a PTR lookup with exponential backoff
354    /// ranging from 2 seconds up to 5 minutes.
355    pub const SERVICE_LOOKUP: QueryParameters = QueryParameters {
356        query_type: QueryType::PTR,
357        base_repeat_interval: Some(Duration::from_secs(2)),
358        exponential_backoff: true,
359        max_repeat_interval: None,
360        timeout_after: Some(Duration::from_secs(60)),
361    };
362
363    /// Parameters suitable for resolving a single host.
364    /// Performs an A lookup with exponential backoff ranging from
365    /// 1 second.  The overall lookup will timeout after 1 minutes.
366    pub const HOST_LOOKUP: QueryParameters = QueryParameters {
367        query_type: QueryType::A,
368        base_repeat_interval: Some(Duration::from_secs(1)),
369        exponential_backoff: true,
370        max_repeat_interval: None,
371        timeout_after: Some(Duration::from_secs(60)),
372    };
373
374    pub fn with_timeout(mut self, timeout: Duration) -> Self {
375        self.timeout_after.replace(timeout);
376        self
377    }
378}
379
380fn make_query(service_name: &str, query_type: QueryType) -> Result<Vec<u8>> {
381    let mut builder = Builder::new_query(rand::random(), false);
382    let prefer_unicast = false;
383    builder.add_question(&service_name, prefer_unicast, query_type, QueryClass::IN);
384    Ok(builder.build().map_err(|_| Error::DnsPacketBuildError)?)
385}
386
387/// The source UDP port in all Multicast DNS responses MUST
388/// be 5353 (the well-known port assigned to mDNS).
389/// Multicast DNS implementations MUST silently ignore any
390/// Multicast DNS responses they receive where the source
391/// UDP port is not 5353.
392///
393/// Also applies the Source Address Check from section 11 of
394/// <https://tools.ietf.org/html/rfc6762>
395fn valid_source_address(addr: SocketAddr) -> bool {
396    if addr.port() != MULTICAST_PORT {
397        false
398    } else {
399        /// Computes the masked address bits.
400        fn masked(addr: &[u8], mask: &[u8]) -> Vec<u8> {
401            assert_eq!(addr.len(), mask.len());
402            addr.iter().zip(mask.iter()).map(|(a, m)| a & m).collect()
403        }
404
405        let ifaces = match crate::net_utils::get_if_addrs() {
406            Ok(i) => i,
407            Err(err) => {
408                log::error!("error while listing local interfaces: {}", err);
409                return false;
410            }
411        };
412
413        for iface in ifaces {
414            let matches_iface = match (&iface.addr, addr.ip()) {
415                (crate::net_utils::IfAddr::V4(a), IpAddr::V4(source)) => {
416                    masked(&a.ip.octets(), &a.netmask.octets())
417                        == masked(&source.octets(), &a.netmask.octets())
418                }
419
420                (crate::net_utils::IfAddr::V6(a), IpAddr::V6(source)) => {
421                    masked(&a.ip.octets(), &a.netmask.octets())
422                        == masked(&source.octets(), &a.netmask.octets())
423                }
424                _ => false,
425            };
426
427            if matches_iface {
428                return true;
429            }
430        }
431
432        false
433    }
434}
435
436/// Resolve records matching the requested service name.
437/// Returns a Receiver that will yield successive responses.
438/// Once `timeout` passes, the Sender side of the receiver
439/// will disconnect and the channel will yield a RecvError.
440pub async fn resolve<S: AsRef<str>>(
441    service_name: S,
442    params: QueryParameters,
443) -> Result<Receiver<Response>> {
444    if params.base_repeat_interval.is_none() && params.timeout_after.is_none() {
445        return Err(Error::InvalidQueryParams);
446    }
447
448    let service_name = service_name.as_ref().to_string();
449    let deadline = params.timeout_after.map(|d| Instant::now() + d);
450
451    let data = make_query(&service_name, params.query_type)?;
452
453    let socket = create_socket().await?;
454    let addr = sockaddr(MULTICAST_ADDR, MULTICAST_PORT);
455
456    socket.send_to(&data, addr).await?;
457
458    let (tx, rx) = bounded(8);
459
460    smol::spawn(async move {
461        let mut retry_interval = params.base_repeat_interval;
462        let mut last_send = Instant::now();
463
464        loop {
465            let now = Instant::now();
466
467            if let Some(deadline) = deadline {
468                if now >= deadline {
469                    log::trace!("resolve loop completing because {now:?} >= {deadline:?}");
470                    break;
471                }
472            }
473
474            let recv_deadline = match retry_interval {
475                Some(retry) => match deadline {
476                    Some(overall) => (last_send + retry).min(overall),
477                    None => last_send + retry,
478                },
479                None => match deadline {
480                    Some(overall) => overall,
481                    None => {
482                        // Shouldn't be possible and we should
483                        // have caught this in the params validation
484                        // at entry to the function.
485                        log::error!("resolve loop aborting because params are invalid");
486                        return Err(Error::InvalidQueryParams);
487                    }
488                },
489            };
490
491            let mut buf = [0u8; 4096];
492
493            let recv = async {
494                let (len, addr) = socket.recv_from(&mut buf).await?;
495                Result::Ok(Some((len, addr)))
496            };
497
498            let timer = async {
499                let timer = smol::Timer::at(recv_deadline);
500                timer.await;
501                Result::Ok(None)
502            };
503
504            if let Some((len, addr)) = recv.or(timer).await? {
505                match Packet::parse(&buf[..len]) {
506                    Ok(dns) => {
507                        let response = Response::new(&dns);
508                        if !valid_source_address(addr) {
509                            log::trace!(
510                                "ignoring response {response:?} from {addr:?} which is not local",
511                            );
512                        } else {
513                            let matched = response
514                                .answers
515                                .iter()
516                                .any(|answer| answer.name == service_name);
517                            if matched {
518                                tx.send(response).await?;
519                            }
520                        }
521                    }
522                    Err(e) => {
523                        log::trace!("failed to parse packet: {e:?} received from {addr:?}");
524                    }
525                }
526            } else {
527                log::trace!("resolve loop read timeout; send another query");
528                // retry_interval exceeded, so send another query
529                let data = make_query(&service_name, params.query_type)?;
530                socket.send_to(&data, addr).await?;
531                last_send = Instant::now();
532
533                // And compute next interval
534                match retry_interval.take() {
535                    None => {
536                        // No retries; we're done!
537                        break;
538                    }
539                    Some(retry) => {
540                        let base = params.base_repeat_interval.unwrap();
541
542                        let retry = if params.exponential_backoff {
543                            retry + retry
544                        } else {
545                            retry + base
546                        };
547
548                        let retry = params
549                            .max_repeat_interval
550                            .map(|max| retry.min(max))
551                            .unwrap_or(retry);
552
553                        retry_interval.replace(retry);
554                    }
555                }
556                log::trace!("updated retry_interval is now {retry_interval:?}");
557            }
558        }
559
560        log::trace!("resolve loop completing OK");
561        Result::Ok(())
562    })
563    .detach();
564
565    Ok(rx)
566}