1#[cfg(not(test))]
9use std::time::{Duration, Instant};
10use std::{
11 cmp,
12 fmt::Debug,
13 marker::PhantomData,
14 net::IpAddr,
15 sync::{
16 Arc,
17 atomic::{AtomicU8, AtomicU32, Ordering},
18 },
19};
20
21use futures_util::lock::Mutex as AsyncMutex;
22use parking_lot::Mutex as SyncMutex;
23#[cfg(test)]
24use tokio::time::{Duration, Instant};
25use tracing::{debug, error, warn};
26
27#[cfg(feature = "metrics")]
28use crate::metrics::ResolverMetrics;
29#[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
30use crate::metrics::opportunistic_encryption::ProbeMetrics;
31use crate::{
32 config::{
33 ConnectionConfig, NameServerConfig, OpportunisticEncryption, ResolverOpts,
34 ServerOrderingStrategy,
35 },
36 connection_provider::ConnectionProvider,
37 name_server_pool::{NameServerTransportState, PoolContext},
38 net::{
39 DnsError, NetError, NoRecords,
40 runtime::{RuntimeProvider, Spawn},
41 xfer::{DnsHandle, FirstAnswer, Protocol},
42 },
43 proto::{
44 op::{DnsRequest, DnsRequestOptions, DnsResponse, Query, ResponseCode},
45 rr::{Name, RecordType},
46 },
47};
48
/// A single upstream name server: its configuration, the connections opened to
/// it so far, and its smoothed round-trip-time statistics.
pub struct NameServer<P: ConnectionProvider> {
    /// Address and per-protocol connection settings for this server.
    config: NameServerConfig,
    /// Connections opened so far. The async mutex is held while a connection is
    /// being selected or established.
    connections: AsyncMutex<Vec<ConnectionState<P>>>,
    /// Metrics for opportunistic-encryption probes started from this server.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    opportunistic_probe_metrics: ProbeMetrics,
    /// Resolver metrics; used to count outgoing queries per protocol.
    #[cfg(feature = "metrics")]
    resolver_metrics: ResolverMetrics,
    /// Server-wide smoothed RTT, updated on responses, failures and cancellations.
    server_srtt: DecayingSrtt,
    /// Opens new connections and supplies the runtime used for background probes.
    connection_provider: P,
}
65
66impl<P: ConnectionProvider> NameServer<P> {
67 pub fn new(
71 connections: impl IntoIterator<Item = (Protocol, P::Conn)>,
72 config: NameServerConfig,
73 options: &ResolverOpts,
74 connection_provider: P,
75 ) -> Self {
76 let mut connections = connections
77 .into_iter()
78 .map(|(protocol, handle)| ConnectionState::new(handle, protocol))
79 .collect::<Vec<_>>();
80
81 if options.server_ordering_strategy != ServerOrderingStrategy::UserProvidedOrder {
84 connections.sort_by_key(|ns| ns.protocol != Protocol::Udp);
85 }
86
87 Self {
88 config,
89 connections: AsyncMutex::new(connections),
90 server_srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
91 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
92 opportunistic_probe_metrics: ProbeMetrics::default(),
93 #[cfg(feature = "metrics")]
94 resolver_metrics: ResolverMetrics::default(),
95 connection_provider,
96 }
97 }
98
99 pub(crate) async fn send(
101 self: Arc<Self>,
102 request: DnsRequest,
103 policy: ConnectionPolicy,
104 cx: &Arc<PoolContext>,
105 ) -> Result<DnsResponse, NetError> {
106 let (handle, meta, protocol) = self.connected_mut_client(policy, cx).await?;
107 #[cfg(feature = "metrics")]
108 self.resolver_metrics.increment_outgoing_query(&protocol);
109 let now = Instant::now();
110 let response = handle.send(request).first_answer().await;
111 let rtt = now.elapsed();
112
113 match response {
114 Ok(response) => {
115 meta.set_status(Status::Established);
116 let result = DnsError::from_response(response);
117 let error = match result {
118 Ok(response) => {
119 meta.srtt.record(rtt);
120 self.server_srtt.record(rtt);
121 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
122 cx.transport_state()
123 .await
124 .response_received(self.config.ip, protocol);
125 }
126 return Ok(response);
127 }
128 Err(error) => error,
129 };
130
131 let update = match error {
132 DnsError::NoRecordsFound(NoRecords {
133 response_code: ResponseCode::ServFail,
134 ..
135 }) => Some(true),
136 DnsError::NoRecordsFound(NoRecords { .. }) => Some(false),
137 _ => None,
138 };
139
140 match update {
141 Some(true) => {
142 meta.srtt.record(rtt);
143 self.server_srtt.record(rtt);
144 }
145 Some(false) => {
146 meta.srtt.record_failure();
148 self.server_srtt.record_failure();
149 }
150 None => {}
151 }
152
153 let err = NetError::from(error);
154 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
155 cx.transport_state()
156 .await
157 .error_received(self.config.ip, protocol, &err)
158 }
159 Err(err)
160 }
161 Err(error) => {
162 debug!(config = ?self.config, %error, "failed to connect to name server");
163
164 meta.set_status(Status::Failed);
166
167 match &error {
171 NetError::Busy | NetError::Io(_) | NetError::Timeout => {
172 meta.srtt.record_failure();
173 self.server_srtt.record_failure();
174 }
175 #[cfg(feature = "__quic")]
176 NetError::QuinnConfigError(_)
177 | NetError::QuinnConnect(_)
178 | NetError::QuinnConnection(_)
179 | NetError::QuinnTlsConfigError(_) => {
180 meta.srtt.record_failure();
181 self.server_srtt.record_failure();
182 }
183 #[cfg(feature = "__tls")]
184 NetError::RustlsError(_) => {
185 meta.srtt.record_failure();
186 self.server_srtt.record_failure();
187 }
188 _ => {}
189 }
190
191 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
192 cx.transport_state()
193 .await
194 .error_received(self.config.ip, protocol, &error);
195 }
196
197 Err(error)
199 }
200 }
201 }
202
    /// Returns a connection handle to this name server together with its
    /// metadata and protocol, opening a new connection when no existing one is
    /// selected.
    ///
    /// The connection list stays locked for the whole call, including while a
    /// new connection is being established.
    ///
    /// # Errors
    ///
    /// Returns `NetError::NoConnections` when the policy leaves no usable
    /// connection config, or the error produced while creating the connection.
    async fn connected_mut_client(
        &self,
        policy: ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) -> Result<(P::Conn, Arc<ConnectionMeta>, Protocol), NetError> {
        let mut connections = self.connections.lock().await;
        // Forget connections that have been marked as failed.
        connections.retain(|conn| matches!(conn.meta.status(), Status::Init | Status::Established));
        // Reuse an existing connection when the policy selects one.
        if let Some(conn) = policy.select_connection(
            self.config.ip,
            &*cx.transport_state().await,
            &cx.opportunistic_encryption,
            &connections,
        ) {
            return Ok((conn.handle.clone(), conn.meta.clone(), conn.protocol));
        }

        debug!(config = ?self.config, "connecting");
        // Otherwise pick the configuration to connect with.
        let config = policy
            .select_connection_config(
                self.config.ip,
                &*cx.transport_state().await,
                &cx.opportunistic_encryption,
                &self.config.connections,
            )
            .ok_or(NetError::NoConnections)?;

        let protocol = config.protocol.to_protocol();
        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
            // Record that an encrypted connection attempt is starting.
            cx.transport_state()
                .await
                .initiate_connection(self.config.ip, protocol);
        } else if cx.opportunistic_encryption.is_enabled() && !protocol.is_encrypted() {
            // Connecting unencrypted: possibly start a background probe of an
            // encrypted transport.
            self.consider_probe_encrypted_transport(&policy, cx).await;
        }

        let handle = Box::pin(self.connection_provider.new_connection(
            self.config.ip,
            config,
            cx,
        )?)
        .await?;

        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
            cx.transport_state()
                .await
                .complete_connection(self.config.ip, protocol);
        }

        // Remember the new connection so later requests can reuse it.
        let state = ConnectionState::new(handle.clone(), protocol);
        let meta = state.meta.clone();
        connections.push(state);
        Ok((handle, meta, protocol))
    }
