//! `hickory_resolver` — `name_server.rs`: a single remote DNS server and its connections.

// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
#[cfg(not(test))]
use std::time::{Duration, Instant};
use std::{
    cmp,
    fmt::Debug,
    marker::PhantomData,
    net::IpAddr,
    sync::{
        Arc,
        atomic::{AtomicU8, AtomicU32, Ordering},
    },
};

use futures_util::lock::Mutex as AsyncMutex;
use parking_lot::Mutex as SyncMutex;
#[cfg(test)]
use tokio::time::{Duration, Instant};
use tracing::{debug, error, warn};

#[cfg(feature = "metrics")]
use crate::metrics::ResolverMetrics;
#[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
use crate::metrics::opportunistic_encryption::ProbeMetrics;
use crate::{
    config::{
        ConnectionConfig, NameServerConfig, OpportunisticEncryption, ResolverOpts,
        ServerOrderingStrategy,
    },
    connection_provider::ConnectionProvider,
    name_server_pool::{NameServerTransportState, PoolContext},
    net::{
        DnsError, NetError, NoRecords,
        runtime::{RuntimeProvider, Spawn},
        xfer::{DnsHandle, FirstAnswer, Protocol},
    },
    proto::{
        op::{DnsRequest, DnsRequestOptions, DnsResponse, Query, ResponseCode},
        rr::{Name, RecordType},
    },
};
48
/// A remote DNS server, identified by its IP address.
///
/// This potentially holds multiple open connections to the server, according to the
/// configured protocols, and will make new connections as needed.
pub struct NameServer<P: ConnectionProvider> {
    /// Static configuration for this server (IP address, connection configs, trust flags).
    config: NameServerConfig,
    /// Currently open connections, at most one logical entry per established transport.
    ///
    /// Guarded by an async mutex because it is locked across `.await` points.
    connections: AsyncMutex<Vec<ConnectionState<P>>>,
    /// Metrics related to opportunistic encryption probes.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    opportunistic_probe_metrics: ProbeMetrics,
    /// Metrics related to outgoing queries.
    #[cfg(feature = "metrics")]
    resolver_metrics: ResolverMetrics,
    /// Server-level smoothed RTT, aggregated across all connections; drives pool ordering.
    server_srtt: DecayingSrtt,
    /// Factory used to establish new connections to this server.
    connection_provider: P,
}
65
66impl<P: ConnectionProvider> NameServer<P> {
67    /// Create a new [`NameServer`] with the given connections and configuration.
68    ///
69    /// The `connections` will usually be empty.
70    pub fn new(
71        connections: impl IntoIterator<Item = (Protocol, P::Conn)>,
72        config: NameServerConfig,
73        options: &ResolverOpts,
74        connection_provider: P,
75    ) -> Self {
76        let mut connections = connections
77            .into_iter()
78            .map(|(protocol, handle)| ConnectionState::new(handle, protocol))
79            .collect::<Vec<_>>();
80
81        // Unless the user specified that we should follow the configured order,
82        // re-order the connections to prioritize UDP.
83        if options.server_ordering_strategy != ServerOrderingStrategy::UserProvidedOrder {
84            connections.sort_by_key(|ns| ns.protocol != Protocol::Udp);
85        }
86
87        Self {
88            config,
89            connections: AsyncMutex::new(connections),
90            server_srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
91            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
92            opportunistic_probe_metrics: ProbeMetrics::default(),
93            #[cfg(feature = "metrics")]
94            resolver_metrics: ResolverMetrics::default(),
95            connection_provider,
96        }
97    }
98
    // TODO: there needs to be some way of customizing the connection based on EDNS options from the server side...
    /// Send `request` to this server, reusing or establishing a connection per `policy`.
    ///
    /// On a successful lookup, records the measured RTT into both the per-connection
    /// and server-level SRTT trackers. On failure, applies failure penalties so the
    /// pool deprioritizes this server/connection. When opportunistic encryption is
    /// enabled, outcomes on encrypted protocols are also reported to the shared
    /// transport state.
    pub(crate) async fn send(
        self: Arc<Self>,
        request: DnsRequest,
        policy: ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) -> Result<DnsResponse, NetError> {
        let (handle, meta, protocol) = self.connected_mut_client(policy, cx).await?;
        #[cfg(feature = "metrics")]
        self.resolver_metrics.increment_outgoing_query(&protocol);
        let now = Instant::now();
        let response = handle.send(request).first_answer().await;
        let rtt = now.elapsed();

        match response {
            Ok(response) => {
                // We got *a* response, so the connection itself works.
                meta.set_status(Status::Established);
                let result = DnsError::from_response(response);
                let error = match result {
                    Ok(response) => {
                        // Successful lookup: record the RTT and (when relevant)
                        // report the encrypted transport as healthy.
                        meta.srtt.record(rtt);
                        self.server_srtt.record(rtt);
                        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
                            cx.transport_state()
                                .await
                                .response_received(self.config.ip, protocol);
                        }
                        return Ok(response);
                    }
                    Err(error) => error,
                };

                // Decide how a DNS-level error affects the SRTT:
                // - ServFail still proved the round trip, so record the RTT;
                // - other NoRecordsFound responses count as a failure penalty;
                // - anything else leaves the SRTT untouched.
                let update = match error {
                    DnsError::NoRecordsFound(NoRecords {
                        response_code: ResponseCode::ServFail,
                        ..
                    }) => Some(true),
                    DnsError::NoRecordsFound(NoRecords { .. }) => Some(false),
                    _ => None,
                };

                match update {
                    Some(true) => {
                        meta.srtt.record(rtt);
                        self.server_srtt.record(rtt);
                    }
                    Some(false) => {
                        // record the failure
                        meta.srtt.record_failure();
                        self.server_srtt.record_failure();
                    }
                    None => {}
                }

                let err = NetError::from(error);
                if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
                    cx.transport_state()
                        .await
                        .error_received(self.config.ip, protocol, &err)
                }
                Err(err)
            }
            Err(error) => {
                debug!(config = ?self.config, %error, "failed to connect to name server");

                // this transitions the state to failure
                meta.set_status(Status::Failed);

                // record the failure on both the per-connection and server-level SRTTs.
                // updating server_srtt ensures the server is deprioritized in pool
                // ordering (decayed_srtt) so other servers get a chance to be tried.
                match &error {
                    NetError::Busy | NetError::Io(_) | NetError::Timeout => {
                        meta.srtt.record_failure();
                        self.server_srtt.record_failure();
                    }
                    #[cfg(feature = "__quic")]
                    NetError::QuinnConfigError(_)
                    | NetError::QuinnConnect(_)
                    | NetError::QuinnConnection(_)
                    | NetError::QuinnTlsConfigError(_) => {
                        meta.srtt.record_failure();
                        self.server_srtt.record_failure();
                    }
                    #[cfg(feature = "__tls")]
                    NetError::RustlsError(_) => {
                        meta.srtt.record_failure();
                        self.server_srtt.record_failure();
                    }
                    _ => {}
                }

                if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
                    cx.transport_state()
                        .await
                        .error_received(self.config.ip, protocol, &error);
                }

                // These are connection failures, not lookup failures, that is handled in the resolver layer
                Err(error)
            }
        }
    }
202
    /// This will return a mutable client to allows for sending messages.
    ///
    /// If the connection is in a failed state, then this will establish a new connection.
    /// Returns the connection handle, its shared metadata, and the protocol chosen.
    async fn connected_mut_client(
        &self,
        policy: ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) -> Result<(P::Conn, Arc<ConnectionMeta>, Protocol), NetError> {
        // Check for an existing usable connection (short lock)
        {
            let mut connections = self.connections.lock().await;
            // Drop connections that have entered the failed state.
            connections
                .retain(|conn| matches!(conn.meta.status(), Status::Init | Status::Established));
            if let Some(conn) = policy.select_connection(
                self.config.ip,
                &*cx.transport_state().await,
                &cx.opportunistic_encryption,
                &connections,
            ) {
                return Ok((conn.handle.clone(), conn.meta.clone(), conn.protocol));
            }
        }

        // Select connection config and update transport state (no lock)
        // NOTE(review): the connections lock is released above, so two concurrent
        // callers may both decide to connect and push separate connections below —
        // presumably benign (extra connection, not incorrect), but worth confirming.
        debug!(config = ?self.config, "connecting");
        let config = policy
            .select_connection_config(
                self.config.ip,
                &*cx.transport_state().await,
                &cx.opportunistic_encryption,
                &self.config.connections,
            )
            .ok_or(NetError::NoConnections)?;

        let protocol = config.protocol.to_protocol();
        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
            // Record the encrypted connection attempt as in-flight.
            cx.transport_state()
                .await
                .initiate_connection(self.config.ip, protocol);
        } else if cx.opportunistic_encryption.is_enabled() && !protocol.is_encrypted() {
            // Using plaintext: consider launching a background encrypted probe.
            self.consider_probe_encrypted_transport(&policy, cx).await;
        }

        // Establish connection
        let handle = Box::pin(self.connection_provider.new_connection(
            self.config.ip,
            config,
            cx,
        )?)
        .await?;

        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
            cx.transport_state()
                .await
                .complete_connection(self.config.ip, protocol);
        }

        // Store the new connection (with lock)
        let state = ConnectionState::new(handle.clone(), protocol);
        let meta = state.meta.clone();
        self.connections.lock().await.push(state);
        Ok((handle, meta, protocol))
    }
