edge_captive/
io.rs

1use core::fmt;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use core::time::Duration;
4
5use edge_nal::{UdpBind, UdpReceive, UdpSend};
6
7use super::*;
8
/// Well-known UDP port used by DNS.
const PORT: u16 = 53;

/// Default local address for the captive DNS responder: the IPv6 unspecified address (`[::]`)
/// on port 53.
///
/// NOTE(review): whether a `[::]` bind also receives IPv4 queries depends on the network
/// stack's dual-stack behaviour — confirm for the target stack.
pub const DEFAULT_SOCKET: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);
12
13#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
14pub enum DnsIoError<E> {
15    DnsError(DnsError),
16    IoError(E),
17}
18
19pub type DnsIoErrorKind = DnsIoError<edge_nal::io::ErrorKind>;
20
21impl<E> DnsIoError<E>
22where
23    E: edge_nal::io::Error,
24{
25    pub fn erase(&self) -> DnsIoError<edge_nal::io::ErrorKind> {
26        match self {
27            Self::DnsError(e) => DnsIoError::DnsError(*e),
28            Self::IoError(e) => DnsIoError::IoError(e.kind()),
29        }
30    }
31}
32
33impl<E> From<DnsError> for DnsIoError<E> {
34    fn from(err: DnsError) -> Self {
35        Self::DnsError(err)
36    }
37}
38
39impl<E> fmt::Display for DnsIoError<E>
40where
41    E: fmt::Display,
42{
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            Self::DnsError(err) => write!(f, "DNS error: {}", err),
46            Self::IoError(err) => write!(f, "IO error: {}", err),
47        }
48    }
49}
50
#[cfg(feature = "defmt")]
impl<E: defmt::Format> defmt::Format for DnsIoError<E> {
    /// defmt rendering: `DNS error: <inner>` or `IO error: <inner>`, where the inner value is
    /// formatted with defmt.
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            DnsIoError::DnsError(inner) => defmt::write!(f, "DNS error: {}", inner),
            DnsIoError::IoError(inner) => defmt::write!(f, "IO error: {}", inner),
        }
    }
}
63
64impl<E> core::error::Error for DnsIoError<E> where E: core::error::Error {}
65
66pub async fn run<S>(
67    stack: &S,
68    local_addr: SocketAddr,
69    tx_buf: &mut [u8],
70    rx_buf: &mut [u8],
71    ip: Ipv4Addr,
72    ttl: Duration,
73) -> Result<(), DnsIoError<S::Error>>
74where
75    S: UdpBind,
76{
77    let mut udp = stack.bind(local_addr).await.map_err(DnsIoError::IoError)?;
78
79    loop {
80        debug!("Waiting for data");
81
82        let (len, remote) = udp.receive(rx_buf).await.map_err(DnsIoError::IoError)?;
83
84        let request = &rx_buf[..len];
85
86        debug!("Received {} bytes from {}", request.len(), remote);
87
88        let len = match crate::reply(request, &ip.octets(), ttl, tx_buf) {
89            Ok(len) => len,
90            Err(err) => match err {
91                DnsError::InvalidMessage => {
92                    warn!("Got invalid message from {}, skipping", remote);
93                    continue;
94                }
95                other => Err(other)?,
96            },
97        };
98
99        udp.send(remote, &tx_buf[..len])
100            .await
101            .map_err(DnsIoError::IoError)?;
102
103        debug!("Sent {} bytes to {}", len, remote);
104    }
105}