// hickory_resolver/name_server.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8#[cfg(not(test))]
9use std::time::{Duration, Instant};
10use std::{
11    cmp,
12    fmt::Debug,
13    marker::PhantomData,
14    net::IpAddr,
15    sync::{
16        Arc,
17        atomic::{AtomicU8, AtomicU32, Ordering},
18    },
19};
20
21use futures_util::lock::Mutex as AsyncMutex;
22use parking_lot::Mutex as SyncMutex;
23#[cfg(test)]
24use tokio::time::{Duration, Instant};
25use tracing::{debug, error, warn};
26
27#[cfg(feature = "metrics")]
28use crate::metrics::ResolverMetrics;
29#[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
30use crate::metrics::opportunistic_encryption::ProbeMetrics;
31use crate::{
32    config::{
33        ConnectionConfig, NameServerConfig, OpportunisticEncryption, ResolverOpts,
34        ServerOrderingStrategy,
35    },
36    connection_provider::ConnectionProvider,
37    name_server_pool::{NameServerTransportState, PoolContext},
38    net::{
39        DnsError, NetError, NoRecords,
40        runtime::{RuntimeProvider, Spawn},
41        xfer::{DnsHandle, FirstAnswer, Protocol},
42    },
43    proto::{
44        op::{DnsRequest, DnsRequestOptions, DnsResponse, Query, ResponseCode},
45        rr::{Name, RecordType},
46    },
47};
48
/// A remote DNS server, identified by its IP address.
///
/// This potentially holds multiple open connections to the server, according to the
/// configured protocols, and will make new connections as needed.
pub struct NameServer<P: ConnectionProvider> {
    /// Static configuration for this server (IP address, per-protocol
    /// connection configs, trust settings).
    config: NameServerConfig,
    /// Currently open connections to this server, at most one per protocol in
    /// practice; guarded by an async mutex since new connections are
    /// established while the lock is held.
    connections: AsyncMutex<Vec<ConnectionState<P>>>,
    /// Metrics related to opportunistic encryption probes.
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    opportunistic_probe_metrics: ProbeMetrics,
    /// Metrics related to outgoing queries.
    #[cfg(feature = "metrics")]
    resolver_metrics: ResolverMetrics,
    /// Server-level smoothed RTT aggregated across all of this server's
    /// connections; read by the pool (via `decayed_srtt`) for server ordering.
    server_srtt: DecayingSrtt,
    /// Factory used to create new connections to this server.
    connection_provider: P,
}
65
66impl<P: ConnectionProvider> NameServer<P> {
67    /// Create a new [`NameServer`] with the given connections and configuration.
68    ///
69    /// The `connections` will usually be empty.
70    pub fn new(
71        connections: impl IntoIterator<Item = (Protocol, P::Conn)>,
72        config: NameServerConfig,
73        options: &ResolverOpts,
74        connection_provider: P,
75    ) -> Self {
76        let mut connections = connections
77            .into_iter()
78            .map(|(protocol, handle)| ConnectionState::new(handle, protocol))
79            .collect::<Vec<_>>();
80
81        // Unless the user specified that we should follow the configured order,
82        // re-order the connections to prioritize UDP.
83        if options.server_ordering_strategy != ServerOrderingStrategy::UserProvidedOrder {
84            connections.sort_by_key(|ns| ns.protocol != Protocol::Udp);
85        }
86
87        Self {
88            config,
89            connections: AsyncMutex::new(connections),
90            server_srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
91            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
92            opportunistic_probe_metrics: ProbeMetrics::default(),
93            #[cfg(feature = "metrics")]
94            resolver_metrics: ResolverMetrics::default(),
95            connection_provider,
96        }
97    }
98
99    // TODO: there needs to be some way of customizing the connection based on EDNS options from the server side...
100    pub(crate) async fn send(
101        self: Arc<Self>,
102        request: DnsRequest,
103        policy: ConnectionPolicy,
104        cx: &Arc<PoolContext>,
105    ) -> Result<DnsResponse, NetError> {
106        let (handle, meta, protocol) = self.connected_mut_client(policy, cx).await?;
107        #[cfg(feature = "metrics")]
108        self.resolver_metrics.increment_outgoing_query(&protocol);
109        let now = Instant::now();
110        let response = handle.send(request).first_answer().await;
111        let rtt = now.elapsed();
112
113        match response {
114            Ok(response) => {
115                meta.set_status(Status::Established);
116                let result = DnsError::from_response(response);
117                let error = match result {
118                    Ok(response) => {
119                        meta.srtt.record(rtt);
120                        self.server_srtt.record(rtt);
121                        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
122                            cx.transport_state()
123                                .await
124                                .response_received(self.config.ip, protocol);
125                        }
126                        return Ok(response);
127                    }
128                    Err(error) => error,
129                };
130
131                let update = match error {
132                    DnsError::NoRecordsFound(NoRecords {
133                        response_code: ResponseCode::ServFail,
134                        ..
135                    }) => Some(true),
136                    DnsError::NoRecordsFound(NoRecords { .. }) => Some(false),
137                    _ => None,
138                };
139
140                match update {
141                    Some(true) => {
142                        meta.srtt.record(rtt);
143                        self.server_srtt.record(rtt);
144                    }
145                    Some(false) => {
146                        // record the failure
147                        meta.srtt.record_failure();
148                        self.server_srtt.record_failure();
149                    }
150                    None => {}
151                }
152
153                let err = NetError::from(error);
154                if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
155                    cx.transport_state()
156                        .await
157                        .error_received(self.config.ip, protocol, &err)
158                }
159                Err(err)
160            }
161            Err(error) => {
162                debug!(config = ?self.config, %error, "failed to connect to name server");
163
164                // this transitions the state to failure
165                meta.set_status(Status::Failed);
166
167                // record the failure on both the per-connection and server-level SRTTs.
168                // updating server_srtt ensures the server is deprioritized in pool
169                // ordering (decayed_srtt) so other servers get a chance to be tried.
170                match &error {
171                    NetError::Busy | NetError::Io(_) | NetError::Timeout => {
172                        meta.srtt.record_failure();
173                        self.server_srtt.record_failure();
174                    }
175                    #[cfg(feature = "__quic")]
176                    NetError::QuinnConfigError(_)
177                    | NetError::QuinnConnect(_)
178                    | NetError::QuinnConnection(_)
179                    | NetError::QuinnTlsConfigError(_) => {
180                        meta.srtt.record_failure();
181                        self.server_srtt.record_failure();
182                    }
183                    #[cfg(feature = "__tls")]
184                    NetError::RustlsError(_) => {
185                        meta.srtt.record_failure();
186                        self.server_srtt.record_failure();
187                    }
188                    _ => {}
189                }
190
191                if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
192                    cx.transport_state()
193                        .await
194                        .error_received(self.config.ip, protocol, &error);
195                }
196
197                // These are connection failures, not lookup failures, that is handled in the resolver layer
198                Err(error)
199            }
200        }
201    }
202
203    /// This will return a mutable client to allows for sending messages.
204    ///
205    /// If the connection is in a failed state, then this will establish a new connection
206    async fn connected_mut_client(
207        &self,
208        policy: ConnectionPolicy,
209        cx: &Arc<PoolContext>,
210    ) -> Result<(P::Conn, Arc<ConnectionMeta>, Protocol), NetError> {
211        let mut connections = self.connections.lock().await;
212        connections.retain(|conn| matches!(conn.meta.status(), Status::Init | Status::Established));
213        if let Some(conn) = policy.select_connection(
214            self.config.ip,
215            &*cx.transport_state().await,
216            &cx.opportunistic_encryption,
217            &connections,
218        ) {
219            return Ok((conn.handle.clone(), conn.meta.clone(), conn.protocol));
220        }
221
222        debug!(config = ?self.config, "connecting");
223        let config = policy
224            .select_connection_config(
225                self.config.ip,
226                &*cx.transport_state().await,
227                &cx.opportunistic_encryption,
228                &self.config.connections,
229            )
230            .ok_or(NetError::NoConnections)?;
231
232        let protocol = config.protocol.to_protocol();
233        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
234            cx.transport_state()
235                .await
236                .initiate_connection(self.config.ip, protocol);
237        } else if cx.opportunistic_encryption.is_enabled() && !protocol.is_encrypted() {
238            self.consider_probe_encrypted_transport(&policy, cx).await;
239        }
240
241        let handle = Box::pin(self.connection_provider.new_connection(
242            self.config.ip,
243            config,
244            cx,
245        )?)
246        .await?;
247
248        if cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted() {
249            cx.transport_state()
250                .await
251                .complete_connection(self.config.ip, protocol);
252        }
253
254        // establish a new connection
255        let state = ConnectionState::new(handle.clone(), protocol);
256        let meta = state.meta.clone();
257        connections.push(state);
258        Ok((handle, meta, protocol))
259    }
260
261    pub(super) fn protocols(&self) -> impl Iterator<Item = Protocol> + '_ {
262        self.config
263            .connections
264            .iter()
265            .map(|conn| conn.protocol.to_protocol())
266    }
267
268    pub(super) fn ip(&self) -> IpAddr {
269        self.config.ip
270    }
271
272    pub(crate) fn decayed_srtt(&self) -> f64 {
273        self.server_srtt.current()
274    }
275
276    /// Records an SRTT observation for a server whose in-flight request was
277    /// cancelled because a parallel request to another server succeeded first.
278    ///
279    /// Records the winner's RTT plus a small penalty (`CANCEL_PENALTY`) as the
280    /// observation: the cancelled server was *at least* that slow (it hadn't
281    /// responded yet), and the penalty ensures the winner retains a sorting
282    /// advantage in the next round. This avoids the full `FAILURE_PENALTY`
283    /// which would be too harsh for a server that's merely slightly slower.
284    ///
285    /// A truly unreachable server will be cancelled on every query and its SRTT
286    /// will ratchet up as the EWMA repeatedly incorporates the winner's RTT
287    /// without ever recording a real (successful) measurement to bring it back
288    /// down.
289    pub(super) fn record_cancelled(&self, winner_rtt: Duration) {
290        const CANCEL_PENALTY: Duration = Duration::from_millis(5);
291        self.server_srtt.record(winner_rtt + CANCEL_PENALTY);
292    }
293
294    #[cfg(test)]
295    pub(crate) fn test_record_failure(&self) {
296        self.server_srtt.record_failure();
297    }
298
299    #[cfg(test)]
300    #[allow(dead_code)]
301    pub(crate) fn is_connected(&self) -> bool {
302        let Some(connections) = self.connections.try_lock() else {
303            // assuming that if someone has it locked it will be or is connected
304            return true;
305        };
306
307        connections.iter().any(|conn| match conn.meta.status() {
308            Status::Established | Status::Init => true,
309            Status::Failed => false,
310        })
311    }
312
313    pub(crate) fn trust_negative_responses(&self) -> bool {
314        self.config.trust_negative_responses
315    }
316
317    async fn consider_probe_encrypted_transport(
318        &self,
319        policy: &ConnectionPolicy,
320        cx: &Arc<PoolContext>,
321    ) {
322        let Some(probe_config) =
323            policy.select_encrypted_connection_config(&self.config.connections)
324        else {
325            warn!("no encrypted connection configs available for probing");
326            return;
327        };
328
329        let probe_protocol = probe_config.protocol.to_protocol();
330        let should_probe = {
331            let state = cx.transport_state().await;
332            state.should_probe_encrypted(
333                self.config.ip,
334                probe_protocol,
335                &cx.opportunistic_encryption,
336            )
337        };
338
339        if !should_probe {
340            return;
341        }
342
343        if let Err(err) = self.probe_encrypted_transport(cx, probe_config) {
344            error!(%err, "opportunistic encrypted probe attempt failed");
345        }
346    }
347
348    fn probe_encrypted_transport(
349        &self,
350        cx: &Arc<PoolContext>,
351        probe_config: &ConnectionConfig,
352    ) -> Result<(), NetError> {
353        let mut budget = cx.opportunistic_probe_budget.load(Ordering::Relaxed);
354        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
355        self.opportunistic_probe_metrics.probe_budget.set(budget);
356        loop {
357            if budget == 0 {
358                debug!("no remaining budget for opportunistic probing");
359                return Ok(());
360            }
361            match cx.opportunistic_probe_budget.compare_exchange_weak(
362                budget,
363                budget - 1,
364                Ordering::AcqRel,
365                Ordering::Relaxed,
366            ) {
367                Ok(_) => break,
368                Err(current) => budget = current,
369            }
370        }
371
372        let connect = ProbeRequest::new(
373            probe_config,
374            self,
375            cx,
376            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
377            self.opportunistic_probe_metrics.clone(),
378        )?;
379        self.connection_provider
380            .runtime_provider()
381            .create_handle()
382            .spawn_bg(connect.run());
383
384        Ok(())
385    }
386}
387
/// A background probe of an encrypted transport to a single server.
///
/// Created by `NameServer::probe_encrypted_transport` and driven to completion
/// on a spawned task via `ProbeRequest::run`.
struct ProbeRequest<P: ConnectionProvider> {
    /// IP address of the server being probed.
    ip: IpAddr,
    /// The (encrypted) protocol being probed.
    proto: Protocol,
    /// The not-yet-driven connection future for the probe.
    connecting: P::FutureConn,
    /// Shared pool context (transport state + probe budget).
    context: Arc<PoolContext>,
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    metrics: ProbeMetrics,
    // Ties the `P` type parameter to the struct without storing a `P`.
    provider: PhantomData<P>,
}
397
impl<P: ConnectionProvider> ProbeRequest<P> {
    /// Build a probe for `config`'s transport against `ns`'s address.
    ///
    /// The connection future is created eagerly (and its creation may fail),
    /// but it is not driven until [`ProbeRequest::run`] is awaited.
    fn new(
        config: &ConnectionConfig,
        ns: &NameServer<P>,
        cx: &Arc<PoolContext>,
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics: ProbeMetrics,
    ) -> Result<Self, NetError> {
        Ok(Self {
            ip: ns.config.ip,
            proto: config.protocol.to_protocol(),
            connecting: ns
                .connection_provider
                .new_connection(ns.config.ip, config, cx)?,
            context: cx.clone(),
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: PhantomData,
        })
    }

