// hickory_resolver/name_server_pool.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::collections::{HashMap, VecDeque};
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::atomic::AtomicU8;
12use std::sync::{
13    Arc,
14    atomic::{AtomicUsize, Ordering as AtomicOrdering},
15};
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant, SystemTime};
18
19use futures_util::lock::{Mutex as AsyncMutex, MutexGuard};
20use futures_util::stream::{FuturesUnordered, Stream, StreamExt, once};
21use futures_util::{
22    Future, FutureExt,
23    future::{BoxFuture, Shared},
24};
25use parking_lot::Mutex;
26#[cfg(feature = "serde")]
27use serde::{Deserialize, Serialize};
28use smallvec::SmallVec;
29#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
30use tracing::info;
31use tracing::{debug, error};
32
33#[cfg(any(feature = "__tls", feature = "__quic"))]
34use crate::config::OpportunisticEncryptionConfig;
35use crate::{
36    config::{NameServerConfig, OpportunisticEncryption, ResolverOpts, ServerOrderingStrategy},
37    connection_provider::{ConnectionProvider, TlsConfig},
38    name_server::{ConnectionPolicy, NameServer},
39    net::{
40        DnsError, NetError, NoRecords,
41        runtime::{RuntimeProvider, Time},
42        xfer::{DnsHandle, Protocol},
43    },
44    proto::{
45        access_control::AccessControlSet,
46        op::{DnsRequest, DnsRequestOptions, DnsResponse, OpCode, Query, ResponseCode},
47        rr::{
48            Name, RData, Record,
49            rdata::{
50                A, AAAA,
51                opt::{ClientSubnet, EdnsCode, EdnsOption},
52            },
53        },
54    },
55};
56
/// A pool of name servers over which queries are distributed.
///
/// Cloning is cheap: the server list and the in-flight request map are
/// shared behind `Arc`s, so clones observe the same state.
#[derive(Clone)]
pub struct NameServerPool<P: ConnectionProvider> {
    // Shared server list, pool context, and round-robin cursor.
    state: Arc<PoolState<P>>,
    // In-flight lookups keyed by query; used by `send` to coalesce
    // identical concurrent requests into one shared future.
    active_requests: Arc<Mutex<HashMap<Arc<CacheKey>, SharedLookup>>>,
    // Optional expiry instant; once passed, `ttl_expired` returns true.
    ttl: Option<TtlInstant>,
    // Zone associated with this pool, if any (set via `with_zone`).
    zone: Option<Name>,
}
65
66impl<P: ConnectionProvider> NameServerPool<P> {
67    /// Construct a NameServerPool from a set of name server configs
68    pub fn from_config(
69        servers: impl IntoIterator<Item = NameServerConfig>,
70        cx: Arc<PoolContext>,
71        conn_provider: P,
72    ) -> Self {
73        Self::from_nameservers(
74            servers
75                .into_iter()
76                .map(|server| {
77                    Arc::new(NameServer::new(
78                        [],
79                        server,
80                        &cx.options,
81                        conn_provider.clone(),
82                    ))
83                })
84                .collect(),
85            cx,
86        )
87    }
88
89    #[doc(hidden)]
90    pub fn from_nameservers(servers: Vec<Arc<NameServer<P>>>, cx: Arc<PoolContext>) -> Self {
91        Self {
92            state: Arc::new(PoolState {
93                servers,
94                cx,
95                next: AtomicUsize::new(0),
96            }),
97            active_requests: Arc::new(Mutex::new(HashMap::new())),
98            ttl: None,
99            zone: None,
100        }
101    }
102
103    /// Set a TTL on the NameServerPool
104    pub fn with_ttl(mut self, ttl: Duration) -> Self {
105        self.ttl = Some(TtlInstant::now() + ttl);
106        self
107    }
108
109    /// Set the zone on the NameServerPool
110    pub fn with_zone(mut self, zone: Name) -> Self {
111        self.zone = Some(zone);
112        self
113    }
114
115    /// Check if the TTL on the NameServerPool (if set) has expired
116    pub fn ttl_expired(&self) -> bool {
117        match self.ttl {
118            Some(ttl) => TtlInstant::now() > ttl,
119            None => false,
120        }
121    }
122
123    /// Returns the pool's options.
124    pub fn context(&self) -> &Arc<PoolContext> {
125        &self.state.cx
126    }
127
128    /// Return the zone associated with the pool
129    pub fn zone(&self) -> Option<&Name> {
130        self.zone.as_ref()
131    }
132}
133
// Type alias for TTL unit tests to use tokio's time pause/advance.
// Without the tokio feature this falls back to the std monotonic clock;
// both types support `now()` and `Instant + Duration` as used above.
#[cfg(not(feature = "tokio"))]
type TtlInstant = std::time::Instant;
#[cfg(feature = "tokio")]
type TtlInstant = tokio::time::Instant;
139
impl<P: ConnectionProvider> DnsHandle for NameServerPool<P> {
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
    type Runtime = P::RuntimeProvider;

    fn lookup(&self, query: Query, mut options: DnsRequestOptions) -> Self::Response {
        debug!("querying: {} {:?}", query.name(), query.query_type());
        // Propagate the pool-wide case-randomization setting onto this
        // request's options before sending.
        options.case_randomization = self.state.cx.options.case_randomization;
        self.send(DnsRequest::from_query(query, options))
    }