260
261 pub(super) fn protocols(&self) -> impl Iterator<Item = Protocol> + '_ {
262 self.config
263 .connections
264 .iter()
265 .map(|conn| conn.protocol.to_protocol())
266 }
267
268 pub(super) fn ip(&self) -> IpAddr {
269 self.config.ip
270 }
271
272 pub(crate) fn decayed_srtt(&self) -> f64 {
273 self.server_srtt.current()
274 }
275
276 pub(super) fn record_cancelled(&self, winner_rtt: Duration) {
290 const CANCEL_PENALTY: Duration = Duration::from_millis(5);
291 self.server_srtt.record(winner_rtt + CANCEL_PENALTY);
292 }
293
294 #[cfg(test)]
295 pub(crate) fn test_record_failure(&self) {
296 self.server_srtt.record_failure();
297 }
298
299 #[cfg(test)]
300 #[allow(dead_code)]
301 pub(crate) fn is_connected(&self) -> bool {
302 let Some(connections) = self.connections.try_lock() else {
303 return true;
305 };
306
307 connections.iter().any(|conn| match conn.meta.status() {
308 Status::Established | Status::Init => true,
309 Status::Failed => false,
310 })
311 }
312
313 pub(crate) fn trust_negative_responses(&self) -> bool {
314 self.config.trust_negative_responses
315 }
316
317 async fn consider_probe_encrypted_transport(
318 &self,
319 policy: &ConnectionPolicy,
320 cx: &Arc<PoolContext>,
321 ) {
322 let Some(probe_config) =
323 policy.select_encrypted_connection_config(&self.config.connections)
324 else {
325 warn!("no encrypted connection configs available for probing");
326 return;
327 };
328
329 let probe_protocol = probe_config.protocol.to_protocol();
330 let should_probe = {
331 let state = cx.transport_state().await;
332 state.should_probe_encrypted(
333 self.config.ip,
334 probe_protocol,
335 &cx.opportunistic_encryption,
336 )
337 };
338
339 if !should_probe {
340 return;
341 }
342
343 if let Err(err) = self.probe_encrypted_transport(cx, probe_config) {
344 error!(%err, "opportunistic encrypted probe attempt failed");
345 }
346 }
347
348 fn probe_encrypted_transport(
349 &self,
350 cx: &Arc<PoolContext>,
351 probe_config: &ConnectionConfig,
352 ) -> Result<(), NetError> {
353 let mut budget = cx.opportunistic_probe_budget.load(Ordering::Relaxed);
354 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
355 self.opportunistic_probe_metrics.probe_budget.set(budget);
356 loop {
357 if budget == 0 {
358 debug!("no remaining budget for opportunistic probing");
359 return Ok(());
360 }
361 match cx.opportunistic_probe_budget.compare_exchange_weak(
362 budget,
363 budget - 1,
364 Ordering::AcqRel,
365 Ordering::Relaxed,
366 ) {
367 Ok(_) => break,
368 Err(current) => budget = current,
369 }
370 }
371
372 let connect = ProbeRequest::new(
373 probe_config,
374 self,
375 cx,
376 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
377 self.opportunistic_probe_metrics.clone(),
378 )?;
379 self.connection_provider
380 .runtime_provider()
381 .create_handle()
382 .spawn_bg(connect.run());
383
384 Ok(())
385 }
386}
387
/// A pending opportunistic-encryption probe against one name server.
struct ProbeRequest<P: ConnectionProvider> {
    /// Address of the name server being probed.
    ip: IpAddr,
    /// Protocol being probed.
    proto: Protocol,
    /// Connection future created by the provider; awaited when the probe runs.
    connecting: P::FutureConn,
    /// Shared pool context holding the transport state and the probe budget.
    context: Arc<PoolContext>,
    /// Probe metrics to update as the probe progresses.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    metrics: ProbeMetrics,
    /// Ties the request to the provider type without storing a provider.
    provider: PhantomData<P>,
}
397
398impl<P: ConnectionProvider> ProbeRequest<P> {
399 fn new(
400 config: &ConnectionConfig,
401 ns: &NameServer<P>,
402 cx: &Arc<PoolContext>,
403 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
404 metrics: ProbeMetrics,
405 ) -> Result<Self, NetError> {
406 Ok(Self {
407 ip: ns.config.ip,
408 proto: config.protocol.to_protocol(),
409 connecting: ns
410 .connection_provider
411 .new_connection(ns.config.ip, config, cx)?,
412 context: cx.clone(),
413 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
414 metrics,
415 provider: PhantomData,
416 })
417 }
418
    /// Runs the probe to completion.
    ///
    /// Marks the connection attempt in the shared transport state, awaits the
    /// connection, sends an NS query for the root name over it and reports the
    /// outcome to the transport state. One unit is added back to the probe
    /// budget both when connecting fails and after the query has finished.
    async fn run(self) {
        let Self {
            ip,
            proto,
            connecting,
            context,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: _,
        } = self;

        // Start time for the probe-duration metric.
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        let start = Instant::now();

        context
            .transport_state()
            .await
            .initiate_connection(ip, proto);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics.increment_attempts(proto);

        let conn = match connecting.await {
            Ok(conn) => conn,
            Err(err) => {
                debug!(?proto, "probe connection failed");
                // Give the budget unit back; `_prev` is only read by metrics.
                let _prev = context
                    .opportunistic_probe_budget
                    .fetch_add(1, Ordering::Relaxed);
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                {
                    metrics.increment_errors(proto, &err);
                    metrics.probe_budget.set(_prev + 1);
                    metrics.record_probe_duration(proto, start.elapsed());
                }
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
                return;
            }
        };

        debug!(?proto, "probe connection succeeded");
        context
            .transport_state()
            .await
            .complete_connection(ip, proto);

        // Probe query: NS records of the root name, default request options.
        match conn
            .send(DnsRequest::from_query(
                Query::query(Name::root(), RecordType::NS),
                DnsRequestOptions::default(),
            ))
            .first_answer()
            .await
        {
            Ok(_) => {
                debug!(?proto, "probe query succeeded");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_successes(proto);
                context.transport_state().await.response_received(ip, proto);
            }
            Err(err) => {
                debug!(?proto, ?err, "probe query failed");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_errors(proto, &err);
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
            }
        }

        // The probe is finished either way: give the budget unit back.
        let _prev = context
            .opportunistic_probe_budget
            .fetch_add(1, Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        {
            metrics.probe_budget.set(_prev + 1);
            metrics.record_probe_duration(proto, start.elapsed());
        }
    }
501}
502
/// One connection to a name server, with its protocol and shared metadata.
struct ConnectionState<P: ConnectionProvider> {
    /// Protocol this connection uses.
    protocol: Protocol,
    /// Handle used to send requests over the connection.
    handle: P::Conn,
    /// Status and SRTT of the connection, shared with in-flight requests.
    meta: Arc<ConnectionMeta>,
}
508
509impl<P: ConnectionProvider> ConnectionState<P> {
510 fn new(handle: P::Conn, protocol: Protocol) -> Self {
511 Self {
512 protocol,
513 handle,
514 meta: Arc::new(ConnectionMeta::default()),
515 }
516 }
517}
518
/// Mutable per-connection bookkeeping that can be shared across tasks.
struct ConnectionMeta {
    /// Current `Status`, stored as its `u8` representation.
    status: AtomicU8,
    /// Smoothed round-trip time of this connection.
    srtt: DecayingSrtt,
}
523
524impl ConnectionMeta {
525 fn set_status(&self, status: Status) {
526 self.status.store(status.into(), Ordering::Release);
527 }
528
529 fn status(&self) -> Status {
530 Status::from(self.status.load(Ordering::Acquire))
531 }
532}
533
534impl Default for ConnectionMeta {
535 fn default() -> Self {
536 Self {
540 status: AtomicU8::new(Status::Init.into()),
541 srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
542 }
543 }
544}
545
/// A smoothed round-trip time whose reported value decays with the time
/// elapsed since the last update.
struct DecayingSrtt {
    /// The smoothed round-trip time in microseconds.
    srtt_microseconds: AtomicU32,

    /// When the SRTT was last updated; `None` until the first update.
    last_update: SyncMutex<Option<Instant>>,
}
579
580impl DecayingSrtt {
581 fn new(initial_srtt: Duration) -> Self {
582 Self {
583 srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
584 last_update: SyncMutex::new(None),
585 }
586 }
587
588 fn record(&self, rtt: Duration) {
589 self.update(
594 rtt.as_micros() as u32,
595 |cur_srtt_microseconds, last_update| {
596 let factor = compute_srtt_factor(last_update, 3);
600 let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
601 + factor * f64::from(cur_srtt_microseconds);
602 new_srtt.round() as u32
603 },
604 );
605 }
606
607 fn record_failure(&self) {
609 self.update(
610 Self::FAILURE_PENALTY,
611 |cur_srtt_microseconds, _last_update| {
612 cur_srtt_microseconds.saturating_add(Self::FAILURE_PENALTY)
613 },
614 );
615 }
616
617 fn current(&self) -> f64 {
626 let srtt = f64::from(self.srtt_microseconds.load(Ordering::Acquire));
627 self.last_update.lock().map_or(srtt, |last_update| {
628 srtt * compute_srtt_factor(last_update, 180)
639 })
640 }
641
642 fn update(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
648 let last_update = self.last_update.lock().replace(Instant::now());
649 let _ = self.srtt_microseconds.fetch_update(
650 Ordering::SeqCst,
651 Ordering::SeqCst,
652 move |cur_srtt_microseconds| {
653 Some(
654 last_update
655 .map_or(default, |last_update| {
656 update_fn(cur_srtt_microseconds, last_update)
657 })
658 .min(Self::MAX_SRTT_MICROS),
659 )
660 },
661 );
662 }
663
664 #[cfg(all(test, feature = "tokio"))]
668 fn as_duration(&self) -> Duration {
669 Duration::from_micros(u64::from(self.srtt_microseconds.load(Ordering::Acquire)))
670 }
671
672 const FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
673 const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
674}
675
/// Returns the exponential decay factor `exp(-t / weight)`, where `t` is the
/// number of seconds elapsed since `last_update`, treated as at least one
/// second.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-elapsed_secs / f64::from(weight)).exp()
}
688
/// State of a connection, stored in `ConnectionMeta` as a `u8`.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
enum Status {
    /// A request over the connection failed; failed connections are discarded
    /// the next time a connection is requested.
    Failed = 0,
    /// Initial state of a connection before any request has completed.
    Init = 1,
    /// At least one request over the connection has received a response.
    Established = 2,
}
705
706impl From<Status> for u8 {
707 fn from(val: Status) -> Self {
709 val as Self
710 }
711}
712
713impl From<u8> for Status {
714 fn from(val: u8) -> Self {
715 match val {
716 2 => Self::Established,
717 1 => Self::Init,
718 _ => Self::Failed,
719 }
720 }
721}
722
/// Constraints applied when choosing a connection or connection config.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub(crate) struct ConnectionPolicy {
    /// When set, UDP connections and configs are never selected.
    pub(crate) disable_udp: bool,
}
727
728impl ConnectionPolicy {
729 pub(crate) fn allows_server<P: ConnectionProvider>(&self, server: &NameServer<P>) -> bool {
731 server.protocols().any(|p| self.allows_protocol(p))
732 }
733
734 fn select_connection<'a, P: ConnectionProvider>(
739 &self,
740 ip: IpAddr,
741 encrypted_transport_state: &NameServerTransportState,
742 opportunistic_encryption: &OpportunisticEncryption,
743 connections: &'a [ConnectionState<P>],
744 ) -> Option<&'a ConnectionState<P>> {
745 let selected = connections
746 .iter()
747 .filter(|conn| self.allows_protocol(conn.protocol))
748 .min_by(|a, b| self.compare_connections(opportunistic_encryption.is_enabled(), a, b));
749
750 let selected = selected?;
751
752 match opportunistic_encryption.is_enabled()
758 && !selected.protocol.is_encrypted()
759 && encrypted_transport_state.any_recent_success(ip, opportunistic_encryption)
760 {
761 true => None,
762 false => Some(selected),
763 }
764 }
765
766 fn select_connection_config<'a>(
771 &self,
772 ip: IpAddr,
773 encrypted_transport_state: &NameServerTransportState,
774 opportunistic_encryption: &OpportunisticEncryption,
775 connection_configs: &'a [ConnectionConfig],
776 ) -> Option<&'a ConnectionConfig> {
777 connection_configs
778 .iter()
779 .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
780 .min_by(|a, b| {
781 self.compare_connection_configs(
782 ip,
783 encrypted_transport_state,
784 opportunistic_encryption,
785 a,
786 b,
787 )
788 })
789 }
790
791 fn select_encrypted_connection_config<'a>(
793 &self,
794 connection_config: &'a [ConnectionConfig],
795 ) -> Option<&'a ConnectionConfig> {
796 connection_config
797 .iter()
798 .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
799 .find(|c| c.protocol.to_protocol().is_encrypted())
800 }
801
802 fn allows_protocol(&self, protocol: Protocol) -> bool {
804 !(self.disable_udp && protocol == Protocol::Udp)
805 }
806
807 fn compare_connections<P: ConnectionProvider>(
810 &self,
811 opportunistic_encryption: bool,
812 a: &ConnectionState<P>,
813 b: &ConnectionState<P>,
814 ) -> cmp::Ordering {
815 if opportunistic_encryption {
818 match (a.protocol.is_encrypted(), b.protocol.is_encrypted()) {
819 (true, false) => return cmp::Ordering::Less,
820 (false, true) => return cmp::Ordering::Greater,
821 _ => {}
823 }
824 }
825
826 match (a.protocol, b.protocol) {
827 (ap, bp) if ap == bp => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
828 (Protocol::Udp, _) => cmp::Ordering::Less,
829 (_, Protocol::Udp) => cmp::Ordering::Greater,
830 _ => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
831 }
832 }
833
834 fn compare_connection_configs(
835 &self,
836 ip: IpAddr,
837 encrypted_transport_state: &NameServerTransportState,
838 opportunistic_encryption: &OpportunisticEncryption,
839 a: &ConnectionConfig,
840 b: &ConnectionConfig,
841 ) -> cmp::Ordering {
842 let a_protocol = a.protocol.to_protocol();
843 let b_protocol = b.protocol.to_protocol();
844
845 if opportunistic_encryption.is_enabled() {
848 let a_recent_enc_success = a_protocol.is_encrypted()
849 && encrypted_transport_state.recent_success(
850 ip,
851 a_protocol,
852 opportunistic_encryption,
853 );
854 let b_recent_enc_success = b_protocol.is_encrypted()
855 && encrypted_transport_state.recent_success(
856 ip,
857 b_protocol,
858 opportunistic_encryption,
859 );
860
861 match (a_recent_enc_success, b_recent_enc_success) {
862 (true, false) => return cmp::Ordering::Less,
863 (false, true) => return cmp::Ordering::Greater,
864 _ => {}
866 }
867 }
868
869 match (a_protocol, b_protocol) {
871 (ap, bp) if ap == bp => cmp::Ordering::Equal,
872 (Protocol::Udp, _) => cmp::Ordering::Less,
873 (_, Protocol::Udp) => cmp::Ordering::Greater,
874 _ => cmp::Ordering::Equal,
875 }
876 }
877}
878
879#[cfg(all(test, feature = "tokio"))]
880mod tests {
881 use std::cmp;
882 use std::net::{IpAddr, Ipv4Addr};
883 use std::str::FromStr;
884 use std::time::Duration;
885
886 use test_support::subscribe;
887 use tokio::net::UdpSocket;
888 use tokio::spawn;
889
890 use super::*;
891 use crate::config::{ConnectionConfig, ProtocolConfig};
892 use crate::connection_provider::TlsConfig;
893 use crate::net::runtime::TokioRuntimeProvider;
894 use crate::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
895 use crate::proto::rr::rdata::NULL;
896 use crate::proto::rr::{Name, RData, Record, RecordType};
897
    /// Sends an A query for `www.example.com.` over UDP to 8.8.8.8 and expects
    /// a `NoError` response.
    ///
    /// NOTE(review): this talks to a real external address, so it presumably
    /// needs network access to pass.
    #[tokio::test]
    async fn test_name_server() {
        subscribe();

        let options = ResolverOpts::default();
        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
        // No pre-established connections: one is opened on the first send.
        let name_server = Arc::new(NameServer::new(
            [].into_iter(),
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        let response = name_server
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .expect("query failed");
        assert_eq!(response.response_code, ResponseCode::NoError);
    }
926
    /// Sends a query to 127.0.0.252 with a 1 ms timeout and expects `send` to
    /// return an error.
    #[tokio::test]
    async fn test_failed_name_server() {
        subscribe();

        let options = ResolverOpts {
            // Keep the timeout tiny so the expected failure arrives quickly.
            timeout: Duration::from_millis(1),
            ..ResolverOpts::default()
        };

        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)));
        let name_server = Arc::new(NameServer::new(
            [],
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        assert!(
            name_server
                .send(
                    DnsRequest::from_query(
                        Query::query(name.clone(), RecordType::A),
                        DnsRequestOptions::default(),
                    ),
                    ConnectionPolicy::default(),
                    &cx
                )
                .await
                .is_err()
        );
    }
960
    /// With case randomization requested, the query name in the response
    /// returned by `send` must still compare equal to the original name under
    /// `eq_case` (presumably a case-sensitive comparison — verify).
    #[tokio::test]
    async fn case_randomization_query_preserved() {
        subscribe();

        let provider = TokioRuntimeProvider::default();
        // Local UDP socket acting as the name server.
        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let name = Name::from_str("dead.beef.").unwrap();
        let data = b"DEADBEEF";

        // Answer exactly one request: echo its queries back and add a NULL
        // record for `name`.
        spawn({
            let name = name.clone();
            async move {
                let mut buffer = [0_u8; 512];
                let (len, addr) = server.recv_from(&mut buffer).await.unwrap();
                let request = Message::from_vec(&buffer[0..len]).unwrap();
                let mut response = Message::response(request.id, request.op_code);
                response.add_queries(request.queries.to_vec());
                response.add_answer(Record::from_rdata(
                    name,
                    0,
                    RData::NULL(NULL::with(data.to_vec())),
                ));
                let response_buffer = response.to_vec().unwrap();
                server.send_to(&response_buffer, addr).await.unwrap();
            }
        });

        let config = NameServerConfig {
            ip: server_addr.ip(),
            trust_negative_responses: true,
            connections: vec![ConnectionConfig {
                port: server_addr.port(),
                protocol: ProtocolConfig::Udp,
                bind_addr: None,
            }],
        };

        let resolver_opts = ResolverOpts {
            case_randomization: true,
            ..Default::default()
        };

        let cx = Arc::new(PoolContext::new(resolver_opts, TlsConfig::new().unwrap()));
        let mut request_options = DnsRequestOptions::default();
        request_options.case_randomization = true;
        let ns = Arc::new(NameServer::new([], config, &cx.options, provider));
        let response = ns
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::NULL),
                    request_options,
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .unwrap();

        let response_query_name = response.queries.first().unwrap().name();
        assert!(response_query_name.eq_case(&name));
    }
1023
    /// Compiles only if `S` is `Send + Sync`; always returns `true`.
    // The type parameter exists solely for its bounds, hence the lint allowance.
    #[allow(clippy::extra_unused_type_parameters)]
    fn is_send_sync<S: Sync + Send>() -> bool {
        true
    }
1028
    /// `ConnectionMeta` must be `Send + Sync`; this is checked at compile time.
    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<ConnectionMeta>());
    }