    /// Drive the probe: connect, send a test query, and record the outcome in
    /// the shared transport state.
    ///
    /// The probe-budget unit consumed when this probe was scheduled is
    /// returned (via `fetch_add`) once the probe finishes, whether it
    /// succeeded or failed.
    async fn run(self) {
        let Self {
            ip,
            proto,
            connecting,
            context,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: _,
        } = self;

        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        let start = Instant::now();

        // Mark the connection attempt in the transport state before connecting.
        context
            .transport_state()
            .await
            .initiate_connection(ip, proto);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics.increment_attempts(proto);

        let conn = match connecting.await {
            Ok(conn) => conn,
            Err(err) => {
                debug!(?proto, "probe connection failed");
                // Return the budget unit consumed when the probe was scheduled.
                let _prev = context
                    .opportunistic_probe_budget
                    .fetch_add(1, Ordering::Relaxed);
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                {
                    metrics.increment_errors(proto, &err);
                    metrics.probe_budget.set(_prev + 1);
                    metrics.record_probe_duration(proto, start.elapsed());
                }
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
                return;
            }
        };

        debug!(?proto, "probe connection succeeded");
        context
            .transport_state()
            .await
            .complete_connection(ip, proto);

        // Send a benign test query (NS for the root zone) to verify the
        // connection actually works end-to-end, not just that it opened.
        match conn
            .send(DnsRequest::from_query(
                Query::query(Name::root(), RecordType::NS),
                DnsRequestOptions::default(),
            ))
            .first_answer()
            .await
        {
            Ok(_) => {
                debug!(?proto, "probe query succeeded");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_successes(proto);
                context.transport_state().await.response_received(ip, proto);
            }
            Err(err) => {
                debug!(?proto, ?err, "probe query failed");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_errors(proto, &err);
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
            }
        }

        // Return the budget unit consumed when the probe was scheduled.
        let _prev = context
            .opportunistic_probe_budget
            .fetch_add(1, Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        {
            metrics.probe_budget.set(_prev + 1);
            metrics.record_probe_duration(proto, start.elapsed());
        }
    }
}
502
/// A single open connection to a server, paired with its protocol and
/// shared per-connection metadata (status + SRTT).
struct ConnectionState<P: ConnectionProvider> {
    /// The protocol this connection uses.
    protocol: Protocol,
    /// The connection handle used to send requests.
    handle: P::Conn,
    /// Shared metadata; `Arc` so in-flight requests can update it after the
    /// connection list lock is released.
    meta: Arc<ConnectionMeta>,
}
508
509impl<P: ConnectionProvider> ConnectionState<P> {
510    fn new(handle: P::Conn, protocol: Protocol) -> Self {
511        Self {
512            protocol,
513            handle,
514            meta: Arc::new(ConnectionMeta::default()),
515        }
516    }
517}
518
/// Per-connection bookkeeping shared between the pool and in-flight requests.
struct ConnectionMeta {
    // Encodes a `Status` value; written with release and read with acquire
    // ordering (see `set_status`/`status`).
    status: AtomicU8,
    // Smoothed RTT for this specific connection (see `DecayingSrtt`).
    srtt: DecayingSrtt,
}
523
524impl ConnectionMeta {
525    fn set_status(&self, status: Status) {
526        self.status.store(status.into(), Ordering::Release);
527    }
528
529    fn status(&self) -> Status {
530        Status::from(self.status.load(Ordering::Acquire))
531    }
532}
533
534impl Default for ConnectionMeta {
535    fn default() -> Self {
536        // Initialize the SRTT to a randomly generated value that represents a
537        // very low RTT. Such a value helps ensure that each server is attempted
538        // early.
539        Self {
540            status: AtomicU8::new(Status::Init.into()),
541            srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
542        }
543    }
544}
545
/// An atomically updatable smoothed round-trip time (SRTT) estimate with
/// time-based decay applied at read time.
struct DecayingSrtt {
    /// The smoothed round-trip time (SRTT).
    ///
    /// This value represents an exponentially weighted moving average (EWMA) of
    /// recorded latencies. The algorithm for computing this value is based on
    /// the following:
    ///
    /// <https://en.wikipedia.org/wiki/Moving_average#Application_to_measuring_computer_performance>
    ///
    /// It is also partially inspired by the BIND and PowerDNS implementations:
    ///
    /// - <https://github.com/isc-projects/bind9/blob/7bf8a7ab1b280c1021bf1e762a239b07aac3c591/lib/dns/adb.c#L3487>
    /// - <https://github.com/PowerDNS/pdns/blob/7c5f9ae6ae4fb17302d933eaeebc8d6f0249aab2/pdns/syncres.cc#L123>
    ///
    /// The algorithm for computing and using this value can be summarized as
    /// follows:
    ///
    /// 1. The value is initialized to a random value that represents a very low
    ///    latency.
    /// 2. If the round-trip time (RTT) was successfully measured for a query,
    ///    then it is incorporated into the EWMA using the formula linked above.
    /// 3. If the RTT could not be measured (i.e. due to a connection failure),
    ///    then a constant penalty factor is applied to the EWMA.
    /// 4. When comparing EWMA values, a time-based decay is applied to each
    ///    value. Note that this decay is only applied at read time.
    ///
    /// For the original discussion regarding this algorithm, see
    /// <https://github.com/hickory-dns/hickory-dns/issues/1702>.
    srtt_microseconds: AtomicU32,