266
267    pub(super) fn protocols(&self) -> impl Iterator<Item = Protocol> + '_ {
268        self.config
269            .connections
270            .iter()
271            .map(|conn| conn.protocol.to_protocol())
272    }
273
    /// The IP address of this name server.
    pub(super) fn ip(&self) -> IpAddr {
        self.config.ip
    }
277
    /// The server-level SRTT with the time-based decay applied; lower is better.
    ///
    /// Used by the pool when ordering name servers.
    pub(crate) fn decayed_srtt(&self) -> f64 {
        self.server_srtt.current()
    }
281
282    /// Records an SRTT observation for a server whose in-flight request was
283    /// cancelled because a parallel request to another server succeeded first.
284    ///
285    /// Records the winner's RTT plus a small penalty (`CANCEL_PENALTY`) as the
286    /// observation: the cancelled server was *at least* that slow (it hadn't
287    /// responded yet), and the penalty ensures the winner retains a sorting
288    /// advantage in the next round. This avoids the full `FAILURE_PENALTY`
289    /// which would be too harsh for a server that's merely slightly slower.
290    ///
291    /// A truly unreachable server will be cancelled on every query and its SRTT
292    /// will ratchet up as the EWMA repeatedly incorporates the winner's RTT
293    /// without ever recording a real (successful) measurement to bring it back
294    /// down.
295    pub(super) fn record_cancelled(&self, winner_rtt: Duration) {
296        const CANCEL_PENALTY: Duration = Duration::from_millis(5);
297        self.server_srtt.record(winner_rtt + CANCEL_PENALTY);
298    }
299
    /// Test helper: apply a failure penalty to the server-level SRTT.
    #[cfg(test)]
    pub(crate) fn test_record_failure(&self) {
        self.server_srtt.record_failure();
    }
304
305    #[cfg(test)]
306    #[allow(dead_code)]
307    pub(crate) fn is_connected(&self) -> bool {
308        let Some(connections) = self.connections.try_lock() else {
309            // assuming that if someone has it locked it will be or is connected
310            return true;
311        };
312
313        connections.iter().any(|conn| match conn.meta.status() {
314            Status::Established | Status::Init => true,
315            Status::Failed => false,
316        })
317    }
318
    /// Whether this server's negative responses are configured as trustworthy.
    pub(crate) fn trust_negative_responses(&self) -> bool {
        self.config.trust_negative_responses
    }