1033
    /// Checks the relative ordering of two SRTT trackers as samples and
    /// failures are recorded, with the tokio clock advanced manually.
    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        use std::cmp::Ordering;
        let srtt_a = DecayingSrtt::new(Duration::from_micros(10));
        let srtt_b = DecayingSrtt::new(Duration::from_micros(20));

        // Nothing recorded yet: the initial values decide (10 < 20).
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // The first sample replaces the initial value, so `a` is now ~30 ms
        // while `b` is still at its initial 20 µs.
        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // `b` records a larger RTT than `a`.
        srtt_b.record(Duration::from_millis(50));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // A failure adds the failure penalty to `a`.
        srtt_a.record_failure();
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Keep using `b` until the decayed value of `a` drops below it.
        while cmp(&srtt_a, &srtt_b) != Ordering::Less {
            srtt_b.record(Duration::from_millis(50));
            tokio::time::advance(Duration::from_secs(5)).await;
        }

        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
    }
1071
1072 fn cmp(a: &DecayingSrtt, b: &DecayingSrtt) -> cmp::Ordering {
1073 a.current().total_cmp(&b.current())
1074 }
1075
    /// The first sample replaces the initial value; a later sample is blended
    /// with it according to the elapsed time.
    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        let first_rtt = Duration::from_millis(50);
        srtt.record(first_rtt);

        // No previous update, so the stored value is exactly the sample.
        assert_eq!(srtt.as_duration(), first_rtt);

        tokio::time::advance(Duration::from_secs(3)).await;

        srtt.record(Duration::from_millis(100));
        // factor = exp(-3 / 3); (1 - factor) * 100000 + factor * 50000 ≈ 81606.
        assert_eq!(srtt.as_duration(), Duration::from_micros(81606));
    }