    /// The last time the `srtt_microseconds` value was updated.
    last_update: SyncMutex<Option<Instant>>,
}
579
impl DecayingSrtt {
    /// Create a new SRTT estimate seeded with `initial_srtt`; no `last_update`
    /// timestamp is recorded until the first `record`/`record_failure` call.
    fn new(initial_srtt: Duration) -> Self {
        Self {
            srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
            last_update: SyncMutex::new(None),
        }
    }

    /// Incorporates a successfully measured round-trip time into the EWMA.
    fn record(&self, rtt: Duration) {
        // If the cast on the result does overflow (it shouldn't), then the
        // value is saturated to u32::MAX, which is above the `MAX_SRTT_MICROS`
        // limit (meaning that any potential overflow is inconsequential).
        // See https://github.com/rust-lang/rust/issues/10184.
        self.update(
            rtt.as_micros() as u32,
            |cur_srtt_microseconds, last_update| {
                // An arbitrarily low weight is used when computing the factor
                // to ensure that recent RTT measurements are weighted more
                // heavily.
                let factor = compute_srtt_factor(last_update, 3);
                let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
                    + factor * f64::from(cur_srtt_microseconds);
                new_srtt.round() as u32
            },
        );
    }

    /// Records a connection failure for a particular query.
    fn record_failure(&self) {
        self.update(
            Self::FAILURE_PENALTY,
            |cur_srtt_microseconds, _last_update| {
                cur_srtt_microseconds.saturating_add(Self::FAILURE_PENALTY)
            },
        );
    }

    /// Returns the SRTT value after applying a time based decay.
    ///
    /// The decay exponentially decreases the SRTT value. The primary reasons
    /// for applying a downwards decay are twofold:
    ///
    /// 1. It helps distribute query load.
    /// 2. It helps detect positive network changes. For example, decreases in
    ///    latency or a server that has recovered from a failure.
    fn current(&self) -> f64 {
        let srtt = f64::from(self.srtt_microseconds.load(Ordering::Acquire));
        // If `last_update` is `None`, no measurement has been recorded yet and
        // the raw (seed) SRTT is returned undecayed.
        self.last_update.lock().map_or(srtt, |last_update| {
            // In general, if the time between queries is relatively short, then
            // the server ordering algorithm will approximate a spike
            // distribution where the servers with the lowest latencies are
            // chosen much more frequently. Conversely, if the time between
            // queries is relatively long, then the query distribution will be
            // more uniform. A larger weight widens the window in which servers
            // with historically lower latencies will be heavily preferred. On
            // the other hand, a larger weight may also increase the time it
            // takes to recover from a failure or to observe positive changes in
            // latency.
            srtt * compute_srtt_factor(last_update, 180)
        })
    }

    /// Updates the SRTT value.
    ///
    /// If the `last_update` value has not been set, then uses the `default`
    /// value to update the SRTT. Otherwise, invokes the `update_fn` with the
    /// current SRTT value and the `last_update` timestamp.
    fn update(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
        // Swap in the current time as the new `last_update`, keeping the
        // previous timestamp for use inside the update closure.
        let last_update = self.last_update.lock().replace(Instant::now());
        let _ = self.srtt_microseconds.fetch_update(
            Ordering::SeqCst,
            Ordering::SeqCst,
            move |cur_srtt_microseconds| {
                Some(
                    last_update
                        .map_or(default, |last_update| {
                            update_fn(cur_srtt_microseconds, last_update)
                        })
                        .min(Self::MAX_SRTT_MICROS),
                )
            },
        );
    }

    /// Returns the raw SRTT value.
    ///
    /// Prefer to use `decayed_srtt` when ordering name servers.
    #[cfg(all(test, feature = "tokio"))]
    fn as_duration(&self) -> Duration {
        Duration::from_micros(u64::from(self.srtt_microseconds.load(Ordering::Acquire)))
    }

    /// Penalty (in microseconds) added to the SRTT on a connection failure.
    const FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
    /// Upper bound (in microseconds) the stored SRTT is clamped to.
    const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
}
675
/// Returns an exponentially weighted value in the range of 0.0 < x < 1.0
///
/// Computes e<sup>-elapsed / weight</sup>, where `elapsed` is the number of
/// seconds since `last_update`, clamped to a minimum of one second.
///
/// As the duration since the `last_update` approaches the provided `weight`,
/// the returned value decreases.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-elapsed_secs / f64::from(weight)).exp()
}
688
/// State of a connection with a remote NameServer.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
enum Status {
    /// The connection failed for some reason. For UDP this would generally be a
    ///  timeout; for TCP this could be either that the connection could never be
    ///  established, or that it failed at some point after. The Failed state
    ///  should *not* be entered due to an error contained in a Message received
    ///  from the server. In all cases, a new connection will need to be created
    ///  to re-establish communication.
    Failed = 0,
    /// Initial state; if Edns is not none, then Edns will be requested
    Init = 1,
    /// There has been successful communication with the remote.
    ///  If no Edns is associated, then the remote does not support Edns.
    Established = 2,
}
705
706impl From<Status> for u8 {
707    /// used for ordering purposes. The highest priority is placed on open connections
708    fn from(val: Status) -> Self {
709        val as Self
710    }
711}
712
713impl From<u8> for Status {
714    fn from(val: u8) -> Self {
715        match val {
716            2 => Self::Established,
717            1 => Self::Init,
718            _ => Self::Failed,
719        }
720    }
721}
722
/// Restrictions on which connections/protocols may be used for a query.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub(crate) struct ConnectionPolicy {
    // When true, connections using `Protocol::Udp` are rejected by
    // `allows_protocol` (and therefore by all selection methods).
    pub(crate) disable_udp: bool,
}
727
728impl ConnectionPolicy {
729    /// Checks if the given server has any protocols compatible with current policy.
730    pub(crate) fn allows_server<P: ConnectionProvider>(&self, server: &NameServer<P>) -> bool {
731        server.protocols().any(|p| self.allows_protocol(p))
732    }
733
734    /// Select the best pre-existing connection to use.
735    ///
736    /// This choice is made based on opportunistic encryption policy & probe history,
737    /// protocol policy, and the SRTT performance metrics.
738    fn select_connection<'a, P: ConnectionProvider>(
739        &self,
740        ip: IpAddr,
741        encrypted_transport_state: &NameServerTransportState,
742        opportunistic_encryption: &OpportunisticEncryption,
743        connections: &'a [ConnectionState<P>],
744    ) -> Option<&'a ConnectionState<P>> {
745        let selected = connections
746            .iter()
747            .filter(|conn| self.allows_protocol(conn.protocol))
748            .min_by(|a, b| self.compare_connections(opportunistic_encryption.is_enabled(), a, b));
749
750        let selected = selected?;
751
752        // If we're using opportunistic encryption and selected a pre-existing unencrypted connection,
753        // and have successfully probed on any supported encrypted protocol, we should _not_ reuse the
754        // existing connection and instead return `None`. This will result in a new encrypted connection
755        // being made to the successfully probed protocol and added to the connection list for future
756        // re-use.
757        match opportunistic_encryption.is_enabled()
758            && !selected.protocol.is_encrypted()
759            && encrypted_transport_state.any_recent_success(ip, opportunistic_encryption)
760        {
761            true => None,
762            false => Some(selected),
763        }
764    }
765
766    /// Select the best connection configuration to use for a new connection.
767    ///
768    /// This choice is made based on opportunistic encryption policy & probe history,
769    /// and protocol policy.
770    fn select_connection_config<'a>(
771        &self,
772        ip: IpAddr,
773        encrypted_transport_state: &NameServerTransportState,
774        opportunistic_encryption: &OpportunisticEncryption,
775        connection_configs: &'a [ConnectionConfig],
776    ) -> Option<&'a ConnectionConfig> {
777        connection_configs
778            .iter()
779            .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
780            .min_by(|a, b| {
781                self.compare_connection_configs(
782                    ip,
783                    encrypted_transport_state,
784                    opportunistic_encryption,
785                    a,
786                    b,
787                )
788            })
789    }
790
791    /// Select the first protocol allowed by current policy that uses an encrypted transport.
792    fn select_encrypted_connection_config<'a>(
793        &self,
794        connection_config: &'a [ConnectionConfig],
795    ) -> Option<&'a ConnectionConfig> {
796        connection_config
797            .iter()
798            .filter(|c| self.allows_protocol(c.protocol.to_protocol()))
799            .find(|c| c.protocol.to_protocol().is_encrypted())
800    }
801
802    /// Checks if the given protocol is allowed by current policy.
803    fn allows_protocol(&self, protocol: Protocol) -> bool {
804        !(self.disable_udp && protocol == Protocol::Udp)
805    }
806
807    /// Compare two connections according to policy, protocol, and performance.
808    /// If opportunistic encryption is enabled we make an effort to select an encrypted connection.
809    fn compare_connections<P: ConnectionProvider>(
810        &self,
811        opportunistic_encryption: bool,
812        a: &ConnectionState<P>,
813        b: &ConnectionState<P>,
814    ) -> cmp::Ordering {
815        // When opportunistic encryption is in-play, we want to consider encrypted
816        // connections with the greatest priority.
817        if opportunistic_encryption {
818            match (a.protocol.is_encrypted(), b.protocol.is_encrypted()) {
819                (true, false) => return cmp::Ordering::Less,
820                (false, true) => return cmp::Ordering::Greater,
821                // When _both_ are encrypted, then decide on ordering based on other properties (like SRTT).
822                _ => {}
823            }
824        }
825
826        match (a.protocol, b.protocol) {
827            (ap, bp) if ap == bp => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
828            (Protocol::Udp, _) => cmp::Ordering::Less,
829            (_, Protocol::Udp) => cmp::Ordering::Greater,
830            _ => a.meta.srtt.current().total_cmp(&b.meta.srtt.current()),
831        }
832    }
833
834    fn compare_connection_configs(
835        &self,
836        ip: IpAddr,
837        encrypted_transport_state: &NameServerTransportState,
838        opportunistic_encryption: &OpportunisticEncryption,
839        a: &ConnectionConfig,
840        b: &ConnectionConfig,
841    ) -> cmp::Ordering {
842        let a_protocol = a.protocol.to_protocol();
843        let b_protocol = b.protocol.to_protocol();
844
845        // When opportunistic encryption is in-play, prioritize encrypted protocols
846        // that have recent successful connections
847        if opportunistic_encryption.is_enabled() {
848            let a_recent_enc_success = a_protocol.is_encrypted()
849                && encrypted_transport_state.recent_success(
850                    ip,
851                    a_protocol,
852                    opportunistic_encryption,
853                );
854            let b_recent_enc_success = b_protocol.is_encrypted()
855                && encrypted_transport_state.recent_success(
856                    ip,
857                    b_protocol,
858                    opportunistic_encryption,
859                );
860
861            match (a_recent_enc_success, b_recent_enc_success) {
862                (true, false) => return cmp::Ordering::Less,
863                (false, true) => return cmp::Ordering::Greater,
864                // When both have recent success or neither do, continue with normal ordering
865                _ => {}
866            }
867        }
868
869        // Default protocol ordering: UDP first, then others
870        match (a_protocol, b_protocol) {
871            (ap, bp) if ap == bp => cmp::Ordering::Equal,
872            (Protocol::Udp, _) => cmp::Ordering::Less,
873            (_, Protocol::Udp) => cmp::Ordering::Greater,
874            _ => cmp::Ordering::Equal,
875        }
876    }
877}
878
#[cfg(all(test, feature = "tokio"))]
mod tests {
    // Tests for `NameServer` query behavior and for the `DecayingSrtt`
    // smoothed-RTT bookkeeping used to rank connections.
    use std::cmp;
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::time::Duration;