    fn send(&self, request: DnsRequest) -> Self::Response {
        let state = self.state.clone();
        let acs = self.state.cx.answer_address_filter.clone();
        let active_requests = self.active_requests.clone();

        Box::pin(once(async move {
            debug!("sending request: {:?}", request.queries);
            // The first query is used for the dedup key context and logging;
            // a request without any query is rejected outright.
            let query = match request.queries.first() {
                Some(q) => q.clone(),
                None => return Err("no query in request".into()),
            };

            let key = Arc::new(CacheKey::from_request(&request));

            // Coalesce identical concurrent requests: if a lookup for this
            // key is already in flight, await its shared future instead of
            // sending a duplicate query upstream. The map lock is only held
            // for this lookup/insert, never across an await.
            let (lookup, is_creator) = {
                let mut active = active_requests.lock();
                if let Some(existing) = active.get(&key) {
                    debug!(%query, "query currently in progress - returning shared lookup");
                    (existing.clone(), false)
                } else {
                    debug!(%query, "creating new shared lookup");

                    let lookup = async move {
                        match state.try_send(request).await {
                            Ok(response) => Some(Ok(response)),
                            Err(e) => Some(Err(e)),
                        }
                    }
                    .boxed()
                    .shared();

                    let shared_lookup = SharedLookup(lookup);
                    active.insert(key.clone(), shared_lookup.clone());
                    (shared_lookup, true)
                }
            };

            // Only the creator removes the key so that the entry is not
            // removed prematurely by a waiter task.  Using a guard ensures
            // the entry is removed even if `lookup.await` panics, which
            // would otherwise leave a poisoned `SharedLookup` in the map and
            // cause every subsequent request for the same key to also panic.
            let _cleanup = is_creator.then(|| ActiveRequestCleanup {
                active_requests: active_requests.clone(),
                key: key.clone(),
            });

            let response = lookup.await;
            let mut response = response?;

            // Fast path: no answer-address filtering configured.
            if acs.allows_all() {
                return Ok(response);
            }

            // Drop A/AAAA records whose address is denied by the filter;
            // all other record types pass through untouched.
            let answer_filter = |record: &Record| {
                let ip = match &record.data {
                    RData::A(A(ipv4)) => (*ipv4).into(),
                    RData::AAAA(AAAA(ipv6)) => (*ipv6).into(),
                    _ => return true,
                };

                if acs.denied(ip) {
                    error!(
                        %query,
                        %ip,
                        "removing ip from response: answer filter matched"
                    );

                    false
                } else {
                    true
                }
            };

            // Snapshot the pre-filter counts so we can tell whether filtering
            // removed every record that was originally present.
            let answers_len = response.answers.len();
            let authorities_len = response.authorities.len();

            response.additionals.retain(answer_filter);
            response.answers.retain(answer_filter);
            response.authorities.retain(answer_filter);

            // If filtering emptied a previously non-empty answer section (or
            // emptied both answers and authorities when authorities existed),
            // surface NXDomain instead of an empty success.
            if response.answers.is_empty() && answers_len != 0
                || (response.answers.is_empty()
                    && response.authorities.is_empty()
                    && authorities_len != 0)
            {
                return Err(NoRecords::new(Box::new(query.clone()), ResponseCode::NXDomain).into());
            }

            // Since the message might have changed, create a new response from
            // the message to update the buffer.
            DnsResponse::from_message(response.into_message()).map_err(NetError::from)
        }))
    }
}
245
/// State shared by all clones of a [`NameServerPool`].
struct PoolState<P: ConnectionProvider> {
    // Configured name servers, in the user-provided order.
    servers: Vec<Arc<NameServer<P>>>,
    // Pool-wide options and shared context.
    cx: Arc<PoolContext>,
    // Monotonic cursor used by the round-robin ordering strategy.
    next: AtomicUsize,
}
251
impl<P: ConnectionProvider> PoolState<P> {
    /// Sends `request` to the pool's servers, honoring the configured
    /// ordering strategy, concurrency level, busy-server backoff, and an
    /// end-to-end deadline derived from `options.timeout`.
    ///
    /// Returns the first successful response, or the most specific error
    /// observed across all attempts (see [`most_specific`]).
    async fn try_send(&self, request: DnsRequest) -> Result<DnsResponse, NetError> {
        // Clone so reordering below doesn't mutate the shared server list.
        let mut servers = self.servers.clone();
        match self.cx.options.server_ordering_strategy {
            // select the highest priority connection
            //   reorder the connections based on current view...
            //   this reorders the inner set
            ServerOrderingStrategy::QueryStatistics => {
                sort_servers_by_query_statistics(&mut servers);
            }
            ServerOrderingStrategy::UserProvidedOrder => {}
            ServerOrderingStrategy::RoundRobin => {
                // Clamp concurrency to at least 1 before advancing the cursor.
                let num_concurrent_reqs = if self.cx.options.num_concurrent_reqs > 1 {
                    self.cx.options.num_concurrent_reqs
                } else {
                    1
                };
                if num_concurrent_reqs < servers.len() {
                    // Advance the shared cursor by one batch size and rotate
                    // the local list so successive calls start at different
                    // servers.
                    let index = self
                        .next
                        .fetch_add(num_concurrent_reqs, AtomicOrdering::SeqCst)
                        % servers.len();
                    servers.rotate_left(index);
                }
            }
        }

        // If the name server we're trying is giving us backpressure by returning NetErrorKind::Busy,
        // we will first try the other name servers (as for other error types). However, if the other
        // servers are also busy, we're going to wait for a little while and then retry each server that
        // returned Busy in the previous round. If the server is still Busy, this continues, while
        // the backoff increases exponentially (by a factor of 2), until it hits 300ms, in which case we
        // give up. The request might still be retried by the caller (likely the DnsRetryHandle).
        //
        // Enforce an end-to-end deadline so the total time spent in this loop never exceeds the
        // configured timeout.  Without this, the pool can spend up to N × timeout (where N is the
        // number of servers) before returning an error — well past the point where clients have
        // given up and retransmitted the query.
        let deadline = Instant::now() + self.cx.options.timeout;

        let mut servers = VecDeque::from(servers);
        let mut backoff = Duration::from_millis(20);
        // Servers that returned Busy this round; retried after the backoff.
        let mut busy = SmallVec::<[Arc<NameServer<P>>; 2]>::new();
        // Most specific error seen so far; returned when all servers fail.
        let mut err = NetError::NoConnections;
        let mut policy = ConnectionPolicy::default();

        loop {
            // Check the deadline before starting a new round of server attempts.
            if Instant::now() >= deadline {
                return Err(NetError::Timeout);
            }

            // construct the parallel requests, 2 is the default
            let mut par_servers = SmallVec::<[_; 2]>::new();
            while !servers.is_empty()
                && par_servers.len() < Ord::max(self.cx.options.num_concurrent_reqs, 1)
            {
                // Servers rejected by the current policy are dropped here
                // (popped but not pushed into the batch).
                if let Some(server) = servers.pop_front() {
                    if policy.allows_server(&server) {
                        par_servers.push(server);
                    }
                }
            }

            if par_servers.is_empty() {
                if !busy.is_empty() && backoff < Duration::from_millis(300) {
                    // Cap the backoff sleep so we don't sleep past the deadline.
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        return Err(NetError::Timeout);
                    }
                    <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
                        backoff.min(remaining),
                    ).await;
                    // Requeue the busy servers (still subject to the policy)
                    // and double the backoff for the next empty round.
                    servers.extend(busy.drain(..).filter(|ns| policy.allows_server(ns)));
                    backoff *= 2;
                    continue;
                }
                return Err(err);
            }

            // Track all servers in the parallel batch so we can penalize any
            // that are still in-flight when a winner is found.
            let in_flight = par_servers.iter().cloned().collect::<SmallVec<[_; 2]>>();

            let batch_start = Instant::now();
            let mut requests = par_servers
                .into_iter()
                .map(|server| {
                    let mut request = request.clone();

                    // Set the retry interval to 1.2 times the current decayed SRTT
                    let retry_interval =
                        Duration::from_micros((server.decayed_srtt() * 1.2) as u64);
                    request.options_mut().retry_interval = retry_interval;
                    debug!(?retry_interval, ip = ?server.ip(), "setting retry_interval");

                    let future = server.clone().send(request, policy, &self.cx);
                    // Pair the response with its server so the loop below can
                    // attribute results.
                    async { (server, future.await) }
                })
                .collect::<FuturesUnordered<_>>();