1092
    /// Recording the largest possible duration leaves the SRTT at its cap.
    #[test]
    fn test_record_rtt_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        srtt.record(Duration::MAX);
        // `update` caps every stored value at `MAX_SRTT_MICROS`.
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }
1104
    /// Each failure adds one `FAILURE_PENALTY`; a later successful sample is
    /// blended with the accumulated penalty.
    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        for failure_count in 1..4 {
            // The first failure replaces the initial value with the penalty;
            // each further failure adds another penalty on top.
            srtt.record_failure();
            assert_eq!(
                srtt.as_duration(),
                Duration::from_micros(
                    DecayingSrtt::FAILURE_PENALTY
                        .checked_mul(failure_count)
                        .expect("checked_mul overflow")
                        .into()
                )
            );
            tokio::time::advance(Duration::from_secs(3)).await;
        }

        srtt.record(Duration::from_millis(50));
        // factor = exp(-3 / 3); (1 - factor) * 50000 + factor * 450000 ≈ 197152.
        assert_eq!(srtt.as_duration(), Duration::from_micros(197152));
    }
1130
    /// Repeated failures cannot push the SRTT beyond `MAX_SRTT_MICROS`.
    #[test]
    fn test_record_connection_failure_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // One more failure than fits below the cap.
        let num_failures = (DecayingSrtt::MAX_SRTT_MICROS / DecayingSrtt::FAILURE_PENALTY) + 1;
        for _ in 0..num_failures {
            srtt.record_failure();
        }

        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }
