1use std::{
8 io,
9 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10 sync::Arc,
11 time::Duration,
12};
13
14use futures_util::{FutureExt, StreamExt};
15use hickory_proto::{op::MessageType, rr::Record};
16#[cfg(feature = "dns-over-rustls")]
17use rustls::{Certificate, PrivateKey, ServerConfig};
18use tokio::{net, task::JoinSet};
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, info, warn};
21
22#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
23use crate::proto::openssl::tls_server::*;
24use crate::{
25 authority::{MessageRequest, MessageResponseBuilder},
26 proto::{
27 error::ProtoError,
28 iocompat::AsyncIoTokioAsStd,
29 op::{Edns, Header, LowerQuery, Query, ResponseCode},
30 serialize::binary::{BinDecodable, BinDecoder},
31 tcp::TcpStream,
32 udp::UdpStream,
33 xfer::SerialMessage,
34 BufDnsStreamHandle,
35 },
36 server::{Protocol, Request, RequestHandler, ResponseHandle, ResponseHandler, TimeoutStream},
37};
38
39pub struct ServerFuture<T: RequestHandler> {
42 handler: Arc<T>,
43 join_set: JoinSet<Result<(), ProtoError>>,
44 shutdown_token: CancellationToken,
45}
46
47impl<T: RequestHandler> ServerFuture<T> {
48 pub fn new(handler: T) -> Self {
50 Self {
51 handler: Arc::new(handler),
52 join_set: JoinSet::new(),
53 shutdown_token: CancellationToken::new(),
54 }
55 }
56
57 pub fn register_socket(&mut self, socket: net::UdpSocket) {
59 debug!("registering udp: {:?}", socket);
60
61 let (mut stream, stream_handle) =
64 UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
65 let shutdown = self.shutdown_token.clone();
66 let handler = self.handler.clone();
67
68 self.join_set.spawn({
70 async move {
71 let mut inner_join_set = JoinSet::new();
72 loop {
73 let message = tokio::select! {
74 message = stream.next() => match message {
75 None => break,
76 Some(message) => message,
77 },
78 _ = shutdown.cancelled() => break,
79 };
80
81 let message = match message {
82 Err(e) => {
83 warn!("error receiving message on udp_socket: {}", e);
84 if is_unrecoverable_socket_error(&e) {
85 break;
86 }
87 continue;
88 }
89 Ok(message) => message,
90 };
91
92 let src_addr = message.addr();
93 debug!("received udp request from: {}", src_addr);
94
95 if let Err(e) = sanitize_src_address(src_addr) {
97 warn!(
98 "address can not be responded to {src_addr}: {e}",
99 src_addr = src_addr,
100 e = e
101 );
102 continue;
103 }
104
105 let handler = handler.clone();
106 let stream_handle = stream_handle.with_remote_addr(src_addr);
107
108 inner_join_set.spawn(async move {
109 handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
110 });
111
112 reap_tasks(&mut inner_join_set);
113 }
114
115 if shutdown.is_cancelled() {
116 Ok(())
117 } else {
118 Err(ProtoError::from("unexpected close of UDP socket"))
120 }
121 }
122 });
123 }
124
125 pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
127 self.register_socket(net::UdpSocket::from_std(socket)?);
128 Ok(())
129 }
130
131 pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
144 debug!("register tcp: {:?}", listener);
145
146 let handler = self.handler.clone();
147
148 let shutdown = self.shutdown_token.clone();
150 self.join_set.spawn(async move {
151 let mut inner_join_set = JoinSet::new();
152 loop {
153 let (tcp_stream, src_addr) = tokio::select! {
154 tcp_stream = listener.accept() => match tcp_stream {
155 Ok((t, s)) => (t, s),
156 Err(e) => {
157 debug!("error receiving TCP tcp_stream error: {}", e);
158 if is_unrecoverable_socket_error(&e) {
159 break;
160 }
161 continue;
162 },
163 },
164 _ = shutdown.cancelled() => {
165 break;
167 },
168 };
169
170 if let Err(e) = sanitize_src_address(src_addr) {
172 warn!(
173 "address can not be responded to {src_addr}: {e}",
174 src_addr = src_addr,
175 e = e
176 );
177 continue;
178 }
179
180 let handler = handler.clone();
181
182 inner_join_set.spawn(async move {
184 debug!("accepted request from: {}", src_addr);
185 let (buf_stream, stream_handle) =
187 TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
188 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
189
190 while let Some(message) = timeout_stream.next().await {
191 let message = match message {
192 Ok(message) => message,
193 Err(e) => {
194 debug!(
195 "error in TCP request_stream src: {} error: {}",
196 src_addr, e
197 );
198 return;
200 }
201 };
202
203 handle_raw_request(
205 message,
206 Protocol::Tcp,
207 handler.clone(),
208 stream_handle.clone(),
209 )
210 .await;
211 }
212 });
213
214 reap_tasks(&mut inner_join_set);
215 }
216
217 if shutdown.is_cancelled() {
218 Ok(())
219 } else {
220 Err(ProtoError::from("unexpected close of socket"))
221 }
222 });
223 }
224
225 pub fn register_listener_std(
238 &mut self,
239 listener: std::net::TcpListener,
240 timeout: Duration,
241 ) -> io::Result<()> {
242 self.register_listener(net::TcpListener::from_std(listener)?, timeout);
243 Ok(())
244 }
245
246 #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
260 #[cfg_attr(
261 docsrs,
262 doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
263 )]
264 pub fn register_tls_listener(
265 &mut self,
266 listener: net::TcpListener,
267 timeout: Duration,
268 certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
269 ) -> io::Result<()> {
270 use crate::proto::openssl::{tls_server, TlsStream};
271 use openssl::ssl::Ssl;
272 use std::pin::Pin;
273 use tokio_openssl::SslStream as TokioSslStream;
274
275 let ((cert, chain), key) = certificate_and_key;
276
277 let handler = self.handler.clone();
278 debug!("registered tcp: {:?}", listener);
279
280 let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
281
282 let shutdown = self.shutdown_watch.clone();
284 self.join_set.spawn(async move {
285 let mut inner_join_set = JoinSet::new();
286 loop {
287 let (tcp_stream, src_addr) = tokio::select! {
288 tcp_stream = listener.accept() => match tcp_stream {
289 Ok((t, s)) => (t, s),
290 Err(e) => {
291 debug!("error receiving TLS tcp_stream error: {}", e);
292 if is_unrecoverable_socket_error(&e) {
293 break;
294 }
295 continue;
296 },
297 },
298 _ = shutdown.clone().signaled() => {
299 break;
301 },
302 };
303
304 if let Err(e) = sanitize_src_address(src_addr) {
306 warn!(
307 "address can not be responded to {src_addr}: {e}",
308 src_addr = src_addr,
309 e = e
310 );
311 continue;
312 }
313
314 let handler = handler.clone();
315 let tls_acceptor = tls_acceptor.clone();
316
317 inner_join_set.spawn(async move {
319 debug!("starting TLS request from: {}", src_addr);
320
321 let mut tls_stream = match Ssl::new(tls_acceptor.context())
323 .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
324 {
325 Ok(tls_stream) => tls_stream,
326 Err(e) => {
327 debug!("tls handshake src: {} error: {}", src_addr, e);
328 return ();
329 }
330 };
331 match Pin::new(&mut tls_stream).accept().await {
332 Ok(()) => {}
333 Err(e) => {
334 debug!("tls handshake src: {} error: {}", src_addr, e);
335 return ();
336 }
337 };
338 debug!("accepted TLS request from: {}", src_addr);
339 let (buf_stream, stream_handle) =
340 TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
341 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
342 while let Some(message) = timeout_stream.next().await {
343 let message = match message {
344 Ok(message) => message,
345 Err(e) => {
346 debug!(
347 "error in TLS request_stream src: {:?} error: {}",
348 src_addr, e
349 );
350
351 return ();
353 }
354 };
355
356 self::handle_raw_request(
357 message,
358 Protocol::Tls,
359 handler.clone(),
360 stream_handle.clone(),
361 )
362 .await;
363 }
364 });
365
366 reap_tasks(&mut inner_join_set);
367 }
368
369 if shutdown.is_cancelled() {
370 Ok(())
371 } else {
372 Err(ProtoError::from("unexpected close of socket"))
373 }
374 });
375
376 Ok(())
377 }
378
379 #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
393 #[cfg_attr(
394 docsrs,
395 doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
396 )]
397 pub fn register_tls_listener_std(
398 &mut self,
399 listener: std::net::TcpListener,
400 timeout: Duration,
401 certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
402 ) -> io::Result<()> {
403 self.register_tls_listener(
404 net::TcpListener::from_std(listener)?,
405 timeout,
406 certificate_and_key,
407 )
408 }
409
410 #[cfg(feature = "dns-over-rustls")]
424 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
425 pub fn register_tls_listener_with_tls_config(
426 &mut self,
427 listener: net::TcpListener,
428 timeout: Duration,
429 tls_config: Arc<ServerConfig>,
430 ) -> io::Result<()> {
431 use crate::proto::rustls::tls_from_stream;
432 use tokio_rustls::TlsAcceptor;
433
434 let handler = self.handler.clone();
435
436 debug!("registered tcp: {:?}", listener);
437
438 let tls_acceptor = TlsAcceptor::from(tls_config);
439
440 let shutdown = self.shutdown_token.clone();
442 self.join_set.spawn(async move {
443 let mut inner_join_set = JoinSet::new();
444 loop {
445 let (tcp_stream, src_addr) = tokio::select! {
446 tcp_stream = listener.accept() => match tcp_stream {
447 Ok((t, s)) => (t, s),
448 Err(e) => {
449 debug!("error receiving TLS tcp_stream error: {}", e);
450 if is_unrecoverable_socket_error(&e) {
451 break;
452 }
453 continue;
454 },
455 },
456 _ = shutdown.cancelled() => {
457 break;
459 },
460 };
461
462 if let Err(e) = sanitize_src_address(src_addr) {
464 warn!(
465 "address can not be responded to {src_addr}: {e}",
466 src_addr = src_addr,
467 e = e
468 );
469 continue;
470 }
471
472 let handler = handler.clone();
473 let tls_acceptor = tls_acceptor.clone();
474
475 inner_join_set.spawn(async move {
477 debug!("starting TLS request from: {}", src_addr);
478
479 let tls_stream = tls_acceptor.accept(tcp_stream).await;
481
482 let tls_stream = match tls_stream {
483 Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
484 Err(e) => {
485 debug!("tls handshake src: {} error: {}", src_addr, e);
486 return;
487 }
488 };
489 debug!("accepted TLS request from: {}", src_addr);
490 let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
491 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
492 while let Some(message) = timeout_stream.next().await {
493 let message = match message {
494 Ok(message) => message,
495 Err(e) => {
496 debug!(
497 "error in TLS request_stream src: {:?} error: {}",
498 src_addr, e
499 );
500
501 return;
503 }
504 };
505
506 handle_raw_request(
507 message,
508 Protocol::Tls,
509 handler.clone(),
510 stream_handle.clone(),
511 )
512 .await;
513 }
514 });
515
516 reap_tasks(&mut inner_join_set);
517 }
518
519 if shutdown.is_cancelled() {
520 Ok(())
521 } else {
522 Err(ProtoError::from("unexpected close of socket"))
523 }
524 });
525
526 Ok(())
527 }
528
529 #[cfg(feature = "dns-over-rustls")]
543 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
544 pub fn register_tls_listener(
545 &mut self,
546 listener: net::TcpListener,
547 timeout: Duration,
548 certificate_and_key: (Vec<Certificate>, PrivateKey),
549 ) -> io::Result<()> {
550 use crate::proto::rustls::tls_server;
551
552 let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
553 .map_err(|e| {
554 io::Error::new(
555 io::ErrorKind::Other,
556 format!("error creating TLS acceptor: {e}"),
557 )
558 })?;
559
560 Self::register_tls_listener_with_tls_config(self, listener, timeout, Arc::new(tls_acceptor))
561 }
562
563 #[cfg(feature = "dns-over-https-rustls")]
577 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https-rustls")))]
578 pub fn register_https_listener(
579 &mut self,
580 listener: net::TcpListener,
581 _timeout: Duration,
583 certificate_and_key: (Vec<Certificate>, PrivateKey),
584 dns_hostname: Option<String>,
585 ) -> io::Result<()> {
586 use tokio_rustls::TlsAcceptor;
587
588 use crate::proto::rustls::tls_server;
589 use crate::server::h2_handler::h2_handler;
590
591 let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
592
593 let handler = self.handler.clone();
594 debug!("registered https: {listener:?}");
595
596 let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
597 .map_err(|e| {
598 io::Error::new(
599 io::ErrorKind::Other,
600 format!("error creating TLS acceptor: {e}"),
601 )
602 })?;
603 let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
604
605 let shutdown = self.shutdown_token.clone();
607 self.join_set.spawn(async move {
608 let mut inner_join_set = JoinSet::new();
609 loop {
610 let shutdown = shutdown.clone();
611 let (tcp_stream, src_addr) = tokio::select! {
612 tcp_stream = listener.accept() => match tcp_stream {
613 Ok((t, s)) => (t, s),
614 Err(e) => {
615 debug!("error receiving HTTPS tcp_stream error: {}", e);
616 if is_unrecoverable_socket_error(&e) {
617 break;
618 }
619 continue;
620 },
621 },
622 _ = shutdown.cancelled() => {
623 break;
625 },
626 };
627
628 if let Err(e) = sanitize_src_address(src_addr) {
630 warn!("address can not be responded to {src_addr}: {e}");
631 continue;
632 }
633
634 let handler = handler.clone();
635 let tls_acceptor = tls_acceptor.clone();
636 let dns_hostname = dns_hostname.clone();
637
638 inner_join_set.spawn(async move {
639 debug!("starting HTTPS request from: {src_addr}");
640
641 let tls_stream = tls_acceptor.accept(tcp_stream).await;
644
645 let tls_stream = match tls_stream {
646 Ok(tls_stream) => tls_stream,
647 Err(e) => {
648 debug!("https handshake src: {src_addr} error: {e}");
649 return;
650 }
651 };
652 debug!("accepted HTTPS request from: {src_addr}");
653
654 h2_handler(
655 handler,
656 tls_stream,
657 src_addr,
658 dns_hostname,
659 shutdown.clone(),
660 )
661 .await;
662 });
663
664 reap_tasks(&mut inner_join_set);
665 }
666
667 if shutdown.is_cancelled() {
668 Ok(())
669 } else {
670 Err(ProtoError::from("unexpected close of socket"))
671 }
672 });
673
674 Ok(())
675 }
676
677 #[cfg(feature = "dns-over-quic")]
691 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-quic")))]
692 pub fn register_quic_listener(
693 &mut self,
694 socket: net::UdpSocket,
695 _timeout: Duration,
697 certificate_and_key: (Vec<Certificate>, PrivateKey),
698 dns_hostname: Option<String>,
699 ) -> io::Result<()> {
700 use crate::proto::quic::QuicServer;
701 use crate::server::quic_handler::quic_handler;
702
703 let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
704
705 let handler = self.handler.clone();
706
707 debug!("registered quic: {:?}", socket);
708 let mut server =
709 QuicServer::with_socket(socket, certificate_and_key.0, certificate_and_key.1)?;
710
711 let shutdown = self.shutdown_token.clone();
713 self.join_set.spawn(async move {
714 let mut inner_join_set = JoinSet::new();
715 loop {
716 let shutdown = shutdown.clone();
717 let (streams, src_addr) = tokio::select! {
718 result = server.next() => match result {
719 Ok(Some(c)) => c,
720 Ok(None) => continue,
721 Err(e) => {
722 debug!("error receiving quic connection: {e}");
723 continue;
724 }
725 },
726 _ = shutdown.cancelled() => {
727 break;
729 },
730 };
731
732 if let Err(e) = sanitize_src_address(src_addr) {
735 warn!(
736 "address can not be responded to {src_addr}: {e}",
737 src_addr = src_addr,
738 e = e
739 );
740 continue;
741 }
742
743 let handler = handler.clone();
744 let dns_hostname = dns_hostname.clone();
745
746 inner_join_set.spawn(async move {
747 debug!("starting quic stream request from: {src_addr}");
748
749 let result =
751 quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
752 .await;
753
754 if let Err(e) = result {
755 warn!("quic stream processing failed from {src_addr}: {e}")
756 }
757 });
758
759 reap_tasks(&mut inner_join_set);
760 }
761
762 Ok(())
763 });
764
765 Ok(())
766 }
767
768 #[cfg(feature = "dns-over-h3")]
782 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-h3")))]
783 pub fn register_h3_listener(
784 &mut self,
785 socket: net::UdpSocket,
786 _timeout: Duration,
788 certificate_and_key: (Vec<Certificate>, PrivateKey),
789 dns_hostname: Option<String>,
790 ) -> io::Result<()> {
791 use crate::proto::h3::h3_server::H3Server;
792 use crate::server::h3_handler::h3_handler;
793
794 let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
795
796 let handler = self.handler.clone();
797
798 debug!("registered h3: {:?}", socket);
799 let mut server =
800 H3Server::with_socket(socket, certificate_and_key.0, certificate_and_key.1)?;
801
802 let shutdown = self.shutdown_token.clone();
804 self.join_set.spawn(async move {
805 let mut inner_join_set = JoinSet::new();
806 loop {
807 let shutdown = shutdown.clone();
808 let (streams, src_addr) = tokio::select! {
809 result = server.accept() => match result {
810 Ok(Some(c)) => c,
811 Ok(None) => continue,
812 Err(e) => {
813 debug!("error receiving h3 connection: {e}");
814 continue;
815 }
816 },
817 _ = shutdown.cancelled() => {
818 break;
820 },
821 };
822
823 if let Err(e) = sanitize_src_address(src_addr) {
826 warn!(
827 "address can not be responded to {src_addr}: {e}",
828 src_addr = src_addr,
829 e = e
830 );
831 continue;
832 }
833
834 let handler = handler.clone();
835 let dns_hostname = dns_hostname.clone();
836
837 inner_join_set.spawn(async move {
838 debug!("starting h3 stream request from: {src_addr}");
839
840 let result =
842 h3_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
843 .await;
844
845 if let Err(e) = result {
846 warn!("h3 stream processing failed from {src_addr}: {e}")
847 }
848 });
849
850 reap_tasks(&mut inner_join_set);
851 }
852
853 Ok(())
854 });
855
856 Ok(())
857 }
858
859 pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
862 self.shutdown_token.cancel();
863
864 block_until_done(&mut self.join_set).await
866 }
867
868 pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
871 block_until_done(&mut self.join_set).await
872 }
873}
874
875async fn block_until_done(
876 join_set: &mut JoinSet<Result<(), ProtoError>>,
877) -> Result<(), ProtoError> {
878 if join_set.is_empty() {
879 warn!("block_until_done called with no pending tasks");
880 return Ok(());
881 }
882
883 let mut out = Ok(());
885 while let Some(join_result) = join_set.join_next().await {
886 match join_result {
887 Ok(result) => {
888 match result {
889 Ok(_) => (),
890 Err(e) => {
891 out = Err(e);
893 }
894 }
895 }
896 Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
897 }
898 }
899 out
900}
901
902fn reap_tasks(join_set: &mut JoinSet<()>) {
904 while FutureExt::now_or_never(join_set.join_next())
905 .flatten()
906 .is_some()
907 {}
908}
909
910pub(crate) async fn handle_raw_request<T: RequestHandler>(
911 message: SerialMessage,
912 protocol: Protocol,
913 request_handler: Arc<T>,
914 response_handler: BufDnsStreamHandle,
915) {
916 let src_addr = message.addr();
917 let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
918
919 handle_request(
920 message.bytes(),
921 src_addr,
922 protocol,
923 request_handler,
924 response_handler,
925 )
926 .await;
927}
928
929#[derive(Clone)]
930struct ReportingResponseHandler<R: ResponseHandler> {
931 request_header: Header,
932 query: LowerQuery,
933 protocol: Protocol,
934 src_addr: SocketAddr,
935 handler: R,
936}
937
938#[async_trait::async_trait]
939#[allow(clippy::uninlined_format_args)]
940impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
941 async fn send_response<'a>(
942 &mut self,
943 response: crate::authority::MessageResponse<
944 '_,
945 'a,
946 impl Iterator<Item = &'a Record> + Send + 'a,
947 impl Iterator<Item = &'a Record> + Send + 'a,
948 impl Iterator<Item = &'a Record> + Send + 'a,
949 impl Iterator<Item = &'a Record> + Send + 'a,
950 >,
951 ) -> io::Result<super::ResponseInfo> {
952 let response_info = self.handler.send_response(response).await?;
953
954 let id = self.request_header.id();
955 let rid = response_info.id();
956 if id != rid {
957 warn!("request id:{id} does not match response id:{rid}");
958 debug_assert_eq!(id, rid, "request id and response id should match");
959 }
960
961 let rflags = response_info.flags();
962 let answer_count = response_info.answer_count();
963 let authority_count = response_info.name_server_count();
964 let additional_count = response_info.additional_count();
965 let response_code = response_info.response_code();
966
967 info!("request:{id} src:{proto}://{addr}#{port} {op}:{query}:{qtype}:{class} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
968 id = rid,
969 proto = self.protocol,
970 addr = self.src_addr.ip(),
971 port = self.src_addr.port(),
972 op = self.request_header.op_code(),
973 query = self.query.name(),
974 qtype = self.query.query_type(),
975 class = self.query.query_class(),
976 qflags = self.request_header.flags(),
977 code = response_code,
978 answers = answer_count,
979 authorities = authority_count,
980 additionals = additional_count,
981 rflags = rflags
982 );
983
984 Ok(response_info)
985 }
986}
987
/// Decodes `message_bytes` as a DNS message and dispatches it to `request_handler`.
///
/// - A successfully decoded message whose type is `Response` is ignored; no reply is sent.
/// - Any other decoded message is wrapped in a `Request` and passed to
///   `request_handler.handle_request`. Its reply goes through a
///   `ReportingResponseHandler`, which logs a summary when the response is sent.
/// - If decoding fails with a form error, a FORMERR response is sent using the header
///   recovered from that error.
/// - Any other decode error is only logged; no reply is sent.
pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
    message_bytes: &[u8],
    src_addr: SocketAddr,
    protocol: Protocol,
    request_handler: Arc<T>,
    response_handler: R,
) {
    let mut decoder = BinDecoder::new(message_bytes);

    // Handles a successfully decoded message. Kept as a closure so it can use
    // `src_addr`, `protocol` and `request_handler` from the enclosing function.
    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
        // Messages marked as responses are not requests: drop them without replying.
        if message.message_type() == MessageType::Response {
            return;
        }

        // Capture header details for the debug log before `message` is moved into `Request`.
        let id = message.id();
        let qflags = message.header().flags();
        let qop_code = message.op_code();
        let message_type = message.message_type();
        let is_dnssec = message.edns().is_some_and(Edns::dnssec_ok);

        let request = Request::new(message, src_addr, protocol);

        let info = request.request_info();
        let query = info.query.clone();
        let query_name = info.query.name();
        let query_type = info.query.query_type();
        let query_class = info.query.query_class();

        debug!(
            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op}:{query}:{qtype}:{class} qflags:{qflags}",
            id = id,
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
            message_type= message_type,
            is_dnssec = is_dnssec,
            op = qop_code,
            query = query_name,
            qtype = query_type,
            class = query_class,
            qflags = qflags,
        );

        // Wrap the caller's handler so the eventual response is logged with request context.
        let reporter = ReportingResponseHandler {
            request_header: *request.header(),
            query,
            protocol,
            src_addr,
            handler: response_handler,
        };

        request_handler.handle_request(&request, reporter).await;
    };

    match MessageRequest::read(&mut decoder) {
        Ok(message) => {
            inner_handle_request(message, response_handler).await;
        }
        // The message failed to parse, but the error carries a header we can answer with FORMERR.
        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
            let (header, error) = kind
                .into_form_error()
                .expect("as form_error already confirmed this is a FormError");
            // No query could be decoded; a default query is used for the log line only.
            let query = LowerQuery::query(Query::default());

            debug!(
                "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
                id = header.id(),
                proto = protocol,
                addr = src_addr.ip(),
                port = src_addr.port(),
                message_type= header.message_type(),
                op = header.op_code(),
                error = error,
            );

            let mut reporter = ReportingResponseHandler {
                request_header: header,
                query,
                protocol,
                src_addr,
                handler: response_handler,
            };

            let response = MessageResponseBuilder::new(None);
            let result = reporter
                .send_response(response.error_msg(&header, ResponseCode::FormErr))
                .await;

            if let Err(e) = result {
                warn!("failed to return FormError to client: {}", e);
            }
        }
        // Not even a header could be recovered: nothing to answer, just log.
        Err(e) => warn!("failed to read message: {}", e),
    }
}
1091
/// Checks that a request's source address is one a reply could actually be sent to.
///
/// Rejects port 0, the unspecified IPv4/IPv6 addresses, and the IPv4 broadcast
/// address. The `Err` string describes the reason.
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    // Nothing can be delivered to port 0.
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    match src.ip() {
        IpAddr::V4(v4) if v4.is_unspecified() => {
            Err(format!("cannot respond to unspecified v4 addr: {v4}"))
        }
        IpAddr::V4(v4) if v4.is_broadcast() => {
            Err(format!("cannot respond to broadcast v4 addr: {v4}"))
        }
        IpAddr::V6(v6) if v6.is_unspecified() => {
            Err(format!("cannot respond to unspecified v6 addr: {v6}"))
        }
        IpAddr::V4(_) | IpAddr::V6(_) => Ok(()),
    }
}
1133
/// Returns `true` for socket errors after which a listener loop should stop
/// (`NotConnected` and `ConnectionAborted`); any other error is treated as retryable.
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
1140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authority::Catalog;
    use futures_util::future;
    #[cfg(feature = "dns-over-rustls")]
    use rustls::{Certificate, PrivateKey};
    use std::net::SocketAddr;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborts a server future before it is ever polled. The test then checks that every
    /// endpoint address can be bound again, i.e. nothing is left holding the sockets.
    #[tokio::test]
    async fn abort() {
        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = ServerFuture::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        endpoints.rebind_all().await;
    }

    /// Registers every endpoint and requires `shutdown_gracefully` to finish without
    /// error within 2 seconds. It then checks that all addresses can be bound again.
    #[tokio::test]
    async fn graceful_shutdown() {
        let mut server_future = ServerFuture::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    /// Port 0, unspecified addresses and the IPv4 broadcast address are rejected;
    /// ordinary unicast/loopback addresses are accepted.
    #[test]
    fn test_sanitize_src_addr() {
        // IPv4 cases.
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4096))).is_err());

        // IPv6 cases.
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4096))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// Loopback addresses, one per transport enabled by the current feature set.
    /// They are reserved by binding to port 0 and recording the assigned address.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        udp_std_addr: SocketAddr,
        tcp_addr: SocketAddr,
        tcp_std_addr: SocketAddr,
        #[cfg(feature = "dns-over-rustls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-https-rustls")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-quic")]
        quic_addr: SocketAddr,
        #[cfg(feature = "dns-over-h3")]
        h3_addr: SocketAddr,
    }

    impl Endpoints {
        /// Picks a free loopback address for every transport. The probe sockets are
        /// dropped at the end of this function, so the addresses can be re-bound later.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-h3")]
            let h3 = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                udp_std_addr: udp_std.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                tcp_std_addr: tcp_std.local_addr().unwrap(),
                #[cfg(feature = "dns-over-rustls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-https-rustls")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-quic")]
                quic_addr: quic.local_addr().unwrap(),
                #[cfg(feature = "dns-over-h3")]
                h3_addr: h3.local_addr().unwrap(),
            }
        }

        /// Binds every recorded address and registers it on `server`, using each
        /// transport's registration method (including the `_std` variants).
        async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server
                .register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
                .unwrap();
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
            );
            server
                .register_listener_std(
                    std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
                    Duration::from_secs(1),
                )
                .unwrap();

            #[cfg(feature = "dns-over-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-https-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-h3")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_h3_listener(
                        UdpSocket::bind(self.h3_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Panics if any recorded address can no longer be bound. Used to check that
        /// the server released its sockets.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            UdpSocket::bind(self.udp_std_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            TcpListener::bind(self.tcp_std_addr).await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
            #[cfg(feature = "dns-over-h3")]
            UdpSocket::bind(self.h3_addr).await.unwrap();
        }
    }

    /// Loads the test certificate and key from `tests/test-data`. The paths are
    /// relative to `$TDNS_WORKSPACE_ROOT`, falling back to `../..`.
    #[cfg(feature = "dns-over-rustls")]
    fn rustls_cert_key() -> (Vec<Certificate>, PrivateKey) {
        use hickory_proto::rustls::tls_server;
        use std::env;
        use std::path::Path;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());

        let cert = tls_server::read_cert(Path::new(&format!(
            "{}/tests/test-data/cert.pem",
            server_path
        )))
        .map_err(|e| format!("error reading cert: {e}"))
        .unwrap();
        let key = tls_server::read_key_from_pem(Path::new(&format!(
            "{}/tests/test-data/cert.key",
            server_path
        )))
        .unwrap();

        (cert, key)
    }

    /// Reaping an empty join set must return immediately without panicking.
    #[test]
    fn task_reap_on_empty_joinset() {
        let mut joinset = JoinSet::new();

        reap_tasks(&mut joinset);
    }

    /// Reaping must not block on a still-running task, and must still work after that
    /// task is aborted.
    #[test]
    fn task_reap_on_nonempty_joinset() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let mut joinset = JoinSet::new();
            let t = joinset.spawn(tokio::time::sleep(Duration::from_secs(2)));

            // The sleep is still pending, so this must return without waiting for it.
            reap_tasks(&mut joinset);
            t.abort();

            // Called again after aborting the pending task.
            reap_tasks(&mut joinset);
        });
    }
}