Skip to main content

hickory_server/server/
mod.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
//! `Server` component for hosting a domain name server's operations.
9
10use std::{
11    fmt, io,
12    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
13    sync::Arc,
14    time::Duration,
15};
16
17use bytes::Bytes;
18use futures_util::StreamExt;
19use ipnet::IpNet;
20#[cfg(feature = "__tls")]
21use rustls::{ServerConfig, server::ResolvesServerCert};
22#[cfg(feature = "__tls")]
23use tokio::time::timeout;
24use tokio::{net, task::JoinSet};
25#[cfg(feature = "__tls")]
26use tokio_rustls::TlsAcceptor;
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, info, warn};
29
30#[cfg(feature = "metrics")]
31use crate::metrics::ResponseHandlerMetrics;
32#[cfg(feature = "__h3")]
33use crate::net::h3::h3_server::H3Server;
34#[cfg(feature = "__quic")]
35use crate::net::quic::QuicServer;
36#[cfg(feature = "__tls")]
37use crate::net::tls::{default_provider, tls_from_stream};
38use crate::{
39    access::AccessControl,
40    net::{
41        BufDnsStreamHandle, NetError,
42        runtime::{TokioRuntimeProvider, TokioTime, iocompat::AsyncIoTokioAsStd},
43        tcp::TcpStream,
44        udp::UdpStream,
45        xfer::Protocol,
46    },
47    proto::{
48        op::{Header, LowerQuery, MessageType, Metadata, ResponseCode, SerialMessage},
49        rr::Record,
50        serialize::binary::{BinDecodable, BinDecoder},
51    },
52    zone_handler::{MessageRequest, MessageResponseBuilder, Queries},
53};
54
55#[cfg(feature = "__https")]
56mod h2_handler;
57#[cfg(feature = "__h3")]
58mod h3_handler;
59#[cfg(feature = "__quic")]
60mod quic_handler;
61mod request_handler;
62pub use request_handler::{Request, RequestHandler, RequestInfo, ResponseInfo};
63mod response_handler;
64pub use response_handler::{ResponseHandle, ResponseHandler};
65mod timeout_stream;
66pub use timeout_stream::TimeoutStream;
67
68// TODO, would be nice to have a Slab for buffers here...
69/// A Futures based implementation of a DNS server
70pub struct Server<T: RequestHandler> {
71    context: Arc<ServerContext<T>>,
72    join_set: JoinSet<Result<(), NetError>>,
73}
74
75impl<T: RequestHandler> Server<T> {
76    /// Creates a new ServerFuture with the specified Handler.
77    pub fn new(handler: T) -> Self {
78        Self::with_access(handler, [], [])
79    }
80
81    /// Creates a new ServerFuture with the specified Handler and denied/allowed networks
82    pub fn with_access(
83        handler: T,
84        denied_networks: impl IntoIterator<Item = IpNet>,
85        allowed_networks: impl IntoIterator<Item = IpNet>,
86    ) -> Self {
87        let mut access = AccessControl::default();
88        access.insert_deny(denied_networks);
89        access.insert_allow(allowed_networks);
90
91        Self {
92            context: Arc::new(ServerContext {
93                handler,
94                access,
95                shutdown: CancellationToken::new(),
96            }),
97            join_set: JoinSet::new(),
98        }
99    }
100
101    /// Register a UDP socket. Should be bound before calling this function.
102    pub fn register_socket(&mut self, socket: net::UdpSocket) {
103        self.join_set
104            .spawn(handle_udp(socket, self.context.clone()));
105    }
106
107    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
108    ///  IPv4 address.
109    ///
110    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
111    ///  to not make this too low depending on use cases.
112    ///
113    /// # Arguments
114    /// * `listener` - a bound TCP socket
115    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
116    ///   requests within this time period will be closed. In the future it should be
117    ///   possible to create long-lived queries, but these should be from trusted sources
118    ///   only, this would require some type of whitelisting.
119    /// * `response_buffer_size` - size of the buffer for outgoing responses per connection
120    pub fn register_listener(
121        &mut self,
122        listener: net::TcpListener,
123        timeout: Duration,
124        response_buffer_size: usize,
125    ) {
126        self.join_set.spawn(handle_tcp(
127            listener,
128            timeout,
129            response_buffer_size,
130            self.context.clone(),
131        ));
132    }
133
134    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
135    /// IPv6 or an IPv4 address.
136    ///
137    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
138    ///  to not make this too low depending on use cases.
139    ///
140    /// The TLS `ServerConfig` should be configured with TLS 1.3 support and the DoT ALPN protocol
141    /// enabled.
142    ///
143    /// # Arguments
144    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
145    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
146    ///   requests within this time period will be closed. In the future it should be
147    ///   possible to create long-lived queries, but these should be from trusted sources
148    ///   only, this would require some type of whitelisting.
149    /// * `tls_config` - rustls server config
150    #[cfg(feature = "__tls")]
151    pub fn register_tls_listener_with_tls_config(
152        &mut self,
153        listener: net::TcpListener,
154        handshake_timeout: Duration,
155        tls_config: Arc<ServerConfig>,
156    ) -> io::Result<()> {
157        self.join_set.spawn(handle_tls(
158            listener,
159            tls_config,
160            handshake_timeout,
161            self.context.clone(),
162        ));
163        Ok(())
164    }
165
166    /// Register a TlsListener to the Server by providing a rustls `ResolvesServerCert`. The
167    /// TlsListener should already be bound to either an IPv6 or an IPv4 address.
168    ///
169    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
170    ///  to not make this too low depending on use cases.
171    ///
172    /// # Arguments
173    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
174    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
175    ///   requests within this time period will be closed. In the future it should be
176    ///   possible to create long-lived queries, but these should be from trusted sources
177    ///   only, this would require some type of whitelisting.
178    /// * `server_cert_resolver` - resolver for the certificate and key used to announce to clients
179    #[cfg(feature = "__tls")]
180    pub fn register_tls_listener(
181        &mut self,
182        listener: net::TcpListener,
183        timeout: Duration,
184        server_cert_resolver: Arc<dyn ResolvesServerCert>,
185    ) -> io::Result<()> {
186        Self::register_tls_listener_with_tls_config(
187            self,
188            listener,
189            timeout,
190            Arc::new(default_tls_server_config(b"dot", server_cert_resolver)?),
191        )
192    }
193
194    /// Register a TcpListener for HTTPS (h2) to the Server for supporting DoH (DNS-over-HTTPS). The TcpListener should already be bound to either an
195    /// IPv6 or an IPv4 address.
196    ///
197    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
198    ///  to not make this too low depending on use cases.
199    ///
200    /// # Arguments
201    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
202    /// * `handshake_timeout` - timeout duration of incoming requests, any connection that does not send
203    ///   requests within this time period will be closed. In the future it should be
204    ///   possible to create long-lived queries, but these should be from trusted sources
205    ///   only, this would require some type of whitelisting.
206    /// * `server_cert_resolver` - resolver for the certificate and key used to announce to clients
207    /// * `dns_hostname` - the DNS hostname of the H2 server.
208    /// * `http_endpoint` - the HTTP endpoint of the H2 server.
209    #[cfg(feature = "__https")]
210    pub fn register_https_listener(
211        &mut self,
212        listener: net::TcpListener,
213        // TODO: need to set a timeout between requests.
214        handshake_timeout: Duration,
215        server_cert_resolver: Arc<dyn ResolvesServerCert>,
216        dns_hostname: Option<String>,
217        http_endpoint: String,
218    ) -> io::Result<()> {
219        self.join_set.spawn(h2_handler::handle_h2(
220            listener,
221            handshake_timeout,
222            server_cert_resolver,
223            dns_hostname,
224            http_endpoint,
225            self.context.clone(),
226        ));
227        Ok(())
228    }
229
230    /// Register a TcpListener for HTTPS (h2) for supporting DoH with the given TLS config.
231    ///
232    /// The TcpListener should already be bound to either an IPv6 or an IPv4 address.
233    ///
234    /// The TLS `ServerConfig` should be configured with TLS 1.3 support and the DoH ALPN protocol
235    /// enabled.
236    ///
237    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
238    ///  to not make this too low depending on use cases.
239    ///
240    /// # Arguments
241    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
242    /// * `handshake_timeout` - timeout duration of incoming requests, any connection that does not send
243    ///   requests within this time period will be closed. In the future it should be
244    ///   possible to create long-lived queries, but these should be from trusted sources
245    ///   only, this would require some type of whitelisting.
246    /// * `tls_config` - a customized `ServerConfig` to use for TLS.
247    /// * `dns_hostname` - the DNS hostname of the H2 server.
248    /// * `http_endpoint` - the HTTP endpoint of the H2 server.
249    #[cfg(feature = "__https")]
250    pub fn register_https_listener_with_tls_config(
251        &mut self,
252        listener: net::TcpListener,
253        // TODO: need to set a timeout between requests.
254        handshake_timeout: Duration,
255        tls_config: Arc<ServerConfig>,
256        dns_hostname: Option<String>,
257        http_endpoint: String,
258    ) -> io::Result<()> {
259        self.join_set.spawn(h2_handler::handle_h2_with_acceptor(
260            listener,
261            handshake_timeout,
262            TlsAcceptor::from(tls_config),
263            dns_hostname,
264            http_endpoint,
265            self.context.clone(),
266        ));
267        Ok(())
268    }
269
270    /// Register a UdpSocket to the Server for supporting DoQ (DNS-over-QUIC). The UdpSocket should already be bound to either an
271    /// IPv6 or an IPv4 address.
272    ///
273    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
274    ///  to not make this too low depending on use cases.
275    ///
276    /// # Arguments
277    /// * `socket` - a bound UDP socket
278    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
279    ///   requests within this time period will be closed. In the future it should be
280    ///   possible to create long-lived queries, but these should be from trusted sources
281    ///   only, this would require some type of whitelisting.
282    /// * `server_cert_resolver` - resolver for certificate and key used to announce to clients
283    /// * `dns_hostname` - the DNS hostname of the DoQ server.
284    #[cfg(feature = "__quic")]
285    pub fn register_quic_listener(
286        &mut self,
287        socket: net::UdpSocket,
288        // TODO: need to set a timeout between requests.
289        _timeout: Duration,
290        server_cert_resolver: Arc<dyn ResolvesServerCert>,
291    ) -> io::Result<()> {
292        let cx = self.context.clone();
293        self.join_set
294            .spawn(quic_handler::handle_quic(socket, server_cert_resolver, cx));
295        Ok(())
296    }
297
298    /// Register a UdpSocket for supporting DoQ (DNS-over-QUIC) with the provided TLS config.
299    ///
300    /// The UdpSocket should already be bound to either an IPv6 or an IPv4 address.
301    ///
302    /// The TLS `ServerConfig` should be configured with TLS 1.3 support and the DoQ ALPN protocol
303    /// enabled.
304    ///
305    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
306    ///  to not make this too low depending on use cases.
307    ///
308    /// # Arguments
309    /// * `socket` - a bound UDP socket
310    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
311    ///   requests within this time period will be closed. In the future it should be
312    ///   possible to create long-lived queries, but these should be from trusted sources
313    ///   only, this would require some type of whitelisting.
314    /// * `tls_config` - a customized ServerConfig to use for TLS.
315    /// * `dns_hostname` - the DNS hostname of the DoQ server.
316    #[cfg(feature = "__quic")]
317    pub fn register_quic_listener_and_tls_config(
318        &mut self,
319        socket: net::UdpSocket,
320        // TODO: need to set a timeout between requests.
321        _timeout: Duration,
322        tls_config: Arc<ServerConfig>,
323    ) -> Result<(), NetError> {
324        let cx = self.context.clone();
325
326        self.join_set.spawn(quic_handler::handle_quic_with_server(
327            QuicServer::with_socket_and_tls_config(socket, tls_config)?,
328            cx,
329        ));
330        Ok(())
331    }
332
333    /// Register a UdpSocket to the Server for supporting DoH3 (DNS-over-HTTP/3). The UdpSocket should already be bound to either an
334    /// IPv6 or an IPv4 address.
335    ///
336    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
337    ///  to not make this too low depending on use cases.
338    ///
339    /// # Arguments
340    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
341    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
342    ///   requests within this time period will be closed. In the future it should be
343    ///   possible to create long-lived queries, but these should be from trusted sources
344    ///   only, this would require some type of whitelisting.
345    /// * `server_cert_resolver` - resolver for certificate and key used to announce to clients
346    #[cfg(feature = "__h3")]
347    pub fn register_h3_listener(
348        &mut self,
349        socket: net::UdpSocket,
350        // TODO: need to set a timeout between requests.
351        _timeout: Duration,
352        server_cert_resolver: Arc<dyn ResolvesServerCert>,
353        dns_hostname: Option<String>,
354    ) -> io::Result<()> {
355        self.join_set.spawn(h3_handler::handle_h3(
356            socket,
357            server_cert_resolver,
358            dns_hostname,
359            self.context.clone(),
360        ));
361        Ok(())
362    }
363
364    /// Register a UdpSocket for supporting DoH3 (DNS-over-HTTP/3) with the specified TLS config.
365    ///
366    /// The UdpSocket should already be bound to either an IPv6 or an IPv4 address.
367    ///
368    /// The TLS `ServerConfig` should be configured with TLS 1.3 support and the DoH3 ALPN protocol
369    /// enabled.
370    ///
371    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
372    ///  to not make this too low depending on use cases.
373    ///
374    /// # Arguments
375    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
376    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
377    ///   requests within this time period will be closed. In the future it should be
378    ///   possible to create long-lived queries, but these should be from trusted sources
379    ///   only, this would require some type of whitelisting.
380    /// * `tls_config` - a customized ServerConfig to use for TLS.
381    #[cfg(feature = "__h3")]
382    pub fn register_h3_listener_with_tls_config(
383        &mut self,
384        socket: net::UdpSocket,
385        // TODO: need to set a timeout between requests.
386        _timeout: Duration,
387        tls_config: Arc<ServerConfig>,
388        dns_hostname: Option<String>,
389    ) -> Result<(), NetError> {
390        self.join_set.spawn(h3_handler::handle_h3_with_server(
391            H3Server::with_socket_and_tls_config(socket, tls_config)?,
392            dns_hostname,
393            self.context.clone(),
394        ));
395        Ok(())
396    }
397
398    /// Triggers a graceful shutdown the server. All background tasks will stop accepting
399    /// new connections and the returned future will complete once all tasks have terminated.
400    pub async fn shutdown_gracefully(&mut self) -> Result<(), NetError> {
401        self.context.shutdown.cancel();
402
403        // Wait for the server to complete.
404        self.block_until_done().await
405    }
406
407    /// Returns a reference to the [`CancellationToken`] used to gracefully shut down the server.
408    ///
409    /// Once cancellation is requested, all background tasks will stop accepting new connections,
410    /// and `block_until_done()` will complete once all tasks have terminated.
411    pub fn shutdown_token(&self) -> &CancellationToken {
412        &self.context.shutdown
413    }
414
415    /// This will run until all background tasks complete. If one or more tasks return an error,
416    /// one will be chosen as the returned error for this future.
417    pub async fn block_until_done(&mut self) -> Result<(), NetError> {
418        if self.join_set.is_empty() {
419            warn!("block_until_done called with no pending tasks");
420            return Ok(());
421        }
422
423        let mut out = Ok(());
424        while let Some(join_result) = self.join_set.join_next().await {
425            match join_result {
426                Ok(Ok(())) => continue,
427                Ok(Err(e)) => out = Err(e),
428                Err(e) => return Err(NetError::from(format!("internal error in spawn: {e}"))),
429            }
430        }
431
432        out
433    }
434}
435
436async fn handle_udp(
437    socket: net::UdpSocket,
438    cx: Arc<ServerContext<impl RequestHandler>>,
439) -> Result<(), NetError> {
440    debug!("registering udp: {:?}", socket);
441
442    // create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
443    //   the address used is acquired from the inbound queries
444    let (mut stream, stream_handle) =
445        UdpStream::<TokioRuntimeProvider>::with_bound(socket, ([127, 255, 255, 254], 0).into());
446
447    let mut inner_join_set = JoinSet::new();
448    loop {
449        let message = tokio::select! {
450            message = stream.next() => match message {
451                None => break,
452                Some(message) => message,
453            },
454            _ = cx.shutdown.cancelled() => break,
455        };
456
457        let message = match message {
458            Err(error) => {
459                warn!(%error, "error receiving message on udp_socket");
460                if is_unrecoverable_socket_error(&error) {
461                    break;
462                }
463                continue;
464            }
465            Ok(message) => message,
466        };
467
468        let src_addr = message.addr();
469        debug!("received udp request from: {}", src_addr);
470
471        // verify that the src address is safe for responses
472        if let Err(e) = sanitize_src_address(src_addr) {
473            warn!(
474                "address can not be responded to {src_addr}: {e}",
475                src_addr = src_addr,
476                e = e
477            );
478            continue;
479        }
480
481        let cx = cx.clone();
482        let stream_handle = stream_handle.with_remote_addr(src_addr);
483        inner_join_set.spawn(async move {
484            cx.handle_raw_request(message, Protocol::Udp, stream_handle)
485                .await;
486        });
487
488        reap_tasks(&mut inner_join_set);
489    }
490
491    if cx.shutdown.is_cancelled() {
492        Ok(())
493    } else {
494        // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
495        Err(NetError::from("unexpected close of UDP socket"))
496    }
497}
498
499async fn handle_tcp(
500    listener: net::TcpListener,
501    timeout: Duration,
502    response_buffer_size: usize,
503    cx: Arc<ServerContext<impl RequestHandler>>,
504) -> Result<(), NetError> {
505    debug!("register tcp: {listener:?}");
506    let mut inner_join_set = JoinSet::new();
507    loop {
508        let (tcp_stream, src_addr) = tokio::select! {
509            tcp_stream = listener.accept() => match tcp_stream {
510                Ok((t, s)) => (t, s),
511                Err(error) => {
512                    debug!(%error, "error receiving TCP tcp_stream error");
513                    if is_unrecoverable_socket_error(&error) {
514                        break;
515                    }
516                    continue;
517                },
518            },
519            _ = cx.shutdown.cancelled() => {
520                // A graceful shutdown was initiated. Break out of the loop.
521                break;
522            },
523        };
524
525        // verify that the src address is safe for responses
526        if let Err(error) = sanitize_src_address(src_addr) {
527            warn!(
528                %src_addr, %error,
529                "address can not be responded to (TCP)",
530            );
531            continue;
532        }
533
534        // and spawn to the io_loop
535        let cx = cx.clone();
536        inner_join_set.spawn(async move {
537            debug!(%src_addr, "accepted TCP request");
538            // take the created stream...
539            let (buf_stream, stream_handle) = TcpStream::from_stream_with_buffer_size(
540                AsyncIoTokioAsStd(tcp_stream),
541                src_addr,
542                response_buffer_size,
543            );
544            let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
545
546            while let Some(message) = timeout_stream.next().await {
547                let message = match message {
548                    Ok(message) => message,
549                    Err(error) => {
550                        debug!(%src_addr, %error, "error in TCP request stream");
551                        // we're going to bail on this connection...
552                        return;
553                    }
554                };
555
556                // we don't spawn here to limit clients from getting too many resources
557                cx.handle_raw_request(message, Protocol::Tcp, stream_handle.clone())
558                    .await;
559            }
560        });
561
562        reap_tasks(&mut inner_join_set);
563    }
564
565    if cx.shutdown.is_cancelled() {
566        Ok(())
567    } else {
568        Err(NetError::from("unexpected close of socket"))
569    }
570}
571
572#[cfg(feature = "__tls")]
573async fn handle_tls(
574    listener: net::TcpListener,
575    tls_config: Arc<ServerConfig>,
576    handshake_timeout: Duration,
577    cx: Arc<ServerContext<impl RequestHandler>>,
578) -> Result<(), NetError> {
579    debug!(?listener, "registered tls");
580    let tls_acceptor = TlsAcceptor::from(tls_config);
581
582    let mut inner_join_set = JoinSet::new();
583    loop {
584        let (tcp_stream, src_addr) = tokio::select! {
585            tcp_stream = listener.accept() => match tcp_stream {
586                Ok((t, s)) => (t, s),
587                Err(error) => {
588                    debug!(%error, "error receiving TLS tcp_stream error");
589                    if is_unrecoverable_socket_error(&error) {
590                        break;
591                    }
592                    continue;
593                },
594            },
595            _ = cx.shutdown.cancelled() => {
596                // A graceful shutdown was initiated. Break out of the loop.
597                break;
598            },
599        };
600
601        // verify that the src address is safe for responses
602        if let Err(error) = sanitize_src_address(src_addr) {
603            warn!(
604                %src_addr, %error,
605                "address can not be responded to (TLS)",
606            );
607            continue;
608        }
609
610        let cx = cx.clone();
611        let tls_acceptor = tls_acceptor.clone();
612        // kick out to a different task immediately, let them do the TLS handshake
613        inner_join_set.spawn(async move {
614            debug!(%src_addr, "starting TLS request");
615
616            // perform the TLS
617            let Ok(tls_stream) = timeout(handshake_timeout, tls_acceptor.accept(tcp_stream)).await
618            else {
619                warn!("tls timeout expired during handshake");
620                return;
621            };
622
623            let tls_stream = match tls_stream {
624                Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
625                Err(error) => {
626                    debug!(%src_addr, %error, "tls handshake error");
627                    return;
628                }
629            };
630            debug!(%src_addr, "accepted TLS request");
631            let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
632            let mut timeout_stream = TimeoutStream::new(buf_stream, handshake_timeout);
633            while let Some(message) = timeout_stream.next().await {
634                let message = match message {
635                    Ok(message) => message,
636                    Err(error) => {
637                        debug!(
638                            %src_addr, %error,
639                            "error in TLS request stream",
640                        );
641
642                        // kill this connection
643                        return;
644                    }
645                };
646
647                cx.handle_raw_request(message, Protocol::Tls, stream_handle.clone())
648                    .await;
649            }
650        });
651
652        reap_tasks(&mut inner_join_set);
653    }
654
655    if cx.shutdown.is_cancelled() {
656        Ok(())
657    } else {
658        Err(NetError::from("unexpected close of socket"))
659    }
660}
661
662/// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
663fn reap_tasks(join_set: &mut JoinSet<()>) {
664    while join_set.try_join_next().is_some() {}
665}
666
667/// Construct a default `ServerConfig` for the given ALPN protocol and server cert resolver.
668#[cfg(feature = "__tls")]
669pub fn default_tls_server_config(
670    protocol: &[u8],
671    server_cert_resolver: Arc<dyn ResolvesServerCert>,
672) -> io::Result<ServerConfig> {
673    let mut config = ServerConfig::builder_with_provider(Arc::new(default_provider()))
674        .with_safe_default_protocol_versions()
675        .map_err(|e| io::Error::other(format!("error creating TLS acceptor: {e}")))?
676        .with_no_client_auth()
677        .with_cert_resolver(server_cert_resolver);
678
679    config.alpn_protocols = vec![protocol.to_vec()];
680
681    Ok(config)
682}
683
684#[derive(Clone)]
685pub(super) struct ReportingResponseHandler<R: ResponseHandler> {
686    pub(super) request_meta: Metadata,
687    queries: Vec<LowerQuery>,
688    pub(super) protocol: Protocol,
689    src_addr: SocketAddr,
690    handler: R,
691    #[cfg(feature = "metrics")]
692    metrics: ResponseHandlerMetrics,
693}
694
695#[async_trait::async_trait]
696impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
697    async fn send_response<'a>(
698        &mut self,
699        response: crate::zone_handler::MessageResponse<
700            '_,
701            'a,
702            impl Iterator<Item = &'a Record> + Send + 'a,
703            impl Iterator<Item = &'a Record> + Send + 'a,
704            impl Iterator<Item = &'a Record> + Send + 'a,
705            impl Iterator<Item = &'a Record> + Send + 'a,
706        >,
707    ) -> Result<ResponseInfo, NetError> {
708        let response_info = self.handler.send_response(response).await?;
709
710        let id = self.request_meta.id;
711        let rid = response_info.id;
712        if id != rid {
713            warn!("request id:{id} does not match response id:{rid}");
714            debug_assert_eq!(id, rid, "request id and response id should match");
715        }
716
717        let rflags = response_info.flags();
718        let answer_count = response_info.counts().answers;
719        let authority_count = response_info.counts().authorities;
720        let additional_count = response_info.counts().additionals;
721        let response_code = response_info.response_code;
722
723        info!(
724            "request:{id} src:{proto}://{addr}#{port} {op} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
725            id = rid,
726            proto = self.protocol,
727            addr = self.src_addr.ip(),
728            port = self.src_addr.port(),
729            op = self.request_meta.op_code,
730            qflags = self.request_meta.flags(),
731            code = response_code,
732            answers = answer_count,
733            authorities = authority_count,
734            additionals = additional_count,
735            rflags = rflags
736        );
737        for query in self.queries.iter() {
738            info!(
739                "query:{query}:{qtype}:{class}",
740                query = query.name(),
741                qtype = query.query_type(),
742                class = query.query_class()
743            );
744        }
745
746        #[cfg(feature = "metrics")]
747        self.metrics.update(self, &response_info);
748
749        Ok(response_info)
750    }
751}
752
753struct ServerContext<T> {
754    handler: T,
755    access: AccessControl,
756    shutdown: CancellationToken,
757}
758
759impl<T: RequestHandler> ServerContext<T> {
760    async fn handle_raw_request(
761        &self,
762        message: SerialMessage,
763        protocol: Protocol,
764        response_handler: BufDnsStreamHandle,
765    ) {
766        let (message, src_addr) = message.into_parts();
767        let response_handler = ResponseHandle::new(src_addr, response_handler, protocol);
768
769        self.handle_request(Bytes::from(message), src_addr, protocol, response_handler)
770            .await;
771    }
772
773    async fn handle_request(
774        &self,
775        message_bytes: Bytes,
776        src_addr: SocketAddr,
777        protocol: Protocol,
778        response_handler: impl ResponseHandler,
779    ) {
780        let mut decoder = BinDecoder::new(&message_bytes);
781        let Ok(header) = Header::read(&mut decoder) else {
782            // This will only fail if the message is less than twelve bytes long. Such messages are
783            // definitely not valid DNS queries, so it should be fine to return without sending a
784            // response.
785            return;
786        };
787
788        if !self.access.allow(src_addr.ip()) {
789            info!(
790                "request:Refused src:{proto}://{addr}#{port}",
791                proto = protocol,
792                addr = src_addr.ip(),
793                port = src_addr.port(),
794            );
795
796            let queries = match Queries::read(&mut decoder, header.counts.queries as usize) {
797                Ok(queries) => queries,
798                Err(_) => Queries::empty(),
799            };
800            error_response_handler(
801                protocol,
802                src_addr,
803                header,
804                queries,
805                ResponseCode::Refused,
806                "request refused",
807                response_handler,
808            )
809            .await;
810
811            return;
812        }
813
814        // Attempt to decode the message
815        let request = match MessageRequest::read(&mut decoder, header) {
816            Ok(message) => Request {
817                message,
818                raw: message_bytes,
819                src: src_addr,
820                protocol,
821            },
822            Err(error) => {
823                // We failed to parse the request due to some issue in the message, but the header is available, so we can respond
824                let queries = Queries::empty();
825
826                error_response_handler(
827                    protocol,
828                    src_addr,
829                    header,
830                    queries,
831                    ResponseCode::FormErr,
832                    error,
833                    response_handler,
834                )
835                .await;
836
837                return;
838            }
839        };
840
841        if request.message.metadata.message_type == MessageType::Response {
842            // Don't process response messages to avoid DoS attacks from reflection.
843            return;
844        }
845
846        let id = request.message.metadata.id;
847        let qflags = request.message.metadata.flags();
848        let qop_code = request.message.metadata.op_code;
849        let message_type = request.message.metadata.message_type;
850        let is_dnssec = request
851            .message
852            .edns
853            .as_ref()
854            .is_some_and(|edns| edns.flags().dnssec_ok);
855
856        debug!(
857            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op} qflags:{qflags}",
858            id = id,
859            proto = request.protocol(),
860            addr = request.src().ip(),
861            port = request.src().port(),
862            message_type = message_type,
863            is_dnssec = is_dnssec,
864            op = qop_code,
865            qflags = qflags
866        );
867        for query in request.queries.queries().iter() {
868            debug!(
869                "query:{query}:{qtype}:{class}",
870                query = query.name(),
871                qtype = query.query_type(),
872                class = query.query_class()
873            );
874        }
875
876        // The reporter will handle making sure to log the result of the request
877        let queries = request.queries.queries().to_vec();
878        let reporter = ReportingResponseHandler {
879            request_meta: request.metadata,
880            queries,
881            protocol: request.protocol(),
882            src_addr: request.src(),
883            handler: response_handler,
884            #[cfg(feature = "metrics")]
885            metrics: ResponseHandlerMetrics::default(),
886        };
887
888        self.handler
889            .handle_request::<_, TokioTime>(&request, reporter)
890            .await;
891    }
892}
893
894// method to return an error to the client
895async fn error_response_handler(
896    protocol: Protocol,
897    src_addr: SocketAddr,
898    header: Header,
899    queries: Queries,
900    response_code: ResponseCode,
901    error: impl fmt::Display,
902    response_handler: impl ResponseHandler,
903) {
904    // debug for more info on why the message parsing failed
905    debug!(
906        "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
907        id = header.id,
908        proto = protocol,
909        addr = src_addr.ip(),
910        port = src_addr.port(),
911        message_type = header.message_type,
912        op = header.op_code,
913        response_code = response_code,
914        error = error,
915    );
916
917    // The reporter will handle making sure to log the result of the request
918    let mut reporter = ReportingResponseHandler {
919        request_meta: header.metadata,
920        queries: queries.queries().to_vec(),
921        protocol,
922        src_addr,
923        handler: response_handler,
924        #[cfg(feature = "metrics")]
925        metrics: ResponseHandlerMetrics::default(),
926    };
927
928    let response = MessageResponseBuilder::new(&queries, None);
929    let result = reporter
930        .send_response(response.error_msg(&header, response_code))
931        .await;
932
933    if let Err(error) = result {
934        warn!(%error, "failed to return FormError to client");
935    }
936}
937
/// Checks whether a source address is safe to send a response to.
///
/// The following are rejected:
///
/// - any address with port `0`;
/// - the unspecified addresses `0.0.0.0` and `::`;
/// - the IPv4 limited broadcast address `255.255.255.255`;
/// - IPv4 (`224.0.0.0/4`) and IPv6 (`ff00::/8`) multicast addresses. These are never valid
///   as the source of a datagram (RFC 1112 for IPv4, RFC 4291 §2.7 for IPv6), so replying
///   would send traffic to a whole group of hosts.
///
/// # Errors
///
/// Returns a human-readable description if the address should not be used for responses.
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v4 addr: {src}"));
        }

        // TODO: add check for is_reserved when that stabilizes

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v6 addr: {src}"));
        }

        Ok(())
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
979
/// Returns `true` if `err` has one of the kinds this server treats as unrecoverable for a
/// socket: [`io::ErrorKind::NotConnected`] or [`io::ErrorKind::ConnectionAborted`].
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
986
// Unit tests for server lifecycle (abort and graceful shutdown releasing all sockets),
// source-address sanitizing, and task reaping.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zone_handler::Catalog;
    use futures_util::future;
    #[cfg(feature = "__tls")]
    use rustls::{
        pki_types::{CertificateDer, PrivateKeyDer},
        sign::{CertifiedKey, SingleCertAndKey},
    };
    use std::net::SocketAddr;
    use test_support::subscribe;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborting the server future must leave every endpoint address free to be re-bound.
    ///
    /// NOTE(review): `abort_handle.abort()` is called before `abortable` is first polled, so
    /// the inner future (server construction, `register`, `block_until_done`) presumably never
    /// runs, which makes the `rebind_all` check close to vacuous — confirm against
    /// `futures_util::future::Abortable` semantics.
    #[tokio::test]
    async fn abort() {
        subscribe();

        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = Server::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        // `rebind_all` unwraps each bind, so it panics if any address is still in use.
        endpoints.rebind_all().await;
    }

    /// A graceful shutdown must finish within two seconds and release every endpoint.
    #[tokio::test]
    async fn graceful_shutdown() {
        subscribe();
        let mut server_future = Server::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        // The outer `expect` catches the timeout; the inner one catches an error result from
        // `shutdown_gracefully`.
        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    /// Table of accepted and rejected source addresses for `sanitize_src_address`.
    #[test]
    fn test_sanitize_src_addr() {
        // ipv4 tests
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4_096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        // port 0, the unspecified address and the limited broadcast address are rejected
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4_096))).is_err());

        // ipv6 tests
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4_096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4_096))).is_ok());

        // the unspecified address `::` and port 0 are rejected
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// One loopback address per listener kind the server supports in this build.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        tcp_addr: SocketAddr,
        #[cfg(feature = "__tls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "__https")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "__quic")]
        quic_addr: SocketAddr,
        #[cfg(feature = "__h3")]
        h3_addr: SocketAddr,
    }

    impl Endpoints {
        /// Picks free loopback addresses by binding to port 0 and recording the assigned
        /// address. The probe sockets are locals and are dropped when this function returns,
        /// releasing the ports for `register`.
        ///
        /// NOTE(review): another process could grab a port between this probe and `register`;
        /// presumably acceptable for tests.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__tls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__https")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__h3")]
            let h3 = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                #[cfg(feature = "__tls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "__https")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "__quic")]
                quic_addr: quic.local_addr().unwrap(),
                #[cfg(feature = "__h3")]
                h3_addr: h3.local_addr().unwrap(),
            }
        }

        /// Binds every recorded address again and registers the socket with `server`.
        /// TLS-based listeners all use the test certificate from `rustls_cert_key`.
        async fn register<T: RequestHandler>(&self, server: &mut Server<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
                32,
            );

            #[cfg(feature = "__tls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__https")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                        "/dns-query".into(),
                    )
                    .unwrap();
            }

            #[cfg(feature = "__quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__h3")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_h3_listener(
                        UdpSocket::bind(self.h3_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Panics unless every recorded address can be bound again, meaning the server has
        /// released all its sockets. Each probe socket is dropped immediately.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            #[cfg(feature = "__tls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "__https")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "__quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
            #[cfg(feature = "__h3")]
            UdpSocket::bind(self.h3_addr).await.unwrap();
        }
    }

    /// Loads the test certificate chain and key from
    /// `$TDNS_WORKSPACE_ROOT/tests/test-data/cert.{pem,key}`; the root defaults to `../..`
    /// when the variable is unset.
    #[cfg(feature = "__tls")]
    fn rustls_cert_key() -> Arc<dyn ResolvesServerCert> {
        use rustls::pki_types::pem::PemObject;
        use std::env;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
        let cert_chain =
            CertificateDer::pem_file_iter(format!("{server_path}/tests/test-data/cert.pem"))
                .unwrap()
                .collect::<Result<Vec<_>, _>>()
                .unwrap();

        let key = PrivateKeyDer::from_pem_file(format!("{server_path}/tests/test-data/cert.key"))
            .unwrap();

        let certified_key = CertifiedKey::from_der(cert_chain, key, &default_provider()).unwrap();
        Arc::new(SingleCertAndKey::from(certified_key))
    }

    /// `reap_tasks` must not block on an empty set (plain `#[test]`, no runtime).
    #[test]
    fn task_reap_on_empty_joinset() {
        let mut joinset = JoinSet::new();

        // this should return immediately
        reap_tasks(&mut joinset);
    }

    /// `reap_tasks` must not block on pending tasks, nor on tasks that have been aborted.
    #[tokio::test]
    async fn task_reap_on_nonempty_joinset() {
        let mut joinset = JoinSet::new();
        let t = joinset.spawn(tokio::time::sleep(Duration::from_secs(2)));

        // this should return immediately since no task is ready
        reap_tasks(&mut joinset);
        t.abort();

        // this should also return immediately since the task has been aborted
        reap_tasks(&mut joinset);
    }
}