edge_mdns/
io.rs

1use core::cell::RefCell;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
3
4use core::pin::pin;
5
6use buf::BufferAccess;
7
8use embassy_futures::select::{select, Either};
9use embassy_sync::blocking_mutex;
10use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
11use embassy_sync::mutex::Mutex;
12use embassy_sync::signal::Signal;
13
14use edge_nal::{MulticastV4, MulticastV6, Readable, UdpBind, UdpReceive, UdpSend};
15
16use embassy_time::{Duration, Timer};
17
18use super::*;
19
/// Wildcard IPv4 socket address on the mDNS port; binds to any IPv4-configured
/// interface available.
pub const IPV4_DEFAULT_SOCKET: SocketAddr =
    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, PORT));

/// Wildcard IPv6 socket address on the mDNS port.
///
/// Binds to any IPv6-configured interface on single-stack implementations, and to
/// any configured interface on dual-stack implementations.
pub const IPV6_DEFAULT_SOCKET: SocketAddr =
    SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, PORT, 0, 0));

/// A quick-and-dirty default: the IPv6 wildcard address, which covers every interface
/// on dual-stack implementations only.
///
/// Don't use in production code.
pub const DEFAULT_SOCKET: SocketAddr = IPV6_DEFAULT_SOCKET;

/// The IPv4 mDNS group address (a multicast address, despite the constant's name),
/// as per spec.
pub const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// The IPv6 mDNS group address (a multicast address, despite the constant's name),
/// as per spec.
pub const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0xfb);