    use test_support::subscribe;
    use tokio::net::UdpSocket;
    use tokio::spawn;

    use super::*;
    use crate::config::{ConnectionConfig, ProtocolConfig};
    use crate::connection_provider::TlsConfig;
    use crate::net::runtime::TokioRuntimeProvider;
    use crate::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
    use crate::proto::rr::rdata::NULL;
    use crate::proto::rr::{Name, RData, Record, RecordType};

    // NOTE(review): performs a live A query against 8.8.8.8, so it requires
    // outbound network access to pass.
    #[tokio::test]
    async fn test_name_server() {
        subscribe();

        let options = ResolverOpts::default();
        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
        let name_server = Arc::new(NameServer::new(
            [].into_iter(),
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        let response = name_server
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .expect("query failed");
        assert_eq!(response.response_code, ResponseCode::NoError);
    }

    // Querying an address with no listener should surface an error rather than
    // hang; the 1ms timeout keeps the failure fast.
    #[tokio::test]
    async fn test_failed_name_server() {
        subscribe();

        let options = ResolverOpts {
            timeout: Duration::from_millis(1), // this is going to fail, make it fail fast...
            ..ResolverOpts::default()
        };

        // 127.0.0.252 is a loopback address where no DNS server is expected to listen.
        let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)));
        let name_server = Arc::new(NameServer::new(
            [],
            config,
            &options,
            TokioRuntimeProvider::default(),
        ));

