hickory_server/server/
server_future.rs

1// Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7use std::{
8    io,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10    sync::Arc,
11    time::Duration,
12};
13
14use futures_util::{FutureExt, StreamExt};
15use hickory_proto::{op::MessageType, rr::Record};
16#[cfg(feature = "dns-over-rustls")]
17use rustls::{Certificate, PrivateKey, ServerConfig};
18use tokio::{net, task::JoinSet};
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, info, warn};
21
22#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
23use crate::proto::openssl::tls_server::*;
24use crate::{
25    authority::{MessageRequest, MessageResponseBuilder},
26    proto::{
27        error::ProtoError,
28        iocompat::AsyncIoTokioAsStd,
29        op::{Edns, Header, LowerQuery, Query, ResponseCode},
30        serialize::binary::{BinDecodable, BinDecoder},
31        tcp::TcpStream,
32        udp::UdpStream,
33        xfer::SerialMessage,
34        BufDnsStreamHandle,
35    },
36    server::{Protocol, Request, RequestHandler, ResponseHandle, ResponseHandler, TimeoutStream},
37};
38
39// TODO, would be nice to have a Slab for buffers here...
40/// A Futures based implementation of a DNS server
41pub struct ServerFuture<T: RequestHandler> {
42    handler: Arc<T>,
43    join_set: JoinSet<Result<(), ProtoError>>,
44    shutdown_token: CancellationToken,
45}
46
47impl<T: RequestHandler> ServerFuture<T> {
48    /// Creates a new ServerFuture with the specified Handler.
49    pub fn new(handler: T) -> Self {
50        Self {
51            handler: Arc::new(handler),
52            join_set: JoinSet::new(),
53            shutdown_token: CancellationToken::new(),
54        }
55    }
56
57    /// Register a UDP socket. Should be bound before calling this function.
58    pub fn register_socket(&mut self, socket: net::UdpSocket) {
59        debug!("registering udp: {:?}", socket);
60
61        // create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
62        //   the address used is acquired from the inbound queries
63        let (mut stream, stream_handle) =
64            UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
65        let shutdown = self.shutdown_token.clone();
66        let handler = self.handler.clone();
67
68        // this spawns a ForEach future which handles all the requests into a Handler.
69        self.join_set.spawn({
70            async move {
71                let mut inner_join_set = JoinSet::new();
72                loop {
73                    let message = tokio::select! {
74                        message = stream.next() => match message {
75                            None => break,
76                            Some(message) => message,
77                        },
78                        _ = shutdown.cancelled() => break,
79                    };
80
81                    let message = match message {
82                        Err(e) => {
83                            warn!("error receiving message on udp_socket: {}", e);
84                            if is_unrecoverable_socket_error(&e) {
85                                break;
86                            }
87                            continue;
88                        }
89                        Ok(message) => message,
90                    };
91
92                    let src_addr = message.addr();
93                    debug!("received udp request from: {}", src_addr);
94
95                    // verify that the src address is safe for responses
96                    if let Err(e) = sanitize_src_address(src_addr) {
97                        warn!(
98                            "address can not be responded to {src_addr}: {e}",
99                            src_addr = src_addr,
100                            e = e
101                        );
102                        continue;
103                    }
104
105                    let handler = handler.clone();
106                    let stream_handle = stream_handle.with_remote_addr(src_addr);
107
108                    inner_join_set.spawn(async move {
109                        handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
110                    });
111
112                    reap_tasks(&mut inner_join_set);
113                }
114
115                if shutdown.is_cancelled() {
116                    Ok(())
117                } else {
118                    // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
119                    Err(ProtoError::from("unexpected close of UDP socket"))
120                }
121            }
122        });
123    }
124
125    /// Register a UDP socket. Should be bound before calling this function.
126    pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
127        self.register_socket(net::UdpSocket::from_std(socket)?);
128        Ok(())
129    }
130
131    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
132    ///  IPv4 address.
133    ///
134    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
135    ///  to not make this too low depending on use cases.
136    ///
137    /// # Arguments
138    /// * `listener` - a bound TCP socket
139    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
140    ///               requests within this time period will be closed. In the future it should be
141    ///               possible to create long-lived queries, but these should be from trusted sources
142    ///               only, this would require some type of whitelisting.
143    pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
144        debug!("register tcp: {:?}", listener);
145
146        let handler = self.handler.clone();
147
148        // for each incoming request...
149        let shutdown = self.shutdown_token.clone();
150        self.join_set.spawn(async move {
151            let mut inner_join_set = JoinSet::new();
152            loop {
153                let (tcp_stream, src_addr) = tokio::select! {
154                    tcp_stream = listener.accept() => match tcp_stream {
155                        Ok((t, s)) => (t, s),
156                        Err(e) => {
157                            debug!("error receiving TCP tcp_stream error: {}", e);
158                            if is_unrecoverable_socket_error(&e) {
159                                break;
160                            }
161                            continue;
162                        },
163                    },
164                    _ = shutdown.cancelled() => {
165                        // A graceful shutdown was initiated. Break out of the loop.
166                        break;
167                    },
168                };
169
170                // verify that the src address is safe for responses
171                if let Err(e) = sanitize_src_address(src_addr) {
172                    warn!(
173                        "address can not be responded to {src_addr}: {e}",
174                        src_addr = src_addr,
175                        e = e
176                    );
177                    continue;
178                }
179
180                let handler = handler.clone();
181
182                // and spawn to the io_loop
183                inner_join_set.spawn(async move {
184                    debug!("accepted request from: {}", src_addr);
185                    // take the created stream...
186                    let (buf_stream, stream_handle) =
187                        TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
188                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
189
190                    while let Some(message) = timeout_stream.next().await {
191                        let message = match message {
192                            Ok(message) => message,
193                            Err(e) => {
194                                debug!(
195                                    "error in TCP request_stream src: {} error: {}",
196                                    src_addr, e
197                                );
198                                // we're going to bail on this connection...
199                                return;
200                            }
201                        };
202
203                        // we don't spawn here to limit clients from getting too many resources
204                        handle_raw_request(
205                            message,
206                            Protocol::Tcp,
207                            handler.clone(),
208                            stream_handle.clone(),
209                        )
210                        .await;
211                    }
212                });
213
214                reap_tasks(&mut inner_join_set);
215            }
216
217            if shutdown.is_cancelled() {
218                Ok(())
219            } else {
220                Err(ProtoError::from("unexpected close of socket"))
221            }
222        });
223    }
224
225    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
226    ///  IPv4 address.
227    ///
228    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
229    ///  to not make this too low depending on use cases.
230    ///
231    /// # Arguments
232    /// * `listener` - a bound TCP socket
233    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
234    ///               requests within this time period will be closed. In the future it should be
235    ///               possible to create long-lived queries, but these should be from trusted sources
236    ///               only, this would require some type of whitelisting.
237    pub fn register_listener_std(
238        &mut self,
239        listener: std::net::TcpListener,
240        timeout: Duration,
241    ) -> io::Result<()> {
242        self.register_listener(net::TcpListener::from_std(listener)?, timeout);
243        Ok(())
244    }
245
246    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
247    /// IPv6 or an IPv4 address.
248    ///
249    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
250    ///  to not make this too low depending on use cases.
251    ///
252    /// # Arguments
253    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
254    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
255    ///               requests within this time period will be closed. In the future it should be
256    ///               possible to create long-lived queries, but these should be from trusted sources
257    ///               only, this would require some type of whitelisting.
258    /// * `pkcs12` - certificate used to announce to clients
259    #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
260    #[cfg_attr(
261        docsrs,
262        doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
263    )]
264    pub fn register_tls_listener(
265        &mut self,
266        listener: net::TcpListener,
267        timeout: Duration,
268        certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
269    ) -> io::Result<()> {
270        use crate::proto::openssl::{tls_server, TlsStream};
271        use openssl::ssl::Ssl;
272        use std::pin::Pin;
273        use tokio_openssl::SslStream as TokioSslStream;
274
275        let ((cert, chain), key) = certificate_and_key;
276
277        let handler = self.handler.clone();
278        debug!("registered tcp: {:?}", listener);
279
280        let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
281
282        // for each incoming request...
283        let shutdown = self.shutdown_watch.clone();
284        self.join_set.spawn(async move {
285            let mut inner_join_set = JoinSet::new();
286            loop {
287                let (tcp_stream, src_addr) = tokio::select! {
288                    tcp_stream = listener.accept() => match tcp_stream {
289                        Ok((t, s)) => (t, s),
290                        Err(e) => {
291                            debug!("error receiving TLS tcp_stream error: {}", e);
292                            if is_unrecoverable_socket_error(&e) {
293                                break;
294                            }
295                            continue;
296                        },
297                    },
298                    _ = shutdown.clone().signaled() => {
299                        // A graceful shutdown was initiated. Break out of the loop.
300                        break;
301                    },
302                };
303
304                // verify that the src address is safe for responses
305                if let Err(e) = sanitize_src_address(src_addr) {
306                    warn!(
307                        "address can not be responded to {src_addr}: {e}",
308                        src_addr = src_addr,
309                        e = e
310                    );
311                    continue;
312                }
313
314                let handler = handler.clone();
315                let tls_acceptor = tls_acceptor.clone();
316
317                // kick out to a different task immediately, let them do the TLS handshake
318                inner_join_set.spawn(async move {
319                    debug!("starting TLS request from: {}", src_addr);
320
321                    // perform the TLS
322                    let mut tls_stream = match Ssl::new(tls_acceptor.context())
323                        .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
324                    {
325                        Ok(tls_stream) => tls_stream,
326                        Err(e) => {
327                            debug!("tls handshake src: {} error: {}", src_addr, e);
328                            return ();
329                        }
330                    };
331                    match Pin::new(&mut tls_stream).accept().await {
332                        Ok(()) => {}
333                        Err(e) => {
334                            debug!("tls handshake src: {} error: {}", src_addr, e);
335                            return ();
336                        }
337                    };
338                    debug!("accepted TLS request from: {}", src_addr);
339                    let (buf_stream, stream_handle) =
340                        TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
341                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
342                    while let Some(message) = timeout_stream.next().await {
343                        let message = match message {
344                            Ok(message) => message,
345                            Err(e) => {
346                                debug!(
347                                    "error in TLS request_stream src: {:?} error: {}",
348                                    src_addr, e
349                                );
350
351                                // kill this connection
352                                return ();
353                            }
354                        };
355
356                        self::handle_raw_request(
357                            message,
358                            Protocol::Tls,
359                            handler.clone(),
360                            stream_handle.clone(),
361                        )
362                        .await;
363                    }
364                });
365
366                reap_tasks(&mut inner_join_set);
367            }
368
369            if shutdown.is_cancelled() {
370                Ok(())
371            } else {
372                Err(ProtoError::from("unexpected close of socket"))
373            }
374        });
375
376        Ok(())
377    }
378
379    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
380    /// IPv6 or an IPv4 address.
381    ///
382    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
383    ///  to not make this too low depending on use cases.
384    ///
385    /// # Arguments
386    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
387    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
388    ///               requests within this time period will be closed. In the future it should be
389    ///               possible to create long-lived queries, but these should be from trusted sources
390    ///               only, this would require some type of whitelisting.
391    /// * `pkcs12` - certificate used to announce to clients
392    #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
393    #[cfg_attr(
394        docsrs,
395        doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
396    )]
397    pub fn register_tls_listener_std(
398        &mut self,
399        listener: std::net::TcpListener,
400        timeout: Duration,
401        certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
402    ) -> io::Result<()> {
403        self.register_tls_listener(
404            net::TcpListener::from_std(listener)?,
405            timeout,
406            certificate_and_key,
407        )
408    }
409
410    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
411    /// IPv6 or an IPv4 address.
412    ///
413    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
414    ///  to not make this too low depending on use cases.
415    ///
416    /// # Arguments
417    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
418    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
419    ///               requests within this time period will be closed. In the future it should be
420    ///               possible to create long-lived queries, but these should be from trusted sources
421    ///               only, this would require some type of whitelisting.
422    /// * `tls_config` - rustls server config
423    #[cfg(feature = "dns-over-rustls")]
424    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
425    pub fn register_tls_listener_with_tls_config(
426        &mut self,
427        listener: net::TcpListener,
428        timeout: Duration,
429        tls_config: Arc<ServerConfig>,
430    ) -> io::Result<()> {
431        use crate::proto::rustls::tls_from_stream;
432        use tokio_rustls::TlsAcceptor;
433
434        let handler = self.handler.clone();
435
436        debug!("registered tcp: {:?}", listener);
437
438        let tls_acceptor = TlsAcceptor::from(tls_config);
439
440        // for each incoming request...
441        let shutdown = self.shutdown_token.clone();
442        self.join_set.spawn(async move {
443            let mut inner_join_set = JoinSet::new();
444            loop {
445                let (tcp_stream, src_addr) = tokio::select! {
446                    tcp_stream = listener.accept() => match tcp_stream {
447                        Ok((t, s)) => (t, s),
448                        Err(e) => {
449                            debug!("error receiving TLS tcp_stream error: {}", e);
450                            if is_unrecoverable_socket_error(&e) {
451                                break;
452                            }
453                            continue;
454                        },
455                    },
456                    _ = shutdown.cancelled() => {
457                        // A graceful shutdown was initiated. Break out of the loop.
458                        break;
459                    },
460                };
461
462                // verify that the src address is safe for responses
463                if let Err(e) = sanitize_src_address(src_addr) {
464                    warn!(
465                        "address can not be responded to {src_addr}: {e}",
466                        src_addr = src_addr,
467                        e = e
468                    );
469                    continue;
470                }
471
472                let handler = handler.clone();
473                let tls_acceptor = tls_acceptor.clone();
474
475                // kick out to a different task immediately, let them do the TLS handshake
476                inner_join_set.spawn(async move {
477                    debug!("starting TLS request from: {}", src_addr);
478
479                    // perform the TLS
480                    let tls_stream = tls_acceptor.accept(tcp_stream).await;
481
482                    let tls_stream = match tls_stream {
483                        Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
484                        Err(e) => {
485                            debug!("tls handshake src: {} error: {}", src_addr, e);
486                            return;
487                        }
488                    };
489                    debug!("accepted TLS request from: {}", src_addr);
490                    let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
491                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
492                    while let Some(message) = timeout_stream.next().await {
493                        let message = match message {
494                            Ok(message) => message,
495                            Err(e) => {
496                                debug!(
497                                    "error in TLS request_stream src: {:?} error: {}",
498                                    src_addr, e
499                                );
500
501                                // kill this connection
502                                return;
503                            }
504                        };
505
506                        handle_raw_request(
507                            message,
508                            Protocol::Tls,
509                            handler.clone(),
510                            stream_handle.clone(),
511                        )
512                        .await;
513                    }
514                });
515
516                reap_tasks(&mut inner_join_set);
517            }
518
519            if shutdown.is_cancelled() {
520                Ok(())
521            } else {
522                Err(ProtoError::from("unexpected close of socket"))
523            }
524        });
525
526        Ok(())
527    }
528
529    /// Register a TlsListener to the Server by providing a pkcs12 certificate and key. The TlsListener
530    /// should already be bound to either an IPv6 or an IPv4 address.
531    ///
532    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
533    ///  to not make this too low depending on use cases.
534    ///
535    /// # Arguments
536    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
537    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
538    ///               requests within this time period will be closed. In the future it should be
539    ///               possible to create long-lived queries, but these should be from trusted sources
540    ///               only, this would require some type of whitelisting.
541    /// * `pkcs12` - certificate used to announce to clients
542    #[cfg(feature = "dns-over-rustls")]
543    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
544    pub fn register_tls_listener(
545        &mut self,
546        listener: net::TcpListener,
547        timeout: Duration,
548        certificate_and_key: (Vec<Certificate>, PrivateKey),
549    ) -> io::Result<()> {
550        use crate::proto::rustls::tls_server;
551
552        let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
553            .map_err(|e| {
554            io::Error::new(
555                io::ErrorKind::Other,
556                format!("error creating TLS acceptor: {e}"),
557            )
558        })?;
559
560        Self::register_tls_listener_with_tls_config(self, listener, timeout, Arc::new(tls_acceptor))
561    }
562
563    /// Register a TcpListener for HTTPS (h2) to the Server for supporting DoH (dns-over-https). The TcpListener should already be bound to either an
564    /// IPv6 or an IPv4 address.
565    ///
566    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
567    ///  to not make this too low depending on use cases.
568    ///
569    /// # Arguments
570    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
571    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
572    ///               requests within this time period will be closed. In the future it should be
573    ///               possible to create long-lived queries, but these should be from trusted sources
574    ///               only, this would require some type of whitelisting.
575    /// * `certificate_and_key` - certificate and key used to announce to clients
576    #[cfg(feature = "dns-over-https-rustls")]
577    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https-rustls")))]
578    pub fn register_https_listener(
579        &mut self,
580        listener: net::TcpListener,
581        // TODO: need to set a timeout between requests.
582        _timeout: Duration,
583        certificate_and_key: (Vec<Certificate>, PrivateKey),
584        dns_hostname: Option<String>,
585    ) -> io::Result<()> {
586        use tokio_rustls::TlsAcceptor;
587
588        use crate::proto::rustls::tls_server;
589        use crate::server::h2_handler::h2_handler;
590
591        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
592
593        let handler = self.handler.clone();
594        debug!("registered https: {listener:?}");
595
596        let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
597            .map_err(|e| {
598            io::Error::new(
599                io::ErrorKind::Other,
600                format!("error creating TLS acceptor: {e}"),
601            )
602        })?;
603        let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
604
605        // for each incoming request...
606        let shutdown = self.shutdown_token.clone();
607        self.join_set.spawn(async move {
608            let mut inner_join_set = JoinSet::new();
609            loop {
610                let shutdown = shutdown.clone();
611                let (tcp_stream, src_addr) = tokio::select! {
612                    tcp_stream = listener.accept() => match tcp_stream {
613                        Ok((t, s)) => (t, s),
614                        Err(e) => {
615                            debug!("error receiving HTTPS tcp_stream error: {}", e);
616                            if is_unrecoverable_socket_error(&e) {
617                                break;
618                            }
619                            continue;
620                        },
621                    },
622                    _ = shutdown.cancelled() => {
623                        // A graceful shutdown was initiated. Break out of the loop.
624                        break;
625                    },
626                };
627
628                // verify that the src address is safe for responses
629                if let Err(e) = sanitize_src_address(src_addr) {
630                    warn!("address can not be responded to {src_addr}: {e}");
631                    continue;
632                }
633
634                let handler = handler.clone();
635                let tls_acceptor = tls_acceptor.clone();
636                let dns_hostname = dns_hostname.clone();
637
638                inner_join_set.spawn(async move {
639                    debug!("starting HTTPS request from: {src_addr}");
640
641                    // TODO: need to consider timeout of total connect...
642                    // take the created stream...
643                    let tls_stream = tls_acceptor.accept(tcp_stream).await;
644
645                    let tls_stream = match tls_stream {
646                        Ok(tls_stream) => tls_stream,
647                        Err(e) => {
648                            debug!("https handshake src: {src_addr} error: {e}");
649                            return;
650                        }
651                    };
652                    debug!("accepted HTTPS request from: {src_addr}");
653
654                    h2_handler(
655                        handler,
656                        tls_stream,
657                        src_addr,
658                        dns_hostname,
659                        shutdown.clone(),
660                    )
661                    .await;
662                });
663
664                reap_tasks(&mut inner_join_set);
665            }
666
667            if shutdown.is_cancelled() {
668                Ok(())
669            } else {
670                Err(ProtoError::from("unexpected close of socket"))
671            }
672        });
673
674        Ok(())
675    }
676
/// Register a UdpSocket to the Server for supporting DoQ (dns-over-quic). The UdpSocket should
/// already be bound to either an IPv6 or an IPv4 address.
///
/// A task is spawned on the server's join set which accepts QUIC connections from the socket
/// until the server's shutdown token is cancelled. Every accepted connection whose source
/// address passes `sanitize_src_address` is served on its own task by `quic_handler`.
///
/// # Arguments
/// * `socket` - a bound UDP socket on which QUIC connections are accepted
/// * `_timeout` - currently unused, see the TODO on the parameter
/// * `certificate_and_key` - certificate chain and private key handed to the `QuicServer`
/// * `dns_hostname` - optional hostname, forwarded to the handler of every connection
///
/// # Errors
/// Fails if the `QuicServer` cannot be constructed from the socket and the certificate.
#[cfg(feature = "dns-over-quic")]
#[cfg_attr(docsrs, doc(cfg(feature = "dns-over-quic")))]
pub fn register_quic_listener(
    &mut self,
    socket: net::UdpSocket,
    // TODO: need to set a timeout between requests.
    _timeout: Duration,
    certificate_and_key: (Vec<Certificate>, PrivateKey),
    dns_hostname: Option<String>,
) -> io::Result<()> {
    use crate::proto::quic::QuicServer;
    use crate::server::quic_handler::quic_handler;

    let (cert_chain, private_key) = certificate_and_key;
    let dns_hostname: Option<Arc<str>> = dns_hostname.map(Into::into);
    let handler = self.handler.clone();

    debug!("registered quic: {:?}", socket);
    let mut server = QuicServer::with_socket(socket, cert_chain, private_key)?;

    // for each incoming request...
    let shutdown = self.shutdown_token.clone();
    self.join_set.spawn(async move {
        // tasks serving the individual connections accepted below
        let mut connection_tasks = JoinSet::new();

        loop {
            // every iteration gets its own token; it is handed to the connection task
            let shutdown = shutdown.clone();

            let accepted = tokio::select! {
                result = server.next() => result,
                // A graceful shutdown was initiated. Break out of the loop.
                _ = shutdown.cancelled() => break,
            };

            let (streams, src_addr) = match accepted {
                Ok(Some(connection)) => connection,
                Ok(None) => continue,
                Err(e) => {
                    debug!("error receiving quic connection: {e}");
                    continue;
                }
            };

            // verify that the src address is safe for responses
            // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let dns_hostname = dns_hostname.clone();

            connection_tasks.spawn(async move {
                debug!("starting quic stream request from: {src_addr}");

                // TODO: need to consider timeout of total connect...
                let outcome =
                    quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
                        .await;

                if let Err(e) = outcome {
                    warn!("quic stream processing failed from {src_addr}: {e}")
                }
            });

            // drop the bookkeeping of connection tasks that already finished
            reap_tasks(&mut connection_tasks);
        }

        Ok(())
    });

    Ok(())
}
767
/// Register a UdpSocket to the Server for supporting DoH3 (dns-over-h3). The UdpSocket should
/// already be bound to either an IPv6 or an IPv4 address.
///
/// A task is spawned on the server's join set which accepts HTTP/3 connections from the socket
/// until the server's shutdown token is cancelled. Every accepted connection whose source
/// address passes `sanitize_src_address` is served on its own task by `h3_handler`.
///
/// # Arguments
/// * `socket` - a bound UDP socket on which HTTP/3 connections are accepted
/// * `_timeout` - currently unused, see the TODO on the parameter
/// * `certificate_and_key` - certificate chain and private key handed to the `H3Server`
/// * `dns_hostname` - optional hostname, forwarded to the handler of every connection
///
/// # Errors
/// Fails if the `H3Server` cannot be constructed from the socket and the certificate.
#[cfg(feature = "dns-over-h3")]
#[cfg_attr(docsrs, doc(cfg(feature = "dns-over-h3")))]
pub fn register_h3_listener(
    &mut self,
    socket: net::UdpSocket,
    // TODO: need to set a timeout between requests.
    _timeout: Duration,
    certificate_and_key: (Vec<Certificate>, PrivateKey),
    dns_hostname: Option<String>,
) -> io::Result<()> {
    use crate::proto::h3::h3_server::H3Server;
    use crate::server::h3_handler::h3_handler;

    let (cert_chain, private_key) = certificate_and_key;
    let dns_hostname: Option<Arc<str>> = dns_hostname.map(Into::into);
    let handler = self.handler.clone();

    debug!("registered h3: {:?}", socket);
    let mut server = H3Server::with_socket(socket, cert_chain, private_key)?;

    // for each incoming request...
    let shutdown = self.shutdown_token.clone();
    self.join_set.spawn(async move {
        // tasks serving the individual connections accepted below
        let mut connection_tasks = JoinSet::new();

        loop {
            // every iteration gets its own token; it is handed to the connection task
            let shutdown = shutdown.clone();

            let accepted = tokio::select! {
                result = server.accept() => result,
                // A graceful shutdown was initiated. Break out of the loop.
                _ = shutdown.cancelled() => break,
            };

            let (streams, src_addr) = match accepted {
                Ok(Some(connection)) => connection,
                Ok(None) => continue,
                Err(e) => {
                    debug!("error receiving h3 connection: {e}");
                    continue;
                }
            };

            // verify that the src address is safe for responses
            // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let dns_hostname = dns_hostname.clone();

            connection_tasks.spawn(async move {
                debug!("starting h3 stream request from: {src_addr}");

                // TODO: need to consider timeout of total connect...
                let outcome =
                    h3_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
                        .await;

                if let Err(e) = outcome {
                    warn!("h3 stream processing failed from {src_addr}: {e}")
                }
            });

            // drop the bookkeeping of connection tasks that already finished
            reap_tasks(&mut connection_tasks);
        }

        Ok(())
    });

    Ok(())
}
858
859    /// Triggers a graceful shutdown the server. All background tasks will stop accepting
860    /// new connections and the returned future will complete once all tasks have terminated.
861    pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
862        self.shutdown_token.cancel();
863
864        // Wait for the server to complete.
865        block_until_done(&mut self.join_set).await
866    }
867
868    /// This will run until all background tasks complete. If one or more tasks return an error,
869    /// one will be chosen as the returned error for this future.
870    pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
871        block_until_done(&mut self.join_set).await
872    }
873}
874
875async fn block_until_done(
876    join_set: &mut JoinSet<Result<(), ProtoError>>,
877) -> Result<(), ProtoError> {
878    if join_set.is_empty() {
879        warn!("block_until_done called with no pending tasks");
880        return Ok(());
881    }
882
883    // Now wait for all of the tasks to complete.
884    let mut out = Ok(());
885    while let Some(join_result) = join_set.join_next().await {
886        match join_result {
887            Ok(result) => {
888                match result {
889                    Ok(_) => (),
890                    Err(e) => {
891                        // Save the last error.
892                        out = Err(e);
893                    }
894                }
895            }
896            Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
897        }
898    }
899    out
900}
901
902/// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
903fn reap_tasks(join_set: &mut JoinSet<()>) {
904    while FutureExt::now_or_never(join_set.join_next())
905        .flatten()
906        .is_some()
907    {}
908}
909
910pub(crate) async fn handle_raw_request<T: RequestHandler>(
911    message: SerialMessage,
912    protocol: Protocol,
913    request_handler: Arc<T>,
914    response_handler: BufDnsStreamHandle,
915) {
916    let src_addr = message.addr();
917    let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
918
919    handle_request(
920        message.bytes(),
921        src_addr,
922        protocol,
923        request_handler,
924        response_handler,
925    )
926    .await;
927}
928
929#[derive(Clone)]
930struct ReportingResponseHandler<R: ResponseHandler> {
931    request_header: Header,
932    query: LowerQuery,
933    protocol: Protocol,
934    src_addr: SocketAddr,
935    handler: R,
936}
937
938#[async_trait::async_trait]
939#[allow(clippy::uninlined_format_args)]
940impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
941    async fn send_response<'a>(
942        &mut self,
943        response: crate::authority::MessageResponse<
944            '_,
945            'a,
946            impl Iterator<Item = &'a Record> + Send + 'a,
947            impl Iterator<Item = &'a Record> + Send + 'a,
948            impl Iterator<Item = &'a Record> + Send + 'a,
949            impl Iterator<Item = &'a Record> + Send + 'a,
950        >,
951    ) -> io::Result<super::ResponseInfo> {
952        let response_info = self.handler.send_response(response).await?;
953
954        let id = self.request_header.id();
955        let rid = response_info.id();
956        if id != rid {
957            warn!("request id:{id} does not match response id:{rid}");
958            debug_assert_eq!(id, rid, "request id and response id should match");
959        }
960
961        let rflags = response_info.flags();
962        let answer_count = response_info.answer_count();
963        let authority_count = response_info.name_server_count();
964        let additional_count = response_info.additional_count();
965        let response_code = response_info.response_code();
966
967        info!("request:{id} src:{proto}://{addr}#{port} {op}:{query}:{qtype}:{class} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
968            id = rid,
969            proto = self.protocol,
970            addr = self.src_addr.ip(),
971            port = self.src_addr.port(),
972            op = self.request_header.op_code(),
973            query = self.query.name(),
974            qtype = self.query.query_type(),
975            class = self.query.query_class(),
976            qflags = self.request_header.flags(),
977            code = response_code,
978            answers = answer_count,
979            authorities = authority_count,
980            additionals = additional_count,
981            rflags = rflags
982        );
983
984        Ok(response_info)
985    }
986}
987
988pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
989    // TODO: allow Message here...
990    message_bytes: &[u8],
991    src_addr: SocketAddr,
992    protocol: Protocol,
993    request_handler: Arc<T>,
994    response_handler: R,
995) {
996    let mut decoder = BinDecoder::new(message_bytes);
997
998    // method to handle the request
999    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
1000        if message.message_type() == MessageType::Response {
1001            // Don't process response messages to avoid DoS attacks from reflection.
1002            return;
1003        }
1004
1005        let id = message.id();
1006        let qflags = message.header().flags();
1007        let qop_code = message.op_code();
1008        let message_type = message.message_type();
1009        let is_dnssec = message.edns().is_some_and(Edns::dnssec_ok);
1010
1011        let request = Request::new(message, src_addr, protocol);
1012
1013        let info = request.request_info();
1014        let query = info.query.clone();
1015        let query_name = info.query.name();
1016        let query_type = info.query.query_type();
1017        let query_class = info.query.query_class();
1018
1019        debug!(
1020            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op}:{query}:{qtype}:{class} qflags:{qflags}",
1021            id = id,
1022            proto = protocol,
1023            addr = src_addr.ip(),
1024            port = src_addr.port(),
1025            message_type= message_type,
1026            is_dnssec = is_dnssec,
1027            op = qop_code,
1028            query = query_name,
1029            qtype = query_type,
1030            class = query_class,
1031            qflags = qflags,
1032        );
1033
1034        // The reporter will handle making sure to log the result of the request
1035        let reporter = ReportingResponseHandler {
1036            request_header: *request.header(),
1037            query,
1038            protocol,
1039            src_addr,
1040            handler: response_handler,
1041        };
1042
1043        request_handler.handle_request(&request, reporter).await;
1044    };
1045
1046    // Attempt to decode the message
1047    match MessageRequest::read(&mut decoder) {
1048        Ok(message) => {
1049            inner_handle_request(message, response_handler).await;
1050        }
1051        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
1052            // We failed to parse the request due to some issue in the message, but the header is available, so we can respond
1053            let (header, error) = kind
1054                .into_form_error()
1055                .expect("as form_error already confirmed this is a FormError");
1056            let query = LowerQuery::query(Query::default());
1057
1058            // debug for more info on why the message parsing failed
1059            debug!(
1060                "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
1061                id = header.id(),
1062                proto = protocol,
1063                addr = src_addr.ip(),
1064                port = src_addr.port(),
1065                message_type= header.message_type(),
1066                op = header.op_code(),
1067                error = error,
1068            );
1069
1070            // The reporter will handle making sure to log the result of the request
1071            let mut reporter = ReportingResponseHandler {
1072                request_header: header,
1073                query,
1074                protocol,
1075                src_addr,
1076                handler: response_handler,
1077            };
1078
1079            let response = MessageResponseBuilder::new(None);
1080            let result = reporter
1081                .send_response(response.error_msg(&header, ResponseCode::FormErr))
1082                .await;
1083
1084            if let Err(e) = result {
1085                warn!("failed to return FormError to client: {}", e);
1086            }
1087        }
1088        Err(e) => warn!("failed to read message: {}", e),
1089    }
1090}
1091
/// Checks if the IP address is safe for returning messages
///
/// An address is rejected when
///
/// * its port is `0`,
/// * it is the unspecified IPv4 or IPv6 address,
/// * it is the IPv4 broadcast address,
/// * it is an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) whose embedded IPv4 address would
///   be rejected by the IPv4 rules above.
///
/// # Returns
///
/// Error if the address should not be used for returned requests
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        // TODO: add check for is_reserved when that stabilizes

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        // An IPv4 source can be presented in its IPv4-mapped IPv6 form; it must not slip
        // past the IPv4 checks just because of that representation. `to_ipv4_mapped` only
        // matches `::ffff:0:0/96`, so plain IPv6 addresses such as `::1` are unaffected.
        match src.to_ipv4_mapped() {
            Some(mapped) => verify_v4(mapped),
            None => Ok(()),
        }
    }

    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