            // Servers that have already completed (successfully or with an
            // error) — used to avoid double-penalizing them.
            let mut completed = SmallVec::<[IpAddr; 2]>::new();

            while let Some((server, result)) = requests.next().await {
                completed.push(server.ip());
                let e = match result {
                    Ok(response) if response.truncation => {
                        debug!("truncated response received, retrying over TCP");
                        policy.disable_udp = true;
                        err = NetError::from("received truncated response");
                        servers.push_front(server);
                        continue;
                    }
                    Ok(response) => {
                        // Penalize servers still in-flight (see `record_cancelled`).
                        let winner_rtt = batch_start.elapsed();
                        for abandoned in &in_flight {
                            if !completed.contains(&abandoned.ip()) {
                                debug!(ip = ?abandoned.ip(), ?winner_rtt, "recording cancelled parallel server");
                                abandoned.record_cancelled(winner_rtt);
                            }
                        }
                        return Ok(response);
                    }
                    Err(e) => e,
                };

                match &e {
                    // We assume the response is spoofed, so ignore it and avoid UDP server for this
                    // request to try and avoid further spoofing.
                    NetError::QueryCaseMismatch => {
                        servers.push_front(server);
                        policy.disable_udp = true;
                        continue;
                    }
                    // If the server is busy, try it again later if necessary.
                    NetError::Busy => busy.push(server),
                    // If the connection failed or timed out, try another one.
                    NetError::Io(_) | NetError::NoConnections | NetError::Timeout => {}
                    // If we got an `NXDomain` response from a server whose negative responses we
                    // don't trust, we should try another server.
                    NetError::Dns(DnsError::NoRecordsFound(NoRecords {
                        response_code: ResponseCode::NXDomain,
                        ..
                    })) if !server.trust_negative_responses() => {}
                    _ => return Err(e),
                }

                // Keep the most informative error across all attempts.
                err = most_specific(err, e);
            }
        }
    }
}
408
409/// Compare two errors to see if one contains a server response.
410fn most_specific(previous: NetError, current: NetError) -> NetError {
411    match (&previous, &current) {
412        (
413            NetError::Dns(DnsError::NoRecordsFound { .. }),
414            NetError::Dns(DnsError::NoRecordsFound { .. }),
415        ) => return previous,
416        (NetError::Dns(DnsError::NoRecordsFound { .. }), _) => return previous,
417        (_, NetError::Dns(DnsError::NoRecordsFound { .. })) => return current,
418        _ => (),
419    }
420
421    match (&previous, &current) {
422        (NetError::Io { .. }, NetError::Io { .. }) => return previous,
423        (NetError::Io { .. }, _) => return current,
424        (_, NetError::Io { .. }) => return previous,
425        _ => (),
426    }
427
428    match (&previous, &current) {
429        (NetError::Timeout, NetError::Timeout) => return previous,
430        (NetError::Timeout, _) => return previous,
431        (_, NetError::Timeout) => return current,
432        _ => (),
433    }
434
435    previous
436}
437
438/// Sorts servers by their decayed SRTT for query-statistics-based ordering.
439///
440/// Uses `sort_by_cached_key` to evaluate each server's decayed SRTT exactly
441/// once. This is critical because `decayed_srtt()` reads shared mutable state
442/// that can change between calls due to concurrent query completions, which
443/// would violate the total-order invariant required by `sort_by`.
444pub(crate) fn sort_servers_by_query_statistics<P: ConnectionProvider>(
445    servers: &mut [Arc<NameServer<P>>],
446) {
447    // Positive f64 bit patterns sort in the same order as their float values,
448    // so to_bits() is a valid u64 ordering key for non-negative SRTT values.
449    servers.sort_by_cached_key(|s| s.decayed_srtt().to_bits());
450}
451
/// Context for a [`NameServerPool`]
///
/// Marked `#[non_exhaustive]` so fields can be added without breaking
/// downstream construction; use [`PoolContext::new`] and the `with_*`
/// builders instead of a struct literal.
#[non_exhaustive]
pub struct PoolContext {
    /// Resolver options
    pub options: ResolverOpts,
    /// TLS configuration
    #[cfg(feature = "__tls")]
    pub tls: rustls::ClientConfig,
    /// Opportunistic probe budget
    pub opportunistic_probe_budget: AtomicU8,
    /// Opportunistic encryption configuration
    pub opportunistic_encryption: OpportunisticEncryption,
    /// Opportunistic encryption name server transport state
    pub transport_state: AsyncMutex<NameServerTransportState>,
    /// Answer address filter
    pub answer_address_filter: AccessControlSet,
}
469
470impl PoolContext {
471    /// Creates a new PoolContext
472    #[cfg_attr(not(feature = "__tls"), expect(unused_variables))]
473    pub fn new(options: ResolverOpts, tls: TlsConfig) -> Self {
474        Self {
475            answer_address_filter: options.answer_address_filter(),
476            options,
477            #[cfg(feature = "__tls")]
478            tls: tls.config,
479            opportunistic_probe_budget: AtomicU8::default(),
480            opportunistic_encryption: OpportunisticEncryption::default(),
481            transport_state: AsyncMutex::new(NameServerTransportState::default()),
482        }
483    }
484
485    /// Set the opportunistic probe budget
486    pub fn with_probe_budget(self, budget: u8) -> Self {
487        self.opportunistic_probe_budget
488            .store(budget, AtomicOrdering::SeqCst);
489        self
490    }
491
492    /// Add an answer address filter
493    pub fn with_answer_filter(mut self, answer_filter: AccessControlSet) -> Self {
494        self.answer_address_filter = answer_filter;
495        self
496    }
497
498    /// Enables opportunistic encryption with default configuration
499    #[cfg(any(feature = "__tls", feature = "__quic"))]
500    pub fn with_opportunistic_encryption(mut self) -> Self {
501        self.opportunistic_encryption = OpportunisticEncryption::Enabled {
502            config: OpportunisticEncryptionConfig::default(),
503        };
504        self
505    }
506
507    /// Sets the transport state
508    pub fn with_transport_state(mut self, transport_state: NameServerTransportState) -> Self {
509        self.transport_state = AsyncMutex::new(transport_state);
510        self
511    }
512
513    pub(crate) async fn transport_state(&self) -> MutexGuard<'_, NameServerTransportState> {
514        self.transport_state.lock().await
515    }
516}
517
/// A mapping from nameserver IP address and protocol to encrypted transport state.
///
/// Consulted by the RFC 9539 opportunistic-encryption logic to decide when
/// to probe or reuse encrypted transports for a given nameserver.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct NameServerTransportState(HashMap<IpAddr, ProtocolTransportState>);
523
impl NameServerTransportState {
    /// Return the count of nameservers with protocol transport state.
    pub fn nameserver_count(&self) -> usize {
        self.0.len()
    }

