1use core::fmt;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use core::time::Duration;
4
5use edge_nal::{UdpBind, UdpReceive, UdpSend};
6
7use super::*;
8
/// UDP port the DNS responder listens on by default.
const PORT: u16 = 53;

/// Default bind address for [`run`]: the unspecified IPv6 address on [`PORT`].
pub const DEFAULT_SOCKET: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);
12
13#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
14pub enum DnsIoError<E> {
15 DnsError(DnsError),
16 IoError(E),
17}
18
19pub type DnsIoErrorKind = DnsIoError<edge_nal::io::ErrorKind>;
20
21impl<E> DnsIoError<E>
22where
23 E: edge_nal::io::Error,
24{
25 pub fn erase(&self) -> DnsIoError<edge_nal::io::ErrorKind> {
26 match self {
27 Self::DnsError(e) => DnsIoError::DnsError(*e),
28 Self::IoError(e) => DnsIoError::IoError(e.kind()),
29 }
30 }
31}
32
33impl<E> From<DnsError> for DnsIoError<E> {
34 fn from(err: DnsError) -> Self {
35 Self::DnsError(err)
36 }
37}
38
39impl<E> fmt::Display for DnsIoError<E>
40where
41 E: fmt::Display,
42{
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 Self::DnsError(err) => write!(f, "DNS error: {}", err),
46 Self::IoError(err) => write!(f, "IO error: {}", err),
47 }
48 }
49}
50
/// `defmt` rendering; mirrors the `Display` output. Only compiled when the
/// `defmt` feature is enabled.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for DnsIoError<E> {
    fn format(&self, out: defmt::Formatter<'_>) {
        match self {
            DnsIoError::DnsError(inner) => defmt::write!(out, "DNS error: {}", inner),
            DnsIoError::IoError(inner) => defmt::write!(out, "IO error: {}", inner),
        }
    }
}
63
64impl<E> core::error::Error for DnsIoError<E> where E: core::error::Error {}
65
66pub async fn run<S>(
67 stack: &S,
68 local_addr: SocketAddr,
69 tx_buf: &mut [u8],
70 rx_buf: &mut [u8],
71 ip: Ipv4Addr,
72 ttl: Duration,
73) -> Result<(), DnsIoError<S::Error>>
74where
75 S: UdpBind,
76{
77 let mut udp = stack.bind(local_addr).await.map_err(DnsIoError::IoError)?;
78
79 loop {
80 debug!("Waiting for data");
81
82 let (len, remote) = udp.receive(rx_buf).await.map_err(DnsIoError::IoError)?;
83
84 let request = &rx_buf[..len];
85
86 debug!("Received {} bytes from {}", request.len(), remote);
87
88 let len = match crate::reply(request, &ip.octets(), ttl, tx_buf) {
89 Ok(len) => len,
90 Err(err) => match err {
91 DnsError::InvalidMessage => {
92 warn!("Got invalid message from {}, skipping", remote);
93 continue;
94 }
95 other => Err(other)?,
96 },
97 };
98
99 udp.send(remote, &tx_buf[..len])
100 .await
101 .map_err(DnsIoError::IoError)?;
102
103 debug!("Sent {} bytes to {}", len, remote);
104 }
105}