hickory_server/server/
mod.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
//! `Server` component for hosting a domain name server's operations.
9
10use std::{
11    io,
12    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
13    sync::Arc,
14    time::Duration,
15};
16
17use futures_util::{FutureExt, StreamExt};
18use hickory_proto::ProtoErrorKind;
19use ipnet::IpNet;
20#[cfg(feature = "__tls")]
21use rustls::{ServerConfig, server::ResolvesServerCert};
22#[cfg(feature = "__tls")]
23use tokio::time::timeout;
24use tokio::{net, task::JoinSet};
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, warn};
27
28#[cfg(feature = "__tls")]
29use crate::proto::rustls::default_provider;
30use crate::{
31    access::AccessControl,
32    authority::{MessageRequest, MessageResponseBuilder, Queries},
33    proto::{
34        BufDnsStreamHandle, ProtoError,
35        op::{Header, LowerQuery, MessageType, ResponseCode},
36        rr::Record,
37        runtime::{TokioRuntimeProvider, iocompat::AsyncIoTokioAsStd},
38        serialize::binary::{BinDecodable, BinDecoder},
39        tcp::TcpStream,
40        udp::UdpStream,
41        xfer::{Protocol, SerialMessage},
42    },
43};
44
45#[cfg(feature = "__https")]
46mod h2_handler;
47#[cfg(feature = "__h3")]
48mod h3_handler;
49#[cfg(feature = "__quic")]
50mod quic_handler;
51mod request_handler;
52pub use request_handler::{Request, RequestHandler, RequestInfo, ResponseInfo};
53mod response_handler;
54pub use response_handler::{ResponseHandle, ResponseHandler};
55#[cfg(feature = "metrics")]
56mod metrics;
57#[cfg(feature = "metrics")]
58use metrics::ResponseHandlerMetrics;
59mod timeout_stream;
60pub use timeout_stream::TimeoutStream;
61
62// TODO, would be nice to have a Slab for buffers here...
63/// A Futures based implementation of a DNS server
64pub struct ServerFuture<T: RequestHandler> {
65    handler: Arc<T>,
66    join_set: JoinSet<Result<(), ProtoError>>,
67    shutdown_token: CancellationToken,
68    access: Arc<AccessControl>,
69}
70
71impl<T: RequestHandler> ServerFuture<T> {
72    /// Creates a new ServerFuture with the specified Handler.
73    pub fn new(handler: T) -> Self {
74        Self::with_access(handler, &[], &[])
75    }
76
77    /// Creates a new ServerFuture with the specified Handler and Access
78    pub fn with_access(handler: T, denied_networks: &[IpNet], allowed_networks: &[IpNet]) -> Self {
79        let mut access = AccessControl::default();
80        access.insert_deny(denied_networks);
81        access.insert_allow(allowed_networks);
82
83        Self {
84            handler: Arc::new(handler),
85            join_set: JoinSet::new(),
86            shutdown_token: CancellationToken::new(),
87            access: Arc::new(access),
88        }
89    }
90
91    /// Register a UDP socket. Should be bound before calling this function.
92    pub fn register_socket(&mut self, socket: net::UdpSocket) {
93        debug!("registering udp: {:?}", socket);
94
95        // create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
96        //   the address used is acquired from the inbound queries
97        let (mut stream, stream_handle) =
98            UdpStream::<TokioRuntimeProvider>::with_bound(socket, ([127, 255, 255, 254], 0).into());
99        let shutdown = self.shutdown_token.clone();
100        let handler = self.handler.clone();
101        let access = self.access.clone();
102
103        // this spawns a ForEach future which handles all the requests into a Handler.
104        self.join_set.spawn({
105            async move {
106                let mut inner_join_set = JoinSet::new();
107                loop {
108                    let message = tokio::select! {
109                        message = stream.next() => match message {
110                            None => break,
111                            Some(message) => message,
112                        },
113                        _ = shutdown.cancelled() => break,
114                    };
115
116                    let message = match message {
117                        Err(e) => {
118                            warn!("error receiving message on udp_socket: {}", e);
119                            if is_unrecoverable_socket_error(&e) {
120                                break;
121                            }
122                            continue;
123                        }
124                        Ok(message) => message,
125                    };
126
127                    let src_addr = message.addr();
128                    debug!("received udp request from: {}", src_addr);
129
130                    // verify that the src address is safe for responses
131                    if let Err(e) = sanitize_src_address(src_addr) {
132                        warn!(
133                            "address can not be responded to {src_addr}: {e}",
134                            src_addr = src_addr,
135                            e = e
136                        );
137                        continue;
138                    }
139
140                    let handler = handler.clone();
141                    let access = access.clone();
142                    let stream_handle = stream_handle.with_remote_addr(src_addr);
143
144                    inner_join_set.spawn(async move {
145                        handle_raw_request(message, Protocol::Udp, access, handler, stream_handle)
146                            .await;
147                    });
148
149                    reap_tasks(&mut inner_join_set);
150                }
151
152                if shutdown.is_cancelled() {
153                    Ok(())
154                } else {
155                    // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
156                    Err(ProtoError::from("unexpected close of UDP socket"))
157                }
158            }
159        });
160    }
161
162    /// Register a UDP socket. Should be bound before calling this function.
163    pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
164        socket.set_nonblocking(true)?;
165        self.register_socket(net::UdpSocket::from_std(socket)?);
166        Ok(())
167    }
168
169    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
170    ///  IPv4 address.
171    ///
172    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
173    ///  to not make this too low depending on use cases.
174    ///
175    /// # Arguments
176    /// * `listener` - a bound TCP socket
177    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
178    ///   requests within this time period will be closed. In the future it should be
179    ///   possible to create long-lived queries, but these should be from trusted sources
180    ///   only, this would require some type of whitelisting.
181    pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
182        debug!("register tcp: {:?}", listener);
183
184        let handler = self.handler.clone();
185        let access = self.access.clone();
186
187        // for each incoming request...
188        let shutdown = self.shutdown_token.clone();
189        self.join_set.spawn(async move {
190            let mut inner_join_set = JoinSet::new();
191            loop {
192                let (tcp_stream, src_addr) = tokio::select! {
193                    tcp_stream = listener.accept() => match tcp_stream {
194                        Ok((t, s)) => (t, s),
195                        Err(e) => {
196                            debug!("error receiving TCP tcp_stream error: {}", e);
197                            if is_unrecoverable_socket_error(&e) {
198                                break;
199                            }
200                            continue;
201                        },
202                    },
203                    _ = shutdown.cancelled() => {
204                        // A graceful shutdown was initiated. Break out of the loop.
205                        break;
206                    },
207                };
208
209                // verify that the src address is safe for responses
210                if let Err(e) = sanitize_src_address(src_addr) {
211                    warn!(
212                        "address can not be responded to {src_addr}: {e}",
213                        src_addr = src_addr,
214                        e = e
215                    );
216                    continue;
217                }
218
219                let handler = handler.clone();
220                let access = access.clone();
221
222                // and spawn to the io_loop
223                inner_join_set.spawn(async move {
224                    debug!("accepted request from: {}", src_addr);
225                    // take the created stream...
226                    let (buf_stream, stream_handle) =
227                        TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
228                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
229
230                    while let Some(message) = timeout_stream.next().await {
231                        let message = match message {
232                            Ok(message) => message,
233                            Err(e) => {
234                                debug!(
235                                    "error in TCP request_stream src: {} error: {}",
236                                    src_addr, e
237                                );
238                                // we're going to bail on this connection...
239                                return;
240                            }
241                        };
242
243                        // we don't spawn here to limit clients from getting too many resources
244                        handle_raw_request(
245                            message,
246                            Protocol::Tcp,
247                            access.clone(),
248                            handler.clone(),
249                            stream_handle.clone(),
250                        )
251                        .await;
252                    }
253                });
254
255                reap_tasks(&mut inner_join_set);
256            }
257
258            if shutdown.is_cancelled() {
259                Ok(())
260            } else {
261                Err(ProtoError::from("unexpected close of socket"))
262            }
263        });
264    }
265
266    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
267    ///  IPv4 address.
268    ///
269    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
270    ///  to not make this too low depending on use cases.
271    ///
272    /// # Arguments
273    /// * `listener` - a bound TCP socket
274    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
275    ///   requests within this time period will be closed. In the future it should be
276    ///   possible to create long-lived queries, but these should be from trusted sources
277    ///   only, this would require some type of whitelisting.
278    pub fn register_listener_std(
279        &mut self,
280        listener: std::net::TcpListener,
281        timeout: Duration,
282    ) -> io::Result<()> {
283        listener.set_nonblocking(true)?;
284        self.register_listener(net::TcpListener::from_std(listener)?, timeout);
285        Ok(())
286    }
287
288    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
289    /// IPv6 or an IPv4 address.
290    ///
291    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
292    ///  to not make this too low depending on use cases.
293    ///
294    /// # Arguments
295    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
296    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
297    ///   requests within this time period will be closed. In the future it should be
298    ///   possible to create long-lived queries, but these should be from trusted sources
299    ///   only, this would require some type of whitelisting.
300    /// * `tls_config` - rustls server config
301    #[cfg(feature = "__tls")]
302    pub fn register_tls_listener_with_tls_config(
303        &mut self,
304        listener: net::TcpListener,
305        handshake_timeout: Duration,
306        tls_config: Arc<ServerConfig>,
307    ) -> io::Result<()> {
308        use crate::proto::rustls::tls_from_stream;
309        use tokio_rustls::TlsAcceptor;
310
311        let handler = self.handler.clone();
312        let access = self.access.clone();
313
314        debug!("registered tcp: {:?}", listener);
315
316        let tls_acceptor = TlsAcceptor::from(tls_config);
317
318        // for each incoming request...
319        let shutdown = self.shutdown_token.clone();
320        self.join_set.spawn(async move {
321            let mut inner_join_set = JoinSet::new();
322            loop {
323                let (tcp_stream, src_addr) = tokio::select! {
324                    tcp_stream = listener.accept() => match tcp_stream {
325                        Ok((t, s)) => (t, s),
326                        Err(e) => {
327                            debug!("error receiving TLS tcp_stream error: {}", e);
328                            if is_unrecoverable_socket_error(&e) {
329                                break;
330                            }
331                            continue;
332                        },
333                    },
334                    _ = shutdown.cancelled() => {
335                        // A graceful shutdown was initiated. Break out of the loop.
336                        break;
337                    },
338                };
339
340                // verify that the src address is safe for responses
341                if let Err(e) = sanitize_src_address(src_addr) {
342                    warn!(
343                        "address can not be responded to {src_addr}: {e}",
344                        src_addr = src_addr,
345                        e = e
346                    );
347                    continue;
348                }
349
350                let handler = handler.clone();
351                let access = access.clone();
352                let tls_acceptor = tls_acceptor.clone();
353
354                // kick out to a different task immediately, let them do the TLS handshake
355                inner_join_set.spawn(async move {
356                    debug!("starting TLS request from: {}", src_addr);
357
358                    // perform the TLS
359                    let Ok(tls_stream) =
360                        timeout(handshake_timeout, tls_acceptor.accept(tcp_stream)).await
361                    else {
362                        warn!("tls timeout expired during handshake");
363                        return;
364                    };
365
366                    let tls_stream = match tls_stream {
367                        Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
368                        Err(e) => {
369                            debug!("tls handshake src: {} error: {}", src_addr, e);
370                            return;
371                        }
372                    };
373                    debug!("accepted TLS request from: {}", src_addr);
374                    let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
375                    let mut timeout_stream = TimeoutStream::new(buf_stream, handshake_timeout);
376                    while let Some(message) = timeout_stream.next().await {
377                        let message = match message {
378                            Ok(message) => message,
379                            Err(e) => {
380                                debug!(
381                                    "error in TLS request_stream src: {:?} error: {}",
382                                    src_addr, e
383                                );
384
385                                // kill this connection
386                                return;
387                            }
388                        };
389
390                        handle_raw_request(
391                            message,
392                            Protocol::Tls,
393                            access.clone(),
394                            handler.clone(),
395                            stream_handle.clone(),
396                        )
397                        .await;
398                    }
399                });
400
401                reap_tasks(&mut inner_join_set);
402            }
403
404            if shutdown.is_cancelled() {
405                Ok(())
406            } else {
407                Err(ProtoError::from("unexpected close of socket"))
408            }
409        });
410
411        Ok(())
412    }
413
414    /// Register a TlsListener to the Server by providing a pkcs12 certificate and key. The TlsListener
415    /// should already be bound to either an IPv6 or an IPv4 address.
416    ///
417    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
418    ///  to not make this too low depending on use cases.
419    ///
420    /// # Arguments
421    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
422    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
423    ///   requests within this time period will be closed. In the future it should be
424    ///   possible to create long-lived queries, but these should be from trusted sources
425    ///   only, this would require some type of whitelisting.
426    /// * `server_cert_resolver` - resolver for the certificate and key used to announce to clients
427    #[cfg(feature = "__tls")]
428    pub fn register_tls_listener(
429        &mut self,
430        listener: net::TcpListener,
431        timeout: Duration,
432        server_cert_resolver: Arc<dyn ResolvesServerCert>,
433    ) -> io::Result<()> {
434        let config = tls_server_config(b"dot", server_cert_resolver)?;
435        Self::register_tls_listener_with_tls_config(self, listener, timeout, Arc::new(config))
436    }
437
438    /// Register a TcpListener for HTTPS (h2) to the Server for supporting DoH (DNS-over-HTTPS). The TcpListener should already be bound to either an
439    /// IPv6 or an IPv4 address.
440    ///
441    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
442    ///  to not make this too low depending on use cases.
443    ///
444    /// # Arguments
445    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
446    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
447    ///   requests within this time period will be closed. In the future it should be
448    ///   possible to create long-lived queries, but these should be from trusted sources
449    ///   only, this would require some type of whitelisting.
450    /// * `server_cert_resolver` - resolver for the certificate and key used to announce to clients
451    #[cfg(feature = "__https")]
452    pub fn register_https_listener(
453        &mut self,
454        listener: net::TcpListener,
455        // TODO: need to set a timeout between requests.
456        handshake_timeout: Duration,
457        server_cert_resolver: Arc<dyn ResolvesServerCert>,
458        dns_hostname: Option<String>,
459        http_endpoint: String,
460    ) -> io::Result<()> {
461        use crate::server::h2_handler::h2_handler;
462        use tokio_rustls::TlsAcceptor;
463
464        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
465        let http_endpoint: Arc<str> = Arc::from(http_endpoint);
466
467        let handler = self.handler.clone();
468        let access = self.access.clone();
469        debug!("registered https: {listener:?}");
470
471        let tls_acceptor =
472            TlsAcceptor::from(Arc::new(tls_server_config(b"h2", server_cert_resolver)?));
473
474        // for each incoming request...
475        let shutdown = self.shutdown_token.clone();
476        self.join_set.spawn(async move {
477            let mut inner_join_set = JoinSet::new();
478            loop {
479                let shutdown = shutdown.clone();
480                let (tcp_stream, src_addr) = tokio::select! {
481                    tcp_stream = listener.accept() => match tcp_stream {
482                        Ok((t, s)) => (t, s),
483                        Err(e) => {
484                            debug!("error receiving HTTPS tcp_stream error: {}", e);
485                            if is_unrecoverable_socket_error(&e) {
486                                break;
487                            }
488                            continue;
489                        },
490                    },
491                    _ = shutdown.cancelled() => {
492                        // A graceful shutdown was initiated. Break out of the loop.
493                        break;
494                    },
495                };
496
497                // verify that the src address is safe for responses
498                if let Err(e) = sanitize_src_address(src_addr) {
499                    warn!("address can not be responded to {src_addr}: {e}");
500                    continue;
501                }
502
503                let handler = handler.clone();
504                let access = access.clone();
505                let tls_acceptor = tls_acceptor.clone();
506                let dns_hostname = dns_hostname.clone();
507                let http_endpoint = http_endpoint.clone();
508
509                inner_join_set.spawn(async move {
510                    debug!("starting HTTPS request from: {src_addr}");
511
512                    // TODO: need to consider timeout of total connect...
513                    // take the created stream...
514                    let Ok(tls_stream) =
515                        timeout(handshake_timeout, tls_acceptor.accept(tcp_stream)).await
516                    else {
517                        warn!("https timeout expired during handshake");
518                        return;
519                    };
520
521                    let tls_stream = match tls_stream {
522                        Ok(tls_stream) => tls_stream,
523                        Err(e) => {
524                            debug!("https handshake src: {src_addr} error: {e}");
525                            return;
526                        }
527                    };
528                    debug!("accepted HTTPS request from: {src_addr}");
529
530                    h2_handler(
531                        access,
532                        handler,
533                        tls_stream,
534                        src_addr,
535                        dns_hostname,
536                        http_endpoint,
537                        shutdown.clone(),
538                    )
539                    .await;
540                });
541
542                reap_tasks(&mut inner_join_set);
543            }
544
545            if shutdown.is_cancelled() {
546                Ok(())
547            } else {
548                Err(ProtoError::from("unexpected close of socket"))
549            }
550        });
551
552        Ok(())
553    }
554
555    /// Register a UdpSocket to the Server for supporting DoQ (DNS-over-QUIC). The UdpSocket should already be bound to either an
556    /// IPv6 or an IPv4 address.
557    ///
558    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
559    ///  to not make this too low depending on use cases.
560    ///
561    /// # Arguments
562    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
563    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
564    ///   requests within this time period will be closed. In the future it should be
565    ///   possible to create long-lived queries, but these should be from trusted sources
566    ///   only, this would require some type of whitelisting.
567    /// * `server_cert_resolver` - resolver for certificate and key used to announce to clients
568    #[cfg(feature = "__quic")]
569    pub fn register_quic_listener(
570        &mut self,
571        socket: net::UdpSocket,
572        // TODO: need to set a timeout between requests.
573        _timeout: Duration,
574        server_cert_resolver: Arc<dyn ResolvesServerCert>,
575        dns_hostname: Option<String>,
576    ) -> io::Result<()> {
577        use crate::proto::quic::QuicServer;
578        use crate::server::quic_handler::quic_handler;
579
580        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
581
582        let handler = self.handler.clone();
583        let access = self.access.clone();
584
585        debug!("registered quic: {:?}", socket);
586        let mut server = QuicServer::with_socket(socket, server_cert_resolver)?;
587
588        // for each incoming request...
589        let shutdown = self.shutdown_token.clone();
590        self.join_set.spawn(async move {
591            let mut inner_join_set = JoinSet::new();
592            loop {
593                let shutdown = shutdown.clone();
594                let (streams, src_addr) = tokio::select! {
595                    result = server.next() => match result {
596                        Ok(Some(c)) => c,
597                        Ok(None) => continue,
598                        Err(e) => {
599                            debug!("error receiving quic connection: {e}");
600                            continue;
601                        }
602                    },
603                    _ = shutdown.cancelled() => {
604                        // A graceful shutdown was initiated. Break out of the loop.
605                        break;
606                    },
607                };
608
609                // verify that the src address is safe for responses
610                // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
611                if let Err(e) = sanitize_src_address(src_addr) {
612                    warn!(
613                        "address can not be responded to {src_addr}: {e}",
614                        src_addr = src_addr,
615                        e = e
616                    );
617                    continue;
618                }
619
620                let handler = handler.clone();
621                let access = access.clone();
622                let dns_hostname = dns_hostname.clone();
623
624                inner_join_set.spawn(async move {
625                    debug!("starting quic stream request from: {src_addr}");
626
627                    // TODO: need to consider timeout of total connect...
628                    let result = quic_handler(
629                        access,
630                        handler,
631                        streams,
632                        src_addr,
633                        dns_hostname,
634                        shutdown.clone(),
635                    )
636                    .await;
637
638                    if let Err(e) = result {
639                        warn!("quic stream processing failed from {src_addr}: {e}")
640                    }
641                });
642
643                reap_tasks(&mut inner_join_set);
644            }
645
646            Ok(())
647        });
648
649        Ok(())
650    }
651
652    /// Register a UdpSocket to the Server for supporting DoH3 (DNS-over-HTTP/3). The UdpSocket should already be bound to either an
653    /// IPv6 or an IPv4 address.
654    ///
655    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
656    ///  to not make this too low depending on use cases.
657    ///
658    /// # Arguments
659    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
660    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
661    ///   requests within this time period will be closed. In the future it should be
662    ///   possible to create long-lived queries, but these should be from trusted sources
663    ///   only, this would require some type of whitelisting.
664    /// * `server_cert_resolver` - resolver for certificate and key used to announce to clients
665    #[cfg(feature = "__h3")]
666    pub fn register_h3_listener(
667        &mut self,
668        socket: net::UdpSocket,
669        // TODO: need to set a timeout between requests.
670        _timeout: Duration,
671        server_cert_resolver: Arc<dyn ResolvesServerCert>,
672        dns_hostname: Option<String>,
673    ) -> io::Result<()> {
674        use crate::proto::h3::h3_server::H3Server;
675        use crate::server::h3_handler::h3_handler;
676
677        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
678
679        let handler = self.handler.clone();
680        let access = self.access.clone();
681
682        debug!("registered h3: {:?}", socket);
683        let mut server = H3Server::with_socket(socket, server_cert_resolver)?;
684
685        // for each incoming request...
686        let shutdown = self.shutdown_token.clone();
687        self.join_set.spawn(async move {
688            let mut inner_join_set = JoinSet::new();
689            loop {
690                let shutdown = shutdown.clone();
691                let (streams, src_addr) = tokio::select! {
692                    result = server.accept() => match result {
693                        Ok(Some(c)) => c,
694                        Ok(None) => continue,
695                        Err(e) => {
696                            debug!("error receiving h3 connection: {e}");
697                            continue;
698                        }
699                    },
700                    _ = shutdown.cancelled() => {
701                        // A graceful shutdown was initiated. Break out of the loop.
702                        break;
703                    },
704                };
705
706                // verify that the src address is safe for responses
707                // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
708                if let Err(e) = sanitize_src_address(src_addr) {
709                    warn!(
710                        "address can not be responded to {src_addr}: {e}",
711                        src_addr = src_addr,
712                        e = e
713                    );
714                    continue;
715                }
716
717                let handler = handler.clone();
718                let access = access.clone();
719                let dns_hostname = dns_hostname.clone();
720
721                inner_join_set.spawn(async move {
722                    debug!("starting h3 stream request from: {src_addr}");
723
724                    // TODO: need to consider timeout of total connect...
725                    let result = h3_handler(
726                        access,
727                        handler,
728                        streams,
729                        src_addr,
730                        dns_hostname,
731                        shutdown.clone(),
732                    )
733                    .await;
734
735                    if let Err(e) = result {
736                        warn!("h3 stream processing failed from {src_addr}: {e}")
737                    }
738                });
739
740                reap_tasks(&mut inner_join_set);
741            }
742
743            Ok(())
744        });
745
746        Ok(())
747    }
748
749    /// Triggers a graceful shutdown the server. All background tasks will stop accepting
750    /// new connections and the returned future will complete once all tasks have terminated.
751    pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
752        self.shutdown_token.cancel();
753
754        // Wait for the server to complete.
755        block_until_done(&mut self.join_set).await
756    }
757
758    /// Returns a reference to the [`CancellationToken`] used to gracefully shut down the server.
759    ///
760    /// Once cancellation is requested, all background tasks will stop accepting new connections,
761    /// and `block_until_done()` will complete once all tasks have terminated.
762    pub fn shutdown_token(&self) -> &CancellationToken {
763        &self.shutdown_token
764    }
765
766    /// This will run until all background tasks complete. If one or more tasks return an error,
767    /// one will be chosen as the returned error for this future.
768    pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
769        block_until_done(&mut self.join_set).await
770    }
771}
772
773async fn block_until_done(
774    join_set: &mut JoinSet<Result<(), ProtoError>>,
775) -> Result<(), ProtoError> {
776    if join_set.is_empty() {
777        warn!("block_until_done called with no pending tasks");
778        return Ok(());
779    }
780
781    // Now wait for all of the tasks to complete.
782    let mut out = Ok(());
783    while let Some(join_result) = join_set.join_next().await {
784        match join_result {
785            Ok(result) => {
786                match result {
787                    Ok(_) => (),
788                    Err(e) => {
789                        // Save the last error.
790                        out = Err(e);
791                    }
792                }
793            }
794            Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
795        }
796    }
797    out
798}
799
800/// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
801fn reap_tasks(join_set: &mut JoinSet<()>) {
802    while FutureExt::now_or_never(join_set.join_next())
803        .flatten()
804        .is_some()
805    {}
806}
807
808pub(crate) async fn handle_raw_request<T: RequestHandler>(
809    message: SerialMessage,
810    protocol: Protocol,
811    access: Arc<AccessControl>,
812    request_handler: Arc<T>,
813    response_handler: BufDnsStreamHandle,
814) {
815    let src_addr = message.addr();
816    let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
817
818    handle_request(
819        message.bytes(),
820        src_addr,
821        protocol,
822        access,
823        request_handler,
824        response_handler,
825    )
826    .await;
827}
828
/// Builds a rustls [`ServerConfig`] that uses the crate's default crypto provider, rustls'
/// safe default protocol versions, no client authentication, and certificates supplied by
/// `server_cert_resolver`. `protocol` is advertised as the only ALPN protocol.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind `Other` if the protocol versions cannot be configured
/// for the provider.
#[cfg(feature = "__tls")]
fn tls_server_config(
    protocol: &[u8],
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
) -> io::Result<ServerConfig> {
    let provider = Arc::new(default_provider());
    let versioned = ServerConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("error creating TLS acceptor: {e}"),
            )
        })?;

    let mut config = versioned
        .with_no_client_auth()
        .with_cert_resolver(server_cert_resolver);
    config.alpn_protocols = vec![protocol.to_vec()];

    Ok(config)
}
848
/// Wraps a [`ResponseHandler`] so that each response sent through it is logged together with
/// details of the request that produced it (and, with the `metrics` feature, recorded in
/// metrics).
#[derive(Clone)]
struct ReportingResponseHandler<R: ResponseHandler> {
    /// Header of the originating request; its id is compared with the response id when a
    /// response is sent.
    request_header: Header,
    /// Queries of the originating request, logged after each response.
    queries: Vec<LowerQuery>,
    /// Protocol the request arrived on, included in the log line.
    protocol: Protocol,
    /// Address of the client that sent the request, included in the log line.
    src_addr: SocketAddr,
    /// The wrapped handler that actually sends the response.
    handler: R,
    /// Metrics updated after each successfully sent response.
    #[cfg(feature = "metrics")]
    metrics: ResponseHandlerMetrics,
}
859
860#[async_trait::async_trait]
861#[allow(clippy::uninlined_format_args)]
862impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
863    async fn send_response<'a>(
864        &mut self,
865        response: crate::authority::MessageResponse<
866            '_,
867            'a,
868            impl Iterator<Item = &'a Record> + Send + 'a,
869            impl Iterator<Item = &'a Record> + Send + 'a,
870            impl Iterator<Item = &'a Record> + Send + 'a,
871            impl Iterator<Item = &'a Record> + Send + 'a,
872        >,
873    ) -> io::Result<ResponseInfo> {
874        let response_info = self.handler.send_response(response).await?;
875
876        let id = self.request_header.id();
877        let rid = response_info.id();
878        if id != rid {
879            warn!("request id:{id} does not match response id:{rid}");
880            debug_assert_eq!(id, rid, "request id and response id should match");
881        }
882
883        let rflags = response_info.flags();
884        let answer_count = response_info.answer_count();
885        let authority_count = response_info.name_server_count();
886        let additional_count = response_info.additional_count();
887        let response_code = response_info.response_code();
888
889        info!(
890            "request:{id} src:{proto}://{addr}#{port} {op} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
891            id = rid,
892            proto = self.protocol,
893            addr = self.src_addr.ip(),
894            port = self.src_addr.port(),
895            op = self.request_header.op_code(),
896            qflags = self.request_header.flags(),
897            code = response_code,
898            answers = answer_count,
899            authorities = authority_count,
900            additionals = additional_count,
901            rflags = rflags
902        );
903        for query in self.queries.iter() {
904            info!(
905                "query:{query}:{qtype}:{class}",
906                query = query.name(),
907                qtype = query.query_type(),
908                class = query.query_class()
909            );
910        }
911
912        #[cfg(feature = "metrics")]
913        self.metrics.update(self, &response_info);
914
915        Ok(response_info)
916    }
917}
918
#[cfg(feature = "metrics")]
impl ResponseHandlerMetrics {
    /// Records a single sent response.
    ///
    /// Request-side counters (protocol, op code, request flags) are taken from
    /// `response_handler`; response-side counters (response code, response flags) are taken
    /// from `response_info`.
    fn update(
        &self,
        response_handler: &ReportingResponseHandler<impl ResponseHandler>,
        response_info: &ResponseInfo,
    ) {
        let request = &response_handler.request_header;

        // Request side.
        self.proto.increment(&response_handler.protocol);
        self.operation.increment(&request.op_code());
        self.request_flags.increment(request);

        // Response side.
        self.response_code.increment(&response_info.response_code());
        self.response_flags.increment(response_info);
    }
}
936
937pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
938    message_bytes: &[u8],
939    src_addr: SocketAddr,
940    protocol: Protocol,
941    access: Arc<AccessControl>,
942    request_handler: Arc<T>,
943    response_handler: R,
944) {
945    let mut decoder = BinDecoder::new(message_bytes);
946
947    // method to handle the request
948    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
949        if message.message_type() == MessageType::Response {
950            // Don't process response messages to avoid DoS attacks from reflection.
951            return;
952        }
953
954        let id = message.id();
955        let qflags = message.header().flags();
956        let qop_code = message.op_code();
957        let message_type = message.message_type();
958        let is_dnssec = message.edns().is_some_and(|edns| edns.flags().dnssec_ok);
959
960        let request = Request::new(message, src_addr, protocol);
961
962        debug!(
963            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op} qflags:{qflags}",
964            id = id,
965            proto = protocol,
966            addr = src_addr.ip(),
967            port = src_addr.port(),
968            message_type = message_type,
969            is_dnssec = is_dnssec,
970            op = qop_code,
971            qflags = qflags
972        );
973        for query in request.queries().iter() {
974            debug!(
975                "query:{query}:{qtype}:{class}",
976                query = query.name(),
977                qtype = query.query_type(),
978                class = query.query_class()
979            );
980        }
981
982        // The reporter will handle making sure to log the result of the request
983        let queries = request.queries().to_vec();
984        let reporter = ReportingResponseHandler {
985            request_header: *request.header(),
986            queries,
987            protocol,
988            src_addr,
989            handler: response_handler,
990            #[cfg(feature = "metrics")]
991            metrics: ResponseHandlerMetrics::default(),
992        };
993
994        request_handler.handle_request(&request, reporter).await;
995    };
996
997    // method to return an error to the client
998    let error_response_handler = |protocol: Protocol,
999                                  src_addr: SocketAddr,
1000                                  header: Header,
1001                                  queries: Queries,
1002                                  response_code: ResponseCode,
1003                                  error: Box<ProtoError>,
1004                                  response_handler: R| async move {
1005        // debug for more info on why the message parsing failed
1006        debug!(
1007            "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
1008            id = header.id(),
1009            proto = protocol,
1010            addr = src_addr.ip(),
1011            port = src_addr.port(),
1012            message_type = header.message_type(),
1013            op = header.op_code(),
1014            response_code = response_code,
1015            error = error,
1016        );
1017
1018        // The reporter will handle making sure to log the result of the request
1019        let mut reporter = ReportingResponseHandler {
1020            request_header: header,
1021            queries: queries.queries().to_vec(),
1022            protocol,
1023            src_addr,
1024            handler: response_handler,
1025            #[cfg(feature = "metrics")]
1026            metrics: ResponseHandlerMetrics::default(),
1027        };
1028
1029        let response = MessageResponseBuilder::new(&queries);
1030        let result = reporter
1031            .send_response(response.error_msg(&header, response_code))
1032            .await;
1033
1034        if let Err(e) = result {
1035            warn!("failed to return FormError to client: {}", e);
1036        }
1037    };
1038
1039    if !access.allow(src_addr.ip()) {
1040        info!(
1041            "request:Refused src:{proto}://{addr}#{port}",
1042            proto = protocol,
1043            addr = src_addr.ip(),
1044            port = src_addr.port(),
1045        );
1046
1047        let Ok(header) = Header::read(&mut decoder) else {
1048            // This will only fail if the message is less than twelve bytes long. Such messages are
1049            // definitely not valid DNS queries, so it should be fine to return without sending a
1050            // response.
1051            return;
1052        };
1053        let queries = match Queries::read(&mut decoder, header.query_count() as usize) {
1054            Ok(queries) => queries,
1055            Err(_) => Queries::empty(),
1056        };
1057        error_response_handler(
1058            protocol,
1059            src_addr,
1060            header,
1061            queries,
1062            ResponseCode::Refused,
1063            Box::new(ProtoErrorKind::RequestRefused.into()),
1064            response_handler,
1065        )
1066        .await;
1067
1068        return;
1069    }
1070
1071    // Attempt to decode the message
1072    match MessageRequest::read(&mut decoder) {
1073        Ok(message) => {
1074            inner_handle_request(message, response_handler).await;
1075        }
1076        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
1077            // We failed to parse the request due to some issue in the message, but the header is available, so we can respond
1078            let (header, error) = kind
1079                .into_form_error()
1080                .expect("as form_error already confirmed this is a FormError");
1081            let queries = Queries::empty();
1082
1083            error_response_handler(
1084                protocol,
1085                src_addr,
1086                header,
1087                queries,
1088                ResponseCode::FormErr,
1089                error,
1090                response_handler,
1091            )
1092            .await;
1093        }
1094        Err(error) => info!(
1095            "request:Failed src:{proto}://{addr}#{port} error:{error}",
1096            proto = protocol,
1097            addr = src_addr.ip(),
1098            port = src_addr.port(),
1099        ),
1100    }
1101}
1102
/// Checks if the IP address is safe for returning messages
///
/// Unsafe source addresses are:
/// * any address with a port of `0`,
/// * the unspecified IPv4 (`0.0.0.0`) or IPv6 (`::`) address,
/// * the IPv4 broadcast address `255.255.255.255`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, as reported for IPv4 peers on dual-stack
/// sockets) are checked against the IPv4 rules. This stops the unspecified and broadcast
/// addresses from slipping through in mapped form.
///
/// # Returns
///
/// Error if the address should not be used for returned requests
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        // TODO: add check for is_reserved when that stabilizes

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        // An IPv4-mapped address is really an IPv4 peer, so the IPv4 rules apply to it.
        if let Some(v4) = src.to_ipv4_mapped() {
            return verify_v4(v4);
        }

        Ok(())
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
1144
/// Returns `true` if `err` has one of the kinds this module treats as unrecoverable for a
/// socket: [`io::ErrorKind::NotConnected`] or [`io::ErrorKind::ConnectionAborted`].
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
1151
#[cfg(test)]
mod tests {
    // Unit tests for this module: `ServerFuture` abort and graceful-shutdown behaviour,
    // `sanitize_src_address`, and the `reap_tasks` helper.
    use super::*;
    use crate::authority::Catalog;
    use futures_util::future;
    #[cfg(feature = "__tls")]
    use rustls::{
        pki_types::{CertificateDer, PrivateKeyDer},
        sign::{CertifiedKey, SingleCertAndKey},
    };
    use std::net::SocketAddr;
    use test_support::subscribe;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Wraps server setup in an abortable future, aborts it, and then checks that every
    /// endpoint address can be bound again.
    ///
    /// NOTE(review): `abort_handle.abort()` is called before `abortable` is awaited for the
    /// first time. If `Abortable` checks its abort flag before polling the inner future, the
    /// inner block (including `register`) never runs. In that case this test does not
    /// exercise socket release on abort. Confirm the intent.
    #[tokio::test]
    async fn abort() {
        subscribe();

        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = ServerFuture::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        endpoints.rebind_all().await;
    }