322
323    async fn consider_probe_encrypted_transport(
324        &self,
325        policy: &ConnectionPolicy,
326        cx: &Arc<PoolContext>,
327    ) {
328        let Some(probe_config) =
329            policy.select_encrypted_connection_config(&self.config.connections)
330        else {
331            warn!("no encrypted connection configs available for probing");
332            return;
333        };
334
335        let probe_protocol = probe_config.protocol.to_protocol();
336        let should_probe = {
337            let state = cx.transport_state().await;
338            state.should_probe_encrypted(
339                self.config.ip,
340                probe_protocol,
341                &cx.opportunistic_encryption,
342            )
343        };
344
345        if !should_probe {
346            return;
347        }
348
349        if let Err(err) = self.probe_encrypted_transport(cx, probe_config) {
350            error!(%err, "opportunistic encrypted probe attempt failed");
351        }
352    }
353
    /// Spend one unit of the shared probe budget and spawn a background probe task.
    ///
    /// Returns `Ok(())` without probing when the budget is exhausted; the budget
    /// unit is restored by the probe task itself when it finishes (see
    /// `ProbeRequest::run`). Errors only if the probe connection future cannot
    /// be created.
    fn probe_encrypted_transport(
        &self,
        cx: &Arc<PoolContext>,
        probe_config: &ConnectionConfig,
    ) -> Result<(), NetError> {
        let mut budget = cx.opportunistic_probe_budget.load(Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        self.opportunistic_probe_metrics.probe_budget.set(budget);
        // CAS loop: atomically decrement the budget, bailing out if it reaches zero.
        loop {
            if budget == 0 {
                debug!("no remaining budget for opportunistic probing");
                return Ok(());
            }
            match cx.opportunistic_probe_budget.compare_exchange_weak(
                budget,
                budget - 1,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                // Lost the race (or a spurious weak-CAS failure): retry with the
                // freshly observed value.
                Err(current) => budget = current,
            }
        }

        let connect = ProbeRequest::new(
            probe_config,
            self,
            cx,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            self.opportunistic_probe_metrics.clone(),
        )?;
        // Run the probe in the background; it reports outcomes via the pool context.
        self.connection_provider
            .runtime_provider()
            .create_handle()
            .spawn_bg(connect.run());

        Ok(())
    }
392}
393
/// A background task that probes an encrypted transport to a name server.
struct ProbeRequest<P: ConnectionProvider> {
    /// Address of the server being probed.
    ip: IpAddr,
    /// Encrypted protocol under test.
    proto: Protocol,
    /// The in-flight connection attempt.
    connecting: P::FutureConn,
    /// Shared pool state; probe outcomes are reported here.
    context: Arc<PoolContext>,
    /// Metrics for probe attempts/outcomes.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    metrics: ProbeMetrics,
    /// Ties this struct to the provider type without storing a provider.
    provider: PhantomData<P>,
}
403
404impl<P: ConnectionProvider> ProbeRequest<P> {
    /// Build a probe for `config`'s protocol against `ns`, creating (but not yet
    /// awaiting) the connection future.
    ///
    /// Errors if the connection future cannot be created.
    fn new(
        config: &ConnectionConfig,
        ns: &NameServer<P>,
        cx: &Arc<PoolContext>,
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics: ProbeMetrics,
    ) -> Result<Self, NetError> {
        Ok(Self {
            ip: ns.config.ip,
            proto: config.protocol.to_protocol(),
            connecting: ns
                .connection_provider
                .new_connection(ns.config.ip, config, cx)?,
            context: cx.clone(),
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: PhantomData,
        })
    }
424
    /// Drive the probe: await the connection, issue a test query, and record the
    /// outcome in the shared transport state (and metrics, when enabled).
    ///
    /// The probe-budget unit taken in `probe_encrypted_transport` is returned
    /// here once the probe completes, whether it succeeded or failed.
    async fn run(self) {
        let Self {
            ip,
            proto,
            connecting,
            context,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: _,
        } = self;

        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        let start = Instant::now();

        context
            .transport_state()
            .await
            .initiate_connection(ip, proto);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics.increment_attempts(proto);

        let conn = match connecting.await {
            Ok(conn) => conn,
            Err(err) => {
                debug!(?proto, "probe connection failed");
                // Return the budget unit; this probe is over.
                let _prev = context
                    .opportunistic_probe_budget
                    .fetch_add(1, Ordering::Relaxed);
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                {
                    metrics.increment_errors(proto, &err);
                    metrics.probe_budget.set(_prev + 1);
                    metrics.record_probe_duration(proto, start.elapsed());
                }
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
                return;
            }
        };

        debug!(?proto, "probe connection succeeded");
        context
            .transport_state()
            .await
            .complete_connection(ip, proto);

        // Issue a benign query (NS records for the root) to verify the transport
        // works end-to-end, not just at the handshake level.
        match conn
            .send(DnsRequest::from_query(
                Query::query(Name::root(), RecordType::NS),
                DnsRequestOptions::default(),
            ))
            .first_answer()
            .await
        {
            Ok(_) => {
                debug!(?proto, "probe query succeeded");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_successes(proto);
                context.transport_state().await.response_received(ip, proto);
            }
            Err(err) => {
                debug!(?proto, ?err, "probe query failed");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_errors(proto, &err);
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
            }
        }

        // Return the budget unit taken when the probe was scheduled.
        let _prev = context
            .opportunistic_probe_budget
            .fetch_add(1, Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        {
            metrics.probe_budget.set(_prev + 1);
            metrics.record_probe_duration(proto, start.elapsed());
        }
    }
507}
508
/// A single open connection to a name server, plus its protocol and metadata.
struct ConnectionState<P: ConnectionProvider> {
    /// Protocol this connection speaks.
    protocol: Protocol,
    /// Cloneable handle used to send requests over this connection.
    handle: P::Conn,
    /// Shared status/SRTT metadata; also held by in-flight sends.
    meta: Arc<ConnectionMeta>,
}
514
515impl<P: ConnectionProvider> ConnectionState<P> {
516    fn new(handle: P::Conn, protocol: Protocol) -> Self {
517        Self {
518            protocol,
519            handle,
520            meta: Arc::new(ConnectionMeta::default()),
521        }
522    }
523}
524
/// Mutable, shareable metadata about a single connection.
struct ConnectionMeta {
    /// Encodes a [`Status`] value; accessed via `set_status`/`status`.
    status: AtomicU8,
    /// Per-connection smoothed RTT, used when ordering connections.
    srtt: DecayingSrtt,
}
529
530impl ConnectionMeta {
531    fn set_status(&self, status: Status) {
532        self.status.store(status.into(), Ordering::Release);
533    }
534
535    fn status(&self) -> Status {
536        Status::from(self.status.load(Ordering::Acquire))
537    }
538}
539
impl Default for ConnectionMeta {
    /// Fresh metadata: status `Init`, SRTT seeded to a tiny random value.
    fn default() -> Self {
        // Initialize the SRTT to a randomly generated value that represents a
        // very low RTT. Such a value helps ensure that each server is attempted
        // early.
        Self {
            status: AtomicU8::new(Status::Init.into()),
            srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
        }
    }
}
551
/// An atomically updatable smoothed round-trip time (SRTT) whose read value
/// decays exponentially with time since the last update.
struct DecayingSrtt {
    /// The smoothed round-trip time (SRTT).
    ///
    /// This value represents an exponentially weighted moving average (EWMA) of
    /// recorded latencies. The algorithm for computing this value is based on
    /// the following:
    ///
    /// <https://en.wikipedia.org/wiki/Moving_average#Application_to_measuring_computer_performance>
    ///
    /// It is also partially inspired by the BIND and PowerDNS implementations:
    ///
    /// - <https://github.com/isc-projects/bind9/blob/7bf8a7ab1b280c1021bf1e762a239b07aac3c591/lib/dns/adb.c#L3487>
    /// - <https://github.com/PowerDNS/pdns/blob/7c5f9ae6ae4fb17302d933eaeebc8d6f0249aab2/pdns/syncres.cc#L123>
    ///
    /// The algorithm for computing and using this value can be summarized as
    /// follows:
    ///
    /// 1. The value is initialized to a random value that represents a very low
    ///    latency.
    /// 2. If the round-trip time (RTT) was successfully measured for a query,
    ///    then it is incorporated into the EWMA using the formula linked above.
    /// 3. If the RTT could not be measured (i.e. due to a connection failure),
    ///    then a constant penalty factor is applied to the EWMA.
    /// 4. When comparing EWMA values, a time-based decay is applied to each
    ///    value. Note that this decay is only applied at read time.
    ///
    /// For the original discussion regarding this algorithm, see
    /// <https://github.com/hickory-dns/hickory-dns/issues/1702>.
    srtt_microseconds: AtomicU32,

    /// The last time the `srtt_microseconds` value was updated.
    last_update: SyncMutex<Option<Instant>>,
}
585
impl DecayingSrtt {
    /// Create a tracker seeded with `initial_srtt` and no update timestamp.
    fn new(initial_srtt: Duration) -> Self {
        Self {
            srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
            last_update: SyncMutex::new(None),
        }
    }

    /// Incorporate a successfully measured round-trip time into the EWMA.
    fn record(&self, rtt: Duration) {
        // If the cast on the result does overflow (it shouldn't), then the
        // value is saturated to u32::MAX, which is above the `MAX_SRTT_MICROS`
        // limit (meaning that any potential overflow is inconsequential).
        // See https://github.com/rust-lang/rust/issues/10184.
        self.update(
            rtt.as_micros() as u32,
            |cur_srtt_microseconds, last_update| {
                // An arbitrarily low weight is used when computing the factor
                // to ensure that recent RTT measurements are weighted more
                // heavily.
                let factor = compute_srtt_factor(last_update, 3);
                let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
                    + factor * f64::from(cur_srtt_microseconds);
                new_srtt.round() as u32
            },
        );
    }

    /// Records a connection failure for a particular query.
    fn record_failure(&self) {
        self.update(
            Self::FAILURE_PENALTY,
            |cur_srtt_microseconds, _last_update| {
                cur_srtt_microseconds.saturating_add(Self::FAILURE_PENALTY)
            },
        );
    }

    /// Returns the SRTT value after applying a time based decay.
    ///
    /// The decay exponentially decreases the SRTT value. The primary reasons
    /// for applying a downwards decay are twofold:
    ///
    /// 1. It helps distribute query load.
    /// 2. It helps detect positive network changes. For example, decreases in
    ///    latency or a server that has recovered from a failure.
    fn current(&self) -> f64 {
        let srtt = f64::from(self.srtt_microseconds.load(Ordering::Acquire));
        self.last_update.lock().map_or(srtt, |last_update| {
            // In general, if the time between queries is relatively short, then
            // the server ordering algorithm will approximate a spike
            // distribution where the servers with the lowest latencies are
            // chosen much more frequently. Conversely, if the time between
            // queries is relatively long, then the query distribution will be
            // more uniform. A larger weight widens the window in which servers
            // with historically lower latencies will be heavily preferred. On
            // the other hand, a larger weight may also increase the time it
            // takes to recover from a failure or to observe positive changes in
            // latency.
            srtt * compute_srtt_factor(last_update, 180)
        })
    }

    /// Updates the SRTT value.
    ///
    /// If the `last_update` value has not been set, then uses the `default`
    /// value to update the SRTT. Otherwise, invokes the `update_fn` with the
    /// current SRTT value and the `last_update` timestamp.
    fn update(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
        // Swap in the new timestamp, retrieving the previous one (if any).
        let last_update = self.last_update.lock().replace(Instant::now());
        let _ = self.srtt_microseconds.fetch_update(
            Ordering::SeqCst,
            Ordering::SeqCst,
            move |cur_srtt_microseconds| {
                Some(
                    last_update
                        .map_or(default, |last_update| {
                            update_fn(cur_srtt_microseconds, last_update)
                        })
                        // Clamp so a run of failures cannot push the SRTT to absurd values.
                        .min(Self::MAX_SRTT_MICROS),
                )
            },
        );
    }

    /// Returns the raw SRTT value.
    ///
    /// Prefer to use `decayed_srtt` when ordering name servers.
    #[cfg(all(test, feature = "tokio"))]
    fn as_duration(&self) -> Duration {
        Duration::from_micros(u64::from(self.srtt_microseconds.load(Ordering::Acquire)))
    }

    /// Penalty added to the SRTT for each observed failure (150 ms).
    const FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
    /// Upper bound on the stored SRTT value (5 s).
    const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
}
681
/// Returns an exponentially weighted value in the range of 0.0 < x < 1.0
///
/// Computes e<sup>-max(elapsed, 1s) / weight</sup>: the longer it has been
/// since `last_update` (relative to `weight`, in seconds), the smaller the
/// returned factor. The elapsed time is clamped to at least one second, so the
/// result never reaches 1.0.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-elapsed_secs / f64::from(weight)).exp()
}
694
/// State of a connection with a remote NameServer.
///
/// Stored as a `u8` inside `ConnectionMeta`; converted via the `From` impls below.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
enum Status {
    /// The connection failed for some reason. For UDP this would generally be a
    /// timeout; for TCP, either the connection could never be established or it
    /// failed at some point afterwards. The `Failed` state should *not* be entered
    /// due to an error contained in a message received from the server. In all
    /// cases, a new connection will need to be created to re-establish.
    Failed = 0,
    /// Initial state; if EDNS is not none, then EDNS will be requested.
    Init = 1,
    /// There has been successful communication with the remote.
    /// If no EDNS is associated, then the remote does not support EDNS.
    Established = 2,
}
711
712impl From<Status> for u8 {
713    /// used for ordering purposes. The highest priority is placed on open connections
714    fn from(val: Status) -> Self {
715        val as Self
716    }
717}
718
719impl From<u8> for Status {
720    fn from(val: u8) -> Self {
721        match val {
722            2 => Self::Established,
723            1 => Self::Init,
724            _ => Self::Failed,
725        }
726    }
727}
728
/// Per-request constraints on which protocols and connections may be used.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub(crate) struct ConnectionPolicy {
    /// When `true`, UDP is not allowed for this request (see `allows_protocol`).
    pub(crate) disable_udp: bool,
}
733
734impl ConnectionPolicy {
735    /// Checks if the given server has any protocols compatible with current policy.
736    pub(crate) fn allows_server<P: ConnectionProvider>(&self, server: &NameServer<P>) -> bool {
737        server.protocols().any(|p| self.allows_protocol(p))
738    }
739
740    /// Select the best pre-existing connection to use.
741    ///
742    /// This choice is made based on opportunistic encryption policy & probe history,
743    /// protocol policy, and the SRTT performance metrics.
744    fn select_connection<'a, P: ConnectionProvider>(
745        &self,
746        ip: IpAddr,
747        encrypted_transport_state: &NameServerTransportState,
748        opportunistic_encryption: &OpportunisticEncryption,
749        connections: &'a [ConnectionState<P>],
750    ) -> Option<&'a ConnectionState<P>> {
751        let selected = connections
752            .iter()
753            .filter(|conn| self.allows_protocol(conn.protocol))
754            .min_by(|a, b| self.compare_connections(opportunistic_encryption.is_enabled(), a, b));
755
756        let selected = selected?;
757
758        // If we're using opportunistic encryption and selected a pre-existing unencrypted connection,
759        // and have successfully probed on any supported encrypted protocol, we should _not_ reuse the
760        // existing connection and instead return `None`. This will result in a new encrypted connection
761        // being made to the successfully probed protocol and added to the connection list for future
762        // re-use.
763        match opportunistic_encryption.is_enabled()
764            && !selected.protocol.is_encrypted()
765            && encrypted_transport_state.any_recent_success(ip, opportunistic_encryption)
766        {
767            true => None,
768            false => Some(selected),
769        }
770    }
771
772    /// Select the best connection configuration to use for a new connection.
773    ///
774    /// This choice is made based on opportunistic encryption policy & probe history,
775    /// and protocol policy.
776    fn select_connection_config<'a>(
777        &self,
778        ip: IpAddr,
779        encrypted_transport_state: &NameServerTransportState,
780        opportunistic_encryption: &OpportunisticEncryption,
781        connection_configs: &'a [ConnectionConfig],
782    ) -> Option<&'a ConnectionConfig> {
783        connection_configs
784            .iter()
785            .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
786            .min_by(|a, b| {
787                self.compare_connection_configs(
788                    ip,
789                    encrypted_transport_state,
790                    opportunistic_encryption,
791                    a,
792                    b,
793                )
794            })
795    }
796
797    /// Select the first protocol allowed by current policy that uses an encrypted transport.
798    fn select_encrypted_connection_config<'a>(
799        &self,
800        connection_config: &'a [ConnectionConfig],
801    ) -> Option<&'a ConnectionConfig> {
802        connection_config
803            .iter()
804            .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
805            .find(|c| c.protocol.to_protocol().is_encrypted())
806    }
807
808    /// Checks if the given protocol is allowed by current policy.
809    fn allows_protocol(&self, protocol: Protocol) -> bool {
810        !(self.disable_udp && protocol == Protocol::Udp)
811    }
812
813    /// Compare two connections according to policy, protocol, and performance.
814    /// If opportunistic encryption is enabled we make an effort to select an encrypted connection.
815    fn compare_connections<P: ConnectionProvider>(
816        &self,
817        opportunistic_encryption: bool,
818        a: &ConnectionState<P>,
819        b: &ConnectionState<P>,
820    ) -> cmp::Ordering {
821        // When opportunistic encryption is in-play, we want to consider encrypted
822        // connections with the greatest priority.
823        if opportunistic_encryption {
824            match (a.protocol.is_encrypted(), b.protocol.is_encrypted()) {
825                (true, false) => return cmp::Ordering::Less,
826                (false, true) => return cmp::Ordering::Greater,
827                // When _both_ are encrypted, then decide on ordering based on other properties (like SRTT).
828                _ => {}
829            }
830        }
831
832        match (a.protocol, b.protocol) {
833            (ap, bp) if ap == bp => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
834            (Protocol::Udp, _) => cmp::Ordering::Less,
835            (_, Protocol::Udp) => cmp::Ordering::Greater,
836            _ => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
837        }
838    }
839
840    fn compare_connection_configs(
841        &self,
842        ip: IpAddr,
843        encrypted_transport_state: &NameServerTransportState,
844        opportunistic_encryption: &OpportunisticEncryption,
845        a: &ConnectionConfig,
846        b: &ConnectionConfig,
847    ) -> cmp::Ordering {
848        let a_protocol = a.protocol.to_protocol();
849        let b_protocol = b.protocol.to_protocol();
850
851        // When opportunistic encryption is in-play, prioritize encrypted protocols
852        // that have recent successful connections
853        if opportunistic_encryption.is_enabled() {
854            let a_recent_enc_success = a_protocol.is_encrypted()
855                && encrypted_transport_state.recent_success(
856                    ip,
857                    a_protocol,
858                    opportunistic_encryption,
859                );
860            let b_recent_enc_success = b_protocol.is_encrypted()
861                && encrypted_transport_state.recent_success(
862                    ip,
863                    b_protocol,
864                    opportunistic_encryption,
865                );
866
867            match (a_recent_enc_success, b_recent_enc_success) {
868                (true, false) => return cmp::Ordering::Less,
869                (false, true) => return cmp::Ordering::Greater,
870                // When both have recent success or neither do, continue with normal ordering
871                _ => {}
872            }
873        }
874
875        // Default protocol ordering: UDP first, then others
876        match (a_protocol, b_protocol) {
877            (ap, bp) if ap == bp => cmp::Ordering::Equal,
878            (Protocol::Udp, _) => cmp::Ordering::Less,
879            (_, Protocol::Udp) => cmp::Ordering::Greater,
880            _ => cmp::Ordering::Equal,
881        }
882    }
883}
884
#[cfg(all(test, feature = "tokio"))]
mod tests {
    //! Unit tests for `NameServer` query handling and for the `DecayingSrtt`
    //! smoothed round-trip-time tracker. Timing-sensitive tests use tokio's
    //! paused clock (`start_paused = true`) so advances are deterministic.

