edge_captive/
io.rs

1use core::fmt;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use core::time::Duration;
4
5use edge_nal::{UdpBind, UdpReceive, UdpSend};
6
7use super::*;
8
/// Port number used by `DEFAULT_SOCKET`.
const PORT: u16 = 53;

/// Default local socket address: the unspecified IPv6 address combined with port 53.
pub const DEFAULT_SOCKET: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);
12
13#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
14pub enum DnsIoError<E> {
15    DnsError(DnsError),
16    IoError(E),
17}
18
19pub type DnsIoErrorKind = DnsIoError<edge_nal::io::ErrorKind>;
20
21impl<E> DnsIoError<E>
22where
23    E: edge_nal::io::Error,
24{
25    pub fn erase(&self) -> DnsIoError<edge_nal::io::ErrorKind> {
26        match self {
27            Self::DnsError(e) => DnsIoError::DnsError(*e),
28            Self::IoError(e) => DnsIoError::IoError(e.kind()),
29        }
30    }
31}
32
33impl<E> From<DnsError> for DnsIoError<E> {
34    fn from(err: DnsError) -> Self {
35        Self::DnsError(err)
36    }
37}
38
39impl<E> fmt::Display for DnsIoError<E>
40where
41    E: fmt::Display,
42{
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            Self::DnsError(err) => write!(f, "DNS error: {}", err),
46            Self::IoError(err) => write!(f, "IO error: {}", err),
47        }
48    }
49}
50
/// `defmt` rendering, only compiled when the `defmt` feature is enabled.
///
/// Uses the same `"DNS error: "` / `"IO error: "` prefixes as the `Display` impl.
#[cfg(feature = "defmt")]
impl<E> defmt::Format for DnsIoError<E>
where
    E: defmt::Format,
{
    fn format(&self, f: defmt::Formatter<'_>) {
        match *self {
            Self::DnsError(ref dns_err) => defmt::write!(f, "DNS error: {}", dns_err),
            Self::IoError(ref io_err) => defmt::write!(f, "IO error: {}", io_err),
        }
    }
}
63
/// Marks [`DnsIoError`] as a standard error type when the `std` feature is
/// enabled and the wrapped I/O error is itself a `std::error::Error`.
#[cfg(feature = "std")]
impl<E> std::error::Error for DnsIoError<E>
where
    E: std::error::Error,
{
}
66
67pub async fn run<S>(
68    stack: &S,
69    local_addr: SocketAddr,
70    tx_buf: &mut [u8],
71    rx_buf: &mut [u8],
72    ip: Ipv4Addr,
73    ttl: Duration,
74) -> Result<(), DnsIoError<S::Error>>
75where
76    S: UdpBind,
77{
78    let mut udp = stack.bind(local_addr).await.map_err(DnsIoError::IoError)?;
79
80    loop {
81        debug!("Waiting for data");
82
83        let (len, remote) = udp.receive(rx_buf).await.map_err(DnsIoError::IoError)?;
84
85        let request = &rx_buf[..len];
86
87        debug!("Received {} bytes from {}", request.len(), remote);
88
89        let len = match crate::reply(request, &ip.octets(), ttl, tx_buf) {
90            Ok(len) => len,
91            Err(err) => match err {
92                DnsError::InvalidMessage => {
93                    warn!("Got invalid message from {}, skipping", remote);
94                    continue;
95                }
96                other => Err(other)?,
97            },
98        };
99
100        udp.send(remote, &tx_buf[..len])
101            .await
102            .map_err(DnsIoError::IoError)?;
103
104        debug!("Sent {} bytes to {}", len, remote);
105    }
106}