    /// Update the transport state for the given IP and protocol to record a connection initiation.
    pub(crate) fn initiate_connection(&mut self, ip: IpAddr, protocol: Protocol) {
        let protocol_state = self.0.entry(ip).or_default();
        // `TransportState::default()` is `Initiated`.
        *protocol_state.get_mut(protocol) = TransportState::default();
    }

    /// Update the transport state for the given IP and protocol to record a connection completion.
    pub(crate) fn complete_connection(&mut self, ip: IpAddr, protocol: Protocol) {
        let protocol_state = self.0.entry(ip).or_default();
        // No response has been read yet on the fresh connection.
        *protocol_state.get_mut(protocol) = TransportState::Success {
            last_response: None,
        };
    }

    /// Update the successful transport state for the given IP and protocol to record a response received.
    ///
    /// No-op unless the transport is currently in the `Success` state.
    pub(crate) fn response_received(&mut self, ip: IpAddr, protocol: Protocol) {
        let Some(protocol_state) = self.0.get_mut(&ip) else {
            return;
        };
        let TransportState::Success { last_response, .. } = protocol_state.get_mut(protocol) else {
            return;
        };
        *last_response = Some(SystemTime::now());
    }

    /// Update the transport state for the given IP and protocol to record a received error.
    ///
    /// Timeouts are recorded distinctly from other failures.
    pub(crate) fn error_received(&mut self, ip: IpAddr, protocol: Protocol, error: &NetError) {
        let protocol_state = self.0.entry(ip).or_default();
        *protocol_state.get_mut(protocol) = match &error {
            NetError::Timeout => TransportState::TimedOut {
                #[cfg(any(feature = "__tls", feature = "__quic"))]
                completed_at: SystemTime::now(),
            },
            _ => TransportState::Failed {
                #[cfg(any(feature = "__tls", feature = "__quic"))]
                completed_at: SystemTime::now(),
            },
        };
    }

    /// Returns true if any supported encrypted protocol had a recent success for the given IP
    /// within the damping period.
    #[cfg(any(feature = "__tls", feature = "__quic"))]
    pub(crate) fn any_recent_success(&self, ip: IpAddr, config: &OpportunisticEncryption) -> bool {
        // The allow attributes cover builds where only one of the two
        // encrypted-transport features is enabled, leaving the other
        // variable untouched.
        #[allow(unused_assignments, unused_mut)]
        let mut tls_success = false;
        #[allow(unused_assignments, unused_mut)]
        let mut quic_success = false;

        #[cfg(feature = "__tls")]
        {
            tls_success = self.recent_success(ip, Protocol::Tls, config);
        }

        #[cfg(feature = "__quic")]
        {
            quic_success = self.recent_success(ip, Protocol::Quic, config);
        }

        tls_success || quic_success
    }

    /// Returns true if any encrypted protocol had a recent success for the given IP within the damping period.
    ///
    /// Stub for builds without any encrypted transport: always false.
    #[cfg(not(any(feature = "__tls", feature = "__quic")))]
    pub(crate) fn any_recent_success(
        &self,
        _ip: IpAddr,
        _config: &OpportunisticEncryption,
    ) -> bool {
        false
    }

    /// Returns true if there has been a successful response within the persistence period for the
    /// IP/protocol.
    ///
    /// Returns false if opportunistic encryption is disabled, or if there has not been a successful
    /// response read.
    #[cfg(any(feature = "__tls", feature = "__quic"))]
    pub(crate) fn recent_success(
        &self,
        ip: IpAddr,
        protocol: Protocol,
        config: &OpportunisticEncryption,
    ) -> bool {
        let OpportunisticEncryption::Enabled { config } = config else {
            return false;
        };

        let Some(protocol_state) = self.0.get(&ip) else {
            return false;
        };

        let TransportState::Success { last_response, .. } = protocol_state.get(protocol) else {
            return false;
        };

        let Some(last_response) = last_response else {
            return false;
        };

        // `SystemTime::elapsed` errors if the clock moved backwards; treat
        // that as "not recent" by substituting the maximum duration.
        last_response.elapsed().unwrap_or(Duration::MAX) <= config.persistence_period
    }

    /// Returns true if there has been a successful response within the persistence period.
    ///
    /// Returns false if opportunistic encryption is disabled, or if there has not been a successful
    /// response read.
    #[cfg(not(any(feature = "__tls", feature = "__quic")))]
    pub(crate) fn recent_success(
        &self,
        _ip: IpAddr,
        _protocol: Protocol,
        _config: &OpportunisticEncryption,
    ) -> bool {
        false
    }

    /// Returns true if we should probe encrypted transport based on RFC 9539 damping logic.
    #[cfg(any(feature = "__tls", feature = "__quic"))]
    pub(crate) fn should_probe_encrypted(
        &self,
        ip: IpAddr,
        protocol: Protocol,
        config: &OpportunisticEncryption,
    ) -> bool {
        debug_assert!(protocol.is_encrypted());

        let OpportunisticEncryption::Enabled { config, .. } = config else {
            return false;
        };

        // No recorded state for this IP yet: probing is allowed.
        let Some(protocol_state) = self.0.get(&ip) else {
            return true;
        };

        match protocol_state.get(protocol) {
            // A probe is already in flight.
            TransportState::Initiated => false,
            TransportState::Success { .. } => true,
            // After a failure or timeout, only re-probe once the damping
            // period has elapsed (clock-skew errors count as "not elapsed
            // recently", i.e. probing is allowed).
            TransportState::Failed { completed_at } | TransportState::TimedOut { completed_at } => {
                completed_at.elapsed().unwrap_or(Duration::MAX) > config.damping_period
            }
        }
    }

    /// Returns true if we should probe encrypted transport based on RFC 9539 damping logic.
    ///
    /// Stub for builds without any encrypted transport: always false.
    #[cfg(not(any(feature = "__tls", feature = "__quic")))]
    pub(crate) fn should_probe_encrypted(
        &self,
        _ip: IpAddr,
        _protocol: Protocol,
        _config: &OpportunisticEncryption,
    ) -> bool {
        false
    }

    /// For testing, set the last response time for successful connections to the ip/protocol.
    #[cfg(all(test, feature = "__tls"))]
    pub(crate) fn set_last_response(&mut self, ip: IpAddr, protocol: Protocol, when: SystemTime) {
        let Some(protocol_state) = self.0.get_mut(&ip) else {
            return;
        };

        let TransportState::Success { last_response, .. } = protocol_state.get_mut(protocol) else {
            return;
        };

        *last_response = Some(when);
    }