    use std::cmp;
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::time::Duration;

    use test_support::subscribe;
    use tokio::net::UdpSocket;
    use tokio::spawn;

    use super::*;
    use crate::config::{ConnectionConfig, ProtocolConfig};
    use crate::connection_provider::TlsConfig;
    use crate::net::runtime::TokioRuntimeProvider;
    use crate::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
    use crate::proto::rr::rdata::NULL;
    use crate::proto::rr::{Name, RData, Record, RecordType};

    // Queries a public resolver (8.8.8.8) over UDP and expects a NoError
    // response. NOTE(review): requires outbound network access.
    #[tokio::test]
    async fn test_name_server() {
        subscribe();

        let options = ResolverOpts::default();
        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
        let name_server = Arc::new(NameServer::new(
            [].into_iter(),
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        let response = name_server
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .expect("query failed");
        assert_eq!(response.response_code, ResponseCode::NoError);
    }

    // Sends a query to a loopback address with no server listening and
    // expects the request to fail (quickly, thanks to the 1ms timeout).
    #[tokio::test]
    async fn test_failed_name_server() {
        subscribe();

        let options = ResolverOpts {
            timeout: Duration::from_millis(1), // this is going to fail, make it fail fast...
            ..ResolverOpts::default()
        };

        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)));
        let name_server = Arc::new(NameServer::new(
            [],
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        assert!(
            name_server
                .send(
                    DnsRequest::from_query(
                        Query::query(name.clone(), RecordType::A),
                        DnsRequestOptions::default(),
                    ),
                    ConnectionPolicy::default(),
                    &cx
                )
                .await
                .is_err()
        );
    }

    // Spins up a one-shot local UDP server that echoes the request's queries
    // back in its response, then verifies that with case randomization
    // enabled, the query name in the response preserves the original case.
    #[tokio::test]
    async fn case_randomization_query_preserved() {
        subscribe();

        let provider = TokioRuntimeProvider::default();
        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let name = Name::from_str("dead.beef.").unwrap();
        let data = b"DEADBEEF";

        // Mock server task: answer exactly one request with a NULL record,
        // echoing the request's queries (including their randomized case).
        spawn({
            let name = name.clone();
            async move {
                let mut buffer = [0_u8; 512];
                let (len, addr) = server.recv_from(&mut buffer).await.unwrap();
                let request = Message::from_vec(&buffer[0..len]).unwrap();
                let mut response = Message::response(request.id, request.op_code);
                response.add_queries(request.queries.to_vec());
                response.add_answer(Record::from_rdata(
                    name,
                    0,
                    RData::NULL(NULL::with(data.to_vec())),
                ));
                let response_buffer = response.to_vec().unwrap();
                server.send_to(&response_buffer, addr).await.unwrap();
            }
        });

        let config = NameServerConfig {
            ip: server_addr.ip(),
            trust_negative_responses: true,
            connections: vec![ConnectionConfig {
                port: server_addr.port(),
                protocol: ProtocolConfig::Udp,
                bind_addr: None,
            }],
        };

        let resolver_opts = ResolverOpts {
            case_randomization: true,
            ..Default::default()
        };

        let cx = Arc::new(PoolContext::new(resolver_opts, TlsConfig::new().unwrap()));
        let mut request_options = DnsRequestOptions::default();
        request_options.case_randomization = true;
        let ns = Arc::new(NameServer::new([], config, &cx.options, provider));
        let response = ns
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::NULL),
                    request_options,
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .unwrap();

        // eq_case is a case-sensitive comparison; the response's query name
        // must match the original name exactly.
        let response_query_name = response.queries.first().unwrap().name();
        assert!(response_query_name.eq_case(&name));
    }

