1use core::cell::RefCell;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
3
4use core::pin::pin;
5
6use buf::BufferAccess;
7
8use embassy_futures::select::{select, Either};
9use embassy_sync::blocking_mutex;
10use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
11use embassy_sync::mutex::Mutex;
12use embassy_sync::signal::Signal;
13
14use edge_nal::{MulticastV4, MulticastV6, Readable, UdpBind, UdpReceive, UdpSend};
15
16use embassy_time::{Duration, Timer};
17
18use super::*;
19
/// UDP port used for all mDNS traffic in this module.
pub const PORT: u16 = 5353;

/// IPv4 group address that mDNS packets are sent to (`224.0.0.251`).
///
/// This is the group joined via `join_v4` in [`bind`]; despite the name it is
/// used as a multicast group address.
pub const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);

/// IPv6 group address that mDNS packets are sent to (`ff02::fb`).
///
/// This is the group joined via `join_v6` in [`bind`].
pub const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0xfb);

/// Socket address made of the unspecified IPv4 address (`0.0.0.0`) and [`PORT`].
pub const IPV4_DEFAULT_SOCKET: SocketAddr =
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), PORT);

/// Socket address made of the unspecified IPv6 address (`::`) and [`PORT`].
pub const IPV6_DEFAULT_SOCKET: SocketAddr =
    SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);

