1#[cfg(not(test))]
9use std::time::{Duration, Instant};
10use std::{
11 cmp,
12 fmt::Debug,
13 marker::PhantomData,
14 net::IpAddr,
15 sync::{
16 Arc,
17 atomic::{AtomicU8, AtomicU32, Ordering},
18 },
19};
20
21use futures_util::lock::Mutex as AsyncMutex;
22use parking_lot::Mutex as SyncMutex;
23#[cfg(test)]
24use tokio::time::{Duration, Instant};
25use tracing::{debug, error, warn};
26
27#[cfg(feature = "metrics")]
28use crate::metrics::ResolverMetrics;
29#[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
30use crate::metrics::opportunistic_encryption::ProbeMetrics;
31use crate::{
32 config::{
33 ConnectionConfig, NameServerConfig, OpportunisticEncryption, ResolverOpts,
34 ServerOrderingStrategy,
35 },
36 connection_provider::ConnectionProvider,
37 name_server_pool::{NameServerTransportState, PoolContext},
38 net::{
39 DnsError, NetError, NoRecords,
40 runtime::{RuntimeProvider, Spawn},
41 xfer::{DnsHandle, FirstAnswer, Protocol},
42 },
43 proto::{
44 op::{DnsRequest, DnsRequestOptions, DnsResponse, Query, ResponseCode},
45 rr::{Name, RecordType},
46 },
47};
48
/// A single upstream name server: its configuration, the connections opened to it, and the
/// smoothed round-trip-time statistics used to rank it against other servers.
pub struct NameServer<P: ConnectionProvider> {
    /// Static configuration: the server IP and the connection configs that may be used.
    config: NameServerConfig,
    /// Connections opened so far; entries whose status is `Failed` are pruned lazily the next
    /// time a connection is selected.
    connections: AsyncMutex<Vec<ConnectionState<P>>>,
    /// Metrics for opportunistic-encryption probes sent to this server.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    opportunistic_probe_metrics: ProbeMetrics,
    /// Resolver-level metrics (e.g. outgoing query counts per protocol).
    #[cfg(feature = "metrics")]
    resolver_metrics: ResolverMetrics,
    /// Server-wide decaying SRTT, updated on every query regardless of the connection used.
    server_srtt: DecayingSrtt,
    /// Factory for new connections and access to the runtime used to spawn probes.
    connection_provider: P,
}
65
66impl<P: ConnectionProvider> NameServer<P> {
67 pub fn new(
71 connections: impl IntoIterator<Item = (Protocol, P::Conn)>,
72 config: NameServerConfig,
73 options: &ResolverOpts,
74 connection_provider: P,
75 ) -> Self {
76 let mut connections = connections
77 .into_iter()
78 .map(|(protocol, handle)| ConnectionState::new(handle, protocol))
79 .collect::<Vec<_>>();
80
81 if options.server_ordering_strategy != ServerOrderingStrategy::UserProvidedOrder {
84 connections.sort_by_key(|ns| ns.protocol != Protocol::Udp);
85 }
86
87 Self {
88 config,
89 connections: AsyncMutex::new(connections),
90 server_srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
91 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
92 opportunistic_probe_metrics: ProbeMetrics::default(),
93 #[cfg(feature = "metrics")]
94 resolver_metrics: ResolverMetrics::default(),
95 connection_provider,
96 }
97 }
98
99 pub(crate) async fn send(
101 self: Arc<Self>,
102 request: DnsRequest,
103 policy: ConnectionPolicy,
104 cx: &Arc<PoolContext>,
105 ) -> Result<DnsResponse, NetError> {
106 let (handle, meta, protocol) = self.connected_mut_client(policy, cx).await?;
107 #[cfg(feature = "metrics")]
108 self.resolver_metrics.increment_outgoing_query(&protocol);
109 let now = Instant::now();
110 let response = handle.send(request).first_answer().await;
111 let rtt = now.elapsed();
112
113 match response {
114 Ok(response) => {
115 meta.set_status(Status::Established);
116 let result = DnsError::from_response(response);
117 let error = match result {
118 Ok(response) => {
119 meta.srtt.record(rtt);
120 self.server_srtt.record(rtt);
121 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
122 cx.transport_state()
123 .await
124 .response_received(self.config.ip, protocol);
125 }
126 return Ok(response);
127 }
128 Err(error) => error,
129 };
130
131 let update = match error {
132 DnsError::NoRecordsFound(NoRecords {
133 response_code: ResponseCode::ServFail,
134 ..
135 }) => Some(true),
136 DnsError::NoRecordsFound(NoRecords { .. }) => Some(false),
137 _ => None,
138 };
139
140 match update {
141 Some(true) => {
142 meta.srtt.record(rtt);
143 self.server_srtt.record(rtt);
144 }
145 Some(false) => {
146 meta.srtt.record_failure();
148 self.server_srtt.record_failure();
149 }
150 None => {}
151 }
152
153 let err = NetError::from(error);
154 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
155 cx.transport_state()
156 .await
157 .error_received(self.config.ip, protocol, &err)
158 }
159 Err(err)
160 }
161 Err(error) => {
162 debug!(config = ?self.config, %error, "failed to connect to name server");
163
164 meta.set_status(Status::Failed);
166
167 match &error {
171 NetError::Busy | NetError::Io(_) | NetError::Timeout => {
172 meta.srtt.record_failure();
173 self.server_srtt.record_failure();
174 }
175 #[cfg(feature = "__quic")]
176 NetError::QuinnConfigError(_)
177 | NetError::QuinnConnect(_)
178 | NetError::QuinnConnection(_)
179 | NetError::QuinnTlsConfigError(_) => {
180 meta.srtt.record_failure();
181 self.server_srtt.record_failure();
182 }
183 #[cfg(feature = "__tls")]
184 NetError::RustlsError(_) => {
185 meta.srtt.record_failure();
186 self.server_srtt.record_failure();
187 }
188 _ => {}
189 }
190
191 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
192 cx.transport_state()
193 .await
194 .error_received(self.config.ip, protocol, &error);
195 }
196
197 Err(error)
199 }
200 }
201 }
202
203 async fn connected_mut_client(
207 &self,
208 policy: ConnectionPolicy,
209 cx: &Arc<PoolContext>,
210 ) -> Result<(P::Conn, Arc<ConnectionMeta>, Protocol), NetError> {
211 {
213 let mut connections = self.connections.lock().await;
214 connections
215 .retain(|conn| matches!(conn.meta.status(), Status::Init | Status::Established));
216 if let Some(conn) = policy.select_connection(
217 self.config.ip,
218 &*cx.transport_state().await,
219 &cx.opportunistic_encryption,
220 &connections,
221 ) {
222 return Ok((conn.handle.clone(), conn.meta.clone(), conn.protocol));
223 }
224 }
225
226 debug!(config = ?self.config, "connecting");
228 let config = policy
229 .select_connection_config(
230 self.config.ip,
231 &*cx.transport_state().await,
232 &cx.opportunistic_encryption,
233 &self.config.connections,
234 )
235 .ok_or(NetError::NoConnections)?;
236
237 let protocol = config.protocol.to_protocol();
238 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
239 cx.transport_state()
240 .await
241 .initiate_connection(self.config.ip, protocol);
242 } else if cx.opportunistic_encryption.is_enabled() && !protocol.is_encrypted() {
243 self.consider_probe_encrypted_transport(&policy, cx).await;
244 }
245
246 let handle = Box::pin(self.connection_provider.new_connection(
248 self.config.ip,
249 config,
250 cx,
251 )?)
252 .await?;
253
254 if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
255 cx.transport_state()
256 .await
257 .complete_connection(self.config.ip, protocol);
258 }
259
260 let state = ConnectionState::new(handle.clone(), protocol);
262 let meta = state.meta.clone();
263 self.connections.lock().await.push(state);
264 Ok((handle, meta, protocol))
265 }
266
267 pub(super) fn protocols(&self) -> impl Iterator<Item = Protocol> + '_ {
268 self.config
269 .connections
270 .iter()
271 .map(|conn| conn.protocol.to_protocol())
272 }
273
274 pub(super) fn ip(&self) -> IpAddr {
275 self.config.ip
276 }
277
278 pub(crate) fn decayed_srtt(&self) -> f64 {
279 self.server_srtt.current()
280 }
281
282 pub(super) fn record_cancelled(&self, winner_rtt: Duration) {
296 const CANCEL_PENALTY: Duration = Duration::from_millis(5);
297 self.server_srtt.record(winner_rtt + CANCEL_PENALTY);
298 }
299
300 #[cfg(test)]
301 pub(crate) fn test_record_failure(&self) {
302 self.server_srtt.record_failure();
303 }
304
305 #[cfg(test)]
306 #[allow(dead_code)]
307 pub(crate) fn is_connected(&self) -> bool {
308 let Some(connections) = self.connections.try_lock() else {
309 return true;
311 };
312
313 connections.iter().any(|conn| match conn.meta.status() {
314 Status::Established | Status::Init => true,
315 Status::Failed => false,
316 })
317 }
318
319 pub(crate) fn trust_negative_responses(&self) -> bool {
320 self.config.trust_negative_responses
321 }
322
323 async fn consider_probe_encrypted_transport(
324 &self,
325 policy: &ConnectionPolicy,
326 cx: &Arc<PoolContext>,
327 ) {
328 let Some(probe_config) =
329 policy.select_encrypted_connection_config(&self.config.connections)
330 else {
331 warn!("no encrypted connection configs available for probing");
332 return;
333 };
334
335 let probe_protocol = probe_config.protocol.to_protocol();
336 let should_probe = {
337 let state = cx.transport_state().await;
338 state.should_probe_encrypted(
339 self.config.ip,
340 probe_protocol,
341 &cx.opportunistic_encryption,
342 )
343 };
344
345 if !should_probe {
346 return;
347 }
348
349 if let Err(err) = self.probe_encrypted_transport(cx, probe_config) {
350 error!(%err, "opportunistic encrypted probe attempt failed");
351 }
352 }
353
354 fn probe_encrypted_transport(
355 &self,
356 cx: &Arc<PoolContext>,
357 probe_config: &ConnectionConfig,
358 ) -> Result<(), NetError> {
359 let mut budget = cx.opportunistic_probe_budget.load(Ordering::Relaxed);
360 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
361 self.opportunistic_probe_metrics.probe_budget.set(budget);
362 loop {
363 if budget == 0 {
364 debug!("no remaining budget for opportunistic probing");
365 return Ok(());
366 }
367 match cx.opportunistic_probe_budget.compare_exchange_weak(
368 budget,
369 budget - 1,
370 Ordering::AcqRel,
371 Ordering::Relaxed,
372 ) {
373 Ok(_) => break,
374 Err(current) => budget = current,
375 }
376 }
377
378 let connect = ProbeRequest::new(
379 probe_config,
380 self,
381 cx,
382 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
383 self.opportunistic_probe_metrics.clone(),
384 )?;
385 self.connection_provider
386 .runtime_provider()
387 .create_handle()
388 .spawn_bg(connect.run());
389
390 Ok(())
391 }
392}
393
/// A background attempt to connect to a name server over a (normally encrypted) protocol and
/// send one test query, recording the outcome in the pool's transport state.
struct ProbeRequest<P: ConnectionProvider> {
    /// Address of the probed name server.
    ip: IpAddr,
    /// Protocol being probed.
    proto: Protocol,
    /// In-flight connection future produced by the connection provider.
    connecting: P::FutureConn,
    /// Shared pool state: transport state and the opportunistic probe budget.
    context: Arc<PoolContext>,
    /// Probe metrics (attempts, errors, successes, budget, duration).
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    metrics: ProbeMetrics,
    /// Marks `P` as used; the other fields only mention it through an associated type.
    provider: PhantomData<P>,
}
403
404impl<P: ConnectionProvider> ProbeRequest<P> {
405 fn new(
406 config: &ConnectionConfig,
407 ns: &NameServer<P>,
408 cx: &Arc<PoolContext>,
409 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
410 metrics: ProbeMetrics,
411 ) -> Result<Self, NetError> {
412 Ok(Self {
413 ip: ns.config.ip,
414 proto: config.protocol.to_protocol(),
415 connecting: ns
416 .connection_provider
417 .new_connection(ns.config.ip, config, cx)?,
418 context: cx.clone(),
419 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
420 metrics,
421 provider: PhantomData,
422 })
423 }
424
425 async fn run(self) {
426 let Self {
427 ip,
428 proto,
429 connecting,
430 context,
431 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
432 metrics,
433 provider: _,
434 } = self;
435
436 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
437 let start = Instant::now();
438
439 context
440 .transport_state()
441 .await
442 .initiate_connection(ip, proto);
443 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
444 metrics.increment_attempts(proto);
445
446 let conn = match connecting.await {
447 Ok(conn) => conn,
448 Err(err) => {
449 debug!(?proto, "probe connection failed");
450 let _prev = context
451 .opportunistic_probe_budget
452 .fetch_add(1, Ordering::Relaxed);
453 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
454 {
455 metrics.increment_errors(proto, &err);
456 metrics.probe_budget.set(_prev + 1);
457 metrics.record_probe_duration(proto, start.elapsed());
458 }
459 context
460 .transport_state()
461 .await
462 .error_received(ip, proto, &err);
463 return;
464 }
465 };
466
467 debug!(?proto, "probe connection succeeded");
468 context
469 .transport_state()
470 .await
471 .complete_connection(ip, proto);
472
473 match conn
474 .send(DnsRequest::from_query(
475 Query::query(Name::root(), RecordType::NS),
476 DnsRequestOptions::default(),
477 ))
478 .first_answer()
479 .await
480 {
481 Ok(_) => {
482 debug!(?proto, "probe query succeeded");
483 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
484 metrics.increment_successes(proto);
485 context.transport_state().await.response_received(ip, proto);
486 }
487 Err(err) => {
488 debug!(?proto, ?err, "probe query failed");
489 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
490 metrics.increment_errors(proto, &err);
491 context
492 .transport_state()
493 .await
494 .error_received(ip, proto, &err);
495 }
496 }
497
498 let _prev = context
499 .opportunistic_probe_budget
500 .fetch_add(1, Ordering::Relaxed);
501 #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
502 {
503 metrics.probe_budget.set(_prev + 1);
504 metrics.record_probe_duration(proto, start.elapsed());
505 }
506 }
507}
508
/// A cached connection to a name server together with its protocol and shared metadata.
struct ConnectionState<P: ConnectionProvider> {
    /// Transport protocol of this connection.
    protocol: Protocol,
    /// Handle used to send requests; cloned out to callers.
    handle: P::Conn,
    /// Status and per-connection SRTT, shared via `Arc` so they can be updated after the
    /// connection list lock has been released.
    meta: Arc<ConnectionMeta>,
}
514
515impl<P: ConnectionProvider> ConnectionState<P> {
516 fn new(handle: P::Conn, protocol: Protocol) -> Self {
517 Self {
518 protocol,
519 handle,
520 meta: Arc::new(ConnectionMeta::default()),
521 }
522 }
523}
524
/// Mutable, thread-safe bookkeeping for one connection.
struct ConnectionMeta {
    /// The connection's `Status`, stored as its `u8` discriminant.
    status: AtomicU8,
    /// Smoothed round-trip time measured on this connection.
    srtt: DecayingSrtt,
}
529
530impl ConnectionMeta {
531 fn set_status(&self, status: Status) {
532 self.status.store(status.into(), Ordering::Release);
533 }
534
535 fn status(&self) -> Status {
536 Status::from(self.status.load(Ordering::Acquire))
537 }
538}
539
540impl Default for ConnectionMeta {
541 fn default() -> Self {
542 Self {
546 status: AtomicU8::new(Status::Init.into()),
547 srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
548 }
549 }
550}
551
/// A smoothed round-trip-time estimate that can be updated concurrently and whose reported
/// value decays with the time since it was last updated.
struct DecayingSrtt {
    /// Current smoothed RTT in microseconds, clamped on update to `MAX_SRTT_MICROS`.
    srtt_microseconds: AtomicU32,