    // Compile-time-ish helper: only instantiable for Send + Sync types.
    #[allow(clippy::extra_unused_type_parameters)]
    fn is_send_sync<S: Sync + Send>() -> bool {
        true
    }

    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<ConnectionMeta>());
    }

    // Exercises the relative ordering of two DecayingSrtt trackers as RTTs
    // and failures are recorded and the (paused) clock advances.
    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        use std::cmp::Ordering;
        let srtt_a = DecayingSrtt::new(Duration::from_micros(10));
        let srtt_b = DecayingSrtt::new(Duration::from_micros(20));

        // No RTTs or failures have been recorded. The initial SRTTs should be
        // compared.
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // Server A was used. Unused server B should now be preferred.
        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Both servers have been used. Server A has a lower SRTT and should be
        // preferred.
        srtt_b.record(Duration::from_millis(50));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // Server A experiences a connection failure, which results in Server B
        // being preferred.
        srtt_a.record_failure();
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Server A should eventually recover and once again be preferred.
        while cmp(&srtt_a, &srtt_b) != Ordering::Less {
            srtt_b.record(Duration::from_millis(50));
            tokio::time::advance(Duration::from_secs(5)).await;
        }

        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
    }

    // Compare two trackers by their current (decayed) SRTT values.
    fn cmp(a: &DecayingSrtt, b: &DecayingSrtt) -> cmp::Ordering {
        a.current().total_cmp(&b.current())
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        let first_rtt = Duration::from_millis(50);
        srtt.record(first_rtt);

        // The first recorded RTT should replace the initial value.
        assert_eq!(srtt.as_duration(), first_rtt);

        tokio::time::advance(Duration::from_secs(3)).await;

        // Subsequent RTTs should factor in previously recorded values.
        // NOTE(review): the expected value here is tied to the smoothing
        // constants in DecayingSrtt; update it if those constants change.
        srtt.record(Duration::from_millis(100));
        assert_eq!(srtt.as_duration(), Duration::from_micros(81606));
    }

    #[test]
    fn test_record_rtt_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        srtt.record(Duration::MAX);
        // Updates to the SRTT are capped at a maximum value.
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // Verify that the SRTT value is initially replaced with the penalty and
        // subsequent failures result in the penalty being added.
        for failure_count in 1..4 {
            srtt.record_failure();
            assert_eq!(
                srtt.as_duration(),
                Duration::from_micros(
                    DecayingSrtt::FAILURE_PENALTY
                        .checked_mul(failure_count)
                        .expect("checked_mul overflow")
                        .into()
                )
            );
            tokio::time::advance(Duration::from_secs(3)).await;
        }

        // Verify that the `last_update` timestamp was updated for a connection
        // failure and is used in subsequent calculations.
        srtt.record(Duration::from_millis(50));
        assert_eq!(srtt.as_duration(), Duration::from_micros(197152));
    }

    #[test]
    fn test_record_connection_failure_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // Enough failures to exceed the cap by at least one penalty.
        let num_failures = (DecayingSrtt::MAX_SRTT_MICROS / DecayingSrtt::FAILURE_PENALTY) + 1;
        for _ in 0..num_failures {
            srtt.record_failure();
        }

        // Updates to the SRTT are capped at a maximum value.
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_srtt = 10;
        let srtt = DecayingSrtt::new(Duration::from_micros(initial_srtt));

        // No decay should be applied to the initial value.
        assert_eq!(srtt.current() as u32, initial_srtt as u32);

        tokio::time::advance(Duration::from_secs(5)).await;
        srtt.record(Duration::from_millis(100));

        // The decay function should assume a minimum of one second has elapsed
        // since the last update.
        tokio::time::advance(Duration::from_millis(500)).await;
        assert_eq!(srtt.current() as u32, 99445);

        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(srtt.current() as u32, 96990);
    }
}
1173
1174#[cfg(all(test, feature = "__tls"))]
1175mod opportunistic_enc_tests {
1176    use std::io;
1177    use std::net::{IpAddr, Ipv4Addr};
1178    use std::sync::Arc;
1179    use std::time::{Duration, SystemTime};
1180
1181    #[cfg(feature = "metrics")]
1182    use metrics::{Label, Unit, with_local_recorder};
1183    #[cfg(feature = "metrics")]
1184    use metrics_util::debugging::DebuggingRecorder;
1185    use mock_provider::{MockClientHandle, MockProvider};
1186    use test_support::subscribe;
1187    #[cfg(feature = "metrics")]
1188    use test_support::{assert_counter_eq, assert_gauge_eq, assert_histogram_sample_count_eq};
1189
1190    use crate::config::{
1191        NameServerConfig, OpportunisticEncryption, OpportunisticEncryptionConfig, ProtocolConfig,
1192        ResolverOpts,
1193    };
1194    use crate::connection_provider::TlsConfig;
1195    #[cfg(feature = "metrics")]
1196    use crate::metrics::opportunistic_encryption::{
1197        PROBE_ATTEMPTS_TOTAL, PROBE_BUDGET_TOTAL, PROBE_DURATION_SECONDS, PROBE_ERRORS_TOTAL,
1198        PROBE_SUCCESSES_TOTAL, PROBE_TIMEOUTS_TOTAL,
1199    };
1200    use crate::name_server::{ConnectionPolicy, ConnectionState, NameServer, mock_provider};
1201    use crate::name_server_pool::{NameServerTransportState, PoolContext};
1202    use crate::net::NetError;
1203    use crate::net::xfer::Protocol;
1204
1205    #[tokio::test]
1206    async fn test_select_connection_opportunistic_enc_disabled() {
1207        let mut policy = ConnectionPolicy::default();
1208        let connections = vec![
1209            mock_connection(Protocol::Udp),
1210            mock_connection(Protocol::Tcp),
1211        ];
1212
1213        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1214        let state = NameServerTransportState::default();
1215        let opp_enc = OpportunisticEncryption::Disabled;
1216
1217        // When opportunistic encryption is disabled, and disable_udp isn't active,
1218        // we should select the UDP conn.
1219        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1220        assert!(selected.is_some());
1221        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1222
1223        // When opportunistic encryption is disabled, and disable_udp is active,
1224        // we should select the TCP conn.
1225        policy.disable_udp = true;
1226        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1227        assert!(selected.is_some());
1228        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1229    }
1230
1231    #[tokio::test]
1232    async fn test_select_connection_opportunistic_enc_enabled() {
1233        let policy = ConnectionPolicy::default();
1234        let connections = [
1235            mock_connection(Protocol::Udp),
1236            mock_connection(Protocol::Tcp),
1237            // Include a pre-existing encrypted protocol connection.
1238            mock_connection(Protocol::Tls),
1239        ];
1240
1241        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1242        let state = NameServerTransportState::default();
1243        let opp_enc = &OpportunisticEncryption::Enabled {
1244            config: OpportunisticEncryptionConfig::default(),
1245        };
1246
1247        // When opportunistic encryption is enabled, and there is an encrypted connection available,
1248        // we should always choose it as the most preferred.
1249        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1250        assert!(selected.is_some());
1251        assert_eq!(selected.unwrap().protocol, Protocol::Tls);
1252    }
1253
1254    #[tokio::test]
1255    async fn test_select_connection_opportunistic_enc_enabled_no_state() {
1256        let mut policy = ConnectionPolicy::default();
1257        let connections = [
1258            mock_connection(Protocol::Udp),
1259            mock_connection(Protocol::Tcp),
1260            // No pre-existing encrypted protocol connection is available.
1261        ];
1262
1263        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1264        let state = NameServerTransportState::default();
1265        let opp_enc = &OpportunisticEncryption::Enabled {
1266            config: OpportunisticEncryptionConfig::default(),
1267        };
1268
1269        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1270        // and we have no probe state, we should select the UDP conn.
1271        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1272        assert!(selected.is_some());
1273        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1274
1275        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1276        // and we have no probe state, we should select the TCP conn.
1277        policy.disable_udp = true;
1278        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1279        assert!(selected.is_some());
1280        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1281    }
1282
1283    #[tokio::test]
1284    async fn test_select_connection_opportunistic_enc_enabled_failed_probe() {
1285        let policy = ConnectionPolicy::default();
1286        let connections = [
1287            mock_connection(Protocol::Udp),
1288            mock_connection(Protocol::Tcp),
1289            // No pre-existing encrypted protocol connection is available.
1290        ];
1291
1292        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1293        let mut state = NameServerTransportState::default();
1294        let opp_enc = &OpportunisticEncryption::Enabled {
1295            config: OpportunisticEncryptionConfig::default(),
1296        };
1297
1298        // Update the state to reflect that we failed a previous probe attempt.
1299        state.error_received(
1300            ns_ip,
1301            Protocol::Tls,
1302            &NetError::from(io::Error::new(
1303                io::ErrorKind::ConnectionRefused,
1304                "nameserver refused TLS connection",
1305            )),
1306        );
1307
1308        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1309        // and our probe state indicates a failure, we should select the UDP conn.
1310        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1311        assert!(selected.is_some());
1312        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1313    }
1314
1315    #[tokio::test]
1316    async fn test_select_connection_opportunistic_enc_enabled_in_progress_probe() {
1317        let policy = ConnectionPolicy::default();
1318        let connections = [
1319            mock_connection(Protocol::Udp),
1320            mock_connection(Protocol::Tcp),
1321            // No pre-existing encrypted protocol connection is available.
1322        ];
1323
1324        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1325        let mut state = NameServerTransportState::default();
1326        let opp_enc = &OpportunisticEncryption::Enabled {
1327            config: OpportunisticEncryptionConfig::default(),
1328        };
1329
1330        // Update the state to reflect that we have an in-progress probe in-flight.
1331        state.initiate_connection(ns_ip, Protocol::Tls);
1332
1333        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1334        // and our probe state indicates an in-flight probe, we should select the UDP conn.
1335        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1336        assert!(selected.is_some());
1337        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1338
1339        // Update the state to reflect that we completed the connection, but haven't
1340        // received a response.
1341        state.complete_connection(ns_ip, Protocol::Tls);
1342
1343        // In this case we should still select the UDP conn.
1344        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1345        assert!(selected.is_some());
1346        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1347    }
1348
1349    #[tokio::test]
1350    async fn test_select_connection_opportunistic_enc_enabled_stale_probe() {
1351        let policy = ConnectionPolicy::default();
1352        let connections = [
1353            mock_connection(Protocol::Udp),
1354            mock_connection(Protocol::Tcp),
1355            // No pre-existing encrypted protocol connection is available.
1356        ];
1357
1358        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1359        let mut state = NameServerTransportState::default();
1360        let opp_enc_config = OpportunisticEncryptionConfig {
1361            persistence_period: Duration::from_secs(10),
1362            ..OpportunisticEncryptionConfig::default()
1363        };
1364        let opp_enc = &OpportunisticEncryption::Enabled {
1365            config: opp_enc_config.clone(),
1366        };
1367
1368        // Update the state to reflect that we have successfully probed this NS.
1369        state.complete_connection(ns_ip, Protocol::Tls);
1370        state.response_received(ns_ip, Protocol::Tls);
1371        // And then update the last response time to be too stale for consideration.
1372        let stale_time =
1373            SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1374        state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1375
1376        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1377        // and our probe state indicates success that is too stale, we should select an unencrypted
1378        // connection since the probe is no longer considered recent.
1379        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1380        assert!(selected.is_some());
1381        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1382    }
1383
1384    #[tokio::test]
1385    async fn test_select_connection_opportunistic_enc_enabled_good_probe() {
1386        let policy = ConnectionPolicy::default();
1387        let connections = [
1388            mock_connection(Protocol::Udp),
1389            mock_connection(Protocol::Tcp),
1390            // No pre-existing encrypted protocol connection is available.
1391        ];
1392
1393        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1394        let mut state = NameServerTransportState::default();
1395        let opp_enc = &OpportunisticEncryption::Enabled {
1396            config: OpportunisticEncryptionConfig::default(),
1397        };
1398
1399        // Update the state to reflect that we have successfully probed this NS within
1400        // the persistence period and received a response.
1401        state.complete_connection(ns_ip, Protocol::Tls);
1402        state.response_received(ns_ip, Protocol::Tls);
1403
1404        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1405        // and our probe state indicates a recent enough success, we should return `None` so that
1406        // we make a new encrypted connection.
1407        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1408        assert!(selected.is_none());
1409    }
1410
1411    #[tokio::test]
1412    async fn test_select_connection_config_opportunistic_enc_disabled() {
1413        let mut policy = ConnectionPolicy::default();
1414
1415        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1416        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1417
1418        let state = NameServerTransportState::default();
1419        let opp_enc = OpportunisticEncryption::Disabled;
1420
1421        // When opportunistic encryption is disabled, and disable_udp isn't active,
1422        // we should select the UDP config.
1423        let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1424        assert!(selected.is_some());
1425        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1426
1427        // When opportunistic encryption is disabled, and disable_udp is active,
1428        // we should select the TCP config.
1429        policy.disable_udp = true;
1430        let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1431        assert!(selected.is_some());
1432        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1433    }
1434
1435    #[tokio::test]
1436    async fn test_select_connection_config_opportunistic_enc_enabled_no_state() {
1437        let mut policy = ConnectionPolicy::default();
1438        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1439        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1440
1441        let state = NameServerTransportState::default();
1442        let opp_enc = &OpportunisticEncryption::Enabled {
1443            config: OpportunisticEncryptionConfig::default(),
1444        };
1445
1446        // When opportunistic encryption is enabled, but we have no probe state,
1447        // we should select the UDP config (default protocol ordering).
1448        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1449        assert!(selected.is_some());
1450        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1451
1452        // When opportunistic encryption is enabled, but we have no probe state,
1453        // and disable_udp is active, we should select the TCP config.
1454        policy.disable_udp = true;
1455        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1456        assert!(selected.is_some());
1457        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1458    }
1459
1460    #[tokio::test]
1461    async fn test_select_connection_config_opportunistic_enc_enabled_failed_probe() {
1462        let policy = ConnectionPolicy::default();
1463        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1464        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1465
1466        let mut state = NameServerTransportState::default();
1467        let opp_enc = &OpportunisticEncryption::Enabled {
1468            config: OpportunisticEncryptionConfig::default(),
1469        };
1470
1471        // Update the state to reflect that we failed a previous probe attempt.
1472        state.error_received(
1473            ns_ip,
1474            Protocol::Tls,
1475            &NetError::from(io::Error::new(
1476                io::ErrorKind::ConnectionRefused,
1477                "nameserver refused TLS connection",
1478            )),
1479        );
1480
1481        // When opportunistic encryption is enabled, but our probe state indicates a failure,
1482        // we should select the UDP config.
1483        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1484        assert!(selected.is_some());
1485        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1486    }
1487
1488    #[tokio::test]
1489    async fn test_select_connection_config_opportunistic_enc_enabled_stale_probe() {
1490        let policy = ConnectionPolicy::default();
1491        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1492        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1493
1494        let mut state = NameServerTransportState::default();
1495        let opp_enc_config = OpportunisticEncryptionConfig {
1496            persistence_period: Duration::from_secs(10),
1497            ..OpportunisticEncryptionConfig::default()
1498        };
1499        let opp_enc = &OpportunisticEncryption::Enabled {
1500            config: opp_enc_config.clone(),
1501        };
1502
1503        // Update the state to reflect that we have successfully probed this NS.
1504        state.complete_connection(ns_ip, Protocol::Tls);
1505        state.response_received(ns_ip, Protocol::Tls);
1506        // And then update the last response time to be too stale for consideration.
1507        let stale_time =
1508            SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1509        state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1510
1511        // When opportunistic encryption is enabled, but our probe state indicates success that is too stale,
1512        // we should select an unencrypted config since the probe is no longer considered recent.
1513        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1514        assert!(selected.is_some());
1515        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1516    }
1517
1518    #[tokio::test]
1519    async fn test_select_connection_config_opportunistic_enc_enabled_good_probe() {
1520        let policy = ConnectionPolicy::default();
1521        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1522        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1523
1524        let mut state = NameServerTransportState::default();
1525        let opp_enc = &OpportunisticEncryption::Enabled {
1526            config: OpportunisticEncryptionConfig::default(),
1527        };
1528
1529        // Update the state to reflect that we have successfully probed this NS within
1530        // the persistence period and received a response.
1531        state.complete_connection(ns_ip, Protocol::Tls);
1532        state.response_received(ns_ip, Protocol::Tls);
1533
1534        // When opportunistic encryption is enabled, and our probe state indicates a recent enough success,
1535        // we should select the encrypted config with highest priority.
1536        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1537        assert!(selected.is_some());
1538        assert!(matches!(
1539            selected.unwrap().protocol,
1540            ProtocolConfig::Tls { .. }
1541        ));
1542    }
1543
1544    #[tokio::test]
1545    async fn test_opportunistic_probe() {
1546        subscribe();
1547
1548        // Enable opportunistic encryption
1549        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1550            .with_opportunistic_encryption()
1551            .with_probe_budget(10);
1552
1553        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1554        let mock_provider = MockProvider::default();
1555        assert!(
1556            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1557                .await
1558                .is_ok()
1559        );
1560
1561        let recorded_calls = mock_provider.new_connection_calls();
1562        // We should have made two new connection calls.
1563        assert_eq!(recorded_calls.len(), 2);
1564        let (ips, protocols): (Vec<IpAddr>, Vec<ProtocolConfig>) =
1565            recorded_calls.into_iter().unzip();
1566        // All connections should be to the expected NS IP.
1567        assert!(ips.iter().all(|ip| *ip == ns_ip));
1568        // We should have made connections for both the UDP protocol, and the encrypted probe protocol.
1569        let protocols = protocols
1570            .iter()
1571            .map(ProtocolConfig::to_protocol)
1572            .collect::<Vec<_>>();
1573        assert!(protocols.contains(&Protocol::Udp));
1574        assert!(protocols.contains(&Protocol::Tls));
1575    }
1576
1577    #[tokio::test]
1578    async fn test_opportunistic_probe_skip_in_progress() {
1579        subscribe();
1580
1581        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1582        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1583            .with_opportunistic_encryption()
1584            .with_probe_budget(10);
1585
1586        // Set up state to show an in-flight connection already initiated
1587        cx.transport_state()
1588            .await
1589            .initiate_connection(ns_ip, Protocol::Tls);
1590
1591        let mock_provider = MockProvider::default();
1592        assert!(
1593            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1594                .await
1595                .is_ok()
1596        );
1597
1598        let recorded_calls = mock_provider.new_connection_calls();
1599        // We should have made only one connection call (UDP), no probe because one is already in-flight
1600        assert_eq!(recorded_calls.len(), 1);
1601        let (ip, protocol) = &recorded_calls[0];
1602        assert_eq!(*ip, ns_ip);
1603        assert_eq!(protocol.to_protocol(), Protocol::Udp);
1604    }
1605
1606    #[tokio::test]
1607    async fn test_opportunistic_probe_skip_recent_failure() {
1608        subscribe();
1609
1610        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1611        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1612            .with_opportunistic_encryption()
1613            .with_probe_budget(10);
1614
1615        // Set up state to show a recent failure within the damping period
1616        cx.transport_state().await.error_received(
1617            ns_ip,
1618            Protocol::Tls,
1619            &NetError::from(io::Error::new(
1620                io::ErrorKind::ConnectionRefused,
1621                "connection refused",
1622            )),
1623        );
1624
1625        let mock_provider = MockProvider::default();
1626        assert!(
1627            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1628                .await
1629                .is_ok()
1630        );
1631
1632        let recorded_calls = mock_provider.new_connection_calls();
1633        // We should have made only one connection call (UDP), no probe due to recent failure
1634        assert_eq!(recorded_calls.len(), 1);
1635        let (ip, protocol) = &recorded_calls[0];
1636        assert_eq!(*ip, ns_ip);
1637        assert_eq!(protocol.to_protocol(), Protocol::Udp);
1638    }
1639
1640    #[tokio::test]
1641    async fn test_opportunistic_probe_stale_failure() {
1642        subscribe();
1643
1644        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1645        let mut cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1646            .with_probe_budget(10);
1647        let opp_enc_config = OpportunisticEncryptionConfig {
1648            damping_period: Duration::from_secs(5),
1649            ..OpportunisticEncryptionConfig::default()
1650        };
1651        cx.opportunistic_encryption = OpportunisticEncryption::Enabled {
1652            config: opp_enc_config.clone(),
1653        };
1654
1655        // Set up state to show an old failure outside the damping period.
1656        {
1657            let mut state = cx.transport_state().await;
1658            let old_failure_time =
1659                SystemTime::now() - opp_enc_config.damping_period - Duration::from_secs(1);
1660            state.set_failure_time(ns_ip, Protocol::Tls, old_failure_time);
1661        }
1662
1663        let mock_provider = MockProvider::default();
1664        assert!(
1665            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1666                .await
1667                .is_ok()
1668        );
1669
1670        let recorded_calls = mock_provider.new_connection_calls();
1671        // We should have made two connection calls (UDP + TLS probe) because the failure is old
1672        assert_eq!(recorded_calls.len(), 2);
1673        let protocols = recorded_calls
1674            .iter()
1675            .map(|(_, protocol)| protocol.to_protocol())
1676            .collect::<Vec<_>>();
1677        assert!(protocols.contains(&Protocol::Udp));
1678        assert!(protocols.contains(&Protocol::Tls));
1679    }
1680
1681    #[tokio::test]
1682    async fn test_opportunistic_probe_skip_no_budget() {
1683        subscribe();
1684
1685        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1686        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1687            .with_opportunistic_encryption();
1688        let mock_provider = MockProvider::default();
1689        // Set budget to 0 to simulate exhausted probe budget
1690        assert!(
1691            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1692                .await
1693                .is_ok()
1694        );
1695
1696        let recorded_calls = mock_provider.new_connection_calls();
1697        // We should have made only one connection call (UDP), no probe due to exhausted budget
1698        assert_eq!(recorded_calls.len(), 1);
1699        let (ip, protocol) = &recorded_calls[0];
1700        assert_eq!(*ip, ns_ip);
1701        assert_eq!(protocol.to_protocol(), Protocol::Udp);
1702    }
1703
    /// Builds a `ConnectionState` wrapping the mock client handle for `protocol`.
    fn mock_connection(protocol: Protocol) -> ConnectionState<MockProvider> {
        ConnectionState::new(MockClientHandle, protocol)
    }
1707
1708    #[cfg(feature = "metrics")]
1709    #[test]
1710    fn test_opportunistic_probe_metrics_success() {
1711        subscribe();
1712        let recorder = DebuggingRecorder::new();
1713        let snapshotter = recorder.snapshotter();
1714        let initial_budget = 10;
1715
1716        with_local_recorder(&recorder, || {
1717            let runtime = tokio::runtime::Builder::new_current_thread()
1718                .enable_all()
1719                .build()
1720                .unwrap();
1721
1722            runtime.block_on(async {
1723                assert!(
1724                    test_connected_mut_client(
1725                        IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1726                        Arc::new(
1727                            PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1728                                .with_opportunistic_encryption()
1729                                .with_probe_budget(initial_budget),
1730                        ),
1731                        &MockProvider::default(),
1732                    )
1733                    .await
1734                    .is_ok()
1735                );
1736            });
1737        });
1738
1739        #[allow(clippy::mutable_key_type)]
1740        let map = snapshotter.snapshot().into_hashmap();
1741
1742        // We should have registered 1 TLS protocol probe attempt.
1743        let protocol = vec![Label::new("protocol", "tls")];
1744        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1745        // And seen 1 probe duration observation.
1746        assert_histogram_sample_count_eq(
1747            &map,
1748            PROBE_DURATION_SECONDS,
1749            protocol.clone(),
1750            1,
1751            Unit::Seconds,
1752        );
1753
1754        // We should have registered 1 TLS protocol probe success.
1755        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol.clone(), 1);
1756
1757        // We should have registered 0 TLS protocol probe errors.
1758        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol, 0);
1759
1760        // The budget should be back to the initial value now that the probe completed.
1761        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1762    }
1763
1764    #[cfg(feature = "metrics")]
1765    #[test]
1766    fn test_opportunistic_probe_metrics_budget_exhausted() {
1767        subscribe();
1768        let recorder = DebuggingRecorder::new();
1769        let snapshotter = recorder.snapshotter();
1770
1771        with_local_recorder(&recorder, || {
1772            let runtime = tokio::runtime::Builder::new_current_thread()
1773                .enable_all()
1774                .build()
1775                .unwrap();
1776
1777            runtime.block_on(async {
1778                assert!(
1779                    test_connected_mut_client(
1780                        IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1781                        Arc::new(
1782                            PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1783                                .with_opportunistic_encryption(),
1784                        ),
1785                        &MockProvider::default(),
1786                    )
1787                    .await
1788                    .is_ok()
1789                );
1790            });
1791        });
1792
1793        #[allow(clippy::mutable_key_type)]
1794        let map = snapshotter.snapshot().into_hashmap();
1795
1796        // The budget metric should confirm that there's no budget.
1797        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], 0);
1798
1799        // We should not have registered a probe attempt.
1800        let protocol = vec![Label::new("protocol", "tls")];
1801        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 0);
1802        // Or seen a probe duration observation.
1803        assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, protocol, 0, Unit::Seconds);
1804    }
1805
1806    #[cfg(feature = "metrics")]
1807    #[test]
1808    fn test_opportunistic_probe_metrics_connection_error() {
1809        subscribe();
1810        let recorder = DebuggingRecorder::new();
1811        let snapshotter = recorder.snapshotter();
1812        let initial_budget = 10;
1813
1814        with_local_recorder(&recorder, || {
1815            let runtime = tokio::runtime::Builder::new_current_thread()
1816                .enable_all()
1817                .build()
1818                .unwrap();
1819
1820            runtime.block_on(async {
1821                let _ = test_connected_mut_client(
1822                    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1823                    Arc::new(
1824                        PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1825                            .with_opportunistic_encryption()
1826                            .with_probe_budget(initial_budget),
1827                    ),
1828                    // Configure a mock provider that always produces an error when new connections are requested.
1829                    &MockProvider {
1830                        new_connection_error: Some(NetError::from(io::Error::new(
1831                            io::ErrorKind::ConnectionRefused,
1832                            "connection refused",
1833                        ))),
1834                        ..MockProvider::default()
1835                    },
1836                )
1837                .await;
1838            });
1839        });
1840
1841        #[allow(clippy::mutable_key_type)]
1842        let map = snapshotter.snapshot().into_hashmap();
1843
1844        // We should have registered 1 TLS protocol probe attempt.
1845        let protocol = vec![Label::new("protocol", "tls")];
1846        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1847        // And seen 1 probe duration observation.
1848        assert_histogram_sample_count_eq(
1849            &map,
1850            PROBE_DURATION_SECONDS,
1851            protocol.clone(),
1852            1,
1853            Unit::Seconds,
1854        );
1855
1856        // We should have registered 1 TLS protocol probe error.
1857        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 1);
1858
1859        // We shouldn't have registered any TLS protocol probe successes due to the
1860        // mock new connection error.
1861        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1862
1863        // The budget should be back to the initial value now that the probe completed.
1864        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1865    }
1866
1867    #[cfg(feature = "metrics")]
1868    #[test]
1869    fn test_opportunistic_probe_metrics_connection_timeout_error() {
1870        subscribe();
1871        let recorder = DebuggingRecorder::new();
1872        let snapshotter = recorder.snapshotter();
1873        let initial_budget = 10;
1874
1875        with_local_recorder(&recorder, || {
1876            let runtime = tokio::runtime::Builder::new_current_thread()
1877                .enable_all()
1878                .build()
1879                .unwrap();
1880
1881            runtime.block_on(async {
1882                let _ = test_connected_mut_client(
1883                    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1884                    Arc::new(
1885                        PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1886                            .with_opportunistic_encryption()
1887                            .with_probe_budget(initial_budget),
1888                    ),
1889                    // Configure a mock provider that always produces a Timeout error when new connections are requested.
1890                    &MockProvider {
1891                        new_connection_error: Some(NetError::Timeout),
1892                        ..MockProvider::default()
1893                    },
1894                )
1895                .await;
1896            });
1897        });
1898
1899        #[allow(clippy::mutable_key_type)]
1900        let map = snapshotter.snapshot().into_hashmap();
1901
1902        // We should have registered 1 TLS protocol probe attempt.
1903        let protocol = vec![Label::new("protocol", "tls")];
1904        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1905        // And seen 1 probe duration observation.
1906        assert_histogram_sample_count_eq(
1907            &map,
1908            PROBE_DURATION_SECONDS,
1909            protocol.clone(),
1910            1,
1911            Unit::Seconds,
1912        );
1913
1914        // We should have registered 1 TLS protocol probe timeout.
1915        assert_counter_eq(&map, PROBE_TIMEOUTS_TOTAL, protocol.clone(), 1);
1916
1917        // We shouldn't have registered a more general probe error.
1918        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 0);
1919
1920        // We shouldn't have registered any TLS protocol probe successes due to the
1921        // mock new connection error.
1922        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1923
1924        // The budget should be back to the initial value now that the probe completed.
1925        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1926    }
1927
1928    /// Construct a nameserver appropriate for opportunistic encryption and assert connected_mut_client
1929    /// returns Ok.
1930    ///
1931    /// Behind the scenes this may provoke probing behaviour that the calling test can observe via
1932    /// the `MockProvider`'s recorded calls.
1933    async fn test_connected_mut_client(
1934        ns_ip: IpAddr,
1935        cx: Arc<PoolContext>,
1936        provider: &MockProvider,
1937    ) -> Result<(), NetError> {
1938        let name_server = NameServer::new(
1939            [],
1940            NameServerConfig::opportunistic_encryption(ns_ip),
1941            &ResolverOpts::default(),
1942            provider.clone(),
1943        );
1944
1945        name_server
1946            .connected_mut_client(ConnectionPolicy::default(), &cx)
1947            .await
1948            .map(|_| ())
1949    }
1950}
1951
#[cfg(all(test, feature = "metrics"))]
mod resolver_metrics_tests {
    use std::net::{IpAddr, Ipv4Addr};

