1use std::{
11 io,
12 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
13 sync::Arc,
14 time::Duration,
15};
16
17use futures_util::{FutureExt, StreamExt};
18use hickory_proto::ProtoErrorKind;
19use ipnet::IpNet;
20#[cfg(feature = "__tls")]
21use rustls::{ServerConfig, server::ResolvesServerCert};
22#[cfg(feature = "__tls")]
23use tokio::time::timeout;
24use tokio::{net, task::JoinSet};
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, warn};
27
28#[cfg(feature = "__tls")]
29use crate::proto::rustls::default_provider;
30use crate::{
31 access::AccessControl,
32 authority::{MessageRequest, MessageResponseBuilder, Queries},
33 proto::{
34 BufDnsStreamHandle, ProtoError,
35 op::{Header, LowerQuery, MessageType, ResponseCode},
36 rr::Record,
37 runtime::{TokioRuntimeProvider, iocompat::AsyncIoTokioAsStd},
38 serialize::binary::{BinDecodable, BinDecoder},
39 tcp::TcpStream,
40 udp::UdpStream,
41 xfer::{Protocol, SerialMessage},
42 },
43};
44
45#[cfg(feature = "__https")]
46mod h2_handler;
47#[cfg(feature = "__h3")]
48mod h3_handler;
49#[cfg(feature = "__quic")]
50mod quic_handler;
51mod request_handler;
52pub use request_handler::{Request, RequestHandler, RequestInfo, ResponseInfo};
53mod response_handler;
54pub use response_handler::{ResponseHandle, ResponseHandler};
55#[cfg(feature = "metrics")]
56mod metrics;
57#[cfg(feature = "metrics")]
58use metrics::ResponseHandlerMetrics;
59mod timeout_stream;
60pub use timeout_stream::TimeoutStream;
61
62pub struct ServerFuture<T: RequestHandler> {
65 handler: Arc<T>,
66 join_set: JoinSet<Result<(), ProtoError>>,
67 shutdown_token: CancellationToken,
68 access: Arc<AccessControl>,
69}
70
71impl<T: RequestHandler> ServerFuture<T> {
72 pub fn new(handler: T) -> Self {
74 Self::with_access(handler, &[], &[])
75 }
76
77 pub fn with_access(handler: T, denied_networks: &[IpNet], allowed_networks: &[IpNet]) -> Self {
79 let mut access = AccessControl::default();
80 access.insert_deny(denied_networks);
81 access.insert_allow(allowed_networks);
82
83 Self {
84 handler: Arc::new(handler),
85 join_set: JoinSet::new(),
86 shutdown_token: CancellationToken::new(),
87 access: Arc::new(access),
88 }
89 }
90
91 pub fn register_socket(&mut self, socket: net::UdpSocket) {
93 debug!("registering udp: {:?}", socket);
94
95 let (mut stream, stream_handle) =
98 UdpStream::<TokioRuntimeProvider>::with_bound(socket, ([127, 255, 255, 254], 0).into());
99 let shutdown = self.shutdown_token.clone();
100 let handler = self.handler.clone();
101 let access = self.access.clone();
102
103 self.join_set.spawn({
105 async move {
106 let mut inner_join_set = JoinSet::new();
107 loop {
108 let message = tokio::select! {
109 message = stream.next() => match message {
110 None => break,
111 Some(message) => message,
112 },
113 _ = shutdown.cancelled() => break,
114 };
115
116 let message = match message {
117 Err(e) => {
118 warn!("error receiving message on udp_socket: {}", e);
119 if is_unrecoverable_socket_error(&e) {
120 break;
121 }
122 continue;
123 }
124 Ok(message) => message,
125 };
126
127 let src_addr = message.addr();
128 debug!("received udp request from: {}", src_addr);
129
130 if let Err(e) = sanitize_src_address(src_addr) {
132 warn!(
133 "address can not be responded to {src_addr}: {e}",
134 src_addr = src_addr,
135 e = e
136 );
137 continue;
138 }
139
140 let handler = handler.clone();
141 let access = access.clone();
142 let stream_handle = stream_handle.with_remote_addr(src_addr);
143
144 inner_join_set.spawn(async move {
145 handle_raw_request(message, Protocol::Udp, access, handler, stream_handle)
146 .await;
147 });
148
149 reap_tasks(&mut inner_join_set);
150 }
151
152 if shutdown.is_cancelled() {
153 Ok(())
154 } else {
155 Err(ProtoError::from("unexpected close of UDP socket"))
157 }
158 }
159 });
160 }
161
162 pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
164 socket.set_nonblocking(true)?;
165 self.register_socket(net::UdpSocket::from_std(socket)?);
166 Ok(())
167 }
168
169 pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
182 debug!("register tcp: {:?}", listener);
183
184 let handler = self.handler.clone();
185 let access = self.access.clone();
186
187 let shutdown = self.shutdown_token.clone();
189 self.join_set.spawn(async move {
190 let mut inner_join_set = JoinSet::new();
191 loop {
192 let (tcp_stream, src_addr) = tokio::select! {
193 tcp_stream = listener.accept() => match tcp_stream {
194 Ok((t, s)) => (t, s),
195 Err(e) => {
196 debug!("error receiving TCP tcp_stream error: {}", e);
197 if is_unrecoverable_socket_error(&e) {
198 break;
199 }
200 continue;
201 },
202 },
203 _ = shutdown.cancelled() => {
204 break;
206 },
207 };
208
209 if let Err(e) = sanitize_src_address(src_addr) {
211 warn!(
212 "address can not be responded to {src_addr}: {e}",
213 src_addr = src_addr,
214 e = e
215 );
216 continue;
217 }
218
219 let handler = handler.clone();
220 let access = access.clone();
221
222 inner_join_set.spawn(async move {
224 debug!("accepted request from: {}", src_addr);
225 let (buf_stream, stream_handle) =
227 TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
228 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
229
230 while let Some(message) = timeout_stream.next().await {
231 let message = match message {
232 Ok(message) => message,
233 Err(e) => {
234 debug!(
235 "error in TCP request_stream src: {} error: {}",
236 src_addr, e
237 );
238 return;
240 }
241 };
242
243 handle_raw_request(
245 message,
246 Protocol::Tcp,
247 access.clone(),
248 handler.clone(),
249 stream_handle.clone(),
250 )
251 .await;
252 }
253 });
254
255 reap_tasks(&mut inner_join_set);
256 }
257
258 if shutdown.is_cancelled() {
259 Ok(())
260 } else {
261 Err(ProtoError::from("unexpected close of socket"))
262 }
263 });
264 }
265
266 pub fn register_listener_std(
279 &mut self,
280 listener: std::net::TcpListener,
281 timeout: Duration,
282 ) -> io::Result<()> {
283 listener.set_nonblocking(true)?;
284 self.register_listener(net::TcpListener::from_std(listener)?, timeout);
285 Ok(())
286 }
287
/// Registers a TCP listener that serves DNS over TLS using a caller-supplied TLS config.
///
/// `handshake_timeout` bounds the TLS handshake of every accepted connection and is also the
/// timeout given to the per-connection [`TimeoutStream`]. The accept task resolves to `Ok(())`
/// when it stops after the shutdown token was cancelled and to an error otherwise.
///
/// # Errors
///
/// The visible code always returns `Ok(())`.
#[cfg(feature = "__tls")]
pub fn register_tls_listener_with_tls_config(
    &mut self,
    listener: net::TcpListener,
    handshake_timeout: Duration,
    tls_config: Arc<ServerConfig>,
) -> io::Result<()> {
    use crate::proto::rustls::tls_from_stream;
    use tokio_rustls::TlsAcceptor;

    let access = self.access.clone();
    let handler = self.handler.clone();

    debug!("registered tcp: {:?}", listener);

    let tls_acceptor = TlsAcceptor::from(tls_config);
    let shutdown = self.shutdown_token.clone();

    self.join_set.spawn(async move {
        let mut connection_tasks = JoinSet::new();
        loop {
            // Wait for the next connection, or stop as soon as shutdown is requested.
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                _ = shutdown.cancelled() => break,
            };

            let (tcp_stream, src_addr) = match accepted {
                Ok(connection) => connection,
                Err(e) => {
                    debug!("error receiving TLS tcp_stream error: {}", e);
                    if is_unrecoverable_socket_error(&e) {
                        break;
                    }
                    continue;
                }
            };

            // Drop connections whose source can never be answered.
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let access = access.clone();
            let tls_acceptor = tls_acceptor.clone();

            connection_tasks.spawn(async move {
                debug!("starting TLS request from: {}", src_addr);

                // Outer `Err` is the handshake timeout, inner `Err` a failed handshake.
                let handshake = timeout(handshake_timeout, tls_acceptor.accept(tcp_stream));
                let tls_stream = match handshake.await {
                    Err(_) => {
                        warn!("tls timeout expired during handshake");
                        return;
                    }
                    Ok(Err(e)) => {
                        debug!("tls handshake src: {} error: {}", src_addr, e);
                        return;
                    }
                    Ok(Ok(tls_stream)) => AsyncIoTokioAsStd(tls_stream),
                };
                debug!("accepted TLS request from: {}", src_addr);

                let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
                let mut requests = TimeoutStream::new(buf_stream, handshake_timeout);

                loop {
                    let message = match requests.next().await {
                        None => break,
                        Some(Ok(message)) => message,
                        Some(Err(e)) => {
                            debug!(
                                "error in TLS request_stream src: {:?} error: {}",
                                src_addr, e
                            );
                            // A stream error ends this connection.
                            break;
                        }
                    };

                    handle_raw_request(
                        message,
                        Protocol::Tls,
                        access.clone(),
                        handler.clone(),
                        stream_handle.clone(),
                    )
                    .await;
                }
            });

            // Collect connection tasks that already finished.
            reap_tasks(&mut connection_tasks);
        }

        if shutdown.is_cancelled() {
            return Ok(());
        }
        Err(ProtoError::from("unexpected close of socket"))
    });

    Ok(())
}
413
/// Registers a TCP listener that serves DNS over TLS.
///
/// Builds a server TLS config advertising the `dot` ALPN protocol from
/// `server_cert_resolver` and delegates to
/// [`Self::register_tls_listener_with_tls_config`], passing `timeout` as its handshake
/// timeout.
///
/// # Errors
///
/// Returns an error if the TLS server config cannot be created.
#[cfg(feature = "__tls")]
pub fn register_tls_listener(
    &mut self,
    listener: net::TcpListener,
    timeout: Duration,
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
) -> io::Result<()> {
    let tls_config = tls_server_config(b"dot", server_cert_resolver).map(Arc::new)?;
    self.register_tls_listener_with_tls_config(listener, timeout, tls_config)
}
437
/// Registers a TCP listener that serves DNS over HTTPS (HTTP/2).
///
/// A TLS config advertising the `h2` ALPN protocol is built from `server_cert_resolver`.
/// Every accepted connection must finish its TLS handshake within `handshake_timeout` and is
/// then handed to `h2_handler` together with `dns_hostname`, `http_endpoint` and a clone of
/// the shutdown token. The accept task resolves to `Ok(())` when it stops after the shutdown
/// token was cancelled and to an error otherwise.
///
/// # Errors
///
/// Returns an error if the TLS server config cannot be created.
#[cfg(feature = "__https")]
pub fn register_https_listener(
    &mut self,
    listener: net::TcpListener,
    handshake_timeout: Duration,
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
    dns_hostname: Option<String>,
    http_endpoint: String,
) -> io::Result<()> {
    use crate::server::h2_handler::h2_handler;
    use tokio_rustls::TlsAcceptor;

    // Shared string forms that are cheap to clone for every connection.
    let dns_hostname: Option<Arc<str>> = dns_hostname.map(Arc::from);
    let http_endpoint: Arc<str> = http_endpoint.into();

    let access = self.access.clone();
    let handler = self.handler.clone();
    debug!("registered https: {listener:?}");

    let tls_config = tls_server_config(b"h2", server_cert_resolver)?;
    let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let shutdown = self.shutdown_token.clone();
    self.join_set.spawn(async move {
        let mut connection_tasks = JoinSet::new();
        loop {
            // Wait for the next connection, or stop as soon as shutdown is requested.
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                _ = shutdown.cancelled() => break,
            };

            let (tcp_stream, src_addr) = match accepted {
                Ok(connection) => connection,
                Err(e) => {
                    debug!("error receiving HTTPS tcp_stream error: {}", e);
                    if is_unrecoverable_socket_error(&e) {
                        break;
                    }
                    continue;
                }
            };

            // Drop connections whose source can never be answered.
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let access = access.clone();
            let tls_acceptor = tls_acceptor.clone();
            let dns_hostname = dns_hostname.clone();
            let http_endpoint = http_endpoint.clone();
            let task_shutdown = shutdown.clone();

            connection_tasks.spawn(async move {
                debug!("starting HTTPS request from: {src_addr}");

                // Outer `Err` is the handshake timeout, inner `Err` a failed handshake.
                let handshake = timeout(handshake_timeout, tls_acceptor.accept(tcp_stream));
                let tls_stream = match handshake.await {
                    Err(_) => {
                        warn!("https timeout expired during handshake");
                        return;
                    }
                    Ok(Err(e)) => {
                        debug!("https handshake src: {src_addr} error: {e}");
                        return;
                    }
                    Ok(Ok(tls_stream)) => tls_stream,
                };
                debug!("accepted HTTPS request from: {src_addr}");

                h2_handler(
                    access,
                    handler,
                    tls_stream,
                    src_addr,
                    dns_hostname,
                    http_endpoint,
                    task_shutdown,
                )
                .await;
            });

            // Collect connection tasks that already finished.
            reap_tasks(&mut connection_tasks);
        }

        if shutdown.is_cancelled() {
            return Ok(());
        }
        Err(ProtoError::from("unexpected close of socket"))
    });

    Ok(())
}
554
/// Registers a UDP socket that serves DNS over QUIC.
///
/// A `QuicServer` is created on `socket` with `server_cert_resolver`; every accepted
/// connection is handed to `quic_handler` on its own sub-task. `_timeout` is currently unused.
/// The accept task only stops when the shutdown token is cancelled and then resolves to
/// `Ok(())`.
///
/// # Errors
///
/// Returns an error if the `QuicServer` cannot be created on the socket.
#[cfg(feature = "__quic")]
pub fn register_quic_listener(
    &mut self,
    socket: net::UdpSocket,
    _timeout: Duration,
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
    dns_hostname: Option<String>,
) -> io::Result<()> {
    use crate::proto::quic::QuicServer;
    use crate::server::quic_handler::quic_handler;

    let dns_hostname: Option<Arc<str>> = dns_hostname.map(Arc::from);

    let access = self.access.clone();
    let handler = self.handler.clone();

    debug!("registered quic: {:?}", socket);
    let mut server = QuicServer::with_socket(socket, server_cert_resolver)?;

    let shutdown = self.shutdown_token.clone();
    self.join_set.spawn(async move {
        let mut connection_tasks = JoinSet::new();
        loop {
            // Wait for the next connection, or stop as soon as shutdown is requested.
            let incoming = tokio::select! {
                incoming = server.next() => incoming,
                _ = shutdown.cancelled() => break,
            };

            // NOTE(review): if `next` keeps yielding `Ok(None)` or an error without waiting,
            // this loop spins until shutdown -- confirm against `QuicServer`.
            let (streams, src_addr) = match incoming {
                Ok(Some(connection)) => connection,
                Ok(None) => continue,
                Err(e) => {
                    debug!("error receiving quic connection: {e}");
                    continue;
                }
            };

            // Drop connections whose source can never be answered.
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let access = access.clone();
            let dns_hostname = dns_hostname.clone();
            let task_shutdown = shutdown.clone();

            connection_tasks.spawn(async move {
                debug!("starting quic stream request from: {src_addr}");

                let outcome = quic_handler(
                    access,
                    handler,
                    streams,
                    src_addr,
                    dns_hostname,
                    task_shutdown,
                )
                .await;

                if let Err(e) = outcome {
                    warn!("quic stream processing failed from {src_addr}: {e}")
                }
            });

            // Collect connection tasks that already finished.
            reap_tasks(&mut connection_tasks);
        }

        Ok(())
    });

    Ok(())
}
651
/// Registers a UDP socket that serves DNS over HTTP/3.
///
/// An `H3Server` is created on `socket` with `server_cert_resolver`; every accepted
/// connection is handed to `h3_handler` on its own sub-task. `_timeout` is currently unused.
/// The accept task only stops when the shutdown token is cancelled and then resolves to
/// `Ok(())`.
///
/// # Errors
///
/// Returns an error if the `H3Server` cannot be created on the socket.
#[cfg(feature = "__h3")]
pub fn register_h3_listener(
    &mut self,
    socket: net::UdpSocket,
    _timeout: Duration,
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
    dns_hostname: Option<String>,
) -> io::Result<()> {
    use crate::proto::h3::h3_server::H3Server;
    use crate::server::h3_handler::h3_handler;

    let dns_hostname: Option<Arc<str>> = dns_hostname.map(Arc::from);

    let access = self.access.clone();
    let handler = self.handler.clone();

    debug!("registered h3: {:?}", socket);
    let mut server = H3Server::with_socket(socket, server_cert_resolver)?;

    let shutdown = self.shutdown_token.clone();
    self.join_set.spawn(async move {
        let mut connection_tasks = JoinSet::new();
        loop {
            // Wait for the next connection, or stop as soon as shutdown is requested.
            let incoming = tokio::select! {
                incoming = server.accept() => incoming,
                _ = shutdown.cancelled() => break,
            };

            // NOTE(review): if `accept` keeps yielding `Ok(None)` or an error without
            // waiting, this loop spins until shutdown -- confirm against `H3Server`.
            let (streams, src_addr) = match incoming {
                Ok(Some(connection)) => connection,
                Ok(None) => continue,
                Err(e) => {
                    debug!("error receiving h3 connection: {e}");
                    continue;
                }
            };

            // Drop connections whose source can never be answered.
            if let Err(e) = sanitize_src_address(src_addr) {
                warn!("address can not be responded to {src_addr}: {e}");
                continue;
            }

            let handler = handler.clone();
            let access = access.clone();
            let dns_hostname = dns_hostname.clone();
            let task_shutdown = shutdown.clone();

            connection_tasks.spawn(async move {
                debug!("starting h3 stream request from: {src_addr}");

                let outcome = h3_handler(
                    access,
                    handler,
                    streams,
                    src_addr,
                    dns_hostname,
                    task_shutdown,
                )
                .await;

                if let Err(e) = outcome {
                    warn!("h3 stream processing failed from {src_addr}: {e}")
                }
            });

            // Collect connection tasks that already finished.
            reap_tasks(&mut connection_tasks);
        }

        Ok(())
    });

    Ok(())
}
748
749 pub async fn shutdown_gracefully(&mut self) -> Result<(), ProtoError> {
752 self.shutdown_token.cancel();
753
754 block_until_done(&mut self.join_set).await
756 }
757
758 pub fn shutdown_token(&self) -> &CancellationToken {
763 &self.shutdown_token
764 }
765
766 pub async fn block_until_done(&mut self) -> Result<(), ProtoError> {
769 block_until_done(&mut self.join_set).await
770 }
771}
772
773async fn block_until_done(
774 join_set: &mut JoinSet<Result<(), ProtoError>>,
775) -> Result<(), ProtoError> {
776 if join_set.is_empty() {
777 warn!("block_until_done called with no pending tasks");
778 return Ok(());
779 }
780
781 let mut out = Ok(());
783 while let Some(join_result) = join_set.join_next().await {
784 match join_result {
785 Ok(result) => {
786 match result {
787 Ok(_) => (),
788 Err(e) => {
789 out = Err(e);
791 }
792 }
793 }
794 Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
795 }
796 }
797 out
798}
799
800fn reap_tasks(join_set: &mut JoinSet<()>) {
802 while FutureExt::now_or_never(join_set.join_next())
803 .flatten()
804 .is_some()
805 {}
806}
807
808pub(crate) async fn handle_raw_request<T: RequestHandler>(
809 message: SerialMessage,
810 protocol: Protocol,
811 access: Arc<AccessControl>,
812 request_handler: Arc<T>,
813 response_handler: BufDnsStreamHandle,
814) {
815 let src_addr = message.addr();
816 let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
817
818 handle_request(
819 message.bytes(),
820 src_addr,
821 protocol,
822 access,
823 request_handler,
824 response_handler,
825 )
826 .await;
827}
828
/// Builds a rustls server config that advertises `protocol` as its only ALPN protocol.
///
/// The config uses the crate's default crypto provider with the safe default protocol
/// versions, requests no client authentication and takes its certificates from
/// `server_cert_resolver`.
///
/// # Errors
///
/// Returns an `io::Error` of kind `Other` if the protocol versions cannot be configured.
#[cfg(feature = "__tls")]
fn tls_server_config(
    protocol: &[u8],
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
) -> io::Result<ServerConfig> {
    let provider = Arc::new(default_provider());
    let builder = ServerConfig::builder_with_provider(provider);

    let versioned = match builder.with_safe_default_protocol_versions() {
        Ok(versioned) => versioned,
        Err(e) => {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!("error creating TLS acceptor: {e}"),
            ));
        }
    };

    let mut config = versioned
        .with_no_client_auth()
        .with_cert_resolver(server_cert_resolver);
    config.alpn_protocols = vec![protocol.to_vec()];

    Ok(config)
}
848
849#[derive(Clone)]
850struct ReportingResponseHandler<R: ResponseHandler> {
851 request_header: Header,
852 queries: Vec<LowerQuery>,
853 protocol: Protocol,
854 src_addr: SocketAddr,
855 handler: R,
856 #[cfg(feature = "metrics")]
857 metrics: ResponseHandlerMetrics,
858}
859
860#[async_trait::async_trait]
861#[allow(clippy::uninlined_format_args)]
862impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
863 async fn send_response<'a>(
864 &mut self,
865 response: crate::authority::MessageResponse<
866 '_,
867 'a,
868 impl Iterator<Item = &'a Record> + Send + 'a,
869 impl Iterator<Item = &'a Record> + Send + 'a,
870 impl Iterator<Item = &'a Record> + Send + 'a,
871 impl Iterator<Item = &'a Record> + Send + 'a,
872 >,
873 ) -> io::Result<ResponseInfo> {
874 let response_info = self.handler.send_response(response).await?;
875
876 let id = self.request_header.id();
877 let rid = response_info.id();
878 if id != rid {
879 warn!("request id:{id} does not match response id:{rid}");
880 debug_assert_eq!(id, rid, "request id and response id should match");
881 }
882
883 let rflags = response_info.flags();
884 let answer_count = response_info.answer_count();
885 let authority_count = response_info.name_server_count();
886 let additional_count = response_info.additional_count();
887 let response_code = response_info.response_code();
888
889 info!(
890 "request:{id} src:{proto}://{addr}#{port} {op} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
891 id = rid,
892 proto = self.protocol,
893 addr = self.src_addr.ip(),
894 port = self.src_addr.port(),
895 op = self.request_header.op_code(),
896 qflags = self.request_header.flags(),
897 code = response_code,
898 answers = answer_count,
899 authorities = authority_count,
900 additionals = additional_count,
901 rflags = rflags
902 );
903 for query in self.queries.iter() {
904 info!(
905 "query:{query}:{qtype}:{class}",
906 query = query.name(),
907 qtype = query.query_type(),
908 class = query.query_class()
909 );
910 }
911
912 #[cfg(feature = "metrics")]
913 self.metrics.update(self, &response_info);
914
915 Ok(response_info)
916 }
917}
918
#[cfg(feature = "metrics")]
impl ResponseHandlerMetrics {
    /// Records one sent response: the protocol, op code and flags of the request, plus the
    /// response code and flags of the response.
    fn update(
        &self,
        response_handler: &ReportingResponseHandler<impl ResponseHandler>,
        response_info: &ResponseInfo,
    ) {
        let request_header = &response_handler.request_header;

        // Request-side counters.
        self.proto.increment(&response_handler.protocol);
        self.operation.increment(&request_header.op_code());
        self.request_flags.increment(request_header);

        // Response-side counters.
        let response_code = response_info.response_code();
        self.response_code.increment(&response_code);
        self.response_flags.increment(response_info);
    }
}
936
/// Decodes one raw DNS message and dispatches it.
///
/// * If `access` does not allow the source IP, a `Refused` response is sent (provided at
///   least the header can be decoded) and the request handler is never invoked.
/// * If the message decodes, it is passed to `request_handler` with a reporting wrapper
///   around `response_handler`; messages whose type is `Response` are dropped silently.
/// * If decoding fails with a form error, a `FormErr` response is sent.
/// * Any other decoding error is only logged.
pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
    message_bytes: &[u8],
    src_addr: SocketAddr,
    protocol: Protocol,
    access: Arc<AccessControl>,
    request_handler: Arc<T>,
    response_handler: R,
) {
    let mut decoder = BinDecoder::new(message_bytes);

    // Normal path: log the request and hand it to the request handler.
    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
        // Only requests are answered; a message flagged as a response is ignored.
        if message.message_type() == MessageType::Response {
            return;
        }

        // Capture the values needed for logging before `message` is moved into the request.
        let id = message.id();
        let qflags = message.header().flags();
        let qop_code = message.op_code();
        let message_type = message.message_type();
        let is_dnssec = message.edns().is_some_and(|edns| edns.flags().dnssec_ok);

        let request = Request::new(message, src_addr, protocol);

        debug!(
            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op} qflags:{qflags}",
            id = id,
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
            message_type = message_type,
            is_dnssec = is_dnssec,
            op = qop_code,
            qflags = qflags
        );
        for query in request.queries().iter() {
            debug!(
                "query:{query}:{qtype}:{class}",
                query = query.name(),
                qtype = query.query_type(),
                class = query.query_class()
            );
        }

        // The reporter owns copies of the header and queries so it can log the response.
        let queries = request.queries().to_vec();
        let reporter = ReportingResponseHandler {
            request_header: *request.header(),
            queries,
            protocol,
            src_addr,
            handler: response_handler,
            #[cfg(feature = "metrics")]
            metrics: ResponseHandlerMetrics::default(),
        };

        request_handler.handle_request(&request, reporter).await;
    };

    // Error path: log the failure and answer with an error message carrying `response_code`.
    let error_response_handler = |protocol: Protocol,
                                  src_addr: SocketAddr,
                                  header: Header,
                                  queries: Queries,
                                  response_code: ResponseCode,
                                  error: Box<ProtoError>,
                                  response_handler: R| async move {
        debug!(
            "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:{response_code}:{error}",
            id = header.id(),
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
            message_type = header.message_type(),
            op = header.op_code(),
            response_code = response_code,
            error = error,
        );

        let mut reporter = ReportingResponseHandler {
            request_header: header,
            queries: queries.queries().to_vec(),
            protocol,
            src_addr,
            handler: response_handler,
            #[cfg(feature = "metrics")]
            metrics: ResponseHandlerMetrics::default(),
        };

        let response = MessageResponseBuilder::new(&queries);
        let result = reporter
            .send_response(response.error_msg(&header, response_code))
            .await;

        // NOTE(review): this message says "FormError" although this closure is also used
        // for the `Refused` response below.
        if let Err(e) = result {
            warn!("failed to return FormError to client: {}", e);
        }
    };

    // Access control is checked before the full message is decoded.
    if !access.allow(src_addr.ip()) {
        info!(
            "request:Refused src:{proto}://{addr}#{port}",
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
        );

        // Without a decodable header there is nothing to respond to.
        let Ok(header) = Header::read(&mut decoder) else {
            return;
        };
        // Echo the queries when they can be read; otherwise respond with none.
        let queries = match Queries::read(&mut decoder, header.query_count() as usize) {
            Ok(queries) => queries,
            Err(_) => Queries::empty(),
        };
        error_response_handler(
            protocol,
            src_addr,
            header,
            queries,
            ResponseCode::Refused,
            Box::new(ProtoErrorKind::RequestRefused.into()),
            response_handler,
        )
        .await;

        return;
    }

    match MessageRequest::read(&mut decoder) {
        Ok(message) => {
            inner_handle_request(message, response_handler).await;
        }
        // A form error still carries the decoded header, so a `FormErr` can be returned.
        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
            let (header, error) = kind
                .into_form_error()
                .expect("as form_error already confirmed this is a FormError");
            let queries = Queries::empty();

            error_response_handler(
                protocol,
                src_addr,
                header,
                queries,
                ResponseCode::FormErr,
                error,
                response_handler,
            )
            .await;
        }
        // Any other decoding failure is logged and no response is sent.
        Err(error) => info!(
            "request:Failed src:{proto}://{addr}#{port} error:{error}",
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
        ),
    }
}
1102
/// Checks whether a response can sensibly be sent back to `src`.
///
/// An address is rejected when:
///
/// * its port is `0`,
/// * it is the unspecified IPv4 or IPv6 address,
/// * it is the IPv4 broadcast address, or
/// * it is an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) whose embedded IPv4 address is
///   unspecified or broadcast.
///
/// # Errors
///
/// Returns a human-readable description of why the address cannot be responded to.
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        // An IPv4-mapped address is an IPv4 address in IPv6 clothing, so it must pass the
        // same checks as a plain IPv4 source; otherwise e.g. `::ffff:255.255.255.255` would
        // slip past the broadcast check.
        if let Some(mapped) = src.to_ipv4_mapped() {
            return verify_v4(mapped);
        }

        Ok(())
    }

    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