1146
    /// `current()` returns the raw value before any update, and afterwards a
    /// value that shrinks as time passes since the last update.
    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_srtt = 10;
        let srtt = DecayingSrtt::new(Duration::from_micros(initial_srtt));

        // No update recorded yet, so no decay is applied.
        assert_eq!(srtt.current() as u32, initial_srtt as u32);

        tokio::time::advance(Duration::from_secs(5)).await;
        srtt.record(Duration::from_millis(100));

        tokio::time::advance(Duration::from_millis(500)).await;
        // Elapsed time below one second is treated as one second:
        // 100000 * exp(-1 / 180) ≈ 99445.
        assert_eq!(srtt.current() as u32, 99445);

        tokio::time::advance(Duration::from_secs(5)).await;
        // 5.5 s since the update: 100000 * exp(-5.5 / 180) ≈ 96990.
        assert_eq!(srtt.current() as u32, 96990);
    }
1166}
1167
1168#[cfg(all(test, feature = "__tls"))]
1169mod opportunistic_enc_tests {
1170 use std::io;
1171 use std::net::{IpAddr, Ipv4Addr};
1172 use std::sync::Arc;
1173 use std::time::{Duration, SystemTime};
1174
1175 #[cfg(feature = "metrics")]
1176 use metrics::{Label, Unit, with_local_recorder};
1177 #[cfg(feature = "metrics")]
1178 use metrics_util::debugging::DebuggingRecorder;
1179 use mock_provider::{MockClientHandle, MockProvider};
1180 use test_support::subscribe;
1181 #[cfg(feature = "metrics")]
1182 use test_support::{assert_counter_eq, assert_gauge_eq, assert_histogram_sample_count_eq};
1183
1184 use crate::config::{
1185 NameServerConfig, OpportunisticEncryption, OpportunisticEncryptionConfig, ProtocolConfig,
1186 ResolverOpts,
1187 };
1188 use crate::connection_provider::TlsConfig;
1189 #[cfg(feature = "metrics")]
1190 use crate::metrics::opportunistic_encryption::{
1191 PROBE_ATTEMPTS_TOTAL, PROBE_BUDGET_TOTAL, PROBE_DURATION_SECONDS, PROBE_ERRORS_TOTAL,
1192 PROBE_SUCCESSES_TOTAL, PROBE_TIMEOUTS_TOTAL,
1193 };
1194 use crate::name_server::{ConnectionPolicy, ConnectionState, NameServer, mock_provider};
1195 use crate::name_server_pool::{NameServerTransportState, PoolContext};
1196 use crate::net::NetError;
1197 use crate::net::xfer::Protocol;
1198
    /// With opportunistic encryption disabled, UDP is preferred over TCP, and
    /// TCP is chosen once the policy disables UDP.
    #[tokio::test]
    async fn test_select_connection_opportunistic_enc_disabled() {
        let mut policy = ConnectionPolicy::default();
        let connections = vec![
            mock_connection(Protocol::Udp),
            mock_connection(Protocol::Tcp),
        ];

        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        let state = NameServerTransportState::default();
        let opp_enc = OpportunisticEncryption::Disabled;

        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
        // The default policy allows UDP, which wins over TCP.
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Udp);

        policy.disable_udp = true;
        // With UDP filtered out, only the TCP connection remains.
        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
    }
1224
    /// With opportunistic encryption enabled, an existing encrypted connection
    /// is preferred over unencrypted ones.
    #[tokio::test]
    async fn test_select_connection_opportunistic_enc_enabled() {
        let policy = ConnectionPolicy::default();
        let connections = [
            mock_connection(Protocol::Udp),
            mock_connection(Protocol::Tcp),
            mock_connection(Protocol::Tls),
        ];

        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        let state = NameServerTransportState::default();
        let opp_enc = &OpportunisticEncryption::Enabled {
            config: OpportunisticEncryptionConfig::default(),
        };

        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
        // The TLS connection is the only encrypted one, so it is selected.
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Tls);
    }
1247
    /// With opportunistic encryption enabled but only unencrypted connections
    /// and a default (empty) transport state, selection behaves as if
    /// encryption were disabled.
    #[tokio::test]
    async fn test_select_connection_opportunistic_enc_enabled_no_state() {
        let mut policy = ConnectionPolicy::default();
        let connections = [
            mock_connection(Protocol::Udp),
            mock_connection(Protocol::Tcp),
        ];

        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        let state = NameServerTransportState::default();
        let opp_enc = &OpportunisticEncryption::Enabled {
            config: OpportunisticEncryptionConfig::default(),
        };

        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
        // The unencrypted UDP connection is still returned.
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Udp);

        policy.disable_udp = true;
        // With UDP disabled, the TCP connection is returned instead.
        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
    }
1276
    /// An error recorded for the TLS transport does not stop the existing
    /// unencrypted UDP connection from being selected.
    #[tokio::test]
    async fn test_select_connection_opportunistic_enc_enabled_failed_probe() {
        let policy = ConnectionPolicy::default();
        let connections = [
            mock_connection(Protocol::Udp),
            mock_connection(Protocol::Tcp),
        ];

        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        let mut state = NameServerTransportState::default();
        let opp_enc = &OpportunisticEncryption::Enabled {
            config: OpportunisticEncryptionConfig::default(),
        };

        // Record a refused TLS connection for this name server.
        state.error_received(
            ns_ip,
            Protocol::Tls,
            &NetError::from(io::Error::new(
                io::ErrorKind::ConnectionRefused,
                "nameserver refused TLS connection",
            )),
        );

        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
        // The UDP connection is still selected.
        assert!(selected.is_some());
        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
    }