    /// For testing, set the completion time for failed connections to the ip/protocol.
    #[cfg(all(test, feature = "__tls"))]
    pub(crate) fn set_failure_time(&mut self, ip: IpAddr, protocol: Protocol, when: SystemTime) {
        let protocol_state = self.0.entry(ip).or_default();
        *protocol_state.get_mut(protocol) = TransportState::Failed { completed_at: when };
    }
}
706
/// Encrypted-transport state for a single nameserver, one slot per
/// encrypted protocol compiled into this build.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct ProtocolTransportState {
    // State of the DNS-over-TLS transport.
    #[cfg(feature = "__tls")]
    tls: TransportState,
    // State of the DNS-over-QUIC transport.
    #[cfg(feature = "__quic")]
    quic: TransportState,
}
715
impl ProtocolTransportState {
    /// Mutable access to the state slot for `protocol`.
    ///
    /// Panics (via `unreachable!`) if called with a protocol other than the
    /// encrypted protocols enabled in this build; callers are expected to
    /// pass only TLS/QUIC.
    #[cfg_attr(not(any(feature = "__tls", feature = "__quic")), allow(dead_code))]
    fn get_mut(&mut self, protocol: Protocol) -> &mut TransportState {
        match protocol {
            #[cfg(feature = "__tls")]
            Protocol::Tls => &mut self.tls,
            #[cfg(feature = "__quic")]
            Protocol::Quic => &mut self.quic,
            _ => unreachable!("unsupported opportunistic encryption protocol: {protocol:?}"),
        }
    }

    /// Shared access to the state slot for `protocol`.
    ///
    /// Same panic contract as [`Self::get_mut`].
    #[cfg_attr(not(any(feature = "__tls", feature = "__quic")), allow(dead_code))]
    fn get(&self, protocol: Protocol) -> &TransportState {
        match protocol {
            #[cfg(feature = "__tls")]
            Protocol::Tls => &self.tls,
            #[cfg(feature = "__quic")]
            Protocol::Quic => &self.quic,
            _ => unreachable!("unsupported opportunistic encryption protocol: {protocol:?}"),
        }
    }
}
739
/// State tracked per nameserver IP/protocol to inform opportunistic encryption.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum TransportState {
    /// Connection attempt has been initiated.
    ///
    /// This is the `Default`, written when a connection is first initiated.
    #[default]
    Initiated,
    /// Connection completed successfully.
    Success {
        /// The last instant at which a response was read on the connection (if any).
        last_response: Option<SystemTime>,
    },
    /// Connection failed with an error.
    Failed {
        /// The instant the connection attempt was completed at.
        // Only tracked when an encrypted transport is compiled in, since
        // it is read solely by the damping-period checks.
        #[cfg(any(feature = "__tls", feature = "__quic"))]
        completed_at: SystemTime,
    },
    /// Connection timed out.
    TimedOut {
        /// The instant the connection attempt was completed at.
        // Same feature gating rationale as `Failed::completed_at`.
        #[cfg(any(feature = "__tls", feature = "__quic"))]
        completed_at: SystemTime,
    },
}
765
766#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
767pub use opportunistic_encryption_persistence::OpportunisticEncryptionStatePersistTask;
768
#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
mod opportunistic_encryption_persistence {
    #[cfg(unix)]
    use std::fs::File;
    use std::{
        fs::{self, OpenOptions},
        io::{self, Write},
        marker::PhantomData,
        path::{Path, PathBuf},
    };

    use tracing::trace;

    use super::*;
    use crate::config::OpportunisticEncryptionPersistence;
    use crate::net::runtime::Spawn;

    /// A background task for periodically saving opportunistic encryption state.
    pub struct OpportunisticEncryptionStatePersistTask<T> {
        /// Shared pool context holding the transport state that gets persisted.
        cx: Arc<PoolContext>,
        /// Destination file for the serialized (TOML) state.
        path: PathBuf,
        /// Interval between periodic saves performed by the background task.
        save_interval: Duration,
        /// Ties this task to a particular timer implementation.
        _time: PhantomData<T>,
    }

    impl<T: Time> OpportunisticEncryptionStatePersistTask<T> {
        /// Starts the persistence task based on the given configuration.
        ///
        /// Performs one save immediately so that path/permission problems
        /// surface to the caller right away instead of only being logged later
        /// by the background task, then spawns the periodic task on the
        /// provider's handle and returns that handle.
        ///
        /// # Errors
        ///
        /// Returns a descriptive error string if the initial save fails.
        pub async fn start<P: RuntimeProvider>(
            config: OpportunisticEncryptionPersistence,
            pool_context: &Arc<PoolContext>,
            conn_provider: P,
        ) -> Result<Option<P::Handle>, String> {
            info!(
                path = %config.path.display(),
                save_interval = ?config.save_interval,
                "spawning encrypted transport state persistence task"
            );

            let new =
                OpportunisticEncryptionStatePersistTask::<P::Timer>::new(config, pool_context);

            // Try to save the state back immediately so we can surface write errors early
            // instead of when the background task runs later on.
            new.save(&*new.cx.transport_state.lock().await)
                .map_err(|err| {
                    format!(
                        "failed to save opportunistic encryption state: {path}: {err}",
                        path = new.path.display()
                    )
                })?;

            let mut handle = conn_provider.create_handle();
            handle.spawn_bg(new.run());
            Ok(Some(handle))
        }

        /// Creates a task from `config`, sharing the given pool context.
        fn new(config: OpportunisticEncryptionPersistence, cx: &Arc<PoolContext>) -> Self {
            Self {
                cx: cx.clone(),
                path: config.path,
                save_interval: config.save_interval,
                _time: PhantomData,
            }
        }

        /// Persists the transport state every `save_interval` until dropped.
        async fn run(self) {
            let Self {
                save_interval,
                path,
                cx,
                ..
            } = &self;

            loop {
                T::delay_for(*save_interval).await;
                trace!(path = %path.display(), ?save_interval, "persisting opportunistic encryption state");
                // A failed save is logged, not propagated: it must not kill
                // the long-running background task.
                if let Err(e) = self.save(&*cx.transport_state.lock().await) {
                    error!("failed to save opportunistic encryption state: {e}");
                }
            }
        }