    /// When the estimate was last updated; `None` until the first `record`/`record_failure`.
    last_update: SyncMutex<Option<Instant>>,
}
585
586impl DecayingSrtt {
587 fn new(initial_srtt: Duration) -> Self {
588 Self {
589 srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
590 last_update: SyncMutex::new(None),
591 }
592 }
593
594 fn record(&self, rtt: Duration) {
595 self.update(
600 rtt.as_micros() as u32,
601 |cur_srtt_microseconds, last_update| {
602 let factor = compute_srtt_factor(last_update, 3);
606 let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
607 + factor * f64::from(cur_srtt_microseconds);
608 new_srtt.round() as u32
609 },
610 );
611 }
612
613 fn record_failure(&self) {
615 self.update(
616 Self::FAILURE_PENALTY,
617 |cur_srtt_microseconds, _last_update| {
618 cur_srtt_microseconds.saturating_add(Self::FAILURE_PENALTY)
619 },
620 );
621 }
622
623 fn current(&self) -> f64 {
632 let srtt = f64::from(self.srtt_microseconds.load(Ordering::Acquire));
633 self.last_update.lock().map_or(srtt, |last_update| {
634 srtt * compute_srtt_factor(last_update, 180)
645 })
646 }
647
648 fn update(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
654 let last_update = self.last_update.lock().replace(Instant::now());
655 let _ = self.srtt_microseconds.fetch_update(
656 Ordering::SeqCst,
657 Ordering::SeqCst,
658 move |cur_srtt_microseconds| {
659 Some(
660 last_update
661 .map_or(default, |last_update| {
662 update_fn(cur_srtt_microseconds, last_update)
663 })
664 .min(Self::MAX_SRTT_MICROS),
665 )
666 },
667 );
668 }
669
670 #[cfg(all(test, feature = "tokio"))]
674 fn as_duration(&self) -> Duration {
675 Duration::from_micros(u64::from(self.srtt_microseconds.load(Ordering::Acquire)))
676 }
677
678 const FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
679 const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
680}
681
/// Returns the decay weight `exp(-max(elapsed_secs, 1) / weight)` for an SRTT last updated at
/// `last_update`. The elapsed time is floored at one second, so the factor never exceeds
/// `exp(-1 / weight)`.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-elapsed_secs / f64::from(weight)).exp()
}
694
/// Lifecycle state of a cached connection, stored in an `AtomicU8` via its discriminant.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
enum Status {
    /// A request on this connection hit an error; the connection is dropped on the next
    /// selection. Also the decoding of any unrecognised raw value.
    Failed = 0,
    /// Created but no response received yet.
    Init = 1,
    /// At least one response has been received.
    Established = 2,
}
711
712impl From<Status> for u8 {
713 fn from(val: Status) -> Self {
715 val as Self
716 }
717}
718
719impl From<u8> for Status {
720 fn from(val: u8) -> Self {
721 match val {
722 2 => Self::Established,
723 1 => Self::Init,
724 _ => Self::Failed,
725 }
726 }
727}
728
/// Per-request constraints on which connections/protocols may be used.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub(crate) struct ConnectionPolicy {
    /// When `true`, UDP connections and configs are excluded.
    pub(crate) disable_udp: bool,
}
733
734impl ConnectionPolicy {
735 pub(crate) fn allows_server<P: ConnectionProvider>(&self, server: &NameServer<P>) -> bool {
737 server.protocols().any(|p| self.allows_protocol(p))
738 }
739
740 fn select_connection<'a, P: ConnectionProvider>(
745 &self,
746 ip: IpAddr,
747 encrypted_transport_state: &NameServerTransportState,
748 opportunistic_encryption: &OpportunisticEncryption,
749 connections: &'a [ConnectionState<P>],
750 ) -> Option<&'a ConnectionState<P>> {
751 let selected = connections
752 .iter()
753 .filter(|conn| self.allows_protocol(conn.protocol))
754 .min_by(|a, b| self.compare_connections(opportunistic_encryption.is_enabled(), a, b));
755
756 let selected = selected?;
757
758 match opportunistic_encryption.is_enabled()
764 && !selected.protocol.is_encrypted()
765 && encrypted_transport_state.any_recent_success(ip, opportunistic_encryption)
766 {
767 true => None,
768 false => Some(selected),
769 }
770 }
771
772 fn select_connection_config<'a>(
777 &self,
778 ip: IpAddr,
779 encrypted_transport_state: &NameServerTransportState,
780 opportunistic_encryption: &OpportunisticEncryption,
781 connection_configs: &'a [ConnectionConfig],
782 ) -> Option<&'a ConnectionConfig> {
783 connection_configs
784 .iter()
785 .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
786 .min_by(|a, b| {
787 self.compare_connection_configs(
788 ip,
789 encrypted_transport_state,
790 opportunistic_encryption,
791 a,
792 b,
793 )
794 })
795 }
796
797 fn select_encrypted_connection_config<'a>(
799 &self,
800 connection_config: &'a [ConnectionConfig],
801 ) -> Option<&'a ConnectionConfig> {
802 connection_config
803 .iter()
804 .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
805 .find(|c| c.protocol.to_protocol().is_encrypted())
806 }
807
808 fn allows_protocol(&self, protocol: Protocol) -> bool {
810 !(self.disable_udp && protocol == Protocol::Udp)
811 }
812
813 fn compare_connections<P: ConnectionProvider>(
816 &self,
817 opportunistic_encryption: bool,
818 a: &ConnectionState<P>,
819 b: &ConnectionState<P>,
820 ) -> cmp::Ordering {
821 if opportunistic_encryption {
824 match (a.protocol.is_encrypted(), b.protocol.is_encrypted()) {
825 (true, false) => return cmp::Ordering::Less,
826 (false, true) => return cmp::Ordering::Greater,
827 _ => {}
829 }
830 }
831
832 match (a.protocol, b.protocol) {
833 (ap, bp) if ap == bp => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
834 (Protocol::Udp, _) => cmp::Ordering::Less,
835 (_, Protocol::Udp) => cmp::Ordering::Greater,
836 _ => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
837 }
838 }
839
840 fn compare_connection_configs(
841 &self,
842 ip: IpAddr,
843 encrypted_transport_state: &NameServerTransportState,
844 opportunistic_encryption: &OpportunisticEncryption,
845 a: &ConnectionConfig,
846 b: &ConnectionConfig,
847 ) -> cmp::Ordering {
848 let a_protocol = a.protocol.to_protocol();
849 let b_protocol = b.protocol.to_protocol();
850
851 if opportunistic_encryption.is_enabled() {
854 let a_recent_enc_success = a_protocol.is_encrypted()
855 && encrypted_transport_state.recent_success(
856 ip,
857 a_protocol,
858 opportunistic_encryption,
859 );
860 let b_recent_enc_success = b_protocol.is_encrypted()
861 && encrypted_transport_state.recent_success(
862 ip,
863 b_protocol,
864 opportunistic_encryption,
865 );
866
867 match (a_recent_enc_success, b_recent_enc_success) {
868 (true, false) => return cmp::Ordering::Less,
869 (false, true) => return cmp::Ordering::Greater,
870 _ => {}
872 }
873 }
874
875 match (a_protocol, b_protocol) {
877 (ap, bp) if ap == bp => cmp::Ordering::Equal,
878 (Protocol::Udp, _) => cmp::Ordering::Less,
879 (_, Protocol::Udp) => cmp::Ordering::Greater,
880 _ => cmp::Ordering::Equal,
881 }
882 }
883}
884
// Unit tests for `NameServer` and `DecayingSrtt`; they require the tokio runtime.
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use std::cmp;
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::time::Duration;

    use test_support::subscribe;
    use tokio::net::UdpSocket;
    use tokio::spawn;

    use super::*;
    use crate::config::{ConnectionConfig, ProtocolConfig};
    use crate::connection_provider::TlsConfig;
    use crate::net::runtime::TokioRuntimeProvider;
    use crate::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
    use crate::proto::rr::rdata::NULL;
    use crate::proto::rr::{Name, RData, Record, RecordType};

    // Sends a real A query to 8.8.8.8 over UDP; NOTE(review): needs outbound network access.
    #[tokio::test]
    async fn test_name_server() {
        subscribe();

        let options = ResolverOpts::default();
        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
        let name_server = Arc::new(NameServer::new(
            [].into_iter(),
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        let response = name_server
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .expect("query failed");
        assert_eq!(response.response_code, ResponseCode::NoError);
    }

    // A query to an address assumed to have no DNS server (127.0.0.252) must fail; the 1 ms
    // timeout makes the failure fast.
    #[tokio::test]
    async fn test_failed_name_server() {
        subscribe();

        let options = ResolverOpts {
            timeout: Duration::from_millis(1),
            ..ResolverOpts::default()
        };

        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)));
        let name_server = Arc::new(NameServer::new(
            [],
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        assert!(
            name_server
                .send(
                    DnsRequest::from_query(
                        Query::query(name.clone(), RecordType::A),
                        DnsRequestOptions::default(),
                    ),
                    ConnectionPolicy::default(),
                    &cx
                )
                .await
                .is_err()
        );
    }

    // With case randomization enabled, the query name in the returned response must keep the
    // caller's original letter case.
    #[tokio::test]
    async fn case_randomization_query_preserved() {
        subscribe();

        let provider = TokioRuntimeProvider::default();
        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let name = Name::from_str("dead.beef.").unwrap();
        let data = b"DEADBEEF";

        // Minimal one-shot UDP server: echoes the query section and answers with a NULL record.
        spawn({
            let name = name.clone();
            async move {
                let mut buffer = [0_u8; 512];
                let (len, addr) = server.recv_from(&mut buffer).await.unwrap();
                let request = Message::from_vec(&buffer[0..len]).unwrap();
                let mut response = Message::response(request.id, request.op_code);
                response.add_queries(request.queries.to_vec());
                response.add_answer(Record::from_rdata(
                    name,
                    0,
                    RData::NULL(NULL::with(data.to_vec())),
                ));
                let response_buffer = response.to_vec().unwrap();
                server.send_to(&response_buffer, addr).await.unwrap();
            }
        });

        let config = NameServerConfig {
            ip: server_addr.ip(),
            trust_negative_responses: true,
            connections: vec![ConnectionConfig {
                port: server_addr.port(),
                protocol: ProtocolConfig::Udp,
                bind_addr: None,
            }],
        };

        let resolver_opts = ResolverOpts {
            case_randomization: true,
            ..Default::default()
        };

        let cx = Arc::new(PoolContext::new(resolver_opts, TlsConfig::new().unwrap()));
        let mut request_options = DnsRequestOptions::default();
        request_options.case_randomization = true;
        let ns = Arc::new(NameServer::new([], config, &cx.options, provider));
        let response = ns
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::NULL),
                    request_options,
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .unwrap();

        let response_query_name = response.queries.first().unwrap().name();
        assert!(response_query_name.eq_case(&name));
    }

    // Compile-time check helper: only instantiates if `S: Send + Sync`.
    #[allow(clippy::extra_unused_type_parameters)]
    fn is_send_sync<S: Sync + Send>() -> bool {
        true
    }

    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<ConnectionMeta>());
    }

    // Uses tokio's paused clock so `advance` controls the decay deterministically.
    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        use std::cmp::Ordering;
        let srtt_a = DecayingSrtt::new(Duration::from_micros(10));
        let srtt_b = DecayingSrtt::new(Duration::from_micros(20));

        // Seeds only: a (10 µs) ranks ahead of b (20 µs).
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // a's first sample (30 ms) replaces its seed, so it now ranks behind b.
        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // b's first sample is slower still, so a ranks ahead again.
        srtt_b.record(Duration::from_millis(50));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // A failure penalty pushes a behind b.
        srtt_a.record_failure();
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Keep updating b until a's idle estimate has decayed below b's.
        while cmp(&srtt_a, &srtt_b) != Ordering::Less {
            srtt_b.record(Duration::from_millis(50));
            tokio::time::advance(Duration::from_secs(5)).await;
        }

        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
    }

    // Compares two estimates by their decayed current value.
    fn cmp(a: &DecayingSrtt, b: &DecayingSrtt) -> cmp::Ordering {
        a.current().total_cmp(&b.current())
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // The first sample replaces the seed.
        let first_rtt = Duration::from_millis(50);
        srtt.record(first_rtt);

        assert_eq!(srtt.as_duration(), first_rtt);

        tokio::time::advance(Duration::from_secs(3)).await;

        // After 3 s the factor is e^-1: (1 - e^-1) * 100000 + e^-1 * 50000 ≈ 81606 µs.
        srtt.record(Duration::from_millis(100));
        assert_eq!(srtt.as_duration(), Duration::from_micros(81606));
    }

    // A huge sample is clamped to `MAX_SRTT_MICROS`.
    #[test]
    fn test_record_rtt_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        srtt.record(Duration::MAX);
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // Each failure adds one `FAILURE_PENALTY` (the first replaces the seed).
        for failure_count in 1..4 {
            srtt.record_failure();
            assert_eq!(
                srtt.as_duration(),
                Duration::from_micros(
                    DecayingSrtt::FAILURE_PENALTY
                        .checked_mul(failure_count)
                        .expect("checked_mul overflow")
                        .into()
                )
            );
            tokio::time::advance(Duration::from_secs(3)).await;
        }

        // (1 - e^-1) * 50000 + e^-1 * 450000 ≈ 197152 µs.
        srtt.record(Duration::from_millis(50));
        assert_eq!(srtt.as_duration(), Duration::from_micros(197152));
    }

    // Repeated failures saturate at `MAX_SRTT_MICROS`.
    #[test]
    fn test_record_connection_failure_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        let num_failures = (DecayingSrtt::MAX_SRTT_MICROS / DecayingSrtt::FAILURE_PENALTY) + 1;
        for _ in 0..num_failures {
            srtt.record_failure();
        }

        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_srtt = 10;
        let srtt = DecayingSrtt::new(Duration::from_micros(initial_srtt));

        // No update yet: the seed is returned undecayed.
        assert_eq!(srtt.current() as u32, initial_srtt as u32);

        tokio::time::advance(Duration::from_secs(5)).await;
        srtt.record(Duration::from_millis(100));

        // Elapsed is floored at 1 s: 100000 * e^(-1/180) ≈ 99445.
        tokio::time::advance(Duration::from_millis(500)).await;
        assert_eq!(srtt.current() as u32, 99445);

        // 5.5 s since the update: 100000 * e^(-5.5/180) ≈ 96990.
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(srtt.current() as u32, 96990);
    }
}
1173
1174#[cfg(all(test, feature = "__tls"))]
1175mod opportunistic_enc_tests {
1176 use std::io;
1177 use std::net::{IpAddr, Ipv4Addr};
1178 use std::sync::Arc;
1179 use std::time::{Duration, SystemTime};
1180
1181 #[cfg(feature = "metrics")]
1182 use metrics::{Label, Unit, with_local_recorder};
1183 #[cfg(feature = "metrics")]
1184 use metrics_util::debugging::DebuggingRecorder;
1185 use mock_provider::{MockClientHandle, MockProvider};
1186 use test_support::subscribe;
1187 #[cfg(feature = "metrics")]
1188 use test_support::{assert_counter_eq, assert_gauge_eq, assert_histogram_sample_count_eq};
1189
1190 use crate::config::{
1191 NameServerConfig, OpportunisticEncryption, OpportunisticEncryptionConfig, ProtocolConfig,
1192 ResolverOpts,
1193 };
1194 use crate::connection_provider::TlsConfig;
1195 #[cfg(feature = "metrics")]
1196 use crate::metrics::opportunistic_encryption::{
1197 PROBE_ATTEMPTS_TOTAL, PROBE_BUDGET_TOTAL, PROBE_DURATION_SECONDS, PROBE_ERRORS_TOTAL,
1198 PROBE_SUCCESSES_TOTAL, PROBE_TIMEOUTS_TOTAL,
1199 };
1200 use crate::name_server::{ConnectionPolicy, ConnectionState, NameServer, mock_provider};
1201 use crate::name_server_pool::{NameServerTransportState, PoolContext};
1202 use crate::net::NetError;
1203 use crate::net::xfer::Protocol;
1204
1205 #[tokio::test]
1206 async fn test_select_connection_opportunistic_enc_disabled() {
1207 let mut policy = ConnectionPolicy::default();
1208 let connections = vec![
1209 mock_connection(Protocol::Udp),
1210 mock_connection(Protocol::Tcp),
1211 ];
1212
1213 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1214 let state = NameServerTransportState::default();
1215 let opp_enc = OpportunisticEncryption::Disabled;
1216
1217 let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1220 assert!(selected.is_some());
1221 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1222
1223 policy.disable_udp = true;
1226 let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1227 assert!(selected.is_some());
1228 assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1229 }
1230
1231 #[tokio::test]
1232 async fn test_select_connection_opportunistic_enc_enabled() {
1233 let policy = ConnectionPolicy::default();
1234 let connections = [
1235 mock_connection(Protocol::Udp),
1236 mock_connection(Protocol::Tcp),
1237 mock_connection(Protocol::Tls),
1239 ];
1240
1241 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1242 let state = NameServerTransportState::default();
1243 let opp_enc = &OpportunisticEncryption::Enabled {
1244 config: OpportunisticEncryptionConfig::default(),
1245 };
1246
1247 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1250 assert!(selected.is_some());
1251 assert_eq!(selected.unwrap().protocol, Protocol::Tls);
1252 }
1253
1254 #[tokio::test]
1255 async fn test_select_connection_opportunistic_enc_enabled_no_state() {
1256 let mut policy = ConnectionPolicy::default();
1257 let connections = [
1258 mock_connection(Protocol::Udp),
1259 mock_connection(Protocol::Tcp),
1260 ];
1262
1263 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1264 let state = NameServerTransportState::default();
1265 let opp_enc = &OpportunisticEncryption::Enabled {
1266 config: OpportunisticEncryptionConfig::default(),
1267 };
1268
1269 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1272 assert!(selected.is_some());
1273 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1274
1275 policy.disable_udp = true;
1278 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1279 assert!(selected.is_some());
1280 assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1281 }
1282
1283 #[tokio::test]
1284 async fn test_select_connection_opportunistic_enc_enabled_failed_probe() {
1285 let policy = ConnectionPolicy::default();
1286 let connections = [
1287 mock_connection(Protocol::Udp),
1288 mock_connection(Protocol::Tcp),
1289 ];
1291
1292 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1293 let mut state = NameServerTransportState::default();
1294 let opp_enc = &OpportunisticEncryption::Enabled {
1295 config: OpportunisticEncryptionConfig::default(),
1296 };
1297
1298 state.error_received(
1300 ns_ip,
1301 Protocol::Tls,
1302 &NetError::from(io::Error::new(
1303 io::ErrorKind::ConnectionRefused,
1304 "nameserver refused TLS connection",
1305 )),
1306 );
1307
1308 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1311 assert!(selected.is_some());
1312 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1313 }
1314
1315 #[tokio::test]
1316 async fn test_select_connection_opportunistic_enc_enabled_in_progress_probe() {
1317 let policy = ConnectionPolicy::default();
1318 let connections = [
1319 mock_connection(Protocol::Udp),
1320 mock_connection(Protocol::Tcp),
1321 ];
1323
1324 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1325 let mut state = NameServerTransportState::default();
1326 let opp_enc = &OpportunisticEncryption::Enabled {
1327 config: OpportunisticEncryptionConfig::default(),
1328 };
1329
1330 state.initiate_connection(ns_ip, Protocol::Tls);
1332
1333 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1336 assert!(selected.is_some());
1337 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1338
1339 state.complete_connection(ns_ip, Protocol::Tls);
1342
1343 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1345 assert!(selected.is_some());
1346 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1347 }
1348
1349 #[tokio::test]
1350 async fn test_select_connection_opportunistic_enc_enabled_stale_probe() {
1351 let policy = ConnectionPolicy::default();
1352 let connections = [
1353 mock_connection(Protocol::Udp),
1354 mock_connection(Protocol::Tcp),
1355 ];
1357
1358 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1359 let mut state = NameServerTransportState::default();
1360 let opp_enc_config = OpportunisticEncryptionConfig {
1361 persistence_period: Duration::from_secs(10),
1362 ..OpportunisticEncryptionConfig::default()
1363 };
1364 let opp_enc = &OpportunisticEncryption::Enabled {
1365 config: opp_enc_config.clone(),
1366 };
1367
1368 state.complete_connection(ns_ip, Protocol::Tls);
1370 state.response_received(ns_ip, Protocol::Tls);
1371 let stale_time =
1373 SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1374 state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1375
1376 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1380 assert!(selected.is_some());
1381 assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1382 }
1383
1384 #[tokio::test]
1385 async fn test_select_connection_opportunistic_enc_enabled_good_probe() {
1386 let policy = ConnectionPolicy::default();
1387 let connections = [
1388 mock_connection(Protocol::Udp),
1389 mock_connection(Protocol::Tcp),
1390 ];
1392
1393 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1394 let mut state = NameServerTransportState::default();
1395 let opp_enc = &OpportunisticEncryption::Enabled {
1396 config: OpportunisticEncryptionConfig::default(),
1397 };
1398
1399 state.complete_connection(ns_ip, Protocol::Tls);
1402 state.response_received(ns_ip, Protocol::Tls);
1403
1404 let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1408 assert!(selected.is_none());
1409 }
1410
1411 #[tokio::test]
1412 async fn test_select_connection_config_opportunistic_enc_disabled() {
1413 let mut policy = ConnectionPolicy::default();
1414
1415 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1416 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1417
1418 let state = NameServerTransportState::default();
1419 let opp_enc = OpportunisticEncryption::Disabled;
1420
1421 let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1424 assert!(selected.is_some());
1425 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1426
1427 policy.disable_udp = true;
1430 let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1431 assert!(selected.is_some());
1432 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1433 }
1434
1435 #[tokio::test]
1436 async fn test_select_connection_config_opportunistic_enc_enabled_no_state() {
1437 let mut policy = ConnectionPolicy::default();
1438 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1439 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1440
1441 let state = NameServerTransportState::default();
1442 let opp_enc = &OpportunisticEncryption::Enabled {
1443 config: OpportunisticEncryptionConfig::default(),
1444 };
1445
1446 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1449 assert!(selected.is_some());
1450 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1451
1452 policy.disable_udp = true;
1455 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1456 assert!(selected.is_some());
1457 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1458 }
1459
1460 #[tokio::test]
1461 async fn test_select_connection_config_opportunistic_enc_enabled_failed_probe() {
1462 let policy = ConnectionPolicy::default();
1463 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1464 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1465
1466 let mut state = NameServerTransportState::default();
1467 let opp_enc = &OpportunisticEncryption::Enabled {
1468 config: OpportunisticEncryptionConfig::default(),
1469 };
1470
1471 state.error_received(
1473 ns_ip,
1474 Protocol::Tls,
1475 &NetError::from(io::Error::new(
1476 io::ErrorKind::ConnectionRefused,
1477 "nameserver refused TLS connection",
1478 )),
1479 );
1480
1481 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1484 assert!(selected.is_some());
1485 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1486 }
1487
1488 #[tokio::test]
1489 async fn test_select_connection_config_opportunistic_enc_enabled_stale_probe() {
1490 let policy = ConnectionPolicy::default();
1491 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1492 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1493
1494 let mut state = NameServerTransportState::default();
1495 let opp_enc_config = OpportunisticEncryptionConfig {
1496 persistence_period: Duration::from_secs(10),
1497 ..OpportunisticEncryptionConfig::default()
1498 };
1499 let opp_enc = &OpportunisticEncryption::Enabled {
1500 config: opp_enc_config.clone(),
1501 };
1502
1503 state.complete_connection(ns_ip, Protocol::Tls);
1505 state.response_received(ns_ip, Protocol::Tls);
1506 let stale_time =
1508 SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1509 state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1510
1511 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1514 assert!(selected.is_some());
1515 assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1516 }
1517
1518 #[tokio::test]
1519 async fn test_select_connection_config_opportunistic_enc_enabled_good_probe() {
1520 let policy = ConnectionPolicy::default();
1521 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1522 let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1523
1524 let mut state = NameServerTransportState::default();
1525 let opp_enc = &OpportunisticEncryption::Enabled {
1526 config: OpportunisticEncryptionConfig::default(),
1527 };
1528
1529 state.complete_connection(ns_ip, Protocol::Tls);
1532 state.response_received(ns_ip, Protocol::Tls);
1533
1534 let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1537 assert!(selected.is_some());
1538 assert!(matches!(
1539 selected.unwrap().protocol,
1540 ProtocolConfig::Tls { .. }
1541 ));
1542 }
1543
1544 #[tokio::test]
1545 async fn test_opportunistic_probe() {
1546 subscribe();
1547
1548 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1550 .with_opportunistic_encryption()
1551 .with_probe_budget(10);
1552
1553 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1554 let mock_provider = MockProvider::default();
1555 assert!(
1556 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1557 .await
1558 .is_ok()
1559 );
1560
1561 let recorded_calls = mock_provider.new_connection_calls();
1562 assert_eq!(recorded_calls.len(), 2);
1564 let (ips, protocols): (Vec<IpAddr>, Vec<ProtocolConfig>) =
1565 recorded_calls.into_iter().unzip();
1566 assert!(ips.iter().all(|ip| *ip == ns_ip));
1568 let protocols = protocols
1570 .iter()
1571 .map(ProtocolConfig::to_protocol)
1572 .collect::<Vec<_>>();
1573 assert!(protocols.contains(&Protocol::Udp));
1574 assert!(protocols.contains(&Protocol::Tls));
1575 }
1576
1577 #[tokio::test]
1578 async fn test_opportunistic_probe_skip_in_progress() {
1579 subscribe();
1580
1581 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1582 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1583 .with_opportunistic_encryption()
1584 .with_probe_budget(10);
1585
1586 cx.transport_state()
1588 .await
1589 .initiate_connection(ns_ip, Protocol::Tls);
1590
1591 let mock_provider = MockProvider::default();
1592 assert!(
1593 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1594 .await
1595 .is_ok()
1596 );
1597
1598 let recorded_calls = mock_provider.new_connection_calls();
1599 assert_eq!(recorded_calls.len(), 1);
1601 let (ip, protocol) = &recorded_calls[0];
1602 assert_eq!(*ip, ns_ip);
1603 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1604 }
1605
1606 #[tokio::test]
1607 async fn test_opportunistic_probe_skip_recent_failure() {
1608 subscribe();
1609
1610 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1611 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1612 .with_opportunistic_encryption()
1613 .with_probe_budget(10);
1614
1615 cx.transport_state().await.error_received(
1617 ns_ip,
1618 Protocol::Tls,
1619 &NetError::from(io::Error::new(
1620 io::ErrorKind::ConnectionRefused,
1621 "connection refused",
1622 )),
1623 );
1624
1625 let mock_provider = MockProvider::default();
1626 assert!(
1627 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1628 .await
1629 .is_ok()
1630 );
1631
1632 let recorded_calls = mock_provider.new_connection_calls();
1633 assert_eq!(recorded_calls.len(), 1);
1635 let (ip, protocol) = &recorded_calls[0];
1636 assert_eq!(*ip, ns_ip);
1637 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1638 }
1639
1640 #[tokio::test]
1641 async fn test_opportunistic_probe_stale_failure() {
1642 subscribe();
1643
1644 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1645 let mut cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1646 .with_probe_budget(10);
1647 let opp_enc_config = OpportunisticEncryptionConfig {
1648 damping_period: Duration::from_secs(5),
1649 ..OpportunisticEncryptionConfig::default()
1650 };
1651 cx.opportunistic_encryption = OpportunisticEncryption::Enabled {
1652 config: opp_enc_config.clone(),
1653 };
1654
1655 {
1657 let mut state = cx.transport_state().await;
1658 let old_failure_time =
1659 SystemTime::now() - opp_enc_config.damping_period - Duration::from_secs(1);
1660 state.set_failure_time(ns_ip, Protocol::Tls, old_failure_time);
1661 }
1662
1663 let mock_provider = MockProvider::default();
1664 assert!(
1665 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1666 .await
1667 .is_ok()
1668 );
1669
1670 let recorded_calls = mock_provider.new_connection_calls();
1671 assert_eq!(recorded_calls.len(), 2);
1673 let protocols = recorded_calls
1674 .iter()
1675 .map(|(_, protocol)| protocol.to_protocol())
1676 .collect::<Vec<_>>();
1677 assert!(protocols.contains(&Protocol::Udp));
1678 assert!(protocols.contains(&Protocol::Tls));
1679 }
1680
1681 #[tokio::test]
1682 async fn test_opportunistic_probe_skip_no_budget() {
1683 subscribe();
1684
1685 let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1686 let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1687 .with_opportunistic_encryption();
1688 let mock_provider = MockProvider::default();
1689 assert!(
1691 test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1692 .await
1693 .is_ok()
1694 );
1695
1696 let recorded_calls = mock_provider.new_connection_calls();
1697 assert_eq!(recorded_calls.len(), 1);
1699 let (ip, protocol) = &recorded_calls[0];
1700 assert_eq!(*ip, ns_ip);
1701 assert_eq!(protocol.to_protocol(), Protocol::Udp);
1702 }
1703
1704 fn mock_connection(protocol: Protocol) -> ConnectionState<MockProvider> {
1705 ConnectionState::new(MockClientHandle, protocol)
1706 }
1707
1708 #[cfg(feature = "metrics")]
1709 #[test]
1710 fn test_opportunistic_probe_metrics_success() {
1711 subscribe();
1712 let recorder = DebuggingRecorder::new();
1713 let snapshotter = recorder.snapshotter();
1714 let initial_budget = 10;
1715
1716 with_local_recorder(&recorder, || {
1717 let runtime = tokio::runtime::Builder::new_current_thread()
1718 .enable_all()
1719 .build()
1720 .unwrap();
1721
1722 runtime.block_on(async {
1723 assert!(
1724 test_connected_mut_client(
1725 IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1726 Arc::new(
1727 PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1728 .with_opportunistic_encryption()
1729 .with_probe_budget(initial_budget),
1730 ),
1731 &MockProvider::default(),
1732 )
1733 .await
1734 .is_ok()
1735 );
1736 });
1737 });
1738
1739 #[allow(clippy::mutable_key_type)]
1740 let map = snapshotter.snapshot().into_hashmap();
1741
1742 let protocol = vec![Label::new("protocol", "tls")];
1744 assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1745 assert_histogram_sample_count_eq(
1747 &map,
1748 PROBE_DURATION_SECONDS,
1749 protocol.clone(),
1750 1,
1751 Unit::Seconds,
1752 );
1753
1754 assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol.clone(), 1);
1756
1757 assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol, 0);
1759
1760 assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1762 }
1763
1764 #[cfg(feature = "metrics")]
1765 #[test]
1766 fn test_opportunistic_probe_metrics_budget_exhausted() {
1767 subscribe();
1768 let recorder = DebuggingRecorder::new();
1769 let snapshotter = recorder.snapshotter();
1770
1771 with_local_recorder(&recorder, || {
1772 let runtime = tokio::runtime::Builder::new_current_thread()
1773 .enable_all()
1774 .build()
1775 .unwrap();
1776
1777 runtime.block_on(async {
1778 assert!(
1779 test_connected_mut_client(
1780 IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1781 Arc::new(
1782 PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1783 .with_opportunistic_encryption(),
1784 ),
1785 &MockProvider::default(),
1786 )
1787 .await
1788 .is_ok()
1789 );
1790 });
1791 });
1792
1793 #[allow(clippy::mutable_key_type)]
1794 let map = snapshotter.snapshot().into_hashmap();
1795
1796 assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], 0);
1798
1799 let protocol = vec![Label::new("protocol", "tls")];
1801 assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 0);
1802 assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, protocol, 0, Unit::Seconds);
1804 }
1805
1806 #[cfg(feature = "metrics")]
1807 #[test]
1808 fn test_opportunistic_probe_metrics_connection_error() {
1809 subscribe();
1810 let recorder = DebuggingRecorder::new();
1811 let snapshotter = recorder.snapshotter();
1812 let initial_budget = 10;
1813
1814 with_local_recorder(&recorder, || {
1815 let runtime = tokio::runtime::Builder::new_current_thread()
1816 .enable_all()
1817 .build()
1818 .unwrap();
1819
1820 runtime.block_on(async {
1821 let _ = test_connected_mut_client(
1822 IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1823 Arc::new(
1824 PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1825 .with_opportunistic_encryption()
1826 .with_probe_budget(initial_budget),
1827 ),
1828 &MockProvider {
1830 new_connection_error: Some(NetError::from(io::Error::new(
1831 io::ErrorKind::ConnectionRefused,
1832 "connection refused",
1833 ))),
1834 ..MockProvider::default()
1835 },
1836 )
1837 .await;
1838 });
1839 });
1840
1841 #[allow(clippy::mutable_key_type)]
1842 let map = snapshotter.snapshot().into_hashmap();
1843
1844 let protocol = vec![Label::new("protocol", "tls")];
1846 assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1847 assert_histogram_sample_count_eq(
1849 &map,
1850 PROBE_DURATION_SECONDS,
1851 protocol.clone(),
1852 1,
1853 Unit::Seconds,
1854 );
1855
1856 assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 1);
1858
1859 assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1862
1863 assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1865 }
1866
1867 #[cfg(feature = "metrics")]
1868 #[test]
1869 fn test_opportunistic_probe_metrics_connection_timeout_error() {
1870 subscribe();
1871 let recorder = DebuggingRecorder::new();
1872 let snapshotter = recorder.snapshotter();
1873 let initial_budget = 10;
1874
1875 with_local_recorder(&recorder, || {
1876 let runtime = tokio::runtime::Builder::new_current_thread()
1877 .enable_all()
1878 .build()
1879 .unwrap();
1880
1881 runtime.block_on(async {
1882 let _ = test_connected_mut_client(
1883 IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1884 Arc::new(
1885 PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1886 .with_opportunistic_encryption()
1887 .with_probe_budget(initial_budget),
1888 ),
1889 &MockProvider {
1891 new_connection_error: Some(NetError::Timeout),
1892 ..MockProvider::default()
1893 },
1894 )
1895 .await;
1896 });
1897 });
1898
1899 #[allow(clippy::mutable_key_type)]
1900 let map = snapshotter.snapshot().into_hashmap();
1901
1902 let protocol = vec![Label::new("protocol", "tls")];
1904 assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1905 assert_histogram_sample_count_eq(
1907 &map,
1908 PROBE_DURATION_SECONDS,
1909 protocol.clone(),
1910 1,
1911 Unit::Seconds,
1912 );
1913
1914 assert_counter_eq(&map, PROBE_TIMEOUTS_TOTAL, protocol.clone(), 1);
1916
1917 assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 0);
1919
1920 assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1923
1924 assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1926 }
1927
1928 async fn test_connected_mut_client(
1934 ns_ip: IpAddr,
1935 cx: Arc<PoolContext>,
1936 provider: &MockProvider,
1937 ) -> Result<(), NetError> {
1938 let name_server = NameServer::new(
1939 [],
1940 NameServerConfig::opportunistic_encryption(ns_ip),
1941 &ResolverOpts::default(),
1942 provider.clone(),
1943 );
1944
1945 name_server
1946 .connected_mut_client(ConnectionPolicy::default(), &cx)
1947 .await
1948 .map(|_| ())
1949 }
1950}
1951
#[cfg(all(test, feature = "metrics"))]
mod resolver_metrics_tests {
    use std::net::{IpAddr, Ipv4Addr};