1308
1309 #[tokio::test]
1310 async fn test_select_connection_opportunistic_enc_enabled_in_progress_probe() {
1311 let policy = ConnectionPolicy::default();
1312 let connections = [
1313 mock_connection(Protocol::Udp),
1314 mock_connection(Protocol::Tcp),
1315 ];
1317
1318 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1319 let mut state = NameServerTransportState::default();
1320 let opp_enc = &OpportunisticEncryption::Enabled {
1321 config: OpportunisticEncryptionConfig::default(),
1322 };
1323
1324 state.initiate_connection(ns_ip, Protocol::Tls);
1326
1327 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1330 assert!(selected.is_some());
1331 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1332
1333 state.complete_connection(ns_ip, Protocol::Tls);
1336
1337 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1339 assert!(selected.is_some());
1340 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1341 }
1342
1343 #[tokio::test]
1344 async fn test_select_connection_opportunistic_enc_enabled_stale_probe() {
1345 let policy = ConnectionPolicy::default();
1346 let connections = [
1347 mock_connection(Protocol::Udp),
1348 mock_connection(Protocol::Tcp),
1349 ];
1351
1352 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1353 let mut state = NameServerTransportState::default();
1354 let opp_enc_config = OpportunisticEncryptionConfig {
1355 persistence_period: Duration::from_secs(10),
1356 ..OpportunisticEncryptionConfig::default()
1357 };
1358 let opp_enc = &OpportunisticEncryption::Enabled {
1359 config: opp_enc_config.clone(),
1360 };
1361
1362 state.complete_connection(ns_ip, Protocol::Tls);
1364 state.response_received(ns_ip, Protocol::Tls);
1365 let stale_time =
1367 SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1368 state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1369
1370 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1374 assert!(selected.is_some());
1375 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1376 }
1377
1378 #[tokio::test]
1379 async fn test_select_connection_opportunistic_enc_enabled_good_probe() {
1380 let policy = ConnectionPolicy::default();
1381 let connections = [
1382 mock_connection(Protocol::Udp),
1383 mock_connection(Protocol::Tcp),
1384 ];
1386
1387 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1388 let mut state = NameServerTransportState::default();
1389 let opp_enc = &OpportunisticEncryption::Enabled {
1390 config: OpportunisticEncryptionConfig::default(),
1391 };
1392
1393 state.complete_connection(ns_ip, Protocol::Tls);
1396 state.response_received(ns_ip, Protocol::Tls);
1397
1398 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1402 assert!(selected.is_none());
1403 }
1404
1405 #[tokio::test]
1406 async fn test_select_connection_config_opportunistic_enc_disabled() {
1407 let mut policy = ConnectionPolicy::default();
1408
1409 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1410 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1411
1412 let state = NameServerTransportState::default();
1413 let opp_enc = OpportunisticEncryption::Disabled;
1414
1415 let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1418 assert!(selected.is_some());
1419 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1420
1421 policy.disable_udp = true;
1424 let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1425 assert!(selected.is_some());
1426 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1427 }
1428
1429 #[tokio::test]
1430 async fn test_select_connection_config_opportunistic_enc_enabled_no_state() {
1431 let mut policy = ConnectionPolicy::default();
1432 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1433 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1434
1435 let state = NameServerTransportState::default();
1436 let opp_enc = &OpportunisticEncryption::Enabled {
1437 config: OpportunisticEncryptionConfig::default(),
1438 };
1439
1440 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1443 assert!(selected.is_some());
1444 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1445
1446 policy.disable_udp = true;
1449 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1450 assert!(selected.is_some());
1451 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1452 }
1453
1454 #[tokio::test]
1455 async fn test_select_connection_config_opportunistic_enc_enabled_failed_probe() {
1456 let policy = ConnectionPolicy::default();
1457 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1458 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1459
1460 let mut state = NameServerTransportState::default();
1461 let opp_enc = &OpportunisticEncryption::Enabled {
1462 config: OpportunisticEncryptionConfig::default(),
1463 };
1464
1465 state.error_received(
1467 ns_ip,
1468 Protocol::Tls,
1469 &NetError::from(io::Error::new(
1470 io::ErrorKind::ConnectionRefused,
1471 "nameserver refused TLS connection",
1472 )),
1473 );
1474
1475 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1478 assert!(selected.is_some());
1479 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1480 }
1481
1482 #[tokio::test]
1483 async fn test_select_connection_config_opportunistic_enc_enabled_stale_probe() {
1484 let policy = ConnectionPolicy::default();
1485 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1486 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1487
1488 let mut state = NameServerTransportState::default();
1489 let opp_enc_config = OpportunisticEncryptionConfig {
1490 persistence_period: Duration::from_secs(10),
1491 ..OpportunisticEncryptionConfig::default()
1492 };
1493 let opp_enc = &OpportunisticEncryption::Enabled {
1494 config: opp_enc_config.clone(),
1495 };
1496
1497 state.complete_connection(ns_ip, Protocol::Tls);
1499 state.response_received(ns_ip, Protocol::Tls);
1500 let stale_time =
1502 SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1503 state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1504
1505 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1508 assert!(selected.is_some());
1509 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1510 }
1511
1512 #[tokio::test]
1513 async fn test_select_connection_config_opportunistic_enc_enabled_good_probe() {
1514 let policy = ConnectionPolicy::default();
1515 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1516 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1517
1518 let mut state = NameServerTransportState::default();
1519 let opp_enc = &OpportunisticEncryption::Enabled {
1520 config: OpportunisticEncryptionConfig::default(),
1521 };
1522
1523 state.complete_connection(ns_ip, Protocol::Tls);
1526 state.response_received(ns_ip, Protocol::Tls);
1527
1528 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1531 assert!(selected.is_some());
1532 assert!(matches!(
1533 selected.unwrap().protocol,
1534 ProtocolConfig::Tls { .. }
1535 ));
1536 }
1537
1538 #[tokio::test]
1539 async fn test_opportunistic_probe() {
1540 subscribe();
1541
1542 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1544 .with_opportunistic_encryption()
1545 .with_probe_budget(10);
1546
1547 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1548 let mock_provider = MockProvider::default();
1549 assert!(
1550 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1551 .await
1552 .is_ok()
1553 );
1554
1555 let recorded_calls = mock_provider.new_connection_calls();
1556 assert_eq!(recorded_calls.len(), 2);
1558 let (ips, protocols): (Vec<IpAddr>, Vec<ProtocolConfig>) =
1559 recorded_calls.into_iter().unzip();
1560 assert!(ips.iter().all(|ip| *ip == ns_ip));
1562 let protocols = protocols
1564 .iter()
1565 .map(ProtocolConfig::to_protocol)
1566 .collect::<Vec<_>>();
1567 assert!(protocols.contains(&Protocol::Udp));
1568 assert!(protocols.contains(&Protocol::Tls));
1569 }
1570
1571 #[tokio::test]
1572 async fn test_opportunistic_probe_skip_in_progress() {
1573 subscribe();
1574
1575 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1576 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1577 .with_opportunistic_encryption()
1578 .with_probe_budget(10);
1579
1580 cx.transport_state()
1582 .await
1583 .initiate_connection(ns_ip, Protocol::Tls);
1584
1585 let mock_provider = MockProvider::default();
1586 assert!(
1587 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1588 .await
1589 .is_ok()
1590 );
1591
1592 let recorded_calls = mock_provider.new_connection_calls();
1593 assert_eq!(recorded_calls.len(), 1);
1595 let (ip, protocol) = &recorded_calls[0];
1596 assert_eq!(*ip, ns_ip);
1597 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1598 }
1599
1600 #[tokio::test]
1601 async fn test_opportunistic_probe_skip_recent_failure() {
1602 subscribe();
1603
1604 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1605 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1606 .with_opportunistic_encryption()
1607 .with_probe_budget(10);
1608
1609 cx.transport_state().await.error_received(
1611 ns_ip,
1612 Protocol::Tls,
1613 &NetError::from(io::Error::new(
1614 io::ErrorKind::ConnectionRefused,
1615 "connection refused",
1616 )),
1617 );
1618
1619 let mock_provider = MockProvider::default();
1620 assert!(
1621 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1622 .await
1623 .is_ok()
1624 );
1625
1626 let recorded_calls = mock_provider.new_connection_calls();
1627 assert_eq!(recorded_calls.len(), 1);
1629 let (ip, protocol) = &recorded_calls[0];
1630 assert_eq!(*ip, ns_ip);
1631 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1632 }
1633
1634 #[tokio::test]
1635 async fn test_opportunistic_probe_stale_failure() {
1636 subscribe();
1637
1638 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1639 let mut cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1640 .with_probe_budget(10);
1641 let opp_enc_config = OpportunisticEncryptionConfig {
1642 damping_period: Duration::from_secs(5),
1643 ..OpportunisticEncryptionConfig::default()
1644 };
1645 cx.opportunistic_encryption = OpportunisticEncryption::Enabled {
1646 config: opp_enc_config.clone(),
1647 };
1648
1649 {
1651 let mut state = cx.transport_state().await;
1652 let old_failure_time =
1653 SystemTime::now() - opp_enc_config.damping_period - Duration::from_secs(1);
1654 state.set_failure_time(ns_ip, Protocol::Tls, old_failure_time);
1655 }
1656
1657 let mock_provider = MockProvider::default();
1658 assert!(
1659 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1660 .await
1661 .is_ok()
1662 );
1663
1664 let recorded_calls = mock_provider.new_connection_calls();
1665 assert_eq!(recorded_calls.len(), 2);
1667 let protocols = recorded_calls
1668 .iter()
1669 .map(|(_, protocol)| protocol.to_protocol())
1670 .collect::<Vec<_>>();
1671 assert!(protocols.contains(&Protocol::Udp));
1672 assert!(protocols.contains(&Protocol::Tls));
1673 }
1674
1675 #[tokio::test]
1676 async fn test_opportunistic_probe_skip_no_budget() {
1677 subscribe();
1678
1679 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1680 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1681 .with_opportunistic_encryption();
1682 let mock_provider = MockProvider::default();
1683 assert!(
1685 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1686 .await
1687 .is_ok()
1688 );
1689
1690 let recorded_calls = mock_provider.new_connection_calls();
1691 assert_eq!(recorded_calls.len(), 1);
1693 let (ip, protocol) = &recorded_calls[0];
1694 assert_eq!(*ip, ns_ip);
1695 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1696 }
1697
1698 fn mock_connection(protocol: Protocol) -> ConnectionState<MockProvider> {
1699 ConnectionState::new(MockClientHandle, protocol)
1700 }
1701
/// A successful connect with opportunistic encryption is expected to record one TLS probe
/// attempt, one duration sample and one success, zero errors, and a probe budget gauge
/// equal to the initial budget.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_success() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            let provider = MockProvider::default();

            let outcome = test_connected_mut_client(ns_ip, Arc::new(cx), &provider).await;
            assert!(outcome.is_ok());
        });
    });

    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];

    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(
        &map,
        PROBE_DURATION_SECONDS,
        tls.clone(),
        1,
        Unit::Seconds,
    );
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
1757
/// Without a probe budget set on the context, the budget gauge is expected to read 0 and
/// no TLS probe attempt or duration sample is expected to be recorded.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_budget_exhausted() {
    subscribe();
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            // Note: `with_probe_budget` is intentionally not called here.
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption();
            let provider = MockProvider::default();

            let outcome = test_connected_mut_client(ns_ip, Arc::new(cx), &provider).await;
            assert!(outcome.is_ok());
        });
    });

    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();

    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], 0);

    let tls = vec![Label::new("protocol", "tls")];
    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 0);
    assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, tls, 0, Unit::Seconds);
}
1799
/// When the provider's connection future fails with a connection-refused I/O error, one
/// TLS probe attempt, one duration sample and one error are expected, zero successes, and
/// a probe budget gauge equal to the initial budget.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_connection_error() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            // Provider whose connection futures all resolve to a refused-connection error.
            let failing_provider = MockProvider {
                new_connection_error: Some(NetError::from(io::Error::new(
                    io::ErrorKind::ConnectionRefused,
                    "connection refused",
                ))),
                ..MockProvider::default()
            };

            // The connect result itself is ignored; only the recorded metrics are checked.
            let _ = test_connected_mut_client(ns_ip, Arc::new(cx), &failing_provider).await;
        });
    });

    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];

    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(
        &map,
        PROBE_DURATION_SECONDS,
        tls.clone(),
        1,
        Unit::Seconds,
    );
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
1860
/// When the provider's connection future fails with `NetError::Timeout`, one TLS probe
/// attempt, one duration sample and one timeout are expected, zero errors and successes,
/// and a probe budget gauge equal to the initial budget.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_connection_timeout_error() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            // Provider whose connection futures all resolve to a timeout error.
            let timing_out_provider = MockProvider {
                new_connection_error: Some(NetError::Timeout),
                ..MockProvider::default()
            };

            // The connect result itself is ignored; only the recorded metrics are checked.
            let _ = test_connected_mut_client(ns_ip, Arc::new(cx), &timing_out_provider).await;
        });
    });

    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];

    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(
        &map,
        PROBE_DURATION_SECONDS,
        tls.clone(),
        1,
        Unit::Seconds,
    );
    assert_counter_eq(&map, PROBE_TIMEOUTS_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls.clone(), 0);
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
1921
1922 async fn test_connected_mut_client(
1928 ns_ip: IpAddr,
1929 cx: Arc<PoolContext>,
1930 provider: &MockProvider,
1931 ) -> Result<(), NetError> {
1932 let name_server = NameServer::new(
1933 [].into_iter(),
1934 NameServerConfig::opportunistic_encryption(ns_ip),
1935 &ResolverOpts::default(),
1936 provider.clone(),
1937 );
1938
1939 name_server
1940 .connected_mut_client(ConnectionPolicy::default(), &cx)
1941 .await
1942 .map(|_| ())
1943 }
1944}
1945
#[cfg(all(test, feature = "metrics"))]
mod resolver_metrics_tests {
    use std::net::{IpAddr, Ipv4Addr};