    /// Registers every endpoint and shuts the server down gracefully within two seconds. It
    /// then rebinds every address to show the server released its sockets.
    #[tokio::test]
    async fn graceful_shutdown() {
        subscribe();
        let mut server_future = ServerFuture::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    /// Port 0, unspecified addresses and IPv4 broadcast are rejected; ordinary unicast and
    /// loopback addresses are accepted.
    #[test]
    fn test_sanitize_src_addr() {
        // ipv4 tests
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4_096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4_096))).is_err());

        // ipv6 tests
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4_096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4_096))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// Loopback addresses, one per transport enabled by the current feature set. They are
    /// used to bind a server and later check that it released its sockets.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        udp_std_addr: SocketAddr,
        tcp_addr: SocketAddr,
        tcp_std_addr: SocketAddr,
        #[cfg(feature = "__tls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "__https")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "__quic")]
        quic_addr: SocketAddr,
        #[cfg(feature = "__h3")]
        h3_addr: SocketAddr,
    }

    impl Endpoints {
        /// Picks a port for each transport by binding `127.0.0.1:0` and recording the assigned
        /// address. The probe sockets are locals and are dropped when this returns, which
        /// frees the ports for `register`.
        // NOTE(review): another process could claim a port between this probe and the bind in
        // `register`; acceptable for tests but a possible source of flakiness.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__tls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__https")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__h3")]
            let h3 = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                udp_std_addr: udp_std.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                tcp_std_addr: tcp_std.local_addr().unwrap(),
                #[cfg(feature = "__tls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "__https")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "__quic")]
                quic_addr: quic.local_addr().unwrap(),
                #[cfg(feature = "__h3")]
                h3_addr: h3.local_addr().unwrap(),
            }
        }

        /// Binds a socket or listener on each address and registers it with `server` using
        /// the matching `register_*` method. Both the tokio and the `std` variants are
        /// covered for UDP and TCP.
        async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server
                .register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
                .unwrap();
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
            );
            server
                .register_listener_std(
                    std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
                    Duration::from_secs(1),
                )
                .unwrap();

            #[cfg(feature = "__tls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__https")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                        "/dns-query".into(),
                    )
                    .unwrap();
            }

            #[cfg(feature = "__quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__h3")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_h3_listener(
                        UdpSocket::bind(self.h3_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Binds every address again and immediately drops the result. This panics (via
        /// `unwrap`) if any address is still in use, for example because a server has not
        /// released it.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            UdpSocket::bind(self.udp_std_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            TcpListener::bind(self.tcp_std_addr).await.unwrap();
            #[cfg(feature = "__tls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "__https")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "__quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
            #[cfg(feature = "__h3")]
            UdpSocket::bind(self.h3_addr).await.unwrap();
        }
    }

    /// Loads the test certificate chain (`tests/test-data/cert.pem`) and private key
    /// (`tests/test-data/cert.key`) and wraps them in a single-certificate resolver. Both
    /// paths are relative to `TDNS_WORKSPACE_ROOT`, which defaults to `../..`.
    #[cfg(feature = "__tls")]
    fn rustls_cert_key() -> Arc<dyn ResolvesServerCert> {
        use rustls::pki_types::pem::PemObject;
        use std::env;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
        let cert_chain =
            CertificateDer::pem_file_iter(format!("{}/tests/test-data/cert.pem", server_path))
                .unwrap()
                .collect::<Result<Vec<_>, _>>()
                .unwrap();

        let key = PrivateKeyDer::from_pem_file(format!("{server_path}/tests/test-data/cert.key"))
            .unwrap();

        let certified_key = CertifiedKey::from_der(cert_chain, key, &default_provider()).unwrap();
        Arc::new(SingleCertAndKey::from(certified_key))
    }

    /// `reap_tasks` must not block when there are no tasks at all.
    #[test]
    fn task_reap_on_empty_joinset() {
        let mut joinset = JoinSet::new();

        // this should return immediately
        reap_tasks(&mut joinset);
    }

    /// `reap_tasks` must not block when tasks exist but none has finished yet.
    #[tokio::test]
    async fn task_reap_on_nonempty_joinset() {
        let mut joinset = JoinSet::new();
        let t = joinset.spawn(tokio::time::sleep(Duration::from_secs(2)));

        // this should return immediately since no task is ready
        reap_tasks(&mut joinset);
        t.abort();

        // this should also return immediately since the task has been aborted
        reap_tasks(&mut joinset);
    }
}