1use std::{
11 fmt, io,
12 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
13 sync::Arc,
14 time::Duration,
15};
16
17use bytes::Bytes;
18use futures_util::StreamExt;
19use ipnet::IpNet;
20#[cfg(feature = "__tls")]
21use rustls::{ServerConfig, server::ResolvesServerCert};
22#[cfg(feature = "__tls")]
23use tokio::time::timeout;
24use tokio::{net, task::JoinSet};
25#[cfg(feature = "__tls")]
26use tokio_rustls::TlsAcceptor;
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, info, warn};
29
30#[cfg(feature = "metrics")]
31use crate::metrics::ResponseHandlerMetrics;
32#[cfg(feature = "__h3")]
33use crate::net::h3::h3_server::H3Server;
34#[cfg(feature = "__quic")]
35use crate::net::quic::QuicServer;
36#[cfg(feature = "__tls")]
37use crate::net::tls::{default_provider, tls_from_stream};
38use crate::{
39 access::AccessControl,
40 net::{
41 BufDnsStreamHandle, NetError,
42 runtime::{TokioRuntimeProvider, TokioTime, iocompat::AsyncIoTokioAsStd},
43 tcp::TcpStream,
44 udp::UdpStream,
45 xfer::Protocol,
46 },
47 proto::{
48 op::{Header, LowerQuery, MessageType, Metadata, ResponseCode, SerialMessage},
49 rr::Record,
50 serialize::binary::{BinDecodable, BinDecoder},
51 },
52 zone_handler::{MessageRequest, MessageResponseBuilder, Queries},
53};
54
55#[cfg(feature = "__https")]
56mod h2_handler;
57#[cfg(feature = "__h3")]
58mod h3_handler;
59#[cfg(feature = "__quic")]
60mod quic_handler;
61mod request_handler;
62pub use request_handler::{Request, RequestHandler, RequestInfo, ResponseInfo};
63mod response_handler;
64pub use response_handler::{ResponseHandle, ResponseHandler};
65mod timeout_stream;
66pub use timeout_stream::TimeoutStream;
67
/// A DNS server that serves one shared [`RequestHandler`] over any number of
/// registered UDP sockets and TCP/TLS/HTTPS/QUIC/H3 listeners.
///
/// Each `register_*` method spawns one background task into `join_set`; those
/// tasks run until the shutdown token held in `context` is cancelled or they
/// stop with an error.
pub struct Server<T: RequestHandler> {
    /// State shared with every spawned task: the request handler, the
    /// access-control list and the shutdown token.
    context: Arc<ServerContext<T>>,
    /// One task per registered socket or listener.
    join_set: JoinSet<Result<(), NetError>>,
}
74
75impl<T: RequestHandler> Server<T> {
76 pub fn new(handler: T) -> Self {
78 Self::with_access(handler, [], [])
79 }
80
81 pub fn with_access(
83 handler: T,
84 denied_networks: impl IntoIterator<Item = IpNet>,
85 allowed_networks: impl IntoIterator<Item = IpNet>,
86 ) -> Self {
87 let mut access = AccessControl::default();
88 access.insert_deny(denied_networks);
89 access.insert_allow(allowed_networks);
90
91 Self {
92 context: Arc::new(ServerContext {
93 handler,
94 access,
95 shutdown: CancellationToken::new(),
96 }),
97 join_set: JoinSet::new(),
98 }
99 }
100
101 pub fn register_socket(&mut self, socket: net::UdpSocket) {
103 self.join_set
104 .spawn(handle_udp(socket, self.context.clone()));
105 }
106
107 pub fn register_listener(
121 &mut self,
122 listener: net::TcpListener,
123 timeout: Duration,
124 response_buffer_size: usize,
125 ) {
126 self.join_set.spawn(handle_tcp(
127 listener,
128 timeout,
129 response_buffer_size,
130 self.context.clone(),
131 ));
132 }
133
134 #[cfg(feature = "__tls")]
151 pub fn register_tls_listener_with_tls_config(
152 &mut self,
153 listener: net::TcpListener,
154 handshake_timeout: Duration,
155 tls_config: Arc<ServerConfig>,
156 ) -> io::Result<()> {
157 self.join_set.spawn(handle_tls(
158 listener,
159 tls_config,
160 handshake_timeout,
161 self.context.clone(),
162 ));
163 Ok(())
164 }
165
166 #[cfg(feature = "__tls")]
180 pub fn register_tls_listener(
181 &mut self,
182 listener: net::TcpListener,
183 timeout: Duration,
184 server_cert_resolver: Arc<dyn ResolvesServerCert>,
185 ) -> io::Result<()> {
186 Self::register_tls_listener_with_tls_config(
187 self,
188 listener,
189 timeout,
190 Arc::new(default_tls_server_config(b"dot", server_cert_resolver)?),
191 )
192 }
193
194 #[cfg(feature = "__https")]
210 pub fn register_https_listener(
211 &mut self,
212 listener: net::TcpListener,
213 handshake_timeout: Duration,
215 server_cert_resolver: Arc<dyn ResolvesServerCert>,
216 dns_hostname: Option<String>,
217 http_endpoint: String,
218 ) -> io::Result<()> {
219 self.join_set.spawn(h2_handler::handle_h2(
220 listener,
221 handshake_timeout,
222 server_cert_resolver,
223 dns_hostname,
224 http_endpoint,
225 self.context.clone(),
226 ));
227 Ok(())
228 }
229
230 #[cfg(feature = "__https")]
250 pub fn register_https_listener_with_tls_config(
251 &mut self,
252 listener: net::TcpListener,
253 handshake_timeout: Duration,
255 tls_config: Arc<ServerConfig>,
256 dns_hostname: Option<String>,
257 http_endpoint: String,
258 ) -> io::Result<()> {
259 self.join_set.spawn(h2_handler::handle_h2_with_acceptor(
260 listener,
261 handshake_timeout,
262 TlsAcceptor::from(tls_config),
263 dns_hostname,
264 http_endpoint,
265 self.context.clone(),
266 ));
267 Ok(())
268 }
269
270 #[cfg(feature = "__quic")]
285 pub fn register_quic_listener(
286 &mut self,
287 socket: net::UdpSocket,
288 _timeout: Duration,
290 server_cert_resolver: Arc<dyn ResolvesServerCert>,
291 ) -> io::Result<()> {
292 let cx = self.context.clone();
293 self.join_set
294 .spawn(quic_handler::handle_quic(socket, server_cert_resolver, cx));
295 Ok(())
296 }
297
298 #[cfg(feature = "__quic")]
317 pub fn register_quic_listener_and_tls_config(
318 &mut self,
319 socket: net::UdpSocket,
320 _timeout: Duration,
322 tls_config: Arc<ServerConfig>,
323 ) -> Result<(), NetError> {
324 let cx = self.context.clone();
325
326 self.join_set.spawn(quic_handler::handle_quic_with_server(
327 QuicServer::with_socket_and_tls_config(socket, tls_config)?,
328 cx,
329 ));
330 Ok(())
331 }
332
333 #[cfg(feature = "__h3")]
347 pub fn register_h3_listener(
348 &mut self,
349 socket: net::UdpSocket,
350 _timeout: Duration,
352 server_cert_resolver: Arc<dyn ResolvesServerCert>,
353 dns_hostname: Option<String>,
354 ) -> io::Result<()> {
355 self.join_set.spawn(h3_handler::handle_h3(
356 socket,
357 server_cert_resolver,
358 dns_hostname,
359 self.context.clone(),
360 ));
361 Ok(())
362 }
363
364 #[cfg(feature = "__h3")]
382 pub fn register_h3_listener_with_tls_config(
383 &mut self,
384 socket: net::UdpSocket,
385 _timeout: Duration,
387 tls_config: Arc<ServerConfig>,
388 dns_hostname: Option<String>,
389 ) -> Result<(), NetError> {
390 self.join_set.spawn(h3_handler::handle_h3_with_server(
391 H3Server::with_socket_and_tls_config(socket, tls_config)?,
392 dns_hostname,
393 self.context.clone(),
394 ));
395 Ok(())
396 }
397
398 pub async fn shutdown_gracefully(&mut self) -> Result<(), NetError> {
401 self.context.shutdown.cancel();
402
403 self.block_until_done().await
405 }
406
407 pub fn shutdown_token(&self) -> &CancellationToken {
412 &self.context.shutdown
413 }
414
415 pub async fn block_until_done(&mut self) -> Result<(), NetError> {
418 if self.join_set.is_empty() {
419 warn!("block_until_done called with no pending tasks");
420 return Ok(());
421 }
422
423 let mut out = Ok(());
424 while let Some(join_result) = self.join_set.join_next().await {
425 match join_result {
426 Ok(Ok(())) => continue,
427 Ok(Err(e)) => out = Err(e),
428 Err(e) => return Err(NetError::from(format!("internal error in spawn: {e}"))),
429 }
430 }
431
432 out
433 }
434}
435
436async fn handle_udp(
437 socket: net::UdpSocket,
438 cx: Arc<ServerContext<impl RequestHandler>>,
439) -> Result<(), NetError> {
440 debug!("registering udp: {:?}", socket);
441
442 let (mut stream, stream_handle) =
445 UdpStream::<TokioRuntimeProvider>::with_bound(socket, ([127, 255, 255, 254], 0).into());
446
447 let mut inner_join_set = JoinSet::new();
448 loop {
449 let message = tokio::select! {
450 message = stream.next() => match message {
451 None => break,
452 Some(message) => message,
453 },
454 _ = cx.shutdown.cancelled() => break,
455 };
456
457 let message = match message {
458 Err(error) => {
459 warn!(%error, "error receiving message on udp_socket");
460 if is_unrecoverable_socket_error(&error) {
461 break;
462 }
463 continue;
464 }
465 Ok(message) => message,
466 };
467
468 let src_addr = message.addr();
469 debug!("received udp request from: {}", src_addr);
470
471 if let Err(e) = sanitize_src_address(src_addr) {
473 warn!(
474 "address can not be responded to {src_addr}: {e}",
475 src_addr = src_addr,
476 e = e
477 );
478 continue;
479 }
480
481 let cx = cx.clone();
482 let stream_handle = stream_handle.with_remote_addr(src_addr);
483 inner_join_set.spawn(async move {
484 cx.handle_raw_request(message, Protocol::Udp, stream_handle)
485 .await;
486 });
487
488 reap_tasks(&mut inner_join_set);
489 }
490
491 if cx.shutdown.is_cancelled() {
492 Ok(())
493 } else {
494 Err(NetError::from("unexpected close of UDP socket"))
496 }
497}
498
499async fn handle_tcp(
500 listener: net::TcpListener,
501 timeout: Duration,
502 response_buffer_size: usize,
503 cx: Arc<ServerContext<impl RequestHandler>>,
504) -> Result<(), NetError> {
505 debug!("register tcp: {listener:?}");
506 let mut inner_join_set = JoinSet::new();
507 loop {
508 let (tcp_stream, src_addr) = tokio::select! {
509 tcp_stream = listener.accept() => match tcp_stream {
510 Ok((t, s)) => (t, s),
511 Err(error) => {
512 debug!(%error, "error receiving TCP tcp_stream error");
513 if is_unrecoverable_socket_error(&error) {
514 break;
515 }
516 continue;
517 },
518 },
519 _ = cx.shutdown.cancelled() => {
520 break;
522 },
523 };
524
525 if let Err(error) = sanitize_src_address(src_addr) {
527 warn!(
528 %src_addr, %error,
529 "address can not be responded to (TCP)",
530 );
531 continue;
532 }
533
534 let cx = cx.clone();
536 inner_join_set.spawn(async move {
537 debug!(%src_addr, "accepted TCP request");
538 let (buf_stream, stream_handle) = TcpStream::from_stream_with_buffer_size(
540 AsyncIoTokioAsStd(tcp_stream),
541 src_addr,
542 response_buffer_size,
543 );
544 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
545
546 while let Some(message) = timeout_stream.next().await {
547 let message = match message {
548 Ok(message) => message,
549 Err(error) => {
550 debug!(%src_addr, %error, "error in TCP request stream");
551 return;
553 }
554 };
555
556 cx.handle_raw_request(message, Protocol::Tcp, stream_handle.clone())
558 .await;
559 }
560 });
561
562 reap_tasks(&mut inner_join_set);
563 }
564
565 if cx.shutdown.is_cancelled() {
566 Ok(())
567 } else {
568 Err(NetError::from("unexpected close of socket"))
569 }
570}
571
/// Accepts DNS-over-TLS connections on `listener` until shutdown or an
/// unrecoverable accept error.
///
/// Each accepted connection from an acceptable source address is served in its
/// own task. The TLS handshake must finish within `handshake_timeout`, and the
/// same duration is then used as the [`TimeoutStream`] timeout for the
/// connection's request stream. The task stops at the first stream error.
///
/// # Errors
///
/// Returns an error if the loop ends for any reason other than cancellation of
/// the shutdown token.
#[cfg(feature = "__tls")]
async fn handle_tls(
    listener: net::TcpListener,
    tls_config: Arc<ServerConfig>,
    handshake_timeout: Duration,
    cx: Arc<ServerContext<impl RequestHandler>>,
) -> Result<(), NetError> {
    debug!(?listener, "registered tls");
    let tls_acceptor = TlsAcceptor::from(tls_config);

    let mut inner_join_set = JoinSet::new();
    loop {
        let (tcp_stream, src_addr) = tokio::select! {
            tcp_stream = listener.accept() => match tcp_stream {
                Ok(pair) => pair,
                Err(error) => {
                    debug!(%error, "error receiving TLS tcp_stream error");
                    if is_unrecoverable_socket_error(&error) {
                        break;
                    }
                    continue;
                },
            },
            _ = cx.shutdown.cancelled() => {
                break;
            },
        };

        if let Err(error) = sanitize_src_address(src_addr) {
            warn!(
                %src_addr, %error,
                "address can not be responded to (TLS)",
            );
            continue;
        }

        let cx = cx.clone();
        let tls_acceptor = tls_acceptor.clone();
        inner_join_set.spawn(async move {
            debug!(%src_addr, "starting TLS request");

            let Ok(tls_stream) = timeout(handshake_timeout, tls_acceptor.accept(tcp_stream)).await
            else {
                // Include the peer so a client that stalls the handshake can be identified;
                // every other per-connection log in this function already carries it.
                warn!(%src_addr, "tls timeout expired during handshake");
                return;
            };

            let tls_stream = match tls_stream {
                Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
                Err(error) => {
                    debug!(%src_addr, %error, "tls handshake error");
                    return;
                }
            };
            debug!(%src_addr, "accepted TLS request");
            let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
            // NOTE(review): the handshake timeout doubles as the per-connection request
            // stream timeout; there is no separate idle-timeout parameter for TLS.
            let mut timeout_stream = TimeoutStream::new(buf_stream, handshake_timeout);
            while let Some(message) = timeout_stream.next().await {
                let message = match message {
                    Ok(message) => message,
                    Err(error) => {
                        debug!(
                            %src_addr, %error,
                            "error in TLS request stream",
                        );

                        return;
                    }
                };

                cx.handle_raw_request(message, Protocol::Tls, stream_handle.clone())
                    .await;
            }
        });

        reap_tasks(&mut inner_join_set);
    }

    if cx.shutdown.is_cancelled() {
        Ok(())
    } else {
        Err(NetError::from("unexpected close of socket"))
    }
}
661
662fn reap_tasks(join_set: &mut JoinSet<()>) {
664 while join_set.try_join_next().is_some() {}
665}
666
/// Builds a rustls [`ServerConfig`] using the crate's default crypto provider,
/// rustls' safe default protocol versions, no client authentication, and
/// `server_cert_resolver` for certificate selection.
///
/// The config advertises exactly one ALPN protocol, `protocol` (for example `b"dot"`).
///
/// # Errors
///
/// Returns an `io::Error` if the protocol versions cannot be configured for
/// the provider.
#[cfg(feature = "__tls")]
pub fn default_tls_server_config(
    protocol: &[u8],
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
) -> io::Result<ServerConfig> {
    let provider = Arc::new(default_provider());
    let builder = ServerConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| io::Error::other(format!("error creating TLS acceptor: {e}")))?;

    let mut config = builder
        .with_no_client_auth()
        .with_cert_resolver(server_cert_resolver);
    config.alpn_protocols = vec![protocol.to_vec()];

    Ok(config)
}
683
/// A [`ResponseHandler`] wrapper that forwards each response to `handler` and
/// then logs a summary of the request and response. With the `metrics`
/// feature it also records response metrics.
#[derive(Clone)]
pub(super) struct ReportingResponseHandler<R: ResponseHandler> {
    /// Header metadata of the request being answered; used for logging and the
    /// request/response id check.
    pub(super) request_meta: Metadata,
    /// Queries from the request, logged after the response has been sent.
    queries: Vec<LowerQuery>,
    /// Transport the request arrived over.
    pub(super) protocol: Protocol,
    /// Address of the client that sent the request.
    src_addr: SocketAddr,
    /// The wrapped handler that actually sends the response.
    handler: R,
    /// Metrics updated after each successfully sent response.
    #[cfg(feature = "metrics")]
    metrics: ResponseHandlerMetrics,
}
694
695#[async_trait::async_trait]
696impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
697 async fn send_response<'a>(
698 &mut self,
699 response: crate::zone_handler::MessageResponse<
700 '_,
701 'a,
702 impl Iterator<Item = &'a Record> + Send + 'a,
703 impl Iterator<Item = &'a Record> + Send + 'a,
704 impl Iterator<Item = &'a Record> + Send + 'a,
705 impl Iterator<Item = &'a Record> + Send + 'a,
706 >,
707 ) -> Result<ResponseInfo, NetError> {
708 let response_info = self.handler.send_response(response).await?;
709
710 let id = self.request_meta.id;
711 let rid = response_info.id;
712 if id != rid {
713 warn!("request id:{id} does not match response id:{rid}");
714 debug_assert_eq!(id, rid, "request id and response id should match");
715 }
716
717 let rflags = response_info.flags();
718 let answer_count = response_info.counts().answers;
719 let authority_count = response_info.counts().authorities;
720 let additional_count = response_info.counts().additionals;
721 let response_code = response_info.response_code;
722
723 info!(
724 "request:{id} src:{proto}://{addr}#{port} {op} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
725 id = rid,
726 proto = self.protocol,
727 addr = self.src_addr.ip(),
728 port = self.src_addr.port(),
729 op = self.request_meta.op_code,
730 qflags = self.request_meta.flags(),
731 code = response_code,
732 answers = answer_count,
733 authorities = authority_count,
734 additionals = additional_count,
735 rflags = rflags
736 );
737 for query in self.queries.iter() {
738 info!(
739 "query:{query}:{qtype}:{class}",
740 query = query.name(),
741 qtype = query.query_type(),
742 class = query.query_class()
743 );
744 }
745
746 #[cfg(feature = "metrics")]
747 self.metrics.update(self, &response_info);
748
749 Ok(response_info)
750 }
751}
752
/// State shared through an `Arc` between the `Server` and all of its listener
/// and per-request tasks.
struct ServerContext<T> {
    /// Application handler invoked for every request that passes access control
    /// and decodes successfully.
    handler: T,
    /// Source-address filter; requests from disallowed addresses are answered
    /// with `Refused`.
    access: AccessControl,
    /// Cancelled to make every listener loop stop accepting new work.
    shutdown: CancellationToken,
}
758
759impl<T: RequestHandler> ServerContext<T> {
760 async fn handle_raw_request(
761 &self,
762 message: SerialMessage,
763 protocol: Protocol,
764 response_handler: BufDnsStreamHandle,
765 ) {
766 let (message, src_addr) = message.into_parts();
767 let response_handler = ResponseHandle::new(src_addr, response_handler, protocol);
768
769 self.handle_request(Bytes::from(message), src_addr, protocol, response_handler)
770 .await;
771 }
772
773 async fn handle_request(
774 &self,
775 message_bytes: Bytes,
776 src_addr: SocketAddr,
777 protocol: Protocol,
778 response_handler: impl ResponseHandler,
779 ) {
780 let mut decoder = BinDecoder::new(&message_bytes);
781 let Ok(header) = Header::read(&mut decoder) else {
782 return;
786 };
787
788 if !self.access.allow(src_addr.ip()) {
789 info!(
790 "request:Refused src:{proto}://{addr}#{port}",
791 proto = protocol,
792 addr = src_addr.ip(),
793 port = src_addr.port(),
794 );
795
796 let queries = match Queries::read(&mut decoder, header.counts.queries as usize) {
797 Ok(queries) => queries,
798 Err(_) => Queries::empty(),
799 };
800 error_response_handler(
801 protocol,
802 src_addr,
803 header,
804 queries,
805 ResponseCode::Refused,
806 "request refused",
807 response_handler,
808 )
809 .await;
810
811 return;
812 }
813
814 let request = match MessageRequest::read(&mut decoder, header) {
816 Ok(message) => Request {
817 message,
818 raw: message_bytes,
819 src: src_addr,
820 protocol,
821 },
822 Err(error) => {
823 let queries = Queries::empty();
825
826 error_response_handler(
827 protocol,
828 src_addr,
829 header,
830 queries,
831 ResponseCode::FormErr,
832 error,
833 response_handler,
834 )
835 .await;
836
837 return;
838 }
839 };
840
841 if request.message.metadata.message_type == MessageType::Response {
842 return;
844 }
845
846 let id = request.message.metadata.id;
847 let qflags = request.message.metadata.flags();
848 let qop_code = request.message.metadata.op_code;
849 let message_type = request.message.metadata.message_type;
850 let is_dnssec = request
851 .message
852 .edns
853 .as_ref()
854 .is_some_and(|edns| edns.flags().dnssec_ok);
855
856 debug!(
857 "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op} qflags:{qflags}",
858 id = id,
859 proto = request.protocol(),
860 addr = request.src().ip(),
861 port = request.src().port(),
862 message_type = message_type,
863 is_dnssec = is_dnssec,
864 op = qop_code,
865 qflags = qflags
866 );
867 for query in request.queries.queries().iter() {
868 debug!(
869 "query:{query}:{qtype}:{class}",
870 query = query.name(),
871 qtype = query.query_type(),
872 class = query.query_class()
873 );
874 }
875
876 let queries = request.queries.queries().to_vec();
878 let reporter = ReportingResponseHandler {
879 request_meta: request.metadata,
880 queries,
881 protocol: request.protocol(),
882 src_addr: request.src(),
883 handler: response_handler,
884 #[cfg(feature = "metrics")]
885 metrics: ResponseHandlerMetrics::default(),
886 };
887
888 self.handler
889 .handle_request::<_, TokioTime>(&request, reporter)
890 .await;
891 }
892}
893
894async fn error_response_handler(
896 protocol: Protocol,
897 src_addr: SocketAddr,
898 header: Header,
899 queries: Queries,
900 response_code: ResponseCode,
901 error: impl fmt::Display,
902 response_handler: impl ResponseHandler,
903) {
904 debug!(
906 "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
907 id = header.id,
908 proto = protocol,
909 addr = src_addr.ip(),
910 port = src_addr.port(),
911 message_type = header.message_type,
912 op = header.op_code,
913 response_code = response_code,
914 error = error,
915 );
916
917 let mut reporter = ReportingResponseHandler {
919 request_meta: header.metadata,
920 queries: queries.queries().to_vec(),
921 protocol,
922 src_addr,
923 handler: response_handler,
924 #[cfg(feature = "metrics")]
925 metrics: ResponseHandlerMetrics::default(),
926 };
927
928 let response = MessageResponseBuilder::new(&queries, None);
929 let result = reporter
930 .send_response(response.error_msg(&header, response_code))
931 .await;
932
933 if let Err(error) = result {
934 warn!(%error, "failed to return FormError to client");
935 }
936}
937
/// Rejects source addresses that a reply could never be delivered to.
///
/// Port 0, the unspecified IPv4/IPv6 addresses, and the IPv4 limited broadcast
/// address are refused. The returned `String` describes why.
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    match src.ip() {
        IpAddr::V4(v4) if v4.is_unspecified() => {
            Err(format!("cannot respond to unspecified v4 addr: {v4}"))
        }
        IpAddr::V4(v4) if v4.is_broadcast() => {
            Err(format!("cannot respond to broadcast v4 addr: {v4}"))
        }
        IpAddr::V6(v6) if v6.is_unspecified() => {
            Err(format!("cannot respond to unspecified v6 addr: {v6}"))
        }
        IpAddr::V4(_) | IpAddr::V6(_) => Ok(()),
    }
}
979
/// Returns `true` for the socket error kinds after which the UDP/TCP/TLS
/// listener loops stop (`NotConnected` and `ConnectionAborted`).
///
/// Any other error is logged by the caller and the loop keeps going.
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
986
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zone_handler::Catalog;
    use futures_util::future;
    #[cfg(feature = "__tls")]
    use rustls::{
        pki_types::{CertificateDer, PrivateKeyDer},
        sign::{CertifiedKey, SingleCertAndKey},
    };
    use std::net::SocketAddr;
    use test_support::subscribe;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborting the server future must leave every endpoint address re-bindable.
    #[tokio::test]
    async fn abort() {
        subscribe();

        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = Server::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        // NOTE(review): `abort()` runs before `abortable` is first awaited, and futures do
        // nothing until polled, so the async block above (including `register`) appears to
        // never execute. As written this only checks that a never-started, aborted server
        // future yields an error. Confirm whether the server was meant to be started first.
        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        endpoints.rebind_all().await;
    }

    /// `shutdown_gracefully` must finish within 2s without a task error and release
    /// every bound address.
    #[tokio::test]
    async fn graceful_shutdown() {
        subscribe();
        let mut server_future = Server::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    #[test]
    fn test_sanitize_src_addr() {
        // Ordinary unicast/loopback v4 sources are accepted.
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4_096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        // Port 0, unspecified and broadcast v4 sources are rejected.
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4_096))).is_err());

        // Ordinary v6 sources (including ::1) are accepted.
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4_096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4_096))).is_ok());

        // Unspecified v6 and any v6 source on port 0 are rejected.
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4_096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// Free loopback addresses, one per transport enabled by the active features.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        tcp_addr: SocketAddr,
        #[cfg(feature = "__tls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "__https")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "__quic")]
        quic_addr: SocketAddr,
        #[cfg(feature = "__h3")]
        h3_addr: SocketAddr,
    }

    impl Endpoints {
        /// Binds port 0 on 127.0.0.1 for each transport to obtain OS-assigned ports, keeping
        /// only the addresses; the sockets themselves are dropped when this returns.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__tls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__https")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "__h3")]
            let h3 = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                #[cfg(feature = "__tls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "__https")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "__quic")]
                quic_addr: quic.local_addr().unwrap(),
                #[cfg(feature = "__h3")]
                h3_addr: h3.local_addr().unwrap(),
            }
        }

        /// Re-binds every address and registers it with `server` on the matching transport.
        async fn register<T: RequestHandler>(&self, server: &mut Server<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
                32,
            );

            #[cfg(feature = "__tls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__https")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                        "/dns-query".into(),
                    )
                    .unwrap();
            }

            #[cfg(feature = "__quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "__h3")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_h3_listener(
                        UdpSocket::bind(self.h3_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Binds every address once more and immediately drops it; the `unwrap`s panic if
        /// any bind fails, presumably because the server still holds that address.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            #[cfg(feature = "__tls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "__https")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "__quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
            #[cfg(feature = "__h3")]
            UdpSocket::bind(self.h3_addr).await.unwrap();
        }
    }

    /// Loads the test certificate chain and key from `tests/test-data` under
    /// `TDNS_WORKSPACE_ROOT` (default `../..`) and wraps them in a single-cert resolver.
    #[cfg(feature = "__tls")]
    fn rustls_cert_key() -> Arc<dyn ResolvesServerCert> {
        use rustls::pki_types::pem::PemObject;
        use std::env;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
        let cert_chain =
            CertificateDer::pem_file_iter(format!("{server_path}/tests/test-data/cert.pem"))
                .unwrap()
                .collect::<Result<Vec<_>, _>>()
                .unwrap();

        let key = PrivateKeyDer::from_pem_file(format!("{server_path}/tests/test-data/cert.key"))
            .unwrap();

        let certified_key = CertifiedKey::from_der(cert_chain, key, &default_provider()).unwrap();
        Arc::new(SingleCertAndKey::from(certified_key))
    }

    /// Reaping an empty set must return immediately without panicking.
    #[test]
    fn task_reap_on_empty_joinset() {
        let mut joinset = JoinSet::new();

        reap_tasks(&mut joinset);
    }

    /// Reaping must not block on a still-running task, and must tolerate an aborted one.
    #[tokio::test]
    async fn task_reap_on_nonempty_joinset() {
        let mut joinset = JoinSet::new();
        let t = joinset.spawn(tokio::time::sleep(Duration::from_secs(2)));

        reap_tasks(&mut joinset);

        t.abort();

        reap_tasks(&mut joinset);
    }
}