    use metrics::{Label, with_local_recorder};
    use metrics_util::debugging::DebuggingRecorder;
    use mock_provider::MockProvider;
    use test_support::assert_counter_eq;
    use test_support::subscribe;

    use super::*;
    use crate::connection_provider::TlsConfig;
    use crate::metrics::OUTGOING_QUERIES_TOTAL;

    /// Name server address used by every test in this module.
    const NS_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

    /// One query sent through a UDP-configured name server is counted under `protocol=udp`.
    #[test]
    fn test_outgoing_query_protocol_metrics_udp() {
        assert_single_outgoing_query(NameServerConfig::udp(NS_IP), "udp");
    }

    /// One query sent through a TCP-configured name server is counted under `protocol=tcp`.
    #[test]
    fn test_outgoing_query_protocol_metrics_tcp() {
        assert_single_outgoing_query(NameServerConfig::tcp(NS_IP), "tcp");
    }

    /// One query sent through a TLS-configured name server is counted under `protocol=tls`.
    #[cfg(feature = "__tls")]
    #[test]
    fn test_outgoing_query_protocol_metrics_tls() {
        assert_single_outgoing_query(NameServerConfig::tls(NS_IP, "dns.google".into()), "tls");
    }

    /// Sends a single `A` query for `www.example.com.` through a [`NameServer`] built from
    /// `config` and a default [`MockProvider`], all under a local [`DebuggingRecorder`],
    /// then asserts that `OUTGOING_QUERIES_TOTAL` labelled `protocol=<protocol>` equals 1.
    ///
    /// Shared by the per-protocol tests above, which differ only in the name server
    /// config and the expected label value.
    fn assert_single_outgoing_query(config: NameServerConfig, protocol: &'static str) {
        subscribe();
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        // The name server and context are built inside the recorder scope so that any
        // metrics they register land in `recorder`.
        with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async move {
                let options = ResolverOpts::default();
                let name_server = Arc::new(NameServer::new(
                    [],
                    config,
                    &options,
                    MockProvider::default(),
                ));

                let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
                let name = Name::parse("www.example.com.", None).unwrap();
                // The response is deliberately ignored; only the recorded counter is checked.
                let _ = name_server
                    .send(
                        DnsRequest::from_query(
                            Query::query(name, RecordType::A),
                            DnsRequestOptions::default(),
                        ),
                        ConnectionPolicy::default(),
                        &cx,
                    )
                    .await;
            });
        });

        #[allow(clippy::mutable_key_type)]
        let map = snapshotter.snapshot().into_hashmap();

        let labels = vec![Label::new("protocol", protocol)];
        assert_counter_eq(&map, OUTGOING_QUERIES_TOTAL, labels, 1);
    }
}
2099
#[cfg(all(test, any(feature = "metrics", feature = "__tls")))]
mod mock_provider {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_util::stream::once;
    use futures_util::{Stream, future};
    use tokio::net::UdpSocket;