        let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
        let name = Name::parse("www.example.com.", None).unwrap();
        assert!(
            name_server
                .send(
                    DnsRequest::from_query(
                        Query::query(name.clone(), RecordType::A),
                        DnsRequestOptions::default(),
                    ),
                    ConnectionPolicy::default(),
                    &cx
                )
                .await
                .is_err()
        );
    }

    // With 0x20 case randomization enabled, the query name echoed back in the
    // response must match the original name case-sensitively after restoration.
    #[tokio::test]
    async fn case_randomization_query_preserved() {
        subscribe();

        let provider = TokioRuntimeProvider::default();
        // Stand up a one-shot UDP "server" on an ephemeral local port that
        // echoes the request's queries back with a NULL-record answer.
        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let name = Name::from_str("dead.beef.").unwrap();
        let data = b"DEADBEEF";

        spawn({
            let name = name.clone();
            async move {
                let mut buffer = [0_u8; 512];
                let (len, addr) = server.recv_from(&mut buffer).await.unwrap();
                let request = Message::from_vec(&buffer[0..len]).unwrap();
                let mut response = Message::response(request.id, request.op_code);
                // Echo the (case-randomized) queries so the client can verify them.
                response.add_queries(request.queries.to_vec());
                response.add_answer(Record::from_rdata(
                    name,
                    0,
                    RData::NULL(NULL::with(data.to_vec())),
                ));
                let response_buffer = response.to_vec().unwrap();
                server.send_to(&response_buffer, addr).await.unwrap();
            }
        });

        let config = NameServerConfig {
            ip: server_addr.ip(),
            trust_negative_responses: true,
            connections: vec![ConnectionConfig {
                port: server_addr.port(),
                protocol: ProtocolConfig::Udp,
                bind_addr: None,
            }],
        };

        let resolver_opts = ResolverOpts {
            case_randomization: true,
            ..Default::default()
        };

        let cx = Arc::new(PoolContext::new(resolver_opts, TlsConfig::new().unwrap()));
        let mut request_options = DnsRequestOptions::default();
        request_options.case_randomization = true;
        let ns = Arc::new(NameServer::new([], config, &cx.options, provider));
        let response = ns
            .send(
                DnsRequest::from_query(
                    Query::query(name.clone(), RecordType::NULL),
                    request_options,
                ),
                ConnectionPolicy::default(),
                &cx,
            )
            .await
            .unwrap();

        // eq_case is a case-sensitive comparison: the randomized casing must
        // have been restored to the original query name.
        let response_query_name = response.queries.first().unwrap().name();
        assert!(response_query_name.eq_case(&name));
    }

    // Compile-time-ish helper: only callable when `S` is Send + Sync.
    #[allow(clippy::extra_unused_type_parameters)]
    fn is_send_sync<S: Sync + Send>() -> bool {
        true
    }

    // `ConnectionMeta` must be shareable across threads.
    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<ConnectionMeta>());
    }

    // Exercises the relative ordering of two DecayingSrtt values as RTTs and
    // failures are recorded; runs under tokio's paused clock so `advance`
    // controls decay deterministically.
    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        use std::cmp::Ordering;
        let srtt_a = DecayingSrtt::new(Duration::from_micros(10));
        let srtt_b = DecayingSrtt::new(Duration::from_micros(20));

        // No RTTs or failures have been recorded. The initial SRTTs should be
        // compared.
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // Server A was used. Unused server B should now be preferred.
        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Both servers have been used. Server A has a lower SRTT and should be
        // preferred.
        srtt_b.record(Duration::from_millis(50));
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);

        // Server A experiences a connection failure, which results in Server B
        // being preferred.
        srtt_a.record_failure();
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);

        // Server A should eventually recover and once again be preferred.
        while cmp(&srtt_a, &srtt_b) != Ordering::Less {
            srtt_b.record(Duration::from_millis(50));
            tokio::time::advance(Duration::from_secs(5)).await;
        }

        srtt_a.record(Duration::from_millis(30));
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
    }

    // Orders two SRTTs by their current (decayed) value; mirrors the comparison
    // used when ranking connections.
    fn cmp(a: &DecayingSrtt, b: &DecayingSrtt) -> cmp::Ordering {
        a.current().total_cmp(&b.current())
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        let first_rtt = Duration::from_millis(50);
        srtt.record(first_rtt);

        // The first recorded RTT should replace the initial value.
        assert_eq!(srtt.as_duration(), first_rtt);

        tokio::time::advance(Duration::from_secs(3)).await;

        // Subsequent RTTs should factor in previously recorded values.
        // The exact value follows from DecayingSrtt's smoothing/decay constants.
        srtt.record(Duration::from_millis(100));
        assert_eq!(srtt.as_duration(), Duration::from_micros(81606));
    }

    #[test]
    fn test_record_rtt_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        srtt.record(Duration::MAX);
        // Updates to the SRTT are capped at a maximum value.
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // Verify that the SRTT value is initially replaced with the penalty and
        // subsequent failures result in the penalty being added.
        for failure_count in 1..4 {
            srtt.record_failure();
            assert_eq!(
                srtt.as_duration(),
                Duration::from_micros(
                    DecayingSrtt::FAILURE_PENALTY
                        .checked_mul(failure_count)
                        .expect("checked_mul overflow")
                        .into()
                )
            );
            tokio::time::advance(Duration::from_secs(3)).await;
        }

        // Verify that the `last_update` timestamp was updated for a connection
        // failure and is used in subsequent calculations.
        srtt.record(Duration::from_millis(50));
        assert_eq!(srtt.as_duration(), Duration::from_micros(197152));
    }

    #[test]
    fn test_record_connection_failure_maximum_value() {
        let srtt = DecayingSrtt::new(Duration::from_micros(10));

        // Record enough failures that the accumulated penalty would exceed the cap.
        let num_failures = (DecayingSrtt::MAX_SRTT_MICROS / DecayingSrtt::FAILURE_PENALTY) + 1;
        for _ in 0..num_failures {
            srtt.record_failure();
        }

        // Updates to the SRTT are capped at a maximum value.
        assert_eq!(
            srtt.as_duration(),
            Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_srtt = 10;
        let srtt = DecayingSrtt::new(Duration::from_micros(initial_srtt));

        // No decay should be applied to the initial value.
        assert_eq!(srtt.current() as u32, initial_srtt as u32);

        tokio::time::advance(Duration::from_secs(5)).await;
        srtt.record(Duration::from_millis(100));

        // The decay function should assume a minimum of one second has elapsed
        // since the last update.
        tokio::time::advance(Duration::from_millis(500)).await;
        assert_eq!(srtt.current() as u32, 99445);

        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(srtt.current() as u32, 96990);
    }
}
1167
1168#[cfg(all(test, feature = "__tls"))]
1169mod opportunistic_enc_tests {
1170    use std::io;
1171    use std::net::{IpAddr, Ipv4Addr};
1172    use std::sync::Arc;
1173    use std::time::{Duration, SystemTime};
1174
1175    #[cfg(feature = "metrics")]
1176    use metrics::{Label, Unit, with_local_recorder};
1177    #[cfg(feature = "metrics")]
1178    use metrics_util::debugging::DebuggingRecorder;
1179    use mock_provider::{MockClientHandle, MockProvider};
1180    use test_support::subscribe;
1181    #[cfg(feature = "metrics")]
1182    use test_support::{assert_counter_eq, assert_gauge_eq, assert_histogram_sample_count_eq};
1183
1184    use crate::config::{
1185        NameServerConfig, OpportunisticEncryption, OpportunisticEncryptionConfig, ProtocolConfig,
1186        ResolverOpts,
1187    };
1188    use crate::connection_provider::TlsConfig;
1189    #[cfg(feature = "metrics")]
1190    use crate::metrics::opportunistic_encryption::{
1191        PROBE_ATTEMPTS_TOTAL, PROBE_BUDGET_TOTAL, PROBE_DURATION_SECONDS, PROBE_ERRORS_TOTAL,
1192        PROBE_SUCCESSES_TOTAL, PROBE_TIMEOUTS_TOTAL,
1193    };
1194    use crate::name_server::{ConnectionPolicy, ConnectionState, NameServer, mock_provider};
1195    use crate::name_server_pool::{NameServerTransportState, PoolContext};
1196    use crate::net::NetError;
1197    use crate::net::xfer::Protocol;
1198
1199    #[tokio::test]
1200    async fn test_select_connection_opportunistic_enc_disabled() {
1201        let mut policy = ConnectionPolicy::default();
1202        let connections = vec![
1203            mock_connection(Protocol::Udp),
1204            mock_connection(Protocol::Tcp),
1205        ];
1206
1207        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1208        let state = NameServerTransportState::default();
1209        let opp_enc = OpportunisticEncryption::Disabled;
1210
1211        // When opportunistic encryption is disabled, and disable_udp isn't active,
1212        // we should select the UDP conn.
1213        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1214        assert!(selected.is_some());
1215        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1216
1217        // When opportunistic encryption is disabled, and disable_udp is active,
1218        // we should select the TCP conn.
1219        policy.disable_udp = true;
1220        let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
1221        assert!(selected.is_some());
1222        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1223    }
1224
1225    #[tokio::test]
1226    async fn test_select_connection_opportunistic_enc_enabled() {
1227        let policy = ConnectionPolicy::default();
1228        let connections = [
1229            mock_connection(Protocol::Udp),
1230            mock_connection(Protocol::Tcp),
1231            // Include a pre-existing encrypted protocol connection.
1232            mock_connection(Protocol::Tls),
1233        ];
1234
1235        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1236        let state = NameServerTransportState::default();
1237        let opp_enc = &OpportunisticEncryption::Enabled {
1238            config: OpportunisticEncryptionConfig::default(),
1239        };
1240
1241        // When opportunistic encryption is enabled, and there is an encrypted connection available,
1242        // we should always choose it as the most preferred.
1243        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1244        assert!(selected.is_some());
1245        assert_eq!(selected.unwrap().protocol, Protocol::Tls);
1246    }
1247
1248    #[tokio::test]
1249    async fn test_select_connection_opportunistic_enc_enabled_no_state() {
1250        let mut policy = ConnectionPolicy::default();
1251        let connections = [
1252            mock_connection(Protocol::Udp),
1253            mock_connection(Protocol::Tcp),
1254            // No pre-existing encrypted protocol connection is available.
1255        ];
1256
1257        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1258        let state = NameServerTransportState::default();
1259        let opp_enc = &OpportunisticEncryption::Enabled {
1260            config: OpportunisticEncryptionConfig::default(),
1261        };
1262
1263        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1264        // and we have no probe state, we should select the UDP conn.
1265        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1266        assert!(selected.is_some());
1267        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1268
1269        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1270        // and we have no probe state, we should select the TCP conn.
1271        policy.disable_udp = true;
1272        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1273        assert!(selected.is_some());
1274        assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
1275    }
1276
1277    #[tokio::test]
1278    async fn test_select_connection_opportunistic_enc_enabled_failed_probe() {
1279        let policy = ConnectionPolicy::default();
1280        let connections = [
1281            mock_connection(Protocol::Udp),
1282            mock_connection(Protocol::Tcp),
1283            // No pre-existing encrypted protocol connection is available.
1284        ];
1285
1286        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1287        let mut state = NameServerTransportState::default();
1288        let opp_enc = &OpportunisticEncryption::Enabled {
1289            config: OpportunisticEncryptionConfig::default(),
1290        };
1291
1292        // Update the state to reflect that we failed a previous probe attempt.
1293        state.error_received(
1294            ns_ip,
1295            Protocol::Tls,
1296            &NetError::from(io::Error::new(
1297                io::ErrorKind::ConnectionRefused,
1298                "nameserver refused TLS connection",
1299            )),
1300        );
1301
1302        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1303        // and our probe state indicates a failure, we should select the UDP conn.
1304        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1305        assert!(selected.is_some());
1306        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1307    }
1308
1309    #[tokio::test]
1310    async fn test_select_connection_opportunistic_enc_enabled_in_progress_probe() {
1311        let policy = ConnectionPolicy::default();
1312        let connections = [
1313            mock_connection(Protocol::Udp),
1314            mock_connection(Protocol::Tcp),
1315            // No pre-existing encrypted protocol connection is available.
1316        ];
1317
1318        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1319        let mut state = NameServerTransportState::default();
1320        let opp_enc = &OpportunisticEncryption::Enabled {
1321            config: OpportunisticEncryptionConfig::default(),
1322        };
1323
1324        // Update the state to reflect that we have an in-progress probe in-flight.
1325        state.initiate_connection(ns_ip, Protocol::Tls);
1326
1327        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1328        // and our probe state indicates an in-flight probe, we should select the UDP conn.
1329        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1330        assert!(selected.is_some());
1331        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1332
1333        // Update the state to reflect that we completed the connection, but haven't
1334        // received a response.
1335        state.complete_connection(ns_ip, Protocol::Tls);
1336
1337        // In this case we should still select the UDP conn.
1338        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1339        assert!(selected.is_some());
1340        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1341    }
1342
1343    #[tokio::test]
1344    async fn test_select_connection_opportunistic_enc_enabled_stale_probe() {
1345        let policy = ConnectionPolicy::default();
1346        let connections = [
1347            mock_connection(Protocol::Udp),
1348            mock_connection(Protocol::Tcp),
1349            // No pre-existing encrypted protocol connection is available.
1350        ];
1351
1352        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1353        let mut state = NameServerTransportState::default();
1354        let opp_enc_config = OpportunisticEncryptionConfig {
1355            persistence_period: Duration::from_secs(10),
1356            ..OpportunisticEncryptionConfig::default()
1357        };
1358        let opp_enc = &OpportunisticEncryption::Enabled {
1359            config: opp_enc_config.clone(),
1360        };
1361
1362        // Update the state to reflect that we have successfully probed this NS.
1363        state.complete_connection(ns_ip, Protocol::Tls);
1364        state.response_received(ns_ip, Protocol::Tls);
1365        // And then update the last response time to be too stale for consideration.
1366        let stale_time =
1367            SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1368        state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1369
1370        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1371        // and our probe state indicates success that is too stale, we should select an unencrypted
1372        // connection since the probe is no longer considered recent.
1373        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1374        assert!(selected.is_some());
1375        assert_eq!(selected.unwrap().protocol, Protocol::Udp);
1376    }
1377
1378    #[tokio::test]
1379    async fn test_select_connection_opportunistic_enc_enabled_good_probe() {
1380        let policy = ConnectionPolicy::default();
1381        let connections = [
1382            mock_connection(Protocol::Udp),
1383            mock_connection(Protocol::Tcp),
1384            // No pre-existing encrypted protocol connection is available.
1385        ];
1386
1387        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1388        let mut state = NameServerTransportState::default();
1389        let opp_enc = &OpportunisticEncryption::Enabled {
1390            config: OpportunisticEncryptionConfig::default(),
1391        };
1392
1393        // Update the state to reflect that we have successfully probed this NS within
1394        // the persistence period and received a response.
1395        state.complete_connection(ns_ip, Protocol::Tls);
1396        state.response_received(ns_ip, Protocol::Tls);
1397
1398        // When opportunistic encryption is enabled, but there are no encrypted connections available,
1399        // and our probe state indicates a recent enough success, we should return `None` so that
1400        // we make a new encrypted connection.
1401        let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
1402        assert!(selected.is_none());
1403    }
1404
1405    #[tokio::test]
1406    async fn test_select_connection_config_opportunistic_enc_disabled() {
1407        let mut policy = ConnectionPolicy::default();
1408
1409        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1410        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1411
1412        let state = NameServerTransportState::default();
1413        let opp_enc = OpportunisticEncryption::Disabled;
1414
1415        // When opportunistic encryption is disabled, and disable_udp isn't active,
1416        // we should select the UDP config.
1417        let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1418        assert!(selected.is_some());
1419        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1420
1421        // When opportunistic encryption is disabled, and disable_udp is active,
1422        // we should select the TCP config.
1423        policy.disable_udp = true;
1424        let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
1425        assert!(selected.is_some());
1426        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1427    }
1428
1429    #[tokio::test]
1430    async fn test_select_connection_config_opportunistic_enc_enabled_no_state() {
1431        let mut policy = ConnectionPolicy::default();
1432        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1433        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1434
1435        let state = NameServerTransportState::default();
1436        let opp_enc = &OpportunisticEncryption::Enabled {
1437            config: OpportunisticEncryptionConfig::default(),
1438        };
1439
1440        // When opportunistic encryption is enabled, but we have no probe state,
1441        // we should select the UDP config (default protocol ordering).
1442        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1443        assert!(selected.is_some());
1444        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1445
1446        // When opportunistic encryption is enabled, but we have no probe state,
1447        // and disable_udp is active, we should select the TCP config.
1448        policy.disable_udp = true;
1449        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1450        assert!(selected.is_some());
1451        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
1452    }
1453
1454    #[tokio::test]
1455    async fn test_select_connection_config_opportunistic_enc_enabled_failed_probe() {
1456        let policy = ConnectionPolicy::default();
1457        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1458        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1459
1460        let mut state = NameServerTransportState::default();
1461        let opp_enc = &OpportunisticEncryption::Enabled {
1462            config: OpportunisticEncryptionConfig::default(),
1463        };
1464
1465        // Update the state to reflect that we failed a previous probe attempt.
1466        state.error_received(
1467            ns_ip,
1468            Protocol::Tls,
1469            &NetError::from(io::Error::new(
1470                io::ErrorKind::ConnectionRefused,
1471                "nameserver refused TLS connection",
1472            )),
1473        );
1474
1475        // When opportunistic encryption is enabled, but our probe state indicates a failure,
1476        // we should select the UDP config.
1477        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1478        assert!(selected.is_some());
1479        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1480    }
1481
1482    #[tokio::test]
1483    async fn test_select_connection_config_opportunistic_enc_enabled_stale_probe() {
1484        let policy = ConnectionPolicy::default();
1485        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1486        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1487
1488        let mut state = NameServerTransportState::default();
1489        let opp_enc_config = OpportunisticEncryptionConfig {
1490            persistence_period: Duration::from_secs(10),
1491            ..OpportunisticEncryptionConfig::default()
1492        };
1493        let opp_enc = &OpportunisticEncryption::Enabled {
1494            config: opp_enc_config.clone(),
1495        };
1496
1497        // Update the state to reflect that we have successfully probed this NS.
1498        state.complete_connection(ns_ip, Protocol::Tls);
1499        state.response_received(ns_ip, Protocol::Tls);
1500        // And then update the last response time to be too stale for consideration.
1501        let stale_time =
1502            SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
1503        state.set_last_response(ns_ip, Protocol::Tls, stale_time);
1504
1505        // When opportunistic encryption is enabled, but our probe state indicates success that is too stale,
1506        // we should select an unencrypted config since the probe is no longer considered recent.
1507        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1508        assert!(selected.is_some());
1509        assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
1510    }
1511
1512    #[tokio::test]
1513    async fn test_select_connection_config_opportunistic_enc_enabled_good_probe() {
1514        let policy = ConnectionPolicy::default();
1515        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1516        let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
1517
1518        let mut state = NameServerTransportState::default();
1519        let opp_enc = &OpportunisticEncryption::Enabled {
1520            config: OpportunisticEncryptionConfig::default(),
1521        };
1522
1523        // Update the state to reflect that we have successfully probed this NS within
1524        // the persistence period and received a response.
1525        state.complete_connection(ns_ip, Protocol::Tls);
1526        state.response_received(ns_ip, Protocol::Tls);
1527
1528        // When opportunistic encryption is enabled, and our probe state indicates a recent enough success,
1529        // we should select the encrypted config with highest priority.
1530        let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
1531        assert!(selected.is_some());
1532        assert!(matches!(
1533            selected.unwrap().protocol,
1534            ProtocolConfig::Tls { .. }
1535        ));
1536    }
1537
1538    #[tokio::test]
1539    async fn test_opportunistic_probe() {
1540        subscribe();
1541
1542        // Enable opportunistic encryption
1543        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1544            .with_opportunistic_encryption()
1545            .with_probe_budget(10);
1546
1547        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1548        let mock_provider = MockProvider::default();
1549        assert!(
1550            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1551                .await
1552                .is_ok()
1553        );
1554
1555        let recorded_calls = mock_provider.new_connection_calls();
1556        // We should have made two new connection calls.
1557        assert_eq!(recorded_calls.len(), 2);
1558        let (ips, protocols): (Vec<IpAddr>, Vec<ProtocolConfig>) =
1559            recorded_calls.into_iter().unzip();
1560        // All connections should be to the expected NS IP.
1561        assert!(ips.iter().all(|ip| *ip == ns_ip));
1562        // We should have made connections for both the UDP protocol, and the encrypted probe protocol.
1563        let protocols = protocols
1564            .iter()
1565            .map(ProtocolConfig::to_protocol)
1566            .collect::<Vec<_>>();
1567        assert!(protocols.contains(&Protocol::Udp));
1568        assert!(protocols.contains(&Protocol::Tls));
1569    }
1570
1571    #[tokio::test]
1572    async fn test_opportunistic_probe_skip_in_progress() {
1573        subscribe();
1574
1575        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1576        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1577            .with_opportunistic_encryption()
1578            .with_probe_budget(10);
1579
1580        // Set up state to show an in-flight connection already initiated
1581        cx.transport_state()
1582            .await
1583            .initiate_connection(ns_ip, Protocol::Tls);
1584
1585        let mock_provider = MockProvider::default();
1586        assert!(
1587            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1588                .await
1589                .is_ok()
1590        );
1591
1592        let recorded_calls = mock_provider.new_connection_calls();
1593        // We should have made only one connection call (UDP), no probe because one is already in-flight
1594        assert_eq!(recorded_calls.len(), 1);
1595        let (ip, protocol) = &recorded_calls[0];
1596        assert_eq!(*ip, ns_ip);
1597        assert_eq!(protocol.to_protocol(), Protocol::Udp);
1598    }
1599
1600    #[tokio::test]
1601    async fn test_opportunistic_probe_skip_recent_failure() {
1602        subscribe();
1603
1604        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1605        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1606            .with_opportunistic_encryption()
1607            .with_probe_budget(10);
1608
1609        // Set up state to show a recent failure within the damping period
1610        cx.transport_state().await.error_received(
1611            ns_ip,
1612            Protocol::Tls,
1613            &NetError::from(io::Error::new(
1614                io::ErrorKind::ConnectionRefused,
1615                "connection refused",
1616            )),
1617        );
1618
1619        let mock_provider = MockProvider::default();
1620        assert!(
1621            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1622                .await
1623                .is_ok()
1624        );
1625
1626        let recorded_calls = mock_provider.new_connection_calls();
1627        // We should have made only one connection call (UDP), no probe due to recent failure
1628        assert_eq!(recorded_calls.len(), 1);
1629        let (ip, protocol) = &recorded_calls[0];
1630        assert_eq!(*ip, ns_ip);
1631        assert_eq!(protocol.to_protocol(), Protocol::Udp);
1632    }
1633
1634    #[tokio::test]
1635    async fn test_opportunistic_probe_stale_failure() {
1636        subscribe();
1637
1638        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
1639        let mut cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1640            .with_probe_budget(10);
1641        let opp_enc_config = OpportunisticEncryptionConfig {
1642            damping_period: Duration::from_secs(5),
1643            ..OpportunisticEncryptionConfig::default()
1644        };
1645        cx.opportunistic_encryption = OpportunisticEncryption::Enabled {
1646            config: opp_enc_config.clone(),
1647        };
1648
1649        // Set up state to show an old failure outside the damping period.
1650        {
1651            let mut state = cx.transport_state().await;
1652            let old_failure_time =
1653                SystemTime::now() - opp_enc_config.damping_period - Duration::from_secs(1);
1654            state.set_failure_time(ns_ip, Protocol::Tls, old_failure_time);
1655        }
1656
1657        let mock_provider = MockProvider::default();
1658        assert!(
1659            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
1660                .await
1661                .is_ok()
1662        );
1663
1664        let recorded_calls = mock_provider.new_connection_calls();
1665        // We should have made two connection calls (UDP + TLS probe) because the failure is old
1666        assert_eq!(recorded_calls.len(), 2);
1667        let protocols = recorded_calls
1668            .iter()
1669            .map(|(_, protocol)| protocol.to_protocol())
1670            .collect::<Vec<_>>();
1671        assert!(protocols.contains(&Protocol::Udp));
1672        assert!(protocols.contains(&Protocol::Tls));
1673    }
1674
    #[tokio::test]
    async fn test_opportunistic_probe_skip_no_budget() {
        subscribe();

        let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        // Note: `with_probe_budget` is deliberately NOT called, so the context
        // keeps its default (zero) probe budget, simulating exhaustion.
        let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
            .with_opportunistic_encryption();
        let mock_provider = MockProvider::default();
        assert!(
            test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
                .await
                .is_ok()
        );

        let recorded_calls = mock_provider.new_connection_calls();
        // We should have made only one connection call (UDP), no probe due to exhausted budget
        assert_eq!(recorded_calls.len(), 1);
        let (ip, protocol) = &recorded_calls[0];
        assert_eq!(*ip, ns_ip);
        assert_eq!(protocol.to_protocol(), Protocol::Udp);
    }
1697
    /// Builds a `ConnectionState` wrapping the mock client handle for the
    /// given protocol, for use in connection-selection tests.
    fn mock_connection(protocol: Protocol) -> ConnectionState<MockProvider> {
        ConnectionState::new(MockClientHandle, protocol)
    }
1701
1702    #[cfg(feature = "metrics")]
1703    #[test]
1704    fn test_opportunistic_probe_metrics_success() {
1705        subscribe();
1706        let recorder = DebuggingRecorder::new();
1707        let snapshotter = recorder.snapshotter();
1708        let initial_budget = 10;
1709
1710        with_local_recorder(&recorder, || {
1711            let runtime = tokio::runtime::Builder::new_current_thread()
1712                .enable_all()
1713                .build()
1714                .unwrap();
1715
1716            runtime.block_on(async {
1717                assert!(
1718                    test_connected_mut_client(
1719                        IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1720                        Arc::new(
1721                            PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1722                                .with_opportunistic_encryption()
1723                                .with_probe_budget(initial_budget),
1724                        ),
1725                        &MockProvider::default(),
1726                    )
1727                    .await
1728                    .is_ok()
1729                );
1730            });
1731        });
1732
1733        #[allow(clippy::mutable_key_type)]
1734        let map = snapshotter.snapshot().into_hashmap();
1735
1736        // We should have registered 1 TLS protocol probe attempt.
1737        let protocol = vec![Label::new("protocol", "tls")];
1738        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1739        // And seen 1 probe duration observation.
1740        assert_histogram_sample_count_eq(
1741            &map,
1742            PROBE_DURATION_SECONDS,
1743            protocol.clone(),
1744            1,
1745            Unit::Seconds,
1746        );
1747
1748        // We should have registered 1 TLS protocol probe success.
1749        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol.clone(), 1);
1750
1751        // We should have registered 0 TLS protocol probe errors.
1752        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol, 0);
1753
1754        // The budget should be back to the initial value now that the probe completed.
1755        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1756    }
1757
1758    #[cfg(feature = "metrics")]
1759    #[test]
1760    fn test_opportunistic_probe_metrics_budget_exhausted() {
1761        subscribe();
1762        let recorder = DebuggingRecorder::new();
1763        let snapshotter = recorder.snapshotter();
1764
1765        with_local_recorder(&recorder, || {
1766            let runtime = tokio::runtime::Builder::new_current_thread()
1767                .enable_all()
1768                .build()
1769                .unwrap();
1770
1771            runtime.block_on(async {
1772                assert!(
1773                    test_connected_mut_client(
1774                        IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1775                        Arc::new(
1776                            PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1777                                .with_opportunistic_encryption(),
1778                        ),
1779                        &MockProvider::default(),
1780                    )
1781                    .await
1782                    .is_ok()
1783                );
1784            });
1785        });
1786
1787        #[allow(clippy::mutable_key_type)]
1788        let map = snapshotter.snapshot().into_hashmap();
1789
1790        // The budget metric should confirm that there's no budget.
1791        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], 0);
1792
1793        // We should not have registered a probe attempt.
1794        let protocol = vec![Label::new("protocol", "tls")];
1795        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 0);
1796        // Or seen a probe duration observation.
1797        assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, protocol, 0, Unit::Seconds);
1798    }
1799
1800    #[cfg(feature = "metrics")]
1801    #[test]
1802    fn test_opportunistic_probe_metrics_connection_error() {
1803        subscribe();
1804        let recorder = DebuggingRecorder::new();
1805        let snapshotter = recorder.snapshotter();
1806        let initial_budget = 10;
1807
1808        with_local_recorder(&recorder, || {
1809            let runtime = tokio::runtime::Builder::new_current_thread()
1810                .enable_all()
1811                .build()
1812                .unwrap();
1813
1814            runtime.block_on(async {
1815                let _ = test_connected_mut_client(
1816                    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1817                    Arc::new(
1818                        PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1819                            .with_opportunistic_encryption()
1820                            .with_probe_budget(initial_budget),
1821                    ),
1822                    // Configure a mock provider that always produces an error when new connections are requested.
1823                    &MockProvider {
1824                        new_connection_error: Some(NetError::from(io::Error::new(
1825                            io::ErrorKind::ConnectionRefused,
1826                            "connection refused",
1827                        ))),
1828                        ..MockProvider::default()
1829                    },
1830                )
1831                .await;
1832            });
1833        });
1834
1835        #[allow(clippy::mutable_key_type)]
1836        let map = snapshotter.snapshot().into_hashmap();
1837
1838        // We should have registered 1 TLS protocol probe attempt.
1839        let protocol = vec![Label::new("protocol", "tls")];
1840        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1841        // And seen 1 probe duration observation.
1842        assert_histogram_sample_count_eq(
1843            &map,
1844            PROBE_DURATION_SECONDS,
1845            protocol.clone(),
1846            1,
1847            Unit::Seconds,
1848        );
1849
1850        // We should have registered 1 TLS protocol probe error.
1851        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 1);
1852
1853        // We shouldn't have registered any TLS protocol probe successes due to the
1854        // mock new connection error.
1855        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1856
1857        // The budget should be back to the initial value now that the probe completed.
1858        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1859    }
1860
1861    #[cfg(feature = "metrics")]
1862    #[test]
1863    fn test_opportunistic_probe_metrics_connection_timeout_error() {
1864        subscribe();
1865        let recorder = DebuggingRecorder::new();
1866        let snapshotter = recorder.snapshotter();
1867        let initial_budget = 10;
1868
1869        with_local_recorder(&recorder, || {
1870            let runtime = tokio::runtime::Builder::new_current_thread()
1871                .enable_all()
1872                .build()
1873                .unwrap();
1874
1875            runtime.block_on(async {
1876                let _ = test_connected_mut_client(
1877                    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
1878                    Arc::new(
1879                        PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
1880                            .with_opportunistic_encryption()
1881                            .with_probe_budget(initial_budget),
1882                    ),
1883                    // Configure a mock provider that always produces a Timeout error when new connections are requested.
1884                    &MockProvider {
1885                        new_connection_error: Some(NetError::Timeout),
1886                        ..MockProvider::default()
1887                    },
1888                )
1889                .await;
1890            });
1891        });
1892
1893        #[allow(clippy::mutable_key_type)]
1894        let map = snapshotter.snapshot().into_hashmap();
1895
1896        // We should have registered 1 TLS protocol probe attempt.
1897        let protocol = vec![Label::new("protocol", "tls")];
1898        assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, protocol.clone(), 1);
1899        // And seen 1 probe duration observation.
1900        assert_histogram_sample_count_eq(
1901            &map,
1902            PROBE_DURATION_SECONDS,
1903            protocol.clone(),
1904            1,
1905            Unit::Seconds,
1906        );
1907
1908        // We should have registered 1 TLS protocol probe timeout.
1909        assert_counter_eq(&map, PROBE_TIMEOUTS_TOTAL, protocol.clone(), 1);
1910
1911        // We shouldn't have registered a more general probe error.
1912        assert_counter_eq(&map, PROBE_ERRORS_TOTAL, protocol.clone(), 0);
1913
1914        // We shouldn't have registered any TLS protocol probe successes due to the
1915        // mock new connection error.
1916        assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, protocol, 0);
1917
1918        // The budget should be back to the initial value now that the probe completed.
1919        assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
1920    }
1921
1922    /// Construct a nameserver appropriate for opportunistic encryption and assert connected_mut_client
1923    /// returns Ok.
1924    ///
1925    /// Behind the scenes this may provoke probing behaviour that the calling test can observe via
1926    /// the `MockProvider`'s recorded calls.
1927    async fn test_connected_mut_client(
1928        ns_ip: IpAddr,
1929        cx: Arc<PoolContext>,
1930        provider: &MockProvider,
1931    ) -> Result<(), NetError> {
1932        let name_server = NameServer::new(
1933            [].into_iter(),
1934            NameServerConfig::opportunistic_encryption(ns_ip),
1935            &ResolverOpts::default(),
1936            provider.clone(),
1937        );
1938
1939        name_server
1940            .connected_mut_client(ConnectionPolicy::default(), &cx)
1941            .await
1942            .map(|_| ())
1943    }
1944}
1945
#[cfg(all(test, feature = "metrics"))]
mod resolver_metrics_tests {
    use std::net::{IpAddr, Ipv4Addr};