    use metrics::{Label, with_local_recorder};
    use metrics_util::debugging::DebuggingRecorder;
    use mock_provider::MockProvider;
    use test_support::assert_counter_eq;
    use test_support::subscribe;

    use super::*;
    use crate::connection_provider::TlsConfig;
    use crate::metrics::OUTGOING_QUERIES_TOTAL;

    /// Sends a single A query for `www.example.com.` through a `NameServer` built
    /// from `config`, then asserts that exactly one outgoing query was recorded
    /// under the given `protocol_label`.
    ///
    /// Shared by the per-protocol tests below so the recorder/runtime/query
    /// boilerplate lives in one place.
    fn assert_outgoing_query_metric(config: NameServerConfig, protocol_label: &'static str) {
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async {
                let options = ResolverOpts::default();
                let name_server = Arc::new(NameServer::new(
                    [],
                    config,
                    &options,
                    MockProvider::default(),
                ));

                let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
                let name = Name::parse("www.example.com.", None).unwrap();
                // The response itself is irrelevant here; only the emitted metric matters.
                let _ = name_server
                    .send(
                        DnsRequest::from_query(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default(),
                        ),
                        ConnectionPolicy::default(),
                        &cx,
                    )
                    .await;
            });
        });

        #[allow(clippy::mutable_key_type)]
        let map = snapshotter.snapshot().into_hashmap();

        // Exactly one query should have been recorded for this protocol label.
        let protocol = vec![Label::new("protocol", protocol_label)];
        assert_counter_eq(&map, OUTGOING_QUERIES_TOTAL, protocol, 1);
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_udp() {
        subscribe();
        assert_outgoing_query_metric(
            NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "udp",
        );
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_tcp() {
        subscribe();
        assert_outgoing_query_metric(
            NameServerConfig::tcp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "tcp",
        );
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn test_outgoing_query_protocol_metrics_tls() {
        subscribe();
        assert_outgoing_query_metric(
            NameServerConfig::tls(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), "dns.google".into()),
            "tls",
        );
    }
}
2105
#[cfg(all(test, any(feature = "metrics", feature = "__tls")))]
mod mock_provider {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;