    use super::*;
    use crate::config::ProtocolConfig;
    use crate::net::runtime::TokioTime;
    use crate::net::runtime::iocompat::AsyncIoTokioAsStd;
    use crate::proto::op::Message;

    /// A [`ConnectionProvider`] test double that records every `new_connection` call and
    /// hands out [`MockClientHandle`]s, or a configured error, without touching the network.
    #[derive(Clone)]
    pub(super) struct MockProvider {
        /// Runtime provider returned by [`ConnectionProvider::runtime_provider`].
        pub(super) runtime: MockSyncRuntimeProvider,
        /// The `(ip, protocol)` of every `new_connection` call, in call order. Held behind
        /// an `Arc`, so clones of the provider append to the same list.
        pub(super) new_connection_calls: Arc<SyncMutex<Vec<(IpAddr, ProtocolConfig)>>>,
        /// When `Some`, connection futures resolve to a clone of this error instead of a
        /// handle.
        pub(super) new_connection_error: Option<NetError>,
    }

    impl Default for MockProvider {
        /// A provider with an empty call log and no injected connection error.
        fn default() -> Self {
            Self {
                runtime: MockSyncRuntimeProvider,
                new_connection_calls: Arc::new(SyncMutex::new(Vec::new())),
                new_connection_error: None,
            }
        }
    }

    impl MockProvider {
        /// Returns a snapshot (clone) of the calls recorded so far.
        pub(super) fn new_connection_calls(&self) -> Vec<(IpAddr, ProtocolConfig)> {
            let recorded = self.new_connection_calls.lock();
            recorded.clone()
        }
    }

    impl ConnectionProvider for MockProvider {
        type Conn = MockClientHandle;
        type FutureConn = Pin<Box<dyn Send + Future<Output = Result<Self::Conn, NetError>>>>;
        type RuntimeProvider = MockSyncRuntimeProvider;

        /// Logs the call, then returns an already-resolved future: the injected error if
        /// one is configured, otherwise a [`MockClientHandle`]. The outer `Result` is
        /// always `Ok`; failures are only reported through the future.
        fn new_connection(
            &self,
            ip: IpAddr,
            config: &ConnectionConfig,
            _cx: &PoolContext,
        ) -> Result<Self::FutureConn, NetError> {
            let call = (ip, config.protocol.clone());
            self.new_connection_calls.lock().push(call);

            let outcome = match self.new_connection_error.as_ref() {
                None => Ok(MockClientHandle),
                Some(err) => Err(err.clone()),
            };
            Ok(Box::pin(future::ready(outcome)))
        }

        fn runtime_provider(&self) -> &Self::RuntimeProvider {
            &self.runtime
        }
    }

    /// A [`DnsHandle`] that answers every request with a `NoError` response echoing the
    /// request's queries.
    #[derive(Clone, Default)]
    pub(super) struct MockClientHandle;

    impl DnsHandle for MockClientHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = MockSyncRuntimeProvider;

        /// Builds a response message with the request's id and op code, response code
        /// `NoError` and a copy of the request's queries, and yields it as a
        /// single-item stream.
        fn send(&self, request: DnsRequest) -> Self::Response {
            let mut message = Message::response(request.id, request.op_code);
            message.metadata.response_code = ResponseCode::NoError;
            message.add_queries(request.queries.clone());

            let response = DnsResponse::from_message(message).unwrap();
            Box::pin(once(future::ready(Ok(response))))
        }
    }

    /// A [`RuntimeProvider`] whose handle runs spawned futures synchronously (see
    /// [`MockSyncHandle`]); its TCP and UDP constructors are unimplemented.
    #[derive(Clone)]
    pub(super) struct MockSyncRuntimeProvider;

    impl RuntimeProvider for MockSyncRuntimeProvider {
        type Handle = MockSyncHandle;
        type Timer = TokioTime;
        type Udp = UdpSocket;
        type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            MockSyncHandle
        }

        /// Not supported by the mock; panics via `unimplemented!` if called.
        #[allow(clippy::unimplemented)]
        fn connect_tcp(
            &self,
            _server_addr: std::net::SocketAddr,
            _bind_addr: Option<std::net::SocketAddr>,
            _timeout: Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
            unimplemented!();
        }

        /// Not supported by the mock; panics via `unimplemented!` if called.
        #[allow(clippy::unimplemented)]
        fn bind_udp(
            &self,
            _local_addr: std::net::SocketAddr,
            _server_addr: std::net::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
            unimplemented!();
        }
    }

    /// A [`Spawn`] implementation that runs the future to completion inside `spawn_bg`
    /// instead of handing it to an executor.
    #[derive(Clone)]
    pub(super) struct MockSyncHandle;

    impl Spawn for MockSyncHandle {
        /// Polls `future` in a loop with a no-op waker until it reports `Ready`, so the
        /// call only returns once the future has finished. A future that stays `Pending`
        /// is re-polled immediately (busy loop).
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            let noop = futures_util::task::noop_waker();
            let mut cx = Context::from_waker(&noop);
            let mut task = Box::pin(future);

            while task.as_mut().poll(&mut cx).is_pending() {}
        }
    }
}