    use metrics::{Label, with_local_recorder};
    use metrics_util::debugging::DebuggingRecorder;
    use mock_provider::MockProvider;
    use test_support::assert_counter_eq;
    use test_support::subscribe;

    use super::*;
    use crate::connection_provider::TlsConfig;
    use crate::metrics::OUTGOING_QUERIES_TOTAL;

    #[test]
    fn test_outgoing_query_protocol_metrics_udp() {
        assert_outgoing_query_counted(
            NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "udp",
        );
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_tcp() {
        assert_outgoing_query_counted(
            NameServerConfig::tcp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "tcp",
        );
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn test_outgoing_query_protocol_metrics_tls() {
        assert_outgoing_query_counted(
            NameServerConfig::tls(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), "dns.google".into()),
            "tls",
        );
    }

    /// Sends a single A query through a `NameServer` built from `config` on a
    /// fresh current-thread runtime, then asserts that exactly one outgoing
    /// query was counted under the given protocol label.
    ///
    /// Shared by the per-protocol tests above to avoid duplicating the
    /// recorder/runtime/snapshot plumbing in each one.
    fn assert_outgoing_query_counted(config: NameServerConfig, protocol_label: &'static str) {
        subscribe();
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async {
                let options = ResolverOpts::default();
                let name_server = Arc::new(NameServer::new(
                    [],
                    config,
                    &options,
                    MockProvider::default(),
                ));

                let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
                let name = Name::parse("www.example.com.", None).unwrap();
                // The send result is irrelevant; only the emitted metric is checked.
                let _ = name_server
                    .send(
                        DnsRequest::from_query(
                            Query::query(name, RecordType::A),
                            DnsRequestOptions::default(),
                        ),
                        ConnectionPolicy::default(),
                        &cx,
                    )
                    .await;
            });
        });

        #[allow(clippy::mutable_key_type)]
        let map = snapshotter.snapshot().into_hashmap();

        // Exactly one query should have been recorded for this protocol.
        let protocol = vec![Label::new("protocol", protocol_label)];
        assert_counter_eq(&map, OUTGOING_QUERIES_TOTAL, protocol, 1);
    }
}
2099
#[cfg(all(test, any(feature = "metrics", feature = "__tls")))]
mod mock_provider {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_util::stream::once;
    use futures_util::{Stream, future};
    use tokio::net::UdpSocket;