    use futures_util::stream::once;
    use futures_util::{Stream, future};
    use tokio::net::UdpSocket;

    use super::*;
    use crate::config::ProtocolConfig;
    use crate::net::runtime::TokioTime;
    use crate::net::runtime::iocompat::AsyncIoTokioAsStd;
    use crate::proto::op::Message;

    /// A `ConnectionProvider` backed by a synchronous runtime provider.
    ///
    /// Every call to `new_connection` is recorded (the target IP and protocol config)
    /// and can be retrieved via `new_connection_calls()` for tests to interrogate.
    /// When `new_connection_error` is set, the returned future resolves to a clone of
    /// that error when polled, mocking a connection failure.
    #[derive(Clone)]
    pub(super) struct MockProvider {
        pub(super) runtime: MockSyncRuntimeProvider,
        pub(super) new_connection_calls: Arc<SyncMutex<Vec<(IpAddr, ProtocolConfig)>>>,
        pub(super) new_connection_error: Option<NetError>,
    }

    impl MockProvider {
        /// Returns a copy of the `(ip, protocol)` pairs recorded so far.
        pub(super) fn new_connection_calls(&self) -> Vec<(IpAddr, ProtocolConfig)> {
            self.new_connection_calls.lock().clone()
        }
    }

    impl ConnectionProvider for MockProvider {
        type Conn = MockClientHandle;
        type FutureConn = Pin<Box<dyn Send + Future<Output = Result<Self::Conn, NetError>>>>;
        type RuntimeProvider = MockSyncRuntimeProvider;