    use metrics::{Label, with_local_recorder};
    use metrics_util::debugging::DebuggingRecorder;
    use mock_provider::MockProvider;
    use test_support::assert_counter_eq;
    use test_support::subscribe;

    use super::*;
    use crate::connection_provider::TlsConfig;
    use crate::metrics::OUTGOING_QUERIES_TOTAL;

    /// Address of the mocked upstream name server used by every test in this module.
    const NAME_SERVER_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

    /// Sends one `A` query for `www.example.com.` through a [`NameServer`] built from
    /// `config` and backed by [`MockProvider`]. It then asserts that exactly one outgoing
    /// query was counted in [`OUTGOING_QUERIES_TOTAL`] under the `protocol` label value.
    ///
    /// Everything, including the name server itself, is created while the debugging
    /// recorder is installed, so that all metrics it emits are captured.
    fn assert_single_outgoing_query(config: NameServerConfig, protocol: &'static str) {
        subscribe();
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async move {
                let options = ResolverOpts::default();
                let name_server = Arc::new(NameServer::new(
                    [],
                    config,
                    &options,
                    MockProvider::default(),
                ));

                let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
                let name = Name::parse("www.example.com.", None).unwrap();
                // The response is irrelevant here; only the emitted metric is checked.
                let _ = name_server
                    .send(
                        DnsRequest::from_query(
                            Query::query(name, RecordType::A),
                            DnsRequestOptions::default(),
                        ),
                        ConnectionPolicy::default(),
                        &cx,
                    )
                    .await;
            });
        });

        #[allow(clippy::mutable_key_type)]
        let map = snapshotter.snapshot().into_hashmap();

        let labels = vec![Label::new("protocol", protocol)];
        assert_counter_eq(&map, OUTGOING_QUERIES_TOTAL, labels, 1);
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_udp() {
        assert_single_outgoing_query(NameServerConfig::udp(NAME_SERVER_IP), "udp");
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_tcp() {
        assert_single_outgoing_query(NameServerConfig::tcp(NAME_SERVER_IP), "tcp");
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn test_outgoing_query_protocol_metrics_tls() {
        assert_single_outgoing_query(
            NameServerConfig::tls(NAME_SERVER_IP, "dns.google".into()),
            "tls",
        );
    }
}
2105
#[cfg(all(test, any(feature = "metrics", feature = "__tls")))]
mod mock_provider {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;

    use futures_util::stream::once;
    use futures_util::{Stream, future};
    use tokio::net::UdpSocket;

    use super::*;
    use crate::config::ProtocolConfig;
    use crate::net::runtime::TokioTime;
    use crate::net::runtime::iocompat::AsyncIoTokioAsStd;
    use crate::proto::op::Message;

    /// Connection provider for tests. It records the `(ip, protocol)` of every
    /// `new_connection` call. Each call yields a [`MockClientHandle`], or a clone of
    /// `new_connection_error` when that is set.
    #[derive(Clone)]
    pub(super) struct MockProvider {
        pub(super) runtime: MockSyncRuntimeProvider,
        pub(super) new_connection_calls: Arc<SyncMutex<Vec<(IpAddr, ProtocolConfig)>>>,
        pub(super) new_connection_error: Option<NetError>,
    }

    impl MockProvider {
        /// Returns a snapshot of every `(ip, protocol)` pair passed to `new_connection` so far.
        pub(super) fn new_connection_calls(&self) -> Vec<(IpAddr, ProtocolConfig)> {
            self.new_connection_calls.lock().to_vec()
        }
    }

    impl ConnectionProvider for MockProvider {
        type Conn = MockClientHandle;
        type FutureConn = Pin<Box<dyn Send + Future<Output = Result<Self::Conn, NetError>>>>;
        type RuntimeProvider = MockSyncRuntimeProvider;

        fn new_connection(
            &self,
            ip: IpAddr,
            config: &ConnectionConfig,
            _cx: &PoolContext,
        ) -> Result<Self::FutureConn, NetError> {
            // Log the attempt first so tests can see which transports were dialed, even
            // when the connection is configured to fail.
            self.new_connection_calls
                .lock()
                .push((ip, config.protocol.clone()));

            let outcome = match self.new_connection_error.as_ref() {
                None => Ok(MockClientHandle),
                Some(err) => Err(err.clone()),
            };
            Ok(Box::pin(future::ready(outcome)))
        }

        fn runtime_provider(&self) -> &Self::RuntimeProvider {
            &self.runtime
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self {
                runtime: MockSyncRuntimeProvider,
                new_connection_calls: Arc::new(SyncMutex::new(vec![])),
                new_connection_error: None,
            }
        }
    }

    /// DNS handle that answers every request immediately with a single `NoError` response.
    /// The response echoes the request's queries and adds no answer records.
    #[derive(Clone, Default)]
    pub(super) struct MockClientHandle;

    impl DnsHandle for MockClientHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = MockSyncRuntimeProvider;

        fn send(&self, request: DnsRequest) -> Self::Response {
            let mut message = Message::response(request.id, request.op_code);
            message.metadata.response_code = ResponseCode::NoError;
            message.add_queries(request.queries.clone());

            let response = DnsResponse::from_message(message).unwrap();
            Box::pin(once(future::ready(Ok(response))))
        }
    }

    /// Runtime provider whose handle runs spawned futures inline. Real TCP/UDP sockets are
    /// never needed by the tests, so those entry points are left unimplemented.
    #[derive(Clone)]
    pub(super) struct MockSyncRuntimeProvider;

    impl RuntimeProvider for MockSyncRuntimeProvider {
        type Handle = MockSyncHandle;
        type Timer = TokioTime;
        type Udp = UdpSocket;
        type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            MockSyncHandle
        }

        #[allow(clippy::unimplemented)]
        fn connect_tcp(
            &self,
            _server_addr: std::net::SocketAddr,
            _bind_addr: Option<std::net::SocketAddr>,
            _timeout: Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
            unimplemented!()
        }

        #[allow(clippy::unimplemented)]
        fn bind_udp(
            &self,
            _local_addr: std::net::SocketAddr,
            _server_addr: std::net::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
            unimplemented!()
        }
    }

    /// Spawn handle that drives each "background" future to completion on the calling
    /// thread before `spawn_bg` returns.
    #[derive(Clone)]
    pub(super) struct MockSyncHandle;

    impl Spawn for MockSyncHandle {
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            let waker = futures_util::task::noop_waker();
            let mut context = Context::from_waker(&waker);
            let mut pinned = Box::pin(future);

            // NOTE(review): this busy-polls with a no-op waker. A future that stays `Pending`
            // until an external wake-up (e.g. a timer or socket driven by the same runtime)
            // would spin here forever. This presumably works only because the mocked
            // futures above complete without waiting.
            while pinned.as_mut().poll(&mut context).is_pending() {}
        }
    }
}