    use super::*;
    use crate::config::ProtocolConfig;
    use crate::net::runtime::TokioTime;
    use crate::net::runtime::iocompat::AsyncIoTokioAsStd;
    use crate::proto::op::Message;

    /// A `ConnectionProvider` backed by a synchronous runtime provider.
    ///
    /// Every `new_connection` call is recorded so that tests can inspect the
    /// requested `(IpAddr, ProtocolConfig)` pairs via
    /// [`MockProvider::new_connection_calls`]. When `new_connection_error` is
    /// set, the returned connection future resolves to that error when polled,
    /// mocking a connection failure.
    #[derive(Clone)]
    pub(super) struct MockProvider {
        pub(super) runtime: MockSyncRuntimeProvider,
        pub(super) new_connection_calls: Arc<SyncMutex<Vec<(IpAddr, ProtocolConfig)>>>,
        pub(super) new_connection_error: Option<NetError>,
    }

    impl MockProvider {
        /// Returns a copy of every `(ip, protocol)` pair passed to `new_connection`.
        pub(super) fn new_connection_calls(&self) -> Vec<(IpAddr, ProtocolConfig)> {
            self.new_connection_calls.lock().clone()
        }
    }

    impl ConnectionProvider for MockProvider {
        type Conn = MockClientHandle;
        type FutureConn = Pin<Box<dyn Send + Future<Output = Result<Self::Conn, NetError>>>>;
        type RuntimeProvider = MockSyncRuntimeProvider;

