antidns/
server.rs

1use std::io;
2use std::time::Instant;
3use std::{net::SocketAddr, sync::Arc};
4
5use mpsc::{UnboundedReceiver, UnboundedSender};
6
7use parking_lot::Mutex;
8
9use resolver::ResolveError;
10
11use snafu::{ResultExt, Snafu};
12
13use tokio::net::UdpSocket;
14use tokio::sync::mpsc;
15use tokio::sync::Semaphore;
16use tokio::task::JoinHandle;
17
18use crate::{
19    packet::PacketError, packet_buffer::BufferError, query_type::QueryType, record::DnsRecord,
20    resolver, BytePacketBuffer, Config, DnsPacket, ResultCode,
21};
22
/// Errors produced by the DNS server layer.
///
/// The `Snafu` derive generates one context selector per variant (for example
/// `InvalidPacketSnafu` and `InvalidBufferSnafu`), which the `.context(...)`
/// calls in this file rely on.
#[derive(Debug, Snafu)]
pub enum ServerError {
    // NOTE(review): the per-variant notes below are plain `//` comments on
    // purpose. snafu can use a variant's doc comment as its `Display` text, so
    // `///` here could change the rendered error messages — confirm against the
    // snafu version in use before converting them.

    // A packet buffer operation failed (used when extracting the encoded
    // response bytes).
    InvalidBuffer { source: BufferError },
    // Parsing the request or serializing the response packet failed.
    InvalidPacket { source: PacketError },

    // NOTE(review): the two socket variants and `ResolutionError` are not
    // constructed anywhere in the visible part of this file; socket and
    // resolver failures are currently only logged.
    SocketRecvError { source: io::Error },
    SocketSendError { source: io::Error },

    ResolutionError { source: ResolveError },

    // The background server task could not be joined (see `Server::stop`).
    // The underlying join error is discarded.
    JoinError,
}
35
/// Result type used by this module, with [`ServerError`] as the error type.
type Result<T> = std::result::Result<T, ServerError>;