1144
/// Returns `true` when `err` means the socket can no longer be used, i.e. its kind is
/// `NotConnected` or `ConnectionAborted`.
fn is_unrecoverable_socket_error(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::NotConnected || kind == io::ErrorKind::ConnectionAborted
}
1151
1152#[cfg(test)]
1153mod tests {
1154 use super::*;
1155 use crate::authority::Catalog;
1156 use futures_util::future;
1157 #[cfg(feature = "__tls")]
1158 use rustls::{
1159 pki_types::{CertificateDer, PrivateKeyDer},
1160 sign::{CertifiedKey, SingleCertAndKey},
1161 };
1162 use std::net::SocketAddr;
1163 use test_support::subscribe;
1164 use tokio::net::{TcpListener, UdpSocket};
1165 use tokio::time::timeout;
1166
1167 #[tokio::test]
1168 async fn abort() {
1169 subscribe();
1170
1171 let endpoints = Endpoints::new().await;
1172
1173 let endpoints2 = endpoints.clone();
1174 let (abortable, abort_handle) = future::abortable(async move {
1175 let mut server_future = ServerFuture::new(Catalog::new());
1176 endpoints2.register(&mut server_future).await;
1177 server_future.block_until_done().await
1178 });
1179
1180 abort_handle.abort();
1181 abortable.await.expect_err("expected abort");
1182
1183 endpoints.rebind_all().await;
1184 }
1185
1186 #[tokio::test]
1187 async fn graceful_shutdown() {
1188 subscribe();
1189 let mut server_future = ServerFuture::new(Catalog::new());
1190 let endpoints = Endpoints::new().await;
1191 endpoints.register(&mut server_future).await;
1192
1193 timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
1194 .await
1195 .expect("timed out waiting for the server to complete")
1196 .expect("error while awaiting tasks");
1197
1198 endpoints.rebind_all().await;
1199 }
1200
/// Checks which IPv4 and IPv6 source addresses `sanitize_src_address` accepts and rejects.
#[test]
fn test_sanitize_src_addr() {
    let accepted = [
        // ipv4 tests
        SocketAddr::from(([192, 168, 1, 1], 4_096)),
        SocketAddr::from(([127, 0, 0, 1], 53)),
        // ipv6 tests
        SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4_096)),
        SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4_096)),
    ];
    for addr in &accepted {
        assert!(sanitize_src_address(*addr).is_ok());
    }

    let rejected = [
        // ipv4 tests
        SocketAddr::from(([0, 0, 0, 0], 0)),
        SocketAddr::from(([192, 168, 1, 1], 0)),
        SocketAddr::from(([0, 0, 0, 0], 4_096)),
        SocketAddr::from(([255, 255, 255, 255], 4_096)),
        // ipv6 tests
        SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4_096)),
        SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0)),
        SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0)),
    ];
    for addr in &rejected {
        assert!(sanitize_src_address(*addr).is_err());
    }
}
1224
/// Local addresses used by the tests, one per transport the server can be registered on.
#[derive(Clone)]
struct Endpoints {
    // Address for the tokio UDP socket.
    udp_addr: SocketAddr,
    // Address for the std UDP socket.
    udp_std_addr: SocketAddr,
    // Address for the tokio TCP listener.
    tcp_addr: SocketAddr,
    // Address for the std TCP listener.
    tcp_std_addr: SocketAddr,
    // Address for the TLS listener.
    #[cfg(feature = "__tls")]
    rustls_addr: SocketAddr,
    // Address for the HTTPS listener.
    #[cfg(feature = "__https")]
    https_rustls_addr: SocketAddr,
    // Address for the QUIC listener.
    #[cfg(feature = "__quic")]
    quic_addr: SocketAddr,
    // Address for the HTTP/3 listener.
    #[cfg(feature = "__h3")]
    h3_addr: SocketAddr,
}
1240
impl Endpoints {
    /// Picks one free loopback address per transport by binding to port 0.
    ///
    /// All sockets stay bound until this function returns and are then dropped, leaving only
    /// the recorded addresses -- presumably so that every endpoint receives a distinct port.
    async fn new() -> Self {
        let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
        #[cfg(feature = "__tls")]
        let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
        #[cfg(feature = "__https")]
        let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
        #[cfg(feature = "__quic")]
        let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        #[cfg(feature = "__h3")]
        let h3 = UdpSocket::bind("127.0.0.1:0").await.unwrap();

        Self {
            udp_addr: udp.local_addr().unwrap(),
            udp_std_addr: udp_std.local_addr().unwrap(),
            tcp_addr: tcp.local_addr().unwrap(),
            tcp_std_addr: tcp_std.local_addr().unwrap(),
            #[cfg(feature = "__tls")]
            rustls_addr: rustls.local_addr().unwrap(),
            #[cfg(feature = "__https")]
            https_rustls_addr: https_rustls.local_addr().unwrap(),
            #[cfg(feature = "__quic")]
            quic_addr: quic.local_addr().unwrap(),
            #[cfg(feature = "__h3")]
            h3_addr: h3.local_addr().unwrap(),
        }
    }