/// The UDP port used by mDNS, as per spec.
pub const PORT: u16 = 5353;
41
42/// A wrapper for mDNS and IO errors.
43#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
44pub enum MdnsIoError<E> {
45    MdnsError(MdnsError),
46    NoRecvBufError,
47    NoSendBufError,
48    IoError(E),
49}
50
51pub type MdnsIoErrorKind = MdnsIoError<edge_nal::io::ErrorKind>;
52
53impl<E> MdnsIoError<E>
54where
55    E: edge_nal::io::Error,
56{
57    pub fn erase(&self) -> MdnsIoError<edge_nal::io::ErrorKind> {
58        match self {
59            Self::MdnsError(e) => MdnsIoError::MdnsError(*e),
60            Self::NoRecvBufError => MdnsIoError::NoRecvBufError,
61            Self::NoSendBufError => MdnsIoError::NoSendBufError,
62            Self::IoError(e) => MdnsIoError::IoError(e.kind()),
63        }
64    }
65}
66
67impl<E> From<MdnsError> for MdnsIoError<E> {
68    fn from(err: MdnsError) -> Self {
69        Self::MdnsError(err)
70    }
71}
72
73impl<E> core::fmt::Display for MdnsIoError<E>
74where
75    E: core::fmt::Display,
76{
77    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78        match self {
79            Self::MdnsError(err) => write!(f, "mDNS error: {}", err),
80            Self::NoRecvBufError => write!(f, "No recv buf available"),
81            Self::NoSendBufError => write!(f, "No send buf available"),
82            Self::IoError(err) => write!(f, "IO error: {}", err),
83        }
84    }
85}
86
/// `defmt` rendering of the error, available only with the `defmt` feature.
/// Emits the same messages as the `Display` implementation.
#[cfg(feature = "defmt")]
impl<E> defmt::Format for MdnsIoError<E>
where
    E: defmt::Format,
{
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::NoRecvBufError => defmt::write!(f, "No recv buf available"),
            Self::NoSendBufError => defmt::write!(f, "No send buf available"),
            Self::MdnsError(mdns_err) => defmt::write!(f, "mDNS error: {}", mdns_err),
            Self::IoError(io_err) => defmt::write!(f, "IO error: {}", io_err),
        }
    }
}
101
102impl<E> core::error::Error for MdnsIoError<E> where E: core::error::Error {}
103
104/// A utility method to bind a socket suitable for mDNS, by using the provided
105/// stack and address, and optionally joining the provided interfaces via multicast.
106///
107/// Note that mDNS is pointless without multicast, so at least one - or both - of the
108/// ipv4 and ipv6 interfaces need to be provided.
109pub async fn bind<S>(
110    stack: &S,
111    addr: SocketAddr,
112    ipv4_interface: Option<Ipv4Addr>,
113    ipv6_interface: Option<u32>,
114) -> Result<S::Socket<'_>, MdnsIoError<S::Error>>
115where
116    S: UdpBind,
117{
118    let mut socket = stack.bind(addr).await.map_err(MdnsIoError::IoError)?;
119
120    if let Some(v4) = ipv4_interface {
121        socket
122            .join_v4(IP_BROADCAST_ADDR, v4)
123            .await
124            .map_err(MdnsIoError::IoError)?;
125    }
126
127    if let Some(v6) = ipv6_interface {
128        socket
129            .join_v6(IPV6_BROADCAST_ADDR, v6)
130            .await
131            .map_err(MdnsIoError::IoError)?;
132    }
133
134    Ok(socket)
135}
136
137/// Represents an mDNS service that can respond to queries using the provided handler.
138///
139/// This structure is generic over the mDNS handler, the UDP receiver and sender, and the
140/// raw mutex type.
141///
142/// The handler is expected to be a type that implements the `MdnsHandler` trait, which
143/// allows it to handle mDNS queries and generate responses, as well as to handle mDNS
144/// responses to queries which we might have issues using the `query` method.
145pub struct Mdns<'a, R, S, RB, SB, C, M = NoopRawMutex>
146where
147    M: RawMutex,
148{
149    ipv4_interface: Option<Ipv4Addr>,
150    ipv6_interface: Option<u32>,
151    recv: Mutex<M, R>,
152    send: Mutex<M, S>,
153    recv_buf: RB,
154    send_buf: SB,
155    rand: blocking_mutex::Mutex<M, RefCell<C>>,
156    broadcast_signal: &'a Signal<M, ()>,
157    wait_readable: bool,
158}
159
160impl<'a, R, S, RB, SB, C, M> Mdns<'a, R, S, RB, SB, C, M>
161where
162    R: UdpReceive + Readable,
163    S: UdpSend<Error = R::Error>,
164    RB: BufferAccess<[u8]>,
165    SB: BufferAccess<[u8]>,
166    C: rand_core::RngCore,
167    M: RawMutex,
168{
169    /// Creates a new mDNS service with the provided handler, interfaces, and UDP receiver and sender.
170    #[allow(clippy::too_many_arguments)]
171    pub fn new(
172        ipv4_interface: Option<Ipv4Addr>,
173        ipv6_interface: Option<u32>,
174        recv: R,
175        send: S,
176        recv_buf: RB,
177        send_buf: SB,
178        rand: C,
179        broadcast_signal: &'a Signal<M, ()>,
180    ) -> Self {
181        Self {
182            ipv4_interface,
183            ipv6_interface,
184            recv: Mutex::new(recv),
185            send: Mutex::new(send),
186            recv_buf,
187            send_buf,
188            rand: blocking_mutex::Mutex::new(RefCell::new(rand)),
189            broadcast_signal,
190            wait_readable: false,
191        }
192    }
193
194    /// Sets whether the mDNS service should wait for the socket to be readable before reading.
195    ///
196    /// Setting this to `true` is only useful when the read buffer is shared with other tasks
197    pub fn wait_readable(&mut self, wait_readable: bool) {
198        self.wait_readable = wait_readable;
199    }
200
201    /// Runs the mDNS service, handling queries and responding to them, as well as broadcasting
202    /// mDNS answers and handling responses to our own queries.
203    ///
204    /// All of the handling logic is expected to be implemented by the provided handler:
205    /// - I.e. hanbdling responses to our own queries cannot happen, unless the supplied handler
206    ///   is capable of doing that (i.e. it is a `PeerMdnsHandler`, or a chain containing it, or similar).
207    /// - Ditto for handling queries coming from other peers - this can only happen if the handler
208    ///   is capable of doing that. I.e., it is a `HostMdnsHandler`, or a chain containing it, or similar.
209    pub async fn run<T>(&self, handler: T) -> Result<(), MdnsIoError<S::Error>>
210    where
211        T: MdnsHandler,
212    {
213        let handler = blocking_mutex::Mutex::<M, _>::new(RefCell::new(handler));
214
215        let mut broadcast = pin!(self.broadcast(&handler));
216        let mut respond = pin!(self.respond(&handler));
217
218        let result = select(&mut broadcast, &mut respond).await;
219
220        match result {
221            Either::First(result) => result,
222            Either::Second(result) => result,
223        }
224    }
225
226    /// Sends a multicast query with the provided payload.
227    /// It is assumed that the payload represents a valid mDNS query message.
228    ///
229    /// The payload is constructed via a closure, because this way we can provide to
230    /// the payload-constructing closure a ready-to-use `&mut [u8]` slice, where the
231    /// closure can arrange the mDNS query message (i.e. we avoid extra memory usage
232    /// by constructing the mDNS query directly in the `send_buf` buffer that was supplied
233    /// when the `Mdns` instance was constructed).
234    pub async fn query<Q>(&self, q: Q) -> Result<(), MdnsIoError<S::Error>>
235    where
236        Q: FnOnce(&mut [u8]) -> Result<usize, MdnsError>,
237    {
238        let mut send_buf = self
239            .send_buf
240            .get()
241            .await
242            .ok_or(MdnsIoError::NoSendBufError)?;
243
244        let mut send_guard = self.send.lock().await;
245        let send = &mut *send_guard;
246
247        let len = q(send_buf.as_mut())?;
248
249        if len > 0 {
250            self.broadcast_once(send, &send_buf.as_mut()[..len]).await?;
251        }
252
253        Ok(())
254    }
255
256    async fn broadcast<T>(
257        &self,
258        handler: &blocking_mutex::Mutex<M, RefCell<T>>,
259    ) -> Result<(), MdnsIoError<S::Error>>
260    where
261        T: MdnsHandler,
262    {
263        loop {
264            {
265                let mut send_buf = self
266                    .send_buf
267                    .get()
268                    .await
269                    .ok_or(MdnsIoError::NoSendBufError)?;
270
271                let mut send_guard = self.send.lock().await;
272                let send = &mut *send_guard;
273
274                let response = handler.lock(|handler| {
275                    handler
276                        .borrow_mut()
277                        .handle(MdnsRequest::None, send_buf.as_mut())
278                })?;
279
280                if let MdnsResponse::Reply { data, delay } = response {
281                    if delay {
282                        // TODO: Not ideal, as we hold the lock during the delay
283                        self.delay().await;
284                    }
285
286                    self.broadcast_once(send, data).await?;
287                }
288            }
289
290            self.broadcast_signal.wait().await;
291        }
292    }
293
294    async fn respond<T>(
295        &self,
296        handler: &blocking_mutex::Mutex<M, RefCell<T>>,
297    ) -> Result<(), MdnsIoError<S::Error>>
298    where
299        T: MdnsHandler,
300    {
301        let mut recv = self.recv.lock().await;
302
303        loop {
304            if self.wait_readable {
305                recv.readable().await.map_err(MdnsIoError::IoError)?;
306            }
307
308            {
309                let mut recv_buf = self
310                    .recv_buf
311                    .get()
312                    .await
313                    .ok_or(MdnsIoError::NoRecvBufError)?;
314
315                let (len, remote) = recv
316                    .receive(recv_buf.as_mut())
317                    .await
318                    .map_err(MdnsIoError::IoError)?;
319
320                debug!("Got mDNS query from {}", remote);
321
322                {
323                    let mut send_buf = self
324                        .send_buf
325                        .get()
326                        .await
327                        .ok_or(MdnsIoError::NoSendBufError)?;
328
329                    let mut send_guard = self.send.lock().await;
330                    let send = &mut *send_guard;
331
332                    let response = match handler.lock(|handler| {
333                        handler.borrow_mut().handle(
334                            MdnsRequest::Request {
335                                data: &recv_buf.as_mut()[..len],
336                                legacy: remote.port() != PORT,
337                                multicast: true, // TODO: Cannot determine this
338                            },
339                            send_buf.as_mut(),
340                        )
341                    }) {
342                        Ok(len) => len,
343                        Err(err) => match err {
344                            MdnsError::InvalidMessage => {
345                                warn!("Got invalid message from {}, skipping", remote);
346                                continue;
347                            }
348                            other => Err(other)?,
349                        },
350                    };
351
352                    if let MdnsResponse::Reply { data, delay } = response {
353                        if remote.port() != PORT {
354                            // Support one-shot legacy queries by replying privately
355                            // to the remote address, if the query was not sent from the mDNS port (as per the spec)
356
357                            debug!(
358                                "Replying privately to a one-shot mDNS query from {}",
359                                remote
360                            );
361
362                            if let Err(err) = send.send(remote, data).await {
363                                warn!(
364                                    "Failed to reply privately to {}: {:?}",
365                                    remote,
366                                    debug2format!(err)
367                                );
368                            }
369                        } else {
370                            // Otherwise, re-broadcast the response
371
372                            if delay {
373                                self.delay().await;
374                            }
375
376                            debug!("Re-broadcasting due to mDNS query from {}", remote);
377
378                            self.broadcast_once(send, data).await?;
379                        }
380                    }
381                }
382            }
383        }
384    }
385
386    async fn broadcast_once(&self, send: &mut S, data: &[u8]) -> Result<(), MdnsIoError<S::Error>> {
387        for remote_addr in
388            core::iter::once(SocketAddr::V4(SocketAddrV4::new(IP_BROADCAST_ADDR, PORT)))
389                .filter(|_| self.ipv4_interface.is_some())
390                .chain(
391                    self.ipv6_interface
392                        .map(|interface| {
393                            SocketAddr::V6(SocketAddrV6::new(
394                                IPV6_BROADCAST_ADDR,
395                                PORT,
396                                0,
397                                interface,
398                            ))
399                        })
400                        .into_iter(),
401                )
402        {
403            if !data.is_empty() {
404                debug!("Broadcasting mDNS entry to {}", remote_addr);
405
406                let fut = pin!(send.send(remote_addr, data));
407
408                fut.await.map_err(MdnsIoError::IoError)?;
409            }
410        }
411
412        Ok(())
413    }
414
415    async fn delay(&self) {
416        let mut b = [0];
417        self.rand.lock(|rand| rand.borrow_mut().fill_bytes(&mut b));
418
419        // Generate a delay between 20 and 120 ms, as per spec
420        let delay_ms = 20 + (b[0] as u32 * 100 / 256);
421
422        Timer::after(Duration::from_millis(delay_ms as _)).await;
423    }
424}