/// Sending half of the response queue: destination address plus the encoded
/// response bytes to write to the socket.
type ResponseSender = UnboundedSender<(SocketAddr, Vec<u8>)>;
/// Receiving half of the response queue, drained by the socket write task.
type ResponseReceiver = UnboundedReceiver<(SocketAddr, Vec<u8>)>;
40
/// Handle on the background task spawned by `Server::start`.
struct ServerProcess {
    /// Join handle of the task running `Server::run`.
    pub join_handle: JoinHandle<()>,
    /// Sender used by `Server::stop` to ask the task to shut down.
    pub tx_stop: mpsc::Sender<()>,
}
45
/// A running DNS server.
///
/// Created by `Server::start` and shut down with `Server::stop`.
pub struct Server {
    /// TXT challenge value shared with the query handler. An empty string means
    /// that no challenge is currently set.
    txt_challenge: Arc<Mutex<String>>,
    /// Background task and the channel used to stop it.
    handle: ServerProcess,
}
50
51impl Server {
52    pub fn start(cfg: Config) -> Server {
53        let (tx_stop, rx) = mpsc::channel(1);
54
55        tracing::info!("starting DNS layer");
56
57        let txt_challenge = Arc::from(Mutex::from(String::default()));
58        let join_handle: JoinHandle<()> = {
59            let challenge_cloned = txt_challenge.clone();
60            tokio::task::spawn(async move {
61                Server::run(cfg, rx, challenge_cloned).await;
62            })
63        };
64
65        tracing::info!("DNS layer started");
66
67        Server {
68            handle: ServerProcess {
69                join_handle,
70                tx_stop,
71            },
72            txt_challenge,
73        }
74    }
75
76    pub async fn set_dns_challenge(&self, challenge: &str) -> Result<()> {
77        let mut guard = self.txt_challenge.lock();
78        *guard = challenge.to_string();
79        tracing::info!("set acme challenge: {}", &*guard);
80        Ok(())
81    }
82
83    async fn handle_query(
84        cfg: &Config,
85        req_buffer: &mut BytePacketBuffer,
86        challenge: Arc<Mutex<String>>,
87    ) -> Result<Vec<u8>> {
88        // Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
89        // a `DnsPacket`.
90        let mut request = DnsPacket::from_buffer(req_buffer).context(InvalidPacketSnafu)?;
91
92        // Create and initialize the response packet
93        let mut packet = DnsPacket::new();
94        packet.header.id = request.header.id;
95        packet.header.recursion_desired = false;
96        packet.header.recursion_available = false;
97        packet.header.response = true;
98
99        // In the normal case, exactly one question is present
100        if let Some(question) = request.questions.pop() {
101            // Handle the special case of a TXT query on our handled domain.
102
103            if question.qtype == QueryType::TXT && question.name.ends_with(&cfg.root_domain) {
104                let guard = challenge.lock();
105                let chall = &*guard.clone();
106                if !chall.is_empty() {
107                    tracing::info!("query is an ACME challenge");
108                    // Resolve the challenge without going to the resolver.
109                    packet.questions.push(question);
110                    let challenge_bytes = chall.as_bytes().to_vec();
111                    packet.answers.push(DnsRecord::TXT {
112                        domain_bytes: vec![192, 12],
113                        ttl: 500,
114                        data_len: challenge_bytes.len() as u16,
115                        text: vec![challenge_bytes],
116                    });
117                    packet.header.authoritative_answer = true;
118                } else {
119                    tracing::warn!("got ACME challenge but no challenge is set");
120                }
121            } else {
122                match resolver::lookup(&question.name, question.qtype, cfg).await {
123                    Ok(Some(result)) => {
124                        packet.questions.push(question);
125                        packet.header.rescode = result.header.rescode;
126
127                        for rec in result.answers {
128                            tracing::debug!("answer: {:?}", rec);
129                            packet.answers.push(rec);
130                        }
131                        for rec in result.authorities {
132                            tracing::debug!("authority: {:?}", rec);
133                            packet.authorities.push(rec);
134                        }
135                        for rec in result.resources {
136                            tracing::debug!("resource: {:?}", rec);
137                            packet.resources.push(rec);
138                        }
139                    }
140                    Ok(None) => {
141                        tracing::debug!("ignoring packet");
142                    }
143                    Err(e) => {
144                        tracing::error!("servfail: {}", e);
145                        packet.header.rescode = ResultCode::ServFail;
146                    }
147                }
148            }
149        }
150        // Being mindful of how unreliable input data from arbitrary senders can be, we
151        // need make sure that a question is actually present. If not, we return `FORMERR`
152        // to indicate that the sender made something wrong.
153        else {
154            tracing::warn!("FORMERR");
155            packet.header.rescode = ResultCode::FormErr;
156        }
157
158        // The only thing remaining is to encode our response and send it off!
159        let mut res_buffer = BytePacketBuffer::new();
160        packet.write(&mut res_buffer).context(InvalidPacketSnafu)?;
161
162        let len = res_buffer.pos();
163        let data = res_buffer.get_range(0, len).context(InvalidBufferSnafu)?;
164
165        tracing::trace!(
166            "sending raw packet of length {} as response: {:?}",
167            len,
168            data
169        );
170
171        // TODO: Instead, take the response buffer as argument as well.
172        Ok(data.to_vec())
173    }
174
175    async fn run(cfg: Config, mut stop_rx: mpsc::Receiver<()>, challenge: Arc<Mutex<String>>) {
176        // Bind to the UDP socket.
177        let socket = match UdpSocket::bind(cfg.listen).await {
178            Ok(s) => Arc::from(s),
179            Err(e) => {
180                tracing::error!("cannot bind to socket: {}", e);
181                return;
182            }
183        };
184
185        let (req_tx, mut req_rx) = mpsc::unbounded_channel();
186        let (resp_tx, mut resp_rx): (ResponseSender, ResponseReceiver) = mpsc::unbounded_channel();
187
188        let (recv_stop_tx, mut recv_stop_rx) = mpsc::channel(1);
189        let (send_stop_tx, mut send_stop_rx) = mpsc::channel(1);
190
191        // Socket read routine.
192        let socket_copy = socket.clone();
193        let recv_task_handle = tokio::task::spawn(async move {
194            loop {
195                let mut req_buffer = BytePacketBuffer::new();
196
197                let recv_future = socket_copy.recv_from(&mut req_buffer.buf);
198                let abort_future = recv_stop_rx.recv();
199
200                let should_abort = tokio::select! {
201                    _ = abort_future => {
202                        true
203                    }
204                    packet_result = recv_future => {
205                        match packet_result {
206                            Ok((_, addr)) => {
207                                if let Err(e) = req_tx.send((addr, req_buffer)) {
208                                    tracing::warn!("failed to send request: {}", e);
209                                }
210                            }
211                            Err(e) => {
212                                tracing::warn!("packet recv error: {}", e);
213                            }
214                        };
215                        false
216                    }
217                };
218
219                if should_abort {
220                    tracing::info!("quitting receive task");
221                    break;
222                }
223            }
224        });
225
226        // Socket write routine.
227        let socket_copy = socket.clone();
228        let send_task_handle = tokio::task::spawn(async move {
229            loop {
230                let recv_future = resp_rx.recv();
231                let abort_future = send_stop_rx.recv();
232
233                let should_abort = tokio::select! {
234                    _ = abort_future => {
235                        true
236                    }
237                    opt_response = recv_future => {
238                        match opt_response {
239                            Some((socket_addr, resp_data)) => {
240                                if let Err(e) = socket_copy.send_to(resp_data.as_ref(), &socket_addr).await {
241                                    tracing::warn!("error sending on socket: {}", e);
242                                }
243                                false
244                            }
245                            None => {
246                                true
247                            }
248
249                        }
250                    }
251                };
252
253                if should_abort {
254                    tracing::info!("quitting send task");
255                    break;
256                }
257            }
258        });
259
260        // DNS server loop (main task).
261        let concurrent_query_sem = Arc::from(Semaphore::new(cfg.nb_of_concurrent_requests));
262        loop {
263            let abort_future = stop_rx.recv();
264            let req_future = req_rx.recv();
265
266            let should_abort = tokio::select! {
267                _ = abort_future => {
268                    true
269                }
270                opt_request = req_future => {
271                    match opt_request {
272                        Some((socket_addr, mut req_buffer)) => {
273                            let wait_start = Instant::now();
274                            let concurrent_query_permit = concurrent_query_sem.clone().acquire_owned().await;
275                            let cloned_cfg = cfg.clone();
276                            let cloned_challenge = challenge.clone();
277                            let cloned_tx = resp_tx.clone();
278
279                            let wait_duration = Instant::now().duration_since(wait_start);
280                            tracing::debug!("started processing packet from {} (waited {}ms)", socket_addr.ip(), wait_duration.as_millis());
281                            let _permit_handle = concurrent_query_permit; // 0% useful, except to keep the permit alive until the end of the tokio task.
282                            match Server::handle_query(&cloned_cfg, &mut req_buffer, cloned_challenge).await {
283                                Ok(data) => {
284                                    if let Err(e) = cloned_tx.send((socket_addr, data)) {
285                                        tracing::error!("failed to send reply to writer thread: {}", e);
286                                    }
287                                }
288                                Err(e) => {
289                                    tracing::error!("uncaught error: {}", e);
290                                }
291                            }
292
293                            false
294                        }
295                        None => {
296                            true
297                        }
298                    }
299                }
300            };
301
302            if should_abort {
303                tracing::info!("quitting main task");
304                break;
305            }
306        }
307
308        if let Err(e) = send_stop_tx.send(()).await {
309            tracing::error!("failed to stop writer task: {}", e);
310        }
311
312        if let Err(e) = recv_stop_tx.send(()).await {
313            tracing::error!("failed to stop reader task: {}", e);
314        }
315
316        if let Err(e) = tokio::try_join!(recv_task_handle, send_task_handle) {
317            tracing::error!("failed to join tasks: {}", e);
318        }
319    }
320
321    pub async fn stop(self) -> Result<()> {
322        tracing::info!("requesting to quit");
323        self.handle.tx_stop.send(()).await.unwrap();
324        self.handle
325            .join_handle
326            .await
327            .map_err(|_e| ServerError::JoinError)?;
328        tracing::info!("exited");
329        Ok(())
330    }
331}