        fn new_connection(
            &self,
            ip: IpAddr,
            config: &ConnectionConfig,
            _cx: &PoolContext,
        ) -> Result<Self::FutureConn, NetError> {
            // Record the call before producing the (possibly failing) future.
            self.new_connection_calls
                .lock()
                .push((ip, config.protocol.clone()));

            let outcome = match self.new_connection_error.clone() {
                Some(err) => Err(err),
                None => Ok(MockClientHandle),
            };
            Ok(Box::pin(future::ready(outcome)))
        }

        fn runtime_provider(&self) -> &Self::RuntimeProvider {
            &self.runtime
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self {
                runtime: MockSyncRuntimeProvider,
                new_connection_calls: Arc::new(SyncMutex::new(Vec::new())),
                new_connection_error: None,
            }
        }
    }

    /// A `DnsHandle` whose `send` always yields a `NoError` response echoing
    /// the request's queries, simulating a successful DNS exchange.
    #[derive(Clone, Default)]
    pub(super) struct MockClientHandle;

    impl DnsHandle for MockClientHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = MockSyncRuntimeProvider;

        fn send(&self, request: DnsRequest) -> Self::Response {
            let mut message = Message::response(request.id, request.op_code);
            message.metadata.response_code = ResponseCode::NoError;
            message.add_queries(request.queries.clone());
            let response = DnsResponse::from_message(message).unwrap();
            Box::pin(once(future::ready(Ok(response))))
        }
    }

    /// A `RuntimeProvider` producing `MockSyncHandle` instances.
    ///
    /// Only `create_handle` is usable; the socket constructors are
    /// intentionally left unimplemented.
    #[derive(Clone)]
    pub(super) struct MockSyncRuntimeProvider;

    impl RuntimeProvider for MockSyncRuntimeProvider {
        type Handle = MockSyncHandle;
        type Timer = TokioTime;
        type Udp = UdpSocket;
        type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            MockSyncHandle
        }

        #[allow(clippy::unimplemented)]
        fn connect_tcp(
            &self,
            _server_addr: std::net::SocketAddr,
            _bind_addr: Option<std::net::SocketAddr>,
            _timeout: Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
            unimplemented!();
        }

        #[allow(clippy::unimplemented)]
        fn bind_udp(
            &self,
            _local_addr: std::net::SocketAddr,
            _server_addr: std::net::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
            unimplemented!();
        }
    }

    /// A `Spawn` implementation that drives task futures synchronously.
    ///
    /// Futures handed to `spawn_bg` are polled to completion inline, so tests
    /// never need to coordinate with background tasks to observe their
    /// completion state.
    #[derive(Clone)]
    pub(super) struct MockSyncHandle;

    impl Spawn for MockSyncHandle {
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            // Poll inline with a no-op waker instead of spawning, busy-looping
            // until the future resolves.
            let waker = futures_util::task::noop_waker();
            let mut context = Context::from_waker(&waker);
            let mut task = Box::pin(future);
            while task.as_mut().poll(&mut context).is_pending() {}
        }
    }
}
2252}