any_dns/
dns_socket.rs

1#![allow(unused)]
2use std::{
3    net::SocketAddr,
4    sync::Arc,
5    time::{Duration, Instant},
6};
7
8use simple_dns::{Packet, SimpleDnsError};
9use tokio::{net::UdpSocket, sync::oneshot};
10use tracing::Level;
11
12use crate::{
13    custom_handler::{CustomHandlerError, HandlerHolder},
14    pending_request::{PendingRequest, PendingRequestStore},
15    query_id_manager::QueryIdManager,
16};
17
18#[non_exhaustive]
19#[derive(thiserror::Error, Debug)]
20pub enum RequestError {
21    #[error("Dns packet parse error: {0}")]
22    Parse(#[from] SimpleDnsError),
23
24    #[error(transparent)]
25    IO(#[from] tokio::io::Error),
26
27    #[error("Timeout. No answer received from forward server.")]
28    Timeout(#[from] tokio::time::error::Elapsed),
29}
30
31/**
32 * DNS UDP socket
33 */
34#[derive(Debug, Clone)]
35pub struct DnsSocket {
36    socket: Arc<UdpSocket>,
37    pending: PendingRequestStore,
38    handler: HandlerHolder,
39    icann_fallback: SocketAddr,
40    id_manager: QueryIdManager,
41}
42
43impl DnsSocket {
44    /**
45     * Creates a new DNS socket
46     */
47    pub async fn new(
48        listening: SocketAddr,
49        icann_fallback: SocketAddr,
50        handler: HandlerHolder,
51    ) -> tokio::io::Result<Self> {
52        let socket = UdpSocket::bind(listening).await?;
53        Ok(Self {
54            socket: Arc::new(socket),
55            pending: PendingRequestStore::new(),
56            handler,
57            icann_fallback,
58            id_manager: QueryIdManager::new(),
59        })
60    }
61
62    /**
63     * Send message to address
64     */
65    pub async fn send_to(&self, buffer: &[u8], target: &SocketAddr) -> tokio::io::Result<usize> {
66        self.socket.send_to(buffer, target).await
67    }
68
69    /**
70     * Run receive loop
71     */
72    pub async fn receive_loop(&mut self) {
73        loop {
74            if let Err(err) = self.receive_datagram().await {
75                tracing::error!("Error while trying to receive {err}");
76            }
77        }
78    }
79
80    async fn receive_datagram(&mut self) -> Result<(), RequestError> {
81        let mut buffer = [0; 1024];
82        let (size, from) = self.socket.recv_from(&mut buffer).await?;
83        let mut data = buffer.to_vec();
84        if data.len() > size {
85            data.drain((size + 1)..data.len());
86        }
87        let packet = Packet::parse(&data)?;
88
89        let pending = self.pending.remove_by_forward_id(&packet.id(), &from);
90        if pending.is_some() {
91            tracing::trace!("Received response from forward server. Send back to client.");
92            let query = pending.unwrap();
93            query.tx.send(data).unwrap();
94            return Ok(());
95        };
96
97        let is_reply = packet.questions.len() == 0;
98        if is_reply {
99            let span = tracing::span!(Level::DEBUG, "", forward_id = packet.id());
100            let guard = span.enter();
101            tracing::debug!(
102                "Received reply without an associated query {:?}. Ignore.",
103                packet
104            );
105            return Ok(());
106        };
107
108        // New query
109        let mut socket = self.clone();
110        tokio::spawn(async move {
111            let start = Instant::now();
112            let query_packet = Packet::parse(&data).unwrap();
113            let span = tracing::span!(Level::INFO, "", query_id = query_packet.id());
114            let guard = span.enter();
115
116            let question = query_packet.questions.first();
117            if question.is_none() {
118                tracing::debug!(
119                    "Query with no associated a question {:?}. Ignore.",
120                    query_packet
121                );
122                return;
123            };
124            let question = question.unwrap();
125            tracing::trace!(
126                "Received new query {} {:?}",
127                question.qname,
128                question.qtype
129            );
130            let query_result = socket.on_query(&data, &from).await;
131                match query_result {
132                    Ok(_) => {
133                        tracing::debug!(
134                            "Processed query {} {:?} within {}ms",
135                            question.qname,
136                            question.qtype,
137                            start.elapsed().as_millis()
138                        );
139                    }
140                    Err(err) => {
141                        tracing::error!(
142                            "Failed to respond to query {} {:?}: {}",
143                            question.qname,
144                            question.qtype,
145                            err
146                        );
147                    }
148                };
149        });
150
151        Ok(())
152    }
153
154    /**
155     * New query received.
156     */
157    async fn on_query(&mut self, query: &Vec<u8>, from: &SocketAddr) -> Result<(), RequestError> {
158        match self.query(query).await {
159            Ok(reply) => {
160                self.send_to(&reply, from).await?;
161                Ok(())
162            }
163            Err(e) => Err(e),
164        }
165    }
166
167    /**
168     * Query this dns for data
169     */
170    pub async fn query(&mut self, query: &Vec<u8>) -> Result<Vec<u8>, RequestError> {
171        tracing::trace!("Try to resolve the query with the custom handler.");
172        let result = self.handler.call(query, self.clone()).await;
173        if let Ok(reply) = result {
174            tracing::trace!("Custom handler resolved the query.");
175            // All good. Handler handled the query
176            return Ok(reply);
177        };
178
179        match result.unwrap_err() {
180            CustomHandlerError::Unhandled => {
181                // Fallback to ICANN
182                tracing::trace!("Custom handler rejected the query.");
183                let reply = self.forward_to_icann(query, Duration::from_secs(5)).await?;
184                Ok(reply)
185            }
186            CustomHandlerError::IO(e) => Err(e),
187        }
188    }
189
190    /**
191     * Replaces the id of the dns packet.
192     */
193    fn replace_packet_id(&self, packet: &mut Vec<u8>, new_id: u16) {
194        let id_bytes = new_id.to_be_bytes();
195        std::mem::replace(&mut packet[0], id_bytes[0]);
196        std::mem::replace(&mut packet[1], id_bytes[1]);
197    }
198
199    /**
200     * Send dns request
201     */
202    pub async fn forward(
203        &mut self,
204        query: &Vec<u8>,
205        to: &SocketAddr,
206        timeout: Duration,
207    ) -> Result<Vec<u8>, RequestError> {
208        let packet = Packet::parse(&query)?;
209        let (tx, rx) = oneshot::channel::<Vec<u8>>();
210        let forward_id = self.id_manager.get_next(to);
211        let original_id = packet.id();
212        let span = tracing::span!(Level::DEBUG, "", forward_id = forward_id);
213        let guard = span.enter();
214        tracing::trace!("Fallback to forward server {to:?}.");
215        let request = PendingRequest {
216            original_query_id: original_id,
217            forward_query_id: forward_id,
218            sent_at: Instant::now(),
219            to: to.clone(),
220            tx,
221        };
222
223        let mut query = packet.build_bytes_vec_compressed()?;
224        self.replace_packet_id(&mut query, forward_id);
225
226        self.pending.insert(request);
227        self.send_to(&query, to).await?;
228
229        // Wait on response
230        let reply = tokio::time::timeout(timeout, rx).await;
231        if reply.is_err() {
232            // Timeout, remove pending again
233            tracing::trace!(
234                "Forwarded query original_id={original_id} forward_id={forward_id} timed out."
235            );
236            self.pending.remove_by_forward_id(&forward_id, &to);
237        };
238        let mut reply = reply?.unwrap();
239        self.replace_packet_id(&mut reply, original_id);
240        Ok(reply)
241    }
242
243    /**
244     * Forward query to icann
245     */
246    pub async fn forward_to_icann(
247        &mut self,
248        query: &Vec<u8>,
249        timeout: Duration,
250    ) -> Result<Vec<u8>, RequestError> {
251        self.forward(query, &self.icann_fallback.clone(), timeout)
252            .await
253    }
254}
255
#[cfg(test)]
mod tests {
    use simple_dns::{Name, Packet, Question};
    use std::{net::SocketAddr, time::Duration};

    use crate::custom_handler::{EmptyHandler, HandlerHolder};

    use super::DnsSocket;

    /// Forwards an `A google.ch` query to 8.8.8.8 through a running socket and
    /// checks that the reply carries the caller's original id and echoes the question.
    ///
    /// Requires outbound network access to 8.8.8.8:53.
    #[tokio::test]
    async fn run_processor() {
        // Port 0 lets the OS pick a free port, so repeated or parallel runs don't collide.
        let listening: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let icann_fallback: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let handler = HandlerHolder::new(EmptyHandler::new());
        let mut socket = DnsSocket::new(listening, icann_fallback, handler)
            .await
            .unwrap();

        // The receive loop must run so the forwarded reply is routed back to `forward`.
        let mut run_socket = socket.clone();
        tokio::spawn(async move {
            run_socket.receive_loop().await;
        });

        let query_id = 0;
        let mut query = Packet::new_query(query_id);
        let qname = Name::new("google.ch").unwrap();
        let qtype = simple_dns::QTYPE::TYPE(simple_dns::TYPE::A);
        let qclass = simple_dns::QCLASS::CLASS(simple_dns::CLASS::IN);
        let question = Question::new(qname, qtype, qclass, false);
        query.questions = vec![question];

        let query = query.build_bytes_vec_compressed().unwrap();
        let to: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let result = socket
            .forward(&query, &to, Duration::from_secs(5))
            .await
            .unwrap();
        let reply = Packet::parse(&result).unwrap();

        // `forward` must restore the caller's id in place of the forward id.
        assert_eq!(reply.id(), query_id);
        assert_eq!(reply.questions.len(), 1);
    }
}
295}