edge_captive/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(clippy::large_futures)]
3#![allow(clippy::uninlined_format_args)]
4#![allow(unknown_lints)]
5
6use core::fmt::Display;
7use core::time::Duration;
8
9use domain::base::wire::Composer;
10use domain::dep::octseq::{OctetsBuilder, Truncate};
11
12use domain::{
13    base::{
14        iana::{Class, Opcode, Rcode},
15        message::ShortMessage,
16        message_builder::PushError,
17        record::Ttl,
18        wire::ParseError,
19        Record, Rtype,
20    },
21    dep::octseq::ShortBuf,
22    rdata::A,
23};
24
25// This mod MUST go first, so that the others see its macros.
26pub(crate) mod fmt;
27
28#[cfg(feature = "io")]
29pub mod io;
30
/// Errors produced while parsing a DNS request or composing its reply.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum DnsError {
    /// The reply could not be written because the output buffer ran out of room.
    ShortBuf,
    /// The incoming request could not be parsed as a DNS message.
    InvalidMessage,
}
36
37impl Display for DnsError {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        match self {
40            Self::ShortBuf => write!(f, "ShortBuf"),
41            Self::InvalidMessage => write!(f, "InvalidMessage"),
42        }
43    }
44}
45
#[cfg(feature = "defmt")]
impl defmt::Format for DnsError {
    /// Emits the variant name through `defmt`, matching the `Display` text.
    fn format(&self, fmt: defmt::Formatter<'_>) {
        match *self {
            DnsError::ShortBuf => defmt::write!(fmt, "ShortBuf"),
            DnsError::InvalidMessage => defmt::write!(fmt, "InvalidMessage"),
        }
    }
}
55
// With the `std` feature enabled, `DnsError` also implements
// `std::error::Error`; its message comes from the `Display` impl above.
#[cfg(feature = "std")]
impl std::error::Error for DnsError {}
58
59impl From<ShortBuf> for DnsError {
60    fn from(_: ShortBuf) -> Self {
61        Self::ShortBuf
62    }
63}
64
65impl From<PushError> for DnsError {
66    fn from(_: PushError) -> Self {
67        Self::ShortBuf
68    }
69}
70
71impl From<ShortMessage> for DnsError {
72    fn from(_: ShortMessage) -> Self {
73        Self::InvalidMessage
74    }
75}
76
77impl From<ParseError> for DnsError {
78    fn from(_: ParseError) -> Self {
79        Self::InvalidMessage
80    }
81}
82
83pub fn reply(
84    request: &[u8],
85    ip: &[u8; 4],
86    ttl: Duration,
87    buf: &mut [u8],
88) -> Result<usize, DnsError> {
89    let buf = Buf(buf, 0);
90
91    let message = domain::base::Message::from_octets(request)?;
92    debug!(
93        "Processing message with header: {:?}",
94        debug2format!(message.header())
95    );
96
97    let mut responseb = domain::base::MessageBuilder::from_target(buf)?;
98
99    let buf = if matches!(message.header().opcode(), Opcode::QUERY) {
100        debug!("Message is of type Query, processing all questions");
101
102        let mut answerb = responseb.start_answer(&message, Rcode::NOERROR)?;
103
104        for question in message.question() {
105            let question = question?;
106
107            if matches!(question.qtype(), Rtype::A) && matches!(question.qclass(), Class::IN) {
108                let record = Record::new(
109                    question.qname(),
110                    Class::IN,
111                    Ttl::from_duration_lossy(ttl),
112                    A::from_octets(ip[0], ip[1], ip[2], ip[3]),
113                );
114                debug!(
115                    "Answering {:?} with {:?}",
116                    debug2format!(question),
117                    debug2format!(record)
118                );
119                answerb.push(record)?;
120            } else {
121                debug!(
122                    "Question {:?} is not of type A, not answering",
123                    debug2format!(question)
124                );
125            }
126        }
127
128        answerb.finish()
129    } else {
130        debug!("Message is not of type Query, replying with NotImp");
131
132        let headerb = responseb.header_mut();
133
134        headerb.set_id(message.header().id());
135        headerb.set_opcode(message.header().opcode());
136        headerb.set_rd(message.header().rd());
137        headerb.set_rcode(domain::base::iana::Rcode::NOTIMP);
138
139        responseb.finish()
140    };
141
142    Ok(buf.1)
143}
144
/// Fixed-capacity output target backed by a caller-provided byte slice.
///
/// Field `.0` is the backing storage. Field `.1` is the number of bytes
/// written so far, which is the logical length exposed through
/// `AsRef`/`AsMut`. `reply` creates it with length 0 and returns `.1` as the
/// reply size.
struct Buf<'a>(pub &'a mut [u8], pub usize);

// Empty impl: relies entirely on `Composer`'s provided methods. Presumably
// required so that `reply` can pass a `Buf` to `MessageBuilder::from_target`.
impl Composer for Buf<'_> {}
148
149impl OctetsBuilder for Buf<'_> {
150    type AppendError = ShortBuf;
151
152    fn append_slice(&mut self, slice: &[u8]) -> Result<(), Self::AppendError> {
153        if self.1 + slice.len() <= self.0.len() {
154            let end = self.1 + slice.len();
155            self.0[self.1..end].copy_from_slice(slice);
156            self.1 = end;
157
158            Ok(())
159        } else {
160            Err(ShortBuf)
161        }
162    }
163}
164
165impl Truncate for Buf<'_> {
166    fn truncate(&mut self, len: usize) {
167        self.1 = len;
168    }
169}
170
171impl AsMut<[u8]> for Buf<'_> {
172    fn as_mut(&mut self) -> &mut [u8] {
173        &mut self.0[..self.1]
174    }
175}
176
177impl AsRef<[u8]> for Buf<'_> {
178    fn as_ref(&self) -> &[u8] {
179        &self.0[..self.1]
180    }
181}