edge_captive/
io.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use core::fmt;
use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use core::time::Duration;

use edge_nal::{UdpBind, UdpReceive, UdpSend};

use log::*;

use super::*;

/// Well-known DNS port used by [`DEFAULT_SOCKET`].
const PORT: u16 = 53;

/// Default local address for the DNS server: the IPv6 unspecified address (`::`) on port 53.
///
/// Suitable as the `local_addr` argument of [`run`].
pub const DEFAULT_SOCKET: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT);

/// Error returned by the DNS server loop ([`run`]).
///
/// It is either a DNS-level failure (see [`DnsError`]) or an I/O failure from the
/// network stack, carried as the stack's own error type `E`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsIoError<E> {
    /// A DNS-level error (see [`DnsError`]).
    DnsError(DnsError),
    /// An error reported by the underlying network stack.
    IoError(E),
}

/// A [`DnsIoError`] whose I/O error has been reduced to an [`edge_nal::io::ErrorKind`],
/// as produced by `DnsIoError::erase`.
pub type DnsIoErrorKind = DnsIoError<edge_nal::io::ErrorKind>;

impl<E> DnsIoError<E>
where
    E: edge_nal::io::Error,
{
    pub fn erase(&self) -> DnsIoError<edge_nal::io::ErrorKind> {
        match self {
            Self::DnsError(e) => DnsIoError::DnsError(*e),
            Self::IoError(e) => DnsIoError::IoError(e.kind()),
        }
    }
}

impl<E> From<DnsError> for DnsIoError<E> {
    fn from(err: DnsError) -> Self {
        Self::DnsError(err)
    }
}

impl<E> fmt::Display for DnsIoError<E>
where
    E: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DnsError(err) => write!(f, "DNS error: {}", err),
            Self::IoError(err) => write!(f, "IO error: {}", err),
        }
    }
}

/// Makes [`DnsIoError`] a standard error (only with the `std` feature) whenever its I/O
/// error type is one.
///
/// The default `source()` is kept, so the wrapped error is exposed only through the
/// `Display` and `Debug` output.
#[cfg(feature = "std")]
impl<E: std::error::Error> std::error::Error for DnsIoError<E> {}

/// Serves DNS over UDP until an unrecoverable error occurs.
///
/// Binds a UDP socket on `local_addr` through `stack`, then loops forever:
/// 1. each incoming datagram is read into `rx_buf`;
/// 2. it is passed to `crate::reply` together with the octets of `ip` and `ttl`;
/// 3. the first `len` bytes that `crate::reply` reports in `tx_buf` are sent back to the sender.
///
/// NOTE(review): the content of the answer is decided entirely by `crate::reply`. It is
/// presumably every queried name resolving to `ip` with the given `ttl`; confirm there.
///
/// # Errors
///
/// The loop has no successful exit. The function only returns when something fails:
/// - [`DnsIoError::IoError`] if binding the socket, receiving, or sending fails;
/// - [`DnsIoError::DnsError`] if `crate::reply` fails with any error other than
///   [`DnsError::InvalidMessage`].
///
/// Requests for which `crate::reply` returns [`DnsError::InvalidMessage`] are logged at
/// `warn` level and dropped without a response.
pub async fn run<S>(
    stack: &S,
    local_addr: SocketAddr,
    tx_buf: &mut [u8],
    rx_buf: &mut [u8],
    ip: Ipv4Addr,
    ttl: Duration,
) -> Result<(), DnsIoError<S::Error>>
where
    S: UdpBind,
{
    let mut udp = stack.bind(local_addr).await.map_err(DnsIoError::IoError)?;

    // The answer address is fixed for the server's lifetime; convert it once, not per request.
    let ip_octets = ip.octets();

    loop {
        debug!("Waiting for data");

        let (len, remote) = udp.receive(rx_buf).await.map_err(DnsIoError::IoError)?;

        let request = &rx_buf[..len];

        debug!("Received {} bytes from {remote}", request.len());

        let len = match crate::reply(request, &ip_octets, ttl, tx_buf) {
            Ok(len) => len,
            // Drop malformed requests so one bad packet does not stop the server.
            Err(DnsError::InvalidMessage) => {
                warn!("Got invalid message from {remote}, skipping");
                continue;
            }
            Err(other) => return Err(DnsIoError::DnsError(other)),
        };

        udp.send(remote, &tx_buf[..len])
            .await
            .map_err(DnsIoError::IoError)?;

        debug!("Sent {len} bytes to {remote}");
    }
}