1use crate::error::LocalError::{self, InvalidAddress, PermissionDenied, Unknown};
2use crate::upstream::HttpsClient;
3use std::{io, net::SocketAddr, sync::Arc};
4use tokio::net::UdpSocket;
5use tracing::{info, info_span, warn, Instrument};
6use trust_dns_proto::op::message::Message;
7
/// A DNS-over-UDP front end: every query received on `udp_socket` is relayed
/// through an [`HttpsClient`] and the answer is sent back to the querying address.
#[derive(Debug)]
pub struct UdpListener {
    // Bound listening socket. Wrapped in `Arc` so each per-request task spawned
    // by `listen` can hold its own handle for sending the reply.
    udp_socket: Arc<UdpSocket>,
    // Upstream client; `listen` clones it for every request it handles.
    https_client: HttpsClient,
}
13
14impl UdpListener {
15 pub async fn new(
16 host: String,
17 port: u16,
18 https_client: HttpsClient,
19 ) -> Result<Self, LocalError> {
20 let socket_addr: SocketAddr = match format!("{}:{}", host, port).parse() {
21 Ok(socket_addr) => socket_addr,
22 Err(_) => return Err(InvalidAddress(host, port)),
23 };
24
25 let udp_socket = match UdpSocket::bind(socket_addr).await {
26 Ok(udp_socket) => Arc::new(udp_socket),
27 Err(error) => match error.kind() {
28 io::ErrorKind::PermissionDenied => return Err(PermissionDenied(host, port)),
29 _ => return Err(Unknown(host, port)),
30 },
31 };
32 info!("listened on {}:{}", host, port);
33
34 Ok(UdpListener {
35 udp_socket,
36 https_client,
37 })
38 }
39
40 pub async fn listen(&self) {
41 loop {
42 let mut buffer = [0; 4096];
43 let mut https_client = self.https_client.clone();
44 let udp_socket = self.udp_socket.clone();
45
46 let (_, addr) = match udp_socket.recv_from(&mut buffer).await {
47 Ok(udp_recv_from_result) => udp_recv_from_result,
48 Err(_) => {
49 warn!("failed to receive the datagram message");
50 continue;
51 }
52 };
53
54 tokio::spawn(
55 async move {
56 let request_message = match Message::from_vec(&buffer) {
57 Ok(request_message) => request_message,
58 Err(_) => {
59 warn!("failed to parse the request");
60 return;
61 }
62 };
63
64 for request_record in request_message.queries().iter() {
65 info!(
66 phase = "request",
67 "{} {} {}",
68 request_record.name(),
69 request_record.query_class(),
70 request_record.query_type(),
71 );
72 }
73
74 let response_message = match https_client.process(request_message).await {
75 Ok(response_message) => response_message,
76 Err(error) => {
77 warn!("{}", error);
78 return;
79 }
80 };
81
82 for response_record in response_message.answers().iter() {
83 info!(phase = "response", "{}", response_record);
84 }
85
86 let raw_response_message = match response_message.to_vec() {
87 Ok(raw_response_message) => raw_response_message,
88 Err(_) => {
89 warn!("failed to parse the response");
90 return;
91 }
92 };
93
94 if udp_socket
95 .send_to(&raw_response_message, &addr)
96 .await
97 .is_err()
98 {
99 warn!("failed to send the inbound response to the client");
100 };
101 }
102 .instrument(info_span!("listen", ?addr)),
103 );
104 }
105 }
106}