    /// Binds every recorded address again and registers the resulting sockets and listeners
    /// with `server`, covering each transport enabled by the crate features.
    async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
        server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
        server
            .register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
            .unwrap();
        server.register_listener(
            TcpListener::bind(self.tcp_addr).await.unwrap(),
            Duration::from_secs(1),
        );
        server
            .register_listener_std(
                std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
                Duration::from_secs(1),
            )
            .unwrap();

        #[cfg(feature = "__tls")]
        {
            let cert_key = rustls_cert_key();
            server
                .register_tls_listener(
                    TcpListener::bind(self.rustls_addr).await.unwrap(),
                    Duration::from_secs(30),
                    cert_key,
                )
                .unwrap();
        }

        #[cfg(feature = "__https")]
        {
            let cert_key = rustls_cert_key();
            server
                .register_https_listener(
                    TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                    Duration::from_secs(1),
                    cert_key,
                    None,
                    "/dns-query".into(),
                )
                .unwrap();
        }

        #[cfg(feature = "__quic")]
        {
            let cert_key = rustls_cert_key();
            server
                .register_quic_listener(
                    UdpSocket::bind(self.quic_addr).await.unwrap(),
                    Duration::from_secs(1),
                    cert_key,
                    None,
                )
                .unwrap();
        }