/// Default socket address for mDNS; an alias of [`IPV6_DEFAULT_SOCKET`].
pub const DEFAULT_SOCKET: SocketAddr = IPV6_DEFAULT_SOCKET;
41
42#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
44pub enum MdnsIoError<E> {
45 MdnsError(MdnsError),
46 NoRecvBufError,
47 NoSendBufError,
48 IoError(E),
49}
50
51pub type MdnsIoErrorKind = MdnsIoError<edge_nal::io::ErrorKind>;
52
53impl<E> MdnsIoError<E>
54where
55 E: edge_nal::io::Error,
56{
57 pub fn erase(&self) -> MdnsIoError<edge_nal::io::ErrorKind> {
58 match self {
59 Self::MdnsError(e) => MdnsIoError::MdnsError(*e),
60 Self::NoRecvBufError => MdnsIoError::NoRecvBufError,
61 Self::NoSendBufError => MdnsIoError::NoSendBufError,
62 Self::IoError(e) => MdnsIoError::IoError(e.kind()),
63 }
64 }
65}
66
67impl<E> From<MdnsError> for MdnsIoError<E> {
68 fn from(err: MdnsError) -> Self {
69 Self::MdnsError(err)
70 }
71}
72
73impl<E> core::fmt::Display for MdnsIoError<E>
74where
75 E: core::fmt::Display,
76{
77 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78 match self {
79 Self::MdnsError(err) => write!(f, "mDNS error: {}", err),
80 Self::NoRecvBufError => write!(f, "No recv buf available"),
81 Self::NoSendBufError => write!(f, "No send buf available"),
82 Self::IoError(err) => write!(f, "IO error: {}", err),
83 }
84 }
85}
86
/// `defmt` rendering of an [`MdnsIoError`]; only compiled with the `defmt`
/// feature. The messages mirror the `Display` implementation.
#[cfg(feature = "defmt")]
impl<E> defmt::Format for MdnsIoError<E>
where
    E: defmt::Format,
{
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::NoRecvBufError => defmt::write!(f, "No recv buf available"),
            Self::NoSendBufError => defmt::write!(f, "No send buf available"),
            Self::MdnsError(cause) => defmt::write!(f, "mDNS error: {}", cause),
            Self::IoError(cause) => defmt::write!(f, "IO error: {}", cause),
        }
    }
}
101
102impl<E> core::error::Error for MdnsIoError<E> where E: core::error::Error {}
103
104pub async fn bind<S>(
110 stack: &S,
111 addr: SocketAddr,
112 ipv4_interface: Option<Ipv4Addr>,
113 ipv6_interface: Option<u32>,
114) -> Result<S::Socket<'_>, MdnsIoError<S::Error>>
115where
116 S: UdpBind,
117{
118 let mut socket = stack.bind(addr).await.map_err(MdnsIoError::IoError)?;
119
120 if let Some(v4) = ipv4_interface {
121 socket
122 .join_v4(IP_BROADCAST_ADDR, v4)
123 .await
124 .map_err(MdnsIoError::IoError)?;
125 }
126
127 if let Some(v6) = ipv6_interface {
128 socket
129 .join_v6(IPV6_BROADCAST_ADDR, v6)
130 .await
131 .map_err(MdnsIoError::IoError)?;
132 }
133
134 Ok(socket)
135}
136
137pub struct Mdns<'a, R, S, RB, SB, C, M = NoopRawMutex>
146where
147 M: RawMutex,
148{
149 ipv4_interface: Option<Ipv4Addr>,
150 ipv6_interface: Option<u32>,
151 recv: Mutex<M, R>,
152 send: Mutex<M, S>,
153 recv_buf: RB,
154 send_buf: SB,
155 rand: blocking_mutex::Mutex<M, RefCell<C>>,
156 broadcast_signal: &'a Signal<M, ()>,
157 wait_readable: bool,
158}
159
160impl<'a, R, S, RB, SB, C, M> Mdns<'a, R, S, RB, SB, C, M>
161where
162 R: UdpReceive + Readable,
163 S: UdpSend<Error = R::Error>,
164 RB: BufferAccess<[u8]>,
165 SB: BufferAccess<[u8]>,
166 C: rand_core::RngCore,
167 M: RawMutex,
168{
169 #[allow(clippy::too_many_arguments)]
171 pub fn new(
172 ipv4_interface: Option<Ipv4Addr>,
173 ipv6_interface: Option<u32>,
174 recv: R,
175 send: S,
176 recv_buf: RB,
177 send_buf: SB,
178 rand: C,
179 broadcast_signal: &'a Signal<M, ()>,
180 ) -> Self {
181 Self {
182 ipv4_interface,
183 ipv6_interface,
184 recv: Mutex::new(recv),
185 send: Mutex::new(send),
186 recv_buf,
187 send_buf,
188 rand: blocking_mutex::Mutex::new(RefCell::new(rand)),
189 broadcast_signal,
190 wait_readable: false,
191 }
192 }
193
194 pub fn wait_readable(&mut self, wait_readable: bool) {
198 self.wait_readable = wait_readable;
199 }
200
201 pub async fn run<T>(&self, handler: T) -> Result<(), MdnsIoError<S::Error>>
210 where
211 T: MdnsHandler,
212 {
213 let handler = blocking_mutex::Mutex::<M, _>::new(RefCell::new(handler));
214
215 let mut broadcast = pin!(self.broadcast(&handler));
216 let mut respond = pin!(self.respond(&handler));
217
218 let result = select(&mut broadcast, &mut respond).await;
219
220 match result {
221 Either::First(result) => result,
222 Either::Second(result) => result,
223 }
224 }
225
226 pub async fn query<Q>(&self, q: Q) -> Result<(), MdnsIoError<S::Error>>
235 where
236 Q: FnOnce(&mut [u8]) -> Result<usize, MdnsError>,
237 {
238 let mut send_buf = self
239 .send_buf
240 .get()
241 .await
242 .ok_or(MdnsIoError::NoSendBufError)?;
243
244 let mut send_guard = self.send.lock().await;
245 let send = &mut *send_guard;
246
247 let len = q(send_buf.as_mut())?;
248
249 if len > 0 {
250 self.broadcast_once(send, &send_buf.as_mut()[..len]).await?;
251 }
252
253 Ok(())
254 }
255
256 async fn broadcast<T>(
257 &self,
258 handler: &blocking_mutex::Mutex<M, RefCell<T>>,
259 ) -> Result<(), MdnsIoError<S::Error>>
260 where
261 T: MdnsHandler,
262 {
263 loop {
264 {
265 let mut send_buf = self
266 .send_buf
267 .get()
268 .await
269 .ok_or(MdnsIoError::NoSendBufError)?;
270
271 let mut send_guard = self.send.lock().await;
272 let send = &mut *send_guard;
273
274 let response = handler.lock(|handler| {
275 handler
276 .borrow_mut()
277 .handle(MdnsRequest::None, send_buf.as_mut())
278 })?;
279
280 if let MdnsResponse::Reply { data, delay } = response {
281 if delay {
282 self.delay().await;
284 }
285
286 self.broadcast_once(send, data).await?;
287 }
288 }
289
290 self.broadcast_signal.wait().await;
291 }
292 }
293
294 async fn respond<T>(
295 &self,
296 handler: &blocking_mutex::Mutex<M, RefCell<T>>,
297 ) -> Result<(), MdnsIoError<S::Error>>
298 where
299 T: MdnsHandler,
300 {
301 let mut recv = self.recv.lock().await;
302
303 loop {
304 if self.wait_readable {
305 recv.readable().await.map_err(MdnsIoError::IoError)?;
306 }
307
308 {
309 let mut recv_buf = self
310 .recv_buf
311 .get()
312 .await
313 .ok_or(MdnsIoError::NoRecvBufError)?;
314
315 let (len, remote) = recv
316 .receive(recv_buf.as_mut())
317 .await
318 .map_err(MdnsIoError::IoError)?;
319
320 debug!("Got mDNS query from {}", remote);
321
322 {
323 let mut send_buf = self
324 .send_buf
325 .get()
326 .await
327 .ok_or(MdnsIoError::NoSendBufError)?;
328
329 let mut send_guard = self.send.lock().await;
330 let send = &mut *send_guard;
331
332 let response = match handler.lock(|handler| {
333 handler.borrow_mut().handle(
334 MdnsRequest::Request {
335 data: &recv_buf.as_mut()[..len],
336 legacy: remote.port() != PORT,
337 multicast: true, },
339 send_buf.as_mut(),
340 )
341 }) {
342 Ok(len) => len,
343 Err(err) => match err {
344 MdnsError::InvalidMessage => {
345 warn!("Got invalid message from {}, skipping", remote);
346 continue;
347 }
348 other => Err(other)?,
349 },
350 };
351
352 if let MdnsResponse::Reply { data, delay } = response {
353 if remote.port() != PORT {
354 debug!(
358 "Replying privately to a one-shot mDNS query from {}",
359 remote
360 );
361
362 if let Err(err) = send.send(remote, data).await {
363 warn!(
364 "Failed to reply privately to {}: {:?}",
365 remote,
366 debug2format!(err)
367 );
368 }
369 } else {
370 if delay {
373 self.delay().await;
374 }
375
376 debug!("Re-broadcasting due to mDNS query from {}", remote);
377
378 self.broadcast_once(send, data).await?;
379 }
380 }
381 }
382 }
383 }
384 }
385
386 async fn broadcast_once(&self, send: &mut S, data: &[u8]) -> Result<(), MdnsIoError<S::Error>> {
387 for remote_addr in
388 core::iter::once(SocketAddr::V4(SocketAddrV4::new(IP_BROADCAST_ADDR, PORT)))
389 .filter(|_| self.ipv4_interface.is_some())
390 .chain(
391 self.ipv6_interface
392 .map(|interface| {
393 SocketAddr::V6(SocketAddrV6::new(
394 IPV6_BROADCAST_ADDR,
395 PORT,
396 0,
397 interface,
398 ))
399 })
400 .into_iter(),
401 )
402 {
403 if !data.is_empty() {
404 debug!("Broadcasting mDNS entry to {}", remote_addr);
405
406 let fut = pin!(send.send(remote_addr, data));
407
408 fut.await.map_err(MdnsIoError::IoError)?;
409 }
410 }
411
412 Ok(())
413 }
414
415 async fn delay(&self) {
416 let mut b = [0];
417 self.rand.lock(|rand| rand.borrow_mut().fill_bytes(&mut b));
418
419 let delay_ms = 20 + (b[0] as u32 * 100 / 256);
421
422 Timer::after(Duration::from_millis(delay_ms as _)).await;
423 }
424}