1133
/// Returns `true` when the kind of `err` is `NotConnected` or `ConnectionAborted`, the two
/// error kinds this file treats as unrecoverable for a socket.
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
1140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authority::Catalog;
    use futures_util::future;
    #[cfg(feature = "dns-over-rustls")]
    use rustls::{Certificate, PrivateKey};
    use std::net::SocketAddr;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborts the future that owns the server and checks that all endpoint addresses can be
    /// bound again afterwards.
    #[tokio::test]
    async fn abort() {
        let endpoints = Endpoints::new().await;

        let server_endpoints = endpoints.clone();
        let (server_task, abort_handle) = future::abortable(async move {
            let mut server = ServerFuture::new(Catalog::new());
            server_endpoints.register(&mut server).await;
            server.block_until_done().await
        });

        abort_handle.abort();
        server_task.await.expect_err("expected abort");

        endpoints.rebind_all().await;
    }

    /// Shuts a fully registered server down gracefully (within two seconds) and checks that
    /// all endpoint addresses can be bound again afterwards.
    #[tokio::test]
    async fn graceful_shutdown() {
        let mut server = ServerFuture::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server).await;

        let shutdown = timeout(Duration::from_secs(2), server.shutdown_gracefully());
        shutdown
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    #[test]
    fn test_sanitize_src_addr() {
        let v4 = |octets: [u8; 4], port: u16| SocketAddr::from((octets, port));
        let v6 = |segments: [u16; 8], port: u16| SocketAddr::from((segments, port));

        // ipv4 tests
        for accepted in [v4([192, 168, 1, 1], 4096), v4([127, 0, 0, 1], 53)] {
            assert!(sanitize_src_address(accepted).is_ok());
        }

        for rejected in [
            v4([0, 0, 0, 0], 0),
            v4([192, 168, 1, 1], 0),
            v4([0, 0, 0, 0], 4096),
            v4([255, 255, 255, 255], 4096),
        ] {
            assert!(sanitize_src_address(rejected).is_err());
        }

        // ipv6 tests
        for accepted in [
            v6([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4096),
            v6([0, 0, 0, 0, 0, 0, 0, 1], 4096),
        ] {
            assert!(sanitize_src_address(accepted).is_ok());
        }

        for rejected in [
            v6([0, 0, 0, 0, 0, 0, 0, 0], 4096),
            v6([0, 0, 0, 0, 0, 0, 0, 0], 0),
            v6([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0),
        ] {
            assert!(sanitize_src_address(rejected).is_err());
        }
    }

    /// Local addresses, one per listener kind, that the tests register with a server.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        udp_std_addr: SocketAddr,
        tcp_addr: SocketAddr,
        tcp_std_addr: SocketAddr,
        #[cfg(feature = "dns-over-rustls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-https-rustls")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-quic")]
        quic_addr: SocketAddr,
        #[cfg(feature = "dns-over-h3")]
        h3_addr: SocketAddr,
    }

    impl Endpoints {
        /// Binds one socket per endpoint to port 0 on localhost and records the address each
        /// of them was given.
        ///
        /// All sockets stay alive until every address has been recorded (presumably so that
        /// no port is handed out twice); they are closed when this function returns.
        async fn new() -> Self {
            const ANY_LOCAL_PORT: &str = "127.0.0.1:0";

            let udp_socket = UdpSocket::bind(ANY_LOCAL_PORT).await.unwrap();
            let udp_std_socket = UdpSocket::bind(ANY_LOCAL_PORT).await.unwrap();
            let tcp_listener = TcpListener::bind(ANY_LOCAL_PORT).await.unwrap();
            let tcp_std_listener = TcpListener::bind(ANY_LOCAL_PORT).await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            let rustls_listener = TcpListener::bind(ANY_LOCAL_PORT).await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            let https_rustls_listener = TcpListener::bind(ANY_LOCAL_PORT).await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            let quic_socket = UdpSocket::bind(ANY_LOCAL_PORT).await.unwrap();
            #[cfg(feature = "dns-over-h3")]
            let h3_socket = UdpSocket::bind(ANY_LOCAL_PORT).await.unwrap();

            Self {
                udp_addr: udp_socket.local_addr().unwrap(),
                udp_std_addr: udp_std_socket.local_addr().unwrap(),
                tcp_addr: tcp_listener.local_addr().unwrap(),
                tcp_std_addr: tcp_std_listener.local_addr().unwrap(),
                #[cfg(feature = "dns-over-rustls")]
                rustls_addr: rustls_listener.local_addr().unwrap(),
                #[cfg(feature = "dns-over-https-rustls")]
                https_rustls_addr: https_rustls_listener.local_addr().unwrap(),
                #[cfg(feature = "dns-over-quic")]
                quic_addr: quic_socket.local_addr().unwrap(),
                #[cfg(feature = "dns-over-h3")]
                h3_addr: h3_socket.local_addr().unwrap(),
            }
        }

        /// Binds every recorded address again and registers the resulting sockets and
        /// listeners with `server`.
        async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
            let udp_socket = UdpSocket::bind(self.udp_addr).await.unwrap();
            server.register_socket(udp_socket);

            let udp_std_socket = std::net::UdpSocket::bind(self.udp_std_addr).unwrap();
            server.register_socket_std(udp_std_socket).unwrap();

            let tcp_listener = TcpListener::bind(self.tcp_addr).await.unwrap();
            server.register_listener(tcp_listener, Duration::from_secs(1));

            let tcp_std_listener = std::net::TcpListener::bind(self.tcp_std_addr).unwrap();
            server
                .register_listener_std(tcp_std_listener, Duration::from_secs(1))
                .unwrap();

            #[cfg(feature = "dns-over-rustls")]
            {
                let cert_key = rustls_cert_key();
                let listener = TcpListener::bind(self.rustls_addr).await.unwrap();
                server
                    .register_tls_listener(listener, Duration::from_secs(30), cert_key)
                    .unwrap();
            }

            #[cfg(feature = "dns-over-https-rustls")]
            {
                let cert_key = rustls_cert_key();
                let listener = TcpListener::bind(self.https_rustls_addr).await.unwrap();
                server
                    .register_https_listener(listener, Duration::from_secs(1), cert_key, None)
                    .unwrap();
            }

            #[cfg(feature = "dns-over-quic")]
            {
                let cert_key = rustls_cert_key();
                let socket = UdpSocket::bind(self.quic_addr).await.unwrap();
                server
                    .register_quic_listener(socket, Duration::from_secs(1), cert_key, None)
                    .unwrap();
            }

            #[cfg(feature = "dns-over-h3")]
            {
                let cert_key = rustls_cert_key();
                let socket = UdpSocket::bind(self.h3_addr).await.unwrap();
                server
                    .register_h3_listener(socket, Duration::from_secs(1), cert_key, None)
                    .unwrap();
            }
        }

        /// Binds every recorded address once more, panicking if one of them is still in use.
        /// Each socket is closed again right after it was bound.
        async fn rebind_all(&self) {
            drop(UdpSocket::bind(self.udp_addr).await.unwrap());
            drop(UdpSocket::bind(self.udp_std_addr).await.unwrap());
            drop(TcpListener::bind(self.tcp_addr).await.unwrap());
            drop(TcpListener::bind(self.tcp_std_addr).await.unwrap());
            #[cfg(feature = "dns-over-rustls")]
            drop(TcpListener::bind(self.rustls_addr).await.unwrap());
            #[cfg(feature = "dns-over-https-rustls")]
            drop(TcpListener::bind(self.https_rustls_addr).await.unwrap());
            #[cfg(feature = "dns-over-quic")]
            drop(UdpSocket::bind(self.quic_addr).await.unwrap());
            #[cfg(feature = "dns-over-h3")]
            drop(UdpSocket::bind(self.h3_addr).await.unwrap());
        }
    }

    /// Loads the test certificate and key from `tests/test-data`, relative to
    /// `TDNS_WORKSPACE_ROOT` or, when that variable is not set, to `../..`.
    #[cfg(feature = "dns-over-rustls")]
    fn rustls_cert_key() -> (Vec<Certificate>, PrivateKey) {
        use hickory_proto::rustls::tls_server;
        use std::env;
        use std::path::Path;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
        let cert_path = format!("{}/tests/test-data/cert.pem", server_path);
        let key_path = format!("{}/tests/test-data/cert.key", server_path);

        let cert = tls_server::read_cert(Path::new(&cert_path))
            .map_err(|e| format!("error reading cert: {e}"))
            .unwrap();
        let key = tls_server::read_key_from_pem(Path::new(&key_path)).unwrap();

        (cert, key)
    }

    #[test]
    fn task_reap_on_empty_joinset() {
        let mut empty_set = JoinSet::new();

        // this should return immediately
        reap_tasks(&mut empty_set);
    }

    #[test]
    fn task_reap_on_nonempty_joinset() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let mut tasks = JoinSet::new();
            let sleeper = tasks.spawn(tokio::time::sleep(Duration::from_secs(2)));

            // this should return immediately since no task is ready
            reap_tasks(&mut tasks);
            sleeper.abort();

            // this should also return immediately since the task has been aborted
            reap_tasks(&mut tasks);
        });
    }
}