        #[cfg(feature = "__h3")]
        {
            let cert_key = rustls_cert_key();
            server
                .register_h3_listener(
                    UdpSocket::bind(self.h3_addr).await.unwrap(),
                    Duration::from_secs(1),
                    cert_key,
                    None,
                )
                .unwrap();
        }
    }

    /// Binds every recorded address once more and panics if any bind fails.
    ///
    /// The tests call this after the server stopped to verify the addresses were released.
    /// Each socket is dropped again immediately.
    async fn rebind_all(&self) {
        UdpSocket::bind(self.udp_addr).await.unwrap();
        UdpSocket::bind(self.udp_std_addr).await.unwrap();
        TcpListener::bind(self.tcp_addr).await.unwrap();
        TcpListener::bind(self.tcp_std_addr).await.unwrap();
        #[cfg(feature = "__tls")]
        TcpListener::bind(self.rustls_addr).await.unwrap();
        #[cfg(feature = "__https")]
        TcpListener::bind(self.https_rustls_addr).await.unwrap();
        #[cfg(feature = "__quic")]
        UdpSocket::bind(self.quic_addr).await.unwrap();
        #[cfg(feature = "__h3")]
        UdpSocket::bind(self.h3_addr).await.unwrap();
    }
}
1356
/// Loads the test certificate chain and private key and wraps them in a certificate resolver.
///
/// The files are read from `tests/test-data` below `TDNS_WORKSPACE_ROOT`, falling back to
/// `../..` when that variable is not set. Panics if either file cannot be read or parsed.
#[cfg(feature = "__tls")]
fn rustls_cert_key() -> Arc<dyn ResolvesServerCert> {
    use rustls::pki_types::pem::PemObject;
    use std::env;

    let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
    let cert_path = format!("{}/tests/test-data/cert.pem", server_path);
    let key_path = format!("{server_path}/tests/test-data/cert.key");

    let cert_chain = CertificateDer::pem_file_iter(cert_path)
        .unwrap()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    let key = PrivateKeyDer::from_pem_file(key_path).unwrap();

    let certified_key = CertifiedKey::from_der(cert_chain, key, &default_provider()).unwrap();
    let resolver = SingleCertAndKey::from(certified_key);
    Arc::new(resolver)
}
1375
/// Reaping an empty join set must return without blocking or panicking.
#[test]
fn task_reap_on_empty_joinset() {
    let mut empty: JoinSet<()> = JoinSet::new();

    reap_tasks(&mut empty);
}
1383
1384 #[tokio::test]
1385 async fn task_reap_on_nonempty_joinset() {
1386 let mut joinset = JoinSet::new();
1387 let t = joinset.spawn(tokio::time::sleep(Duration::from_secs(2)));
1388
1389 reap_tasks(&mut joinset);
1391 t.abort();
1392
1393 reap_tasks(&mut joinset);
1395 }
1396}