1#![allow(unused)]
2use std::{
3 net::SocketAddr,
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use simple_dns::{Packet, SimpleDnsError};
9use tokio::{net::UdpSocket, sync::oneshot};
10use tracing::Level;
11
12use crate::{
13 custom_handler::{CustomHandlerError, HandlerHolder},
14 pending_request::{PendingRequest, PendingRequestStore},
15 query_id_manager::QueryIdManager,
16};
17
18#[non_exhaustive]
19#[derive(thiserror::Error, Debug)]
20pub enum RequestError {
21 #[error("Dns packet parse error: {0}")]
22 Parse(#[from] SimpleDnsError),
23
24 #[error(transparent)]
25 IO(#[from] tokio::io::Error),
26
27 #[error("Timeout. No answer received from forward server.")]
28 Timeout(#[from] tokio::time::error::Elapsed),
29}
30
31#[derive(Debug, Clone)]
35pub struct DnsSocket {
36 socket: Arc<UdpSocket>,
37 pending: PendingRequestStore,
38 handler: HandlerHolder,
39 icann_fallback: SocketAddr,
40 id_manager: QueryIdManager,
41}
42
43impl DnsSocket {
44 pub async fn new(
48 listening: SocketAddr,
49 icann_fallback: SocketAddr,
50 handler: HandlerHolder,
51 ) -> tokio::io::Result<Self> {
52 let socket = UdpSocket::bind(listening).await?;
53 Ok(Self {
54 socket: Arc::new(socket),
55 pending: PendingRequestStore::new(),
56 handler,
57 icann_fallback,
58 id_manager: QueryIdManager::new(),
59 })
60 }
61
62 pub async fn send_to(&self, buffer: &[u8], target: &SocketAddr) -> tokio::io::Result<usize> {
66 self.socket.send_to(buffer, target).await
67 }
68
69 pub async fn receive_loop(&mut self) {
73 loop {
74 if let Err(err) = self.receive_datagram().await {
75 tracing::error!("Error while trying to receive {err}");
76 }
77 }
78 }
79
80 async fn receive_datagram(&mut self) -> Result<(), RequestError> {
81 let mut buffer = [0; 1024];
82 let (size, from) = self.socket.recv_from(&mut buffer).await?;
83 let mut data = buffer.to_vec();
84 if data.len() > size {
85 data.drain((size + 1)..data.len());
86 }
87 let packet = Packet::parse(&data)?;
88
89 let pending = self.pending.remove_by_forward_id(&packet.id(), &from);
90 if pending.is_some() {
91 tracing::trace!("Received response from forward server. Send back to client.");
92 let query = pending.unwrap();
93 query.tx.send(data).unwrap();
94 return Ok(());
95 };
96
97 let is_reply = packet.questions.len() == 0;
98 if is_reply {
99 let span = tracing::span!(Level::DEBUG, "", forward_id = packet.id());
100 let guard = span.enter();
101 tracing::debug!(
102 "Received reply without an associated query {:?}. Ignore.",
103 packet
104 );
105 return Ok(());
106 };
107
108 let mut socket = self.clone();
110 tokio::spawn(async move {
111 let start = Instant::now();
112 let query_packet = Packet::parse(&data).unwrap();
113 let span = tracing::span!(Level::INFO, "", query_id = query_packet.id());
114 let guard = span.enter();
115
116 let question = query_packet.questions.first();
117 if question.is_none() {
118 tracing::debug!(
119 "Query with no associated a question {:?}. Ignore.",
120 query_packet
121 );
122 return;
123 };
124 let question = question.unwrap();
125 tracing::trace!(
126 "Received new query {} {:?}",
127 question.qname,
128 question.qtype
129 );
130 let query_result = socket.on_query(&data, &from).await;
131 match query_result {
132 Ok(_) => {
133 tracing::debug!(
134 "Processed query {} {:?} within {}ms",
135 question.qname,
136 question.qtype,
137 start.elapsed().as_millis()
138 );
139 }
140 Err(err) => {
141 tracing::error!(
142 "Failed to respond to query {} {:?}: {}",
143 question.qname,
144 question.qtype,
145 err
146 );
147 }
148 };
149 });
150
151 Ok(())
152 }
153
154 async fn on_query(&mut self, query: &Vec<u8>, from: &SocketAddr) -> Result<(), RequestError> {
158 match self.query(query).await {
159 Ok(reply) => {
160 self.send_to(&reply, from).await?;
161 Ok(())
162 }
163 Err(e) => Err(e),
164 }
165 }
166
167 pub async fn query(&mut self, query: &Vec<u8>) -> Result<Vec<u8>, RequestError> {
171 tracing::trace!("Try to resolve the query with the custom handler.");
172 let result = self.handler.call(query, self.clone()).await;
173 if let Ok(reply) = result {
174 tracing::trace!("Custom handler resolved the query.");
175 return Ok(reply);
177 };
178
179 match result.unwrap_err() {
180 CustomHandlerError::Unhandled => {
181 tracing::trace!("Custom handler rejected the query.");
183 let reply = self.forward_to_icann(query, Duration::from_secs(5)).await?;
184 Ok(reply)
185 }
186 CustomHandlerError::IO(e) => Err(e),
187 }
188 }
189
190 fn replace_packet_id(&self, packet: &mut Vec<u8>, new_id: u16) {
194 let id_bytes = new_id.to_be_bytes();
195 std::mem::replace(&mut packet[0], id_bytes[0]);
196 std::mem::replace(&mut packet[1], id_bytes[1]);
197 }
198
199 pub async fn forward(
203 &mut self,
204 query: &Vec<u8>,
205 to: &SocketAddr,
206 timeout: Duration,
207 ) -> Result<Vec<u8>, RequestError> {
208 let packet = Packet::parse(&query)?;
209 let (tx, rx) = oneshot::channel::<Vec<u8>>();
210 let forward_id = self.id_manager.get_next(to);
211 let original_id = packet.id();
212 let span = tracing::span!(Level::DEBUG, "", forward_id = forward_id);
213 let guard = span.enter();
214 tracing::trace!("Fallback to forward server {to:?}.");
215 let request = PendingRequest {
216 original_query_id: original_id,
217 forward_query_id: forward_id,
218 sent_at: Instant::now(),
219 to: to.clone(),
220 tx,
221 };
222
223 let mut query = packet.build_bytes_vec_compressed()?;
224 self.replace_packet_id(&mut query, forward_id);
225
226 self.pending.insert(request);
227 self.send_to(&query, to).await?;
228
229 let reply = tokio::time::timeout(timeout, rx).await;
231 if reply.is_err() {
232 tracing::trace!(
234 "Forwarded query original_id={original_id} forward_id={forward_id} timed out."
235 );
236 self.pending.remove_by_forward_id(&forward_id, &to);
237 };
238 let mut reply = reply?.unwrap();
239 self.replace_packet_id(&mut reply, original_id);
240 Ok(reply)
241 }
242
243 pub async fn forward_to_icann(
247 &mut self,
248 query: &Vec<u8>,
249 timeout: Duration,
250 ) -> Result<Vec<u8>, RequestError> {
251 self.forward(query, &self.icann_fallback.clone(), timeout)
252 .await
253 }
254}
255
#[cfg(test)]
mod tests {
    use simple_dns::{Name, Packet, Question};
    use std::{net::SocketAddr, time::Duration};

    use crate::custom_handler::{EmptyHandler, HandlerHolder};

    use super::DnsSocket;

    /// Forwards an `A google.ch` query to 8.8.8.8. The reply must be routed back through
    /// the receive loop with the client's original query id restored.
    ///
    /// NOTE: requires outbound UDP access to 8.8.8.8:53.
    #[tokio::test]
    async fn run_processor() {
        // Port 0 lets the OS pick a free port, so concurrent test runs cannot collide.
        let listening: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let icann_fallback: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let handler = HandlerHolder::new(EmptyHandler::new());
        let mut socket = DnsSocket::new(listening, icann_fallback, handler)
            .await
            .unwrap();

        let mut run_socket = socket.clone();
        tokio::spawn(async move {
            run_socket.receive_loop().await;
        });

        let original_id = 0;
        let mut query = Packet::new_query(original_id);
        let qname = Name::new("google.ch").unwrap();
        let qtype = simple_dns::QTYPE::TYPE(simple_dns::TYPE::A);
        let qclass = simple_dns::QCLASS::CLASS(simple_dns::CLASS::IN);
        let question = Question::new(qname, qtype, qclass, false);
        query.questions = vec![question];

        let query = query.build_bytes_vec_compressed().unwrap();
        let to: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let result = socket
            .forward(&query, &to, Duration::from_secs(5))
            .await
            .unwrap();

        let reply = Packet::parse(&result).unwrap();
        assert_eq!(reply.id(), original_id, "reply must carry the original query id");
        assert_eq!(reply.questions.len(), 1, "reply must echo the single question");
    }
}
295}