1use core::fmt;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use core::time::Duration;
4
5use edge_nal::{UdpBind, UdpReceive, UdpSend};
6
7use super::*;
8
/// UDP port the DNS server uses by default (53, the standard DNS port).
const PORT: u16 = 53;

/// Default local address for the server: the IPv6 unspecified address (`::`) on port 53.
///
/// NOTE(review): whether a socket bound to `::` also receives IPv4 traffic depends on the
/// dual-stack behavior of the network stack in use — confirm for the target stack.
pub const DEFAULT_SOCKET: SocketAddr =
    SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);
12
13#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
14pub enum DnsIoError<E> {
15 DnsError(DnsError),
16 IoError(E),
17}
18
19pub type DnsIoErrorKind = DnsIoError<edge_nal::io::ErrorKind>;
20
21impl<E> DnsIoError<E>
22where
23 E: edge_nal::io::Error,
24{
25 pub fn erase(&self) -> DnsIoError<edge_nal::io::ErrorKind> {
26 match self {
27 Self::DnsError(e) => DnsIoError::DnsError(*e),
28 Self::IoError(e) => DnsIoError::IoError(e.kind()),
29 }
30 }
31}
32
33impl<E> From<DnsError> for DnsIoError<E> {
34 fn from(err: DnsError) -> Self {
35 Self::DnsError(err)
36 }
37}
38
39impl<E> fmt::Display for DnsIoError<E>
40where
41 E: fmt::Display,
42{
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 Self::DnsError(err) => write!(f, "DNS error: {}", err),
46 Self::IoError(err) => write!(f, "IO error: {}", err),
47 }
48 }
49}
50
/// `defmt` rendering that mirrors the `Display` impl: a `DNS error: ` or `IO error: `
/// prefix followed by the wrapped error.
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for DnsIoError<E> {
    fn format(&self, out: defmt::Formatter<'_>) {
        match self {
            DnsIoError::DnsError(inner) => defmt::write!(out, "DNS error: {}", inner),
            DnsIoError::IoError(inner) => defmt::write!(out, "IO error: {}", inner),
        }
    }
}
63
/// With the `std` feature enabled, `DnsIoError<E>` is a standard error type whenever
/// its I/O error type `E` is one.
#[cfg(feature = "std")]
impl<E: std::error::Error> std::error::Error for DnsIoError<E> {}
66
67pub async fn run<S>(
68 stack: &S,
69 local_addr: SocketAddr,
70 tx_buf: &mut [u8],
71 rx_buf: &mut [u8],
72 ip: Ipv4Addr,
73 ttl: Duration,
74) -> Result<(), DnsIoError<S::Error>>
75where
76 S: UdpBind,
77{
78 let mut udp = stack.bind(local_addr).await.map_err(DnsIoError::IoError)?;
79
80 loop {
81 debug!("Waiting for data");
82
83 let (len, remote) = udp.receive(rx_buf).await.map_err(DnsIoError::IoError)?;
84
85 let request = &rx_buf[..len];
86
87 debug!("Received {} bytes from {}", request.len(), remote);
88
89 let len = match crate::reply(request, &ip.octets(), ttl, tx_buf) {
90 Ok(len) => len,
91 Err(err) => match err {
92 DnsError::InvalidMessage => {
93 warn!("Got invalid message from {}, skipping", remote);
94 continue;
95 }
96 other => Err(other)?,
97 },
98 };
99
100 udp.send(remote, &tx_buf[..len])
101 .await
102 .map_err(DnsIoError::IoError)?;
103
104 debug!("Sent {} bytes to {}", len, remote);
105 }
106}