1use std::io;
2use std::time::Instant;
3use std::{net::SocketAddr, sync::Arc};
4
5use mpsc::{UnboundedReceiver, UnboundedSender};
6
7use parking_lot::Mutex;
8
9use resolver::ResolveError;
10
11use snafu::{ResultExt, Snafu};
12
13use tokio::net::UdpSocket;
14use tokio::sync::mpsc;
15use tokio::sync::Semaphore;
16use tokio::task::JoinHandle;
17
18use crate::{
19 packet::PacketError, packet_buffer::BufferError, query_type::QueryType, record::DnsRecord,
20 resolver, BytePacketBuffer, Config, DnsPacket, ResultCode,
21};
22
23#[derive(Debug, Snafu)]
24pub enum ServerError {
25 InvalidBuffer { source: BufferError },
26 InvalidPacket { source: PacketError },
27
28 SocketRecvError { source: io::Error },
29 SocketSendError { source: io::Error },
30
31 ResolutionError { source: ResolveError },
32
33 JoinError,
34}
35
36type Result<T> = std::result::Result<T, ServerError>;
37
38type ResponseSender = UnboundedSender<(SocketAddr, Vec<u8>)>;
39type ResponseReceiver = UnboundedReceiver<(SocketAddr, Vec<u8>)>;
40
41struct ServerProcess {
42 pub join_handle: JoinHandle<()>,
43 pub tx_stop: mpsc::Sender<()>,
44}
45
46pub struct Server {
47 txt_challenge: Arc<Mutex<String>>,
48 handle: ServerProcess,
49}
50
51impl Server {
52 pub fn start(cfg: Config) -> Server {
53 let (tx_stop, rx) = mpsc::channel(1);
54
55 tracing::info!("starting DNS layer");
56
57 let txt_challenge = Arc::from(Mutex::from(String::default()));
58 let join_handle: JoinHandle<()> = {
59 let challenge_cloned = txt_challenge.clone();
60 tokio::task::spawn(async move {
61 Server::run(cfg, rx, challenge_cloned).await;
62 })
63 };
64
65 tracing::info!("DNS layer started");
66
67 Server {
68 handle: ServerProcess {
69 join_handle,
70 tx_stop,
71 },
72 txt_challenge,
73 }
74 }
75
76 pub async fn set_dns_challenge(&self, challenge: &str) -> Result<()> {
77 let mut guard = self.txt_challenge.lock();
78 *guard = challenge.to_string();
79 tracing::info!("set acme challenge: {}", &*guard);
80 Ok(())
81 }
82
83 async fn handle_query(
84 cfg: &Config,
85 req_buffer: &mut BytePacketBuffer,
86 challenge: Arc<Mutex<String>>,
87 ) -> Result<Vec<u8>> {
88 let mut request = DnsPacket::from_buffer(req_buffer).context(InvalidPacketSnafu)?;
91
92 let mut packet = DnsPacket::new();
94 packet.header.id = request.header.id;
95 packet.header.recursion_desired = false;
96 packet.header.recursion_available = false;
97 packet.header.response = true;
98
99 if let Some(question) = request.questions.pop() {
101 if question.qtype == QueryType::TXT && question.name.ends_with(&cfg.root_domain) {
104 let guard = challenge.lock();
105 let chall = &*guard.clone();
106 if !chall.is_empty() {
107 tracing::info!("query is an ACME challenge");
108 packet.questions.push(question);
110 let challenge_bytes = chall.as_bytes().to_vec();
111 packet.answers.push(DnsRecord::TXT {
112 domain_bytes: vec![192, 12],
113 ttl: 500,
114 data_len: challenge_bytes.len() as u16,
115 text: vec![challenge_bytes],
116 });
117 packet.header.authoritative_answer = true;
118 } else {
119 tracing::warn!("got ACME challenge but no challenge is set");
120 }
121 } else {
122 match resolver::lookup(&question.name, question.qtype, cfg).await {
123 Ok(Some(result)) => {
124 packet.questions.push(question);
125 packet.header.rescode = result.header.rescode;
126
127 for rec in result.answers {
128 tracing::debug!("answer: {:?}", rec);
129 packet.answers.push(rec);
130 }
131 for rec in result.authorities {
132 tracing::debug!("authority: {:?}", rec);
133 packet.authorities.push(rec);
134 }
135 for rec in result.resources {
136 tracing::debug!("resource: {:?}", rec);
137 packet.resources.push(rec);
138 }
139 }
140 Ok(None) => {
141 tracing::debug!("ignoring packet");
142 }
143 Err(e) => {
144 tracing::error!("servfail: {}", e);
145 packet.header.rescode = ResultCode::ServFail;
146 }
147 }
148 }
149 }
150 else {
154 tracing::warn!("FORMERR");
155 packet.header.rescode = ResultCode::FormErr;
156 }
157
158 let mut res_buffer = BytePacketBuffer::new();
160 packet.write(&mut res_buffer).context(InvalidPacketSnafu)?;
161
162 let len = res_buffer.pos();
163 let data = res_buffer.get_range(0, len).context(InvalidBufferSnafu)?;
164
165 tracing::trace!(
166 "sending raw packet of length {} as response: {:?}",
167 len,
168 data
169 );
170
171 Ok(data.to_vec())
173 }
174
175 async fn run(cfg: Config, mut stop_rx: mpsc::Receiver<()>, challenge: Arc<Mutex<String>>) {
176 let socket = match UdpSocket::bind(cfg.listen).await {
178 Ok(s) => Arc::from(s),
179 Err(e) => {
180 tracing::error!("cannot bind to socket: {}", e);
181 return;
182 }
183 };
184
185 let (req_tx, mut req_rx) = mpsc::unbounded_channel();
186 let (resp_tx, mut resp_rx): (ResponseSender, ResponseReceiver) = mpsc::unbounded_channel();
187
188 let (recv_stop_tx, mut recv_stop_rx) = mpsc::channel(1);
189 let (send_stop_tx, mut send_stop_rx) = mpsc::channel(1);
190
191 let socket_copy = socket.clone();
193 let recv_task_handle = tokio::task::spawn(async move {
194 loop {
195 let mut req_buffer = BytePacketBuffer::new();
196
197 let recv_future = socket_copy.recv_from(&mut req_buffer.buf);
198 let abort_future = recv_stop_rx.recv();
199
200 let should_abort = tokio::select! {
201 _ = abort_future => {
202 true
203 }
204 packet_result = recv_future => {
205 match packet_result {
206 Ok((_, addr)) => {
207 if let Err(e) = req_tx.send((addr, req_buffer)) {
208 tracing::warn!("failed to send request: {}", e);
209 }
210 }
211 Err(e) => {
212 tracing::warn!("packet recv error: {}", e);
213 }
214 };
215 false
216 }
217 };
218
219 if should_abort {
220 tracing::info!("quitting receive task");
221 break;
222 }
223 }
224 });
225
226 let socket_copy = socket.clone();
228 let send_task_handle = tokio::task::spawn(async move {
229 loop {
230 let recv_future = resp_rx.recv();
231 let abort_future = send_stop_rx.recv();
232
233 let should_abort = tokio::select! {
234 _ = abort_future => {
235 true
236 }
237 opt_response = recv_future => {
238 match opt_response {
239 Some((socket_addr, resp_data)) => {
240 if let Err(e) = socket_copy.send_to(resp_data.as_ref(), &socket_addr).await {
241 tracing::warn!("error sending on socket: {}", e);
242 }
243 false
244 }
245 None => {
246 true
247 }
248
249 }
250 }
251 };
252
253 if should_abort {
254 tracing::info!("quitting send task");
255 break;
256 }
257 }
258 });
259
260 let concurrent_query_sem = Arc::from(Semaphore::new(cfg.nb_of_concurrent_requests));
262 loop {
263 let abort_future = stop_rx.recv();
264 let req_future = req_rx.recv();
265
266 let should_abort = tokio::select! {
267 _ = abort_future => {
268 true
269 }
270 opt_request = req_future => {
271 match opt_request {
272 Some((socket_addr, mut req_buffer)) => {
273 let wait_start = Instant::now();
274 let concurrent_query_permit = concurrent_query_sem.clone().acquire_owned().await;
275 let cloned_cfg = cfg.clone();
276 let cloned_challenge = challenge.clone();
277 let cloned_tx = resp_tx.clone();
278
279 let wait_duration = Instant::now().duration_since(wait_start);
280 tracing::debug!("started processing packet from {} (waited {}ms)", socket_addr.ip(), wait_duration.as_millis());
281 let _permit_handle = concurrent_query_permit; match Server::handle_query(&cloned_cfg, &mut req_buffer, cloned_challenge).await {
283 Ok(data) => {
284 if let Err(e) = cloned_tx.send((socket_addr, data)) {
285 tracing::error!("failed to send reply to writer thread: {}", e);
286 }
287 }
288 Err(e) => {
289 tracing::error!("uncaught error: {}", e);
290 }
291 }
292
293 false
294 }
295 None => {
296 true
297 }
298 }
299 }
300 };
301
302 if should_abort {
303 tracing::info!("quitting main task");
304 break;
305 }
306 }
307
308 if let Err(e) = send_stop_tx.send(()).await {
309 tracing::error!("failed to stop writer task: {}", e);
310 }
311
312 if let Err(e) = recv_stop_tx.send(()).await {
313 tracing::error!("failed to stop reader task: {}", e);
314 }
315
316 if let Err(e) = tokio::try_join!(recv_task_handle, send_task_handle) {
317 tracing::error!("failed to join tasks: {}", e);
318 }
319 }
320
321 pub async fn stop(self) -> Result<()> {
322 tracing::info!("requesting to quit");
323 self.handle.tx_stop.send(()).await.unwrap();
324 self.handle
325 .join_handle
326 .await
327 .map_err(|_e| ServerError::JoinError)?;
328 tracing::info!("exited");
329 Ok(())
330 }
331}