edge-net 0.2.0

no_std and no-alloc async implementations of various network protocols.
Documentation
use core::fmt;
use core::time::Duration;

use log::info;

use domain::{
    base::{
        iana::{Class, Opcode, Rcode},
        octets::*,
        Record, Rtype,
    },
    rdata::A,
};

#[cfg(feature = "std")]
pub use server::*;

#[derive(Debug)]
pub struct InnerError<T: fmt::Debug + fmt::Display>(T);

#[derive(Debug)]
pub enum DnsError {
    ShortBuf(InnerError<ShortBuf>),
    ParseError(InnerError<ParseError>),
}

impl fmt::Display for DnsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DnsError::ShortBuf(e) => e.0.fmt(f),
            DnsError::ParseError(e) => e.0.fmt(f),
        }
    }
}

impl From<ShortBuf> for DnsError {
    fn from(e: ShortBuf) -> Self {
        Self::ShortBuf(InnerError(e))
    }
}

impl From<ParseError> for DnsError {
    fn from(e: ParseError) -> Self {
        Self::ParseError(InnerError(e))
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DnsError {}

pub fn process_dns_request(
    request: impl AsRef<[u8]>,
    ip: &[u8; 4],
    ttl: Duration,
) -> Result<impl AsRef<[u8]>, DnsError> {
    let request = request.as_ref();
    let response = Octets512::new();

    let message = domain::base::Message::from_octets(request)?;
    info!("Processing message with header: {:?}", message.header());

    let mut responseb = domain::base::MessageBuilder::from_target(response)?;

    let response = if matches!(message.header().opcode(), Opcode::Query) {
        info!("Message is of type Query, processing all questions");

        let mut answerb = responseb.start_answer(&message, Rcode::NoError)?;

        for question in message.question() {
            let question = question?;

            if matches!(question.qtype(), Rtype::A) {
                info!(
                    "Question {:?} is of type A, answering with IP {:?}, TTL {:?}",
                    question, ip, ttl
                );

                let record = Record::new(
                    question.qname(),
                    Class::In,
                    ttl.as_secs() as u32,
                    A::from_octets(ip[0], ip[1], ip[2], ip[3]),
                );
                info!("Answering question {:?} with {:?}", question, record);

                answerb.push(record)?;
            } else {
                info!("Question {:?} is not of type A, not answering", question);
            }
        }

        answerb.finish()
    } else {
        info!("Message is not of type Query, replying with NotImp");

        let headerb = responseb.header_mut();

        headerb.set_id(message.header().id());
        headerb.set_opcode(message.header().opcode());
        headerb.set_rd(message.header().rd());
        headerb.set_rcode(domain::base::iana::Rcode::NotImp);

        responseb.finish()
    };

    Ok(response)
}

#[cfg(feature = "std")]
mod server {
    use std::{
        io, mem,
        net::{Ipv4Addr, SocketAddrV4, UdpSocket},
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread::{self, JoinHandle},
        time::Duration,
    };

    use log::*;

    #[derive(Clone, Debug)]
    pub struct DnsConf {
        pub bind_ip: Ipv4Addr,
        pub bind_port: u16,
        pub ip: Ipv4Addr,
        pub ttl: Duration,
    }

    impl DnsConf {
        pub fn new(ip: Ipv4Addr) -> Self {
            Self {
                bind_ip: Ipv4Addr::new(0, 0, 0, 0),
                bind_port: 53,
                ip,
                ttl: Duration::from_secs(60),
            }
        }
    }

    #[derive(Debug)]
    pub enum Status {
        Stopped,
        Started,
        Error(io::Error),
    }

    pub struct DnsServer {
        conf: DnsConf,
        status: Status,
        running: Arc<AtomicBool>,
        handle: Option<JoinHandle<Result<(), io::Error>>>,
    }

    impl DnsServer {
        pub fn new(conf: DnsConf) -> Self {
            Self {
                conf,
                status: Status::Stopped,
                running: Arc::new(AtomicBool::new(false)),
                handle: None,
            }
        }

        pub fn get_status(&mut self) -> &Status {
            self.cleanup();
            &self.status
        }

        pub fn start(&mut self) -> Result<(), io::Error> {
            if matches!(self.get_status(), Status::Started) {
                return Ok(());
            }

            let socket =
                UdpSocket::bind(SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port))?;

            socket.set_read_timeout(Some(Duration::from_secs(1)))?;

            let running = self.running.clone();
            let ip = self.conf.ip;
            let ttl = self.conf.ttl;

            self.running.store(true, Ordering::Relaxed);

            self.handle = Some(thread::spawn(move || {
                let result = Self::run(&running, ip, ttl, socket);

                running.store(false, Ordering::Relaxed);

                result
            }));

            Ok(())
        }

        pub fn stop(&mut self) -> Result<(), io::Error> {
            if matches!(self.get_status(), Status::Stopped) {
                return Ok(());
            }

            self.running.store(false, Ordering::Relaxed);
            self.cleanup();

            let mut status = Status::Stopped;
            mem::swap(&mut self.status, &mut status);

            match status {
                Status::Error(e) => Err(e),
                _ => Ok(()),
            }
        }

        fn cleanup(&mut self) {
            if !self.running.load(Ordering::Relaxed) && self.handle.is_some() {
                self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() {
                    Ok(_) => Status::Stopped,
                    Err(e) => Status::Error(e),
                };
            }
        }

        fn run(
            running: &AtomicBool,
            ip: Ipv4Addr,
            ttl: Duration,
            socket: UdpSocket,
        ) -> Result<(), io::Error> {
            while running.load(Ordering::Relaxed) {
                info!("Waiting for data");

                let mut request_arr = [0_u8; 512];

                let (request_len, source_addr) = match socket.recv_from(&mut request_arr) {
                    Ok(value) => value,
                    Err(err) => match err.kind() {
                        std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue,
                        _ => return Err(err),
                    },
                };

                let request = &request_arr[..request_len];

                info!("Received {} bytes from {}", request.len(), source_addr);

                let response = super::process_dns_request(request, &ip.octets(), ttl)
                    .map_err(|_| io::ErrorKind::Other)?;

                socket.send_to(response.as_ref(), source_addr)?;

                info!("Sent {} bytes to {}", response.as_ref().len(), source_addr);
            }

            Ok(())
        }
    }
}