        fn new_connection(
            &self,
            ip: IpAddr,
            config: &ConnectionConfig,
            _cx: &PoolContext,
        ) -> Result<Self::FutureConn, NetError> {
            // Record the call before producing the (possibly failing) connection future.
            self.new_connection_calls
                .lock()
                .push((ip, config.protocol.clone()));

            let outcome = match &self.new_connection_error {
                Some(err) => Err(err.clone()),
                None => Ok(MockClientHandle),
            };
            Ok(Box::pin(future::ready(outcome)))
        }

        fn runtime_provider(&self) -> &Self::RuntimeProvider {
            &self.runtime
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self {
                runtime: MockSyncRuntimeProvider,
                new_connection_calls: Arc::new(SyncMutex::new(Vec::new())),
                new_connection_error: None,
            }
        }
    }

    /// A `DnsHandle` whose `send()` always yields a `NoError` response echoing the
    /// request's queries, simulating a successful DNS exchange.
    #[derive(Clone, Default)]
    pub(super) struct MockClientHandle;

    impl DnsHandle for MockClientHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = MockSyncRuntimeProvider;

        fn send(&self, request: DnsRequest) -> Self::Response {
            let mut reply = Message::response(request.id, request.op_code);
            reply.metadata.response_code = ResponseCode::NoError;
            reply.add_queries(request.queries.clone());
            let response = DnsResponse::from_message(reply).unwrap();
            Box::pin(once(future::ready(Ok(response))))
        }
    }

    /// A `RuntimeProvider` that produces `MockSyncHandle` instances.
    ///
    /// Only `create_handle` is supported; the socket constructors are unimplemented.
    #[derive(Clone)]
    pub(super) struct MockSyncRuntimeProvider;

    impl RuntimeProvider for MockSyncRuntimeProvider {
        type Handle = MockSyncHandle;
        type Timer = TokioTime;
        type Udp = UdpSocket;
        type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            MockSyncHandle
        }

        #[allow(clippy::unimplemented)]
        fn connect_tcp(
            &self,
            _server_addr: std::net::SocketAddr,
            _bind_addr: Option<std::net::SocketAddr>,
            _timeout: Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
            unimplemented!();
        }

        #[allow(clippy::unimplemented)]
        fn bind_udp(
            &self,
            _local_addr: std::net::SocketAddr,
            _server_addr: std::net::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
            unimplemented!();
        }
    }

    /// A `Spawn` implementation that drives spawned futures to completion inline.
    ///
    /// Polling synchronously means tests never have to coordinate with background
    /// tasks to learn when they have finished.
    #[derive(Clone)]
    pub(super) struct MockSyncHandle;

    impl Spawn for MockSyncHandle {
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            // Busy-poll the future with a no-op waker until it resolves, instead of
            // handing it off to an executor as a background task.
            let waker = futures_util::task::noop_waker();
            let mut task_cx = Context::from_waker(&waker);
            let mut task = Box::pin(future);
            while task.as_mut().poll(&mut task_cx).is_pending() {}
        }
    }
}