        /// Atomically writes `state` as TOML to `self.path`.
        ///
        /// Uses the write-to-temp-then-rename pattern so readers never observe
        /// a partially written state file.
        ///
        /// # Errors
        ///
        /// Returns any serialization or I/O error encountered.
        fn save(&self, state: &NameServerTransportState) -> Result<(), io::Error> {
            let toml_content = toml::to_string_pretty(state).map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("failed to serialize state to TOML: {e}"),
                )
            })?;

            if let Some(parent) = parent_directory(&self.path) {
                fs::create_dir_all(parent)?;
            }

            // Sibling temp file (same directory) so the final rename stays on
            // one filesystem and remains atomic.
            let temp_path = {
                let mut temp = self.path.as_os_str().to_os_string();
                temp.push(".tmp");
                PathBuf::from(temp)
            };

            {
                let mut temp_file = OpenOptions::new()
                    .write(true)
                    .create(true)
                    .truncate(true)
                    .open(&temp_path)?;

                temp_file.write_all(toml_content.as_bytes())?;
                // Flush the file's contents to disk before the rename makes
                // them visible under the final name.
                temp_file.sync_all()?;
            }

            fs::rename(&temp_path, &self.path)?;

            // Sync the parent directory *after* the rename: the rename mutates
            // the directory, so only a post-rename sync makes the new entry
            // durable across a crash. (Syncing before the rename would only
            // persist the temp file's directory entry, not the rename itself.)
            #[cfg(unix)]
            if let Some(parent) = parent_directory(&self.path) {
                File::open(parent)?.sync_all()?;
            }

            debug!(state_file = %self.path.display(), "saved opportunistic encryption state");
            Ok(())
        }
    }

    /// Gets the parent directory of an absolute or relative path.
    fn parent_directory(path: &Path) -> Option<&Path> {
        let parent = path.parent()?;
        // Special case: if the path has only one component, `parent()` will return an empty string. We
        // should return "." instead, a relative path pointing at the current directory.
        Some(match parent == Path::new("") {
            true => Path::new("."),
            false => parent,
        })
    }
}
902
/// RAII guard that removes a deduplication key from `active_requests` when dropped.
///
/// This is created only by the "creator" task (the one that inserted the key).
/// Using `Drop` guarantees the entry is removed even if the inner future panics,
/// preventing a poisoned [`SharedLookup`] from remaining in the map and causing
/// every subsequent request for the same key to also panic.
struct ActiveRequestCleanup {
    /// Map of in-flight deduplicated lookups, shared with the pool.
    active_requests: Arc<Mutex<HashMap<Arc<CacheKey>, SharedLookup>>>,
    /// The deduplication key this guard is responsible for removing.
    key: Arc<CacheKey>,
}
913
914impl Drop for ActiveRequestCleanup {
915    fn drop(&mut self) {
916        self.active_requests.lock().remove(&self.key);
917    }
918}
919
/// Fields of a [`DnsRequest`] that are used as a key when memoizing queries.
///
/// Requests that agree on all of these fields are deduplicated onto a single
/// in-flight lookup.
#[derive(PartialEq, Eq, Hash)]
struct CacheKey {
    // Header fields copied from the request.
    op_code: OpCode,
    recursion_desired: bool,
    checking_disabled: bool,
    /// The questions carried by the request.
    queries: Vec<Query>,
    /// The EDNS `DNSSEC OK` flag (false when the request has no EDNS section).
    dnssec_ok: bool,
    /// The EDNS client-subnet option, if present.
    client_subnet: Option<ClientSubnet>,
}
930
931impl CacheKey {
932    fn from_request(request: &DnsRequest) -> Self {
933        let dnssec_ok;
934        let client_subnet;
935        if let Some(edns) = &request.edns {
936            dnssec_ok = edns.flags().dnssec_ok;
937            if let Some(EdnsOption::Subnet(subnet)) = edns.option(EdnsCode::Subnet) {
938                client_subnet = Some(*subnet);
939            } else {
940                client_subnet = None;
941            }
942        } else {
943            dnssec_ok = false;
944            client_subnet = None;
945        }
946        Self {
947            op_code: request.op_code,
948            recursion_desired: request.recursion_desired,
949            checking_disabled: request.checking_disabled,
950            queries: request.queries.clone(),
951            dnssec_ok,
952            client_subnet,
953        }
954    }
955}
956
/// A cloneable handle to a memoized in-flight lookup.
///
/// Wraps a [`Shared`] boxed future so multiple deduplicated requests can await
/// the same underlying result; an inner `None` is turned into an error when
/// the handle is awaited.
#[derive(Clone)]
pub(crate) struct SharedLookup(Shared<BoxFuture<'static, Option<Result<DnsResponse, NetError>>>>);
959
960impl Future for SharedLookup {
961    type Output = Result<DnsResponse, NetError>;
962
963    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
964        self.0.poll_unpin(cx).map(|o| match o {
965            Some(r) => r,
966            None => Err("no response from nameserver".into()),
967        })
968    }
969}
970
971#[cfg(test)]
972#[cfg(feature = "tokio")]
973mod tests {
974    use std::collections::HashSet;
975    use std::future::Future;
976    use std::io;
977    use std::net::{IpAddr, SocketAddr};
978    use std::pin::Pin;
979    use std::str::FromStr;
980    use std::sync::atomic::{AtomicBool, Ordering};
981    use std::thread;
982    use std::time::Duration;
983
984    use futures_util::future;
985    use test_support::{
986        MockNetworkHandler, MockProvider, MockRecord, MockTcpStream, MockUdpSocket, subscribe,
987    };
988    use tokio::runtime::Runtime;
989
990    use super::*;
991    use crate::config::{NameServerConfig, ResolverConfig, ServerOrderingStrategy};
992    use crate::net::runtime::{RuntimeProvider, TokioHandle, TokioRuntimeProvider, TokioTime};
993    use crate::net::xfer::{DnsHandle, FirstAnswer};
994    use crate::proto::op::{DnsRequestOptions, Query};
995    use crate::proto::rr::{Name, RecordType};
996
    #[ignore]
    // ignored: this test makes a real network connection (to 8.8.8.8) and so
    // needs a reasonable timeout to run reliably
    #[test]
    #[allow(clippy::uninlined_format_args)]
    fn test_failed_then_success_pool() {
        subscribe();

        // First server is a local address that won't answer; the second is a
        // real public resolver that should eventually succeed.
        let mut config1 = NameServerConfig::udp(IpAddr::from([127, 0, 0, 252]));
        config1.trust_negative_responses = false;
        let config2 = NameServerConfig::udp(IpAddr::from([8, 8, 8, 8]));

        let mut resolver_config = ResolverConfig::default();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let pool = NameServerPool::from_config(
            resolver_config.name_servers,
            Arc::new(PoolContext::new(
                ResolverOpts::default(),
                TlsConfig::new().unwrap(),
            )),
            TokioRuntimeProvider::new(),
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        // TODO: it's not clear why there are two failures before the success
        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        // Once the good server has been reached, lookups should keep succeeding.
        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }
1057
    /// Verifies that a single `NameServer` shared by the pool keeps its
    /// connection across multiple lookups instead of reconnecting per query.
    #[tokio::test]
    async fn test_multi_use_conns() {
        subscribe();

        let conn_provider = TokioRuntimeProvider::default();
        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };

        let tcp = NameServerConfig::tcp(IpAddr::from([8, 8, 8, 8]));
        let name_server = Arc::new(NameServer::new([], tcp, &opts, conn_provider));
        let name_servers = vec![name_server];
        let pool = NameServerPool::from_nameservers(
            name_servers.clone(),
            Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        // first lookup
        let response = pool
            .lookup(
                Query::query(name.clone(), RecordType::A),
                DnsRequestOptions::default(),
            )
            .first_answer()
            .await
            .expect("lookup failed");

        assert!(!response.answers.is_empty());

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        // second lookup — must reuse the connection established above
        let response = pool
            .lookup(
                Query::query(name, RecordType::AAAA),
                DnsRequestOptions::default(),
            )
            .first_answer()
            .await
            .expect("lookup failed");

        assert!(!response.answers.is_empty());

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
1112
    /// Regression test: when the first name server in the pool times out, the pool should
    /// try the remaining servers rather than returning the timeout error immediately.
    ///
    /// Before the fix (adding `NetError::Timeout` to the retry match arm in `try_send`),
    /// a timeout from one server would cause the entire lookup to fail even when other
    /// servers in the pool could have answered successfully.
    #[tokio::test]
    async fn test_pool_retries_on_timeout() {
        subscribe();

        let timeout_ip = IpAddr::from([10, 0, 0, 1]);
        let good_ip = IpAddr::from([10, 0, 0, 2]);
        let query_name = Name::from_str("example.com.").unwrap();

        // Set up a mock handler where the good server returns a valid A record.
        let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
        let handler = MockNetworkHandler::new(responses);
        let mock_provider = MockProvider::new(handler);

        // Wrap in TimeoutProvider so that the timeout_ip always fails with TimedOut.
        let provider = TimeoutProvider::new(mock_provider, vec![timeout_ip]);

        // Sequential queries in user-provided order guarantee the timing-out
        // server is tried first.
        let opts = ResolverOpts {
            num_concurrent_reqs: 1,
            server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
            ..ResolverOpts::default()
        };

        let pool = NameServerPool::from_nameservers(
            vec![
                Arc::new(NameServer::new(
                    [].into_iter(),
                    NameServerConfig::udp(timeout_ip),
                    &opts,
                    provider.clone(),
                )),
                Arc::new(NameServer::new(
                    [].into_iter(),
                    NameServerConfig::udp(good_ip),
                    &opts,
                    provider.clone(),
                )),
            ],
            Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
        );

        // This should succeed: the pool should fall through the timeout from the first
        // server and get the answer from the second server.
        let response = pool
            .lookup(
                Query::query(query_name.clone(), RecordType::A),
                DnsRequestOptions::default(),
            )
            .first_answer()
            .await
            .expect("pool should retry on timeout and succeed with the second server");

        assert!(
            !response.answers.is_empty(),
            "expected A record in response"
        );
    }
1175
    /// Regression test: when a server times out, its server-level SRTT should be penalized
    /// so that it gets deprioritized in future pool ordering.
    #[tokio::test]
    async fn test_timeout_penalizes_server_srtt() {
        subscribe();

        let timeout_ip = IpAddr::from([10, 0, 0, 1]);
        let good_ip = IpAddr::from([10, 0, 0, 2]);
        let query_name = Name::from_str("example.com.").unwrap();

        let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
        let handler = MockNetworkHandler::new(responses);
        let mock_provider = MockProvider::new(handler);
        let provider = TimeoutProvider::new(mock_provider, vec![timeout_ip]);

        let opts = ResolverOpts {
            num_concurrent_reqs: 1,
            server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
            ..ResolverOpts::default()
        };

        let ns_timeout = Arc::new(NameServer::new(
            [].into_iter(),
            NameServerConfig::udp(timeout_ip),
            &opts,
            provider.clone(),
        ));
        let ns_good = Arc::new(NameServer::new(
            [].into_iter(),
            NameServerConfig::udp(good_ip),
            &opts,
            provider.clone(),
        ));

        // Snapshot the SRTT before the lookup so the penalty is observable.
        let initial_srtt_timeout = ns_timeout.decayed_srtt();

        let pool = NameServerPool::from_nameservers(
            vec![ns_timeout.clone(), ns_good.clone()],
            Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
        );

        // Perform a lookup - the first server will timeout, second will succeed.
        let _response = pool
            .lookup(
                Query::query(query_name.clone(), RecordType::A),
                DnsRequestOptions::default(),
            )
            .first_answer()
            .await
            .expect("lookup should succeed via second server");

        // The timeout server's SRTT should have been penalized (increased).
        assert!(
            ns_timeout.decayed_srtt() > initial_srtt_timeout,
            "timeout server SRTT should increase after failure: {} should be > {}",
            ns_timeout.decayed_srtt(),
            initial_srtt_timeout,
        );

        // The good server's SRTT should not have been penalized.
        // It may have changed slightly due to recording a successful RTT, but should
        // not have jumped to the failure penalty value.
        let failure_penalty = 5_000_000.0_f64; // SRTT failure penalty
        assert!(
            ns_good.decayed_srtt() < failure_penalty,
            "good server SRTT should not be penalized: {}",
            ns_good.decayed_srtt(),
        );
    }
1245
    /// A [`RuntimeProvider`] wrapper that returns `io::ErrorKind::TimedOut` from `bind_udp`
    /// for a specified set of server IPs, simulating a connection-level timeout. All other
    /// IPs are delegated to the inner provider.
    #[derive(Clone)]
    struct TimeoutProvider {
        /// Provider handling all non-timeout addresses.
        inner: MockProvider,
        /// Addresses that always fail to connect with `TimedOut`.
        timeout_ips: Arc<HashSet<IpAddr>>,
    }
1254
1255    impl TimeoutProvider {
1256        fn new(inner: MockProvider, timeout_ips: Vec<IpAddr>) -> Self {
1257            Self {
1258                inner,
1259                timeout_ips: Arc::new(timeout_ips.into_iter().collect()),
1260            }
1261        }
1262    }
1263
1264    impl RuntimeProvider for TimeoutProvider {
1265        type Handle = TokioHandle;
1266        type Timer = TokioTime;
1267        type Udp = MockUdpSocket;
1268        type Tcp = MockTcpStream;
1269
1270        fn create_handle(&self) -> Self::Handle {
1271            self.inner.create_handle()
1272        }
1273
1274        fn connect_tcp(
1275            &self,
1276            server_addr: SocketAddr,
1277            bind_addr: Option<SocketAddr>,
1278            timeout: Option<Duration>,
1279        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
1280            if self.timeout_ips.contains(&server_addr.ip()) {
1281                Box::pin(future::ready(Err(io::Error::from(io::ErrorKind::TimedOut))))
1282            } else {
1283                self.inner.connect_tcp(server_addr, bind_addr, timeout)
1284            }
1285        }
1286
1287        fn bind_udp(
1288            &self,
1289            local_addr: SocketAddr,
1290            server_addr: SocketAddr,
1291        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
1292            if self.timeout_ips.contains(&server_addr.ip()) {
1293                Box::pin(future::ready(Err(io::Error::from(io::ErrorKind::TimedOut))))
1294            } else {
1295                self.inner.bind_udp(local_addr, server_addr)
1296            }
1297        }
1298    }
1299
    /// Regression test: an unreachable server racing in parallel must be penalized.
    ///
    /// When `num_concurrent_reqs >= 2`, multiple servers are queried in parallel
    /// via `FuturesUnordered`. If a reachable server responds first, the
    /// unreachable server's future is dropped (cancelled). Before the fix, this
    /// meant `record_failure()` was never called for the unreachable server,
    /// leaving its SRTT unchanged so it would be retried on every subsequent
    /// query.
    #[tokio::test]
    async fn test_cancelled_parallel_server_is_penalized() {
        subscribe();

        let unreachable_ip = IpAddr::from([10, 0, 0, 1]);
        let good_ip = IpAddr::from([10, 0, 0, 2]);
        let query_name = Name::from_str("example.com.").unwrap();

        // The unreachable server's connection future never resolves; the good
        // server answers from the mock handler.
        let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
        let handler = MockNetworkHandler::new(responses);
        let mock_provider = MockProvider::new(handler);
        let provider = PendingProvider::new(mock_provider, vec![unreachable_ip]);

        let opts = ResolverOpts {
            // Both servers are queried in parallel — the key condition for this bug.
            num_concurrent_reqs: 2,
            server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
            ..ResolverOpts::default()
        };

        let ns_unreachable = Arc::new(NameServer::new(
            [].into_iter(),
            NameServerConfig::udp(unreachable_ip),
            &opts,
            provider.clone(),
        ));
        let ns_good = Arc::new(NameServer::new(
            [].into_iter(),
            NameServerConfig::udp(good_ip),
            &opts,
            provider.clone(),
        ));

        let initial_srtt = ns_unreachable.decayed_srtt();

        let pool = NameServerPool::from_nameservers(
            vec![ns_unreachable.clone(), ns_good.clone()],
            Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
        );

        // The good server wins the race; the unreachable server's future is cancelled.
        let _response = pool
            .lookup(
                Query::query(query_name.clone(), RecordType::A),
                DnsRequestOptions::default(),
            )
            .first_answer()
            .await
            .expect("lookup should succeed via good server");

        // The unreachable server's SRTT must have increased despite its future
        // being cancelled (not completing with an error).
        assert!(
            ns_unreachable.decayed_srtt() > initial_srtt,
            "unreachable server SRTT should increase after being cancelled: {} should be > {}",
            ns_unreachable.decayed_srtt(),
            initial_srtt,
        );

        // The good server should not have been penalized.
        let failure_penalty = 5_000_000.0_f64;
        assert!(
            ns_good.decayed_srtt() < failure_penalty,
            "good server SRTT should not be penalized: {}",
            ns_good.decayed_srtt(),
        );
    }
1375
    /// A [`RuntimeProvider`] wrapper where specified IPs never complete their
    /// connection — the future stays pending forever. This simulates an
    /// unreachable server (SYN sent, no SYN-ACK) where the OS TCP handshake
    /// hasn't timed out yet.
    #[derive(Clone)]
    struct PendingProvider {
        /// Provider handling all reachable addresses.
        inner: MockProvider,
        /// Addresses whose connection futures never resolve.
        pending_ips: Arc<HashSet<IpAddr>>,
    }
1385
1386    impl PendingProvider {
1387        fn new(inner: MockProvider, pending_ips: Vec<IpAddr>) -> Self {
1388            Self {
1389                inner,
1390                pending_ips: Arc::new(pending_ips.into_iter().collect()),
1391            }
1392        }
1393    }
1394
1395    impl RuntimeProvider for PendingProvider {
1396        type Handle = TokioHandle;
1397        type Timer = TokioTime;
1398        type Udp = MockUdpSocket;
1399        type Tcp = MockTcpStream;
1400
1401        fn create_handle(&self) -> Self::Handle {
1402            self.inner.create_handle()
1403        }
1404
1405        fn connect_tcp(
1406            &self,
1407            server_addr: SocketAddr,
1408            bind_addr: Option<SocketAddr>,
1409            timeout: Option<Duration>,
1410        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
1411            if self.pending_ips.contains(&server_addr.ip()) {
1412                Box::pin(future::pending())
1413            } else {
1414                self.inner.connect_tcp(server_addr, bind_addr, timeout)
1415            }
1416        }
1417
1418        fn bind_udp(
1419            &self,
1420            local_addr: SocketAddr,
1421            server_addr: SocketAddr,
1422        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
1423            if self.pending_ips.contains(&server_addr.ip()) {
1424                Box::pin(future::pending())
1425            } else {
1426                self.inner.bind_udp(local_addr, server_addr)
1427            }
1428        }
1429    }
1430
    /// Regression test: `sort_servers_by_query_statistics` must not panic when
    /// SRTT values are concurrently modified.
    ///
    /// `record()` and `record_failure()` can modify a server's SRTT while
    /// another thread sorts the server list. With `sort_by`, the comparator
    /// re-evaluates `decayed_srtt()` on every comparison, observing values
    /// that change between calls and violating the total-order invariant.
    /// The fix uses `sort_by_cached_key`, which evaluates each key exactly
    /// once before sorting.
    #[test]
    fn test_sort_by_decayed_srtt_does_not_panic() {
        let opts = ResolverOpts::default();
        let mock_provider = MockProvider::new(MockNetworkHandler::new(vec![]));

        let mut servers = (1..=50)
            .map(|i| {
                let ns = Arc::new(NameServer::new(
                    [],
                    NameServerConfig::udp(IpAddr::from([10, 0, 0, i])),
                    &opts,
                    mock_provider.clone(),
                ));
                // Activate the time-based decay path by recording a failure,
                // which sets `last_update` to `Some(now)`.
                ns.test_record_failure();
                ns
            })
            .collect::<Vec<_>>();

        // Spawn a thread that continuously modifies SRTT values, simulating
        // concurrent queries completing on other threads.
        let servers_writer = servers.clone();
        let stop = Arc::new(AtomicBool::new(false));
        let stop_writer = stop.clone();
        let writer = thread::spawn(move || {
            while !stop_writer.load(Ordering::Relaxed) {
                for s in &servers_writer {
                    s.test_record_failure();
                }
            }
        });

        // Ensure the writer thread stops even if the test panics.
        struct StopGuard(Arc<AtomicBool>);
        impl Drop for StopGuard {
            fn drop(&mut self) {
                self.0.store(true, Ordering::Relaxed);
            }
        }
        let _guard = StopGuard(stop.clone());

        // Call the production sort function many times while the writer
        // thread concurrently modifies SRTT values. With sort_by_cached_key
        // this is safe. With sort_by, the concurrent modifications cause
        // inconsistent comparisons that panic the sort.
        for _ in 0..100_000 {
            sort_servers_by_query_statistics(&mut servers);
        }

        stop.store(true, Ordering::Relaxed);
        writer.join().unwrap();
    }
1493}