1use std::collections::{HashMap, VecDeque};
9use std::net::IpAddr;
10use std::pin::Pin;
11use std::sync::atomic::AtomicU8;
12use std::sync::{
13 Arc,
14 atomic::{AtomicUsize, Ordering as AtomicOrdering},
15};
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant, SystemTime};
18
19use futures_util::lock::{Mutex as AsyncMutex, MutexGuard};
20use futures_util::stream::{FuturesUnordered, Stream, StreamExt, once};
21use futures_util::{
22 Future, FutureExt,
23 future::{BoxFuture, Shared},
24};
25use parking_lot::Mutex;
26#[cfg(feature = "serde")]
27use serde::{Deserialize, Serialize};
28use smallvec::SmallVec;
29#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
30use tracing::info;
31use tracing::{debug, error};
32
33#[cfg(any(feature = "__tls", feature = "__quic"))]
34use crate::config::OpportunisticEncryptionConfig;
35use crate::{
36 config::{NameServerConfig, OpportunisticEncryption, ResolverOpts, ServerOrderingStrategy},
37 connection_provider::{ConnectionProvider, TlsConfig},
38 name_server::{ConnectionPolicy, NameServer},
39 net::{
40 DnsError, NetError, NoRecords,
41 runtime::{RuntimeProvider, Time},
42 xfer::{DnsHandle, Protocol},
43 },
44 proto::{
45 access_control::AccessControlSet,
46 op::{DnsRequest, DnsRequestOptions, DnsResponse, OpCode, Query, ResponseCode},
47 rr::{
48 Name, RData, Record,
49 rdata::{
50 A, AAAA,
51 opt::{ClientSubnet, EdnsCode, EdnsOption},
52 },
53 },
54 },
55};
56
/// A pool of name servers through which DNS requests are sent.
///
/// The server list and the map of in-flight requests sit behind `Arc`s, so clones of a pool
/// share them; `ttl` and `zone` are copied per clone.
#[derive(Clone)]
pub struct NameServerPool<P: ConnectionProvider> {
    /// Servers, pool context and round-robin counter shared by all clones.
    state: Arc<PoolState<P>>,
    /// Requests currently in flight, keyed by [`CacheKey`], so that identical concurrent
    /// requests can await one shared lookup.
    active_requests: Arc<Mutex<HashMap<Arc<CacheKey>, SharedLookup>>>,
    /// Instant after which [`Self::ttl_expired`] reports `true`; `None` means it never does.
    ttl: Option<TtlInstant>,
    /// Zone attached through [`Self::with_zone`], if any.
    zone: Option<Name>,
}
65
66impl<P: ConnectionProvider> NameServerPool<P> {
67 pub fn from_config(
69 servers: impl IntoIterator<Item = NameServerConfig>,
70 cx: Arc<PoolContext>,
71 conn_provider: P,
72 ) -> Self {
73 Self::from_nameservers(
74 servers
75 .into_iter()
76 .map(|server| {
77 Arc::new(NameServer::new(
78 [],
79 server,
80 &cx.options,
81 conn_provider.clone(),
82 ))
83 })
84 .collect(),
85 cx,
86 )
87 }
88
89 #[doc(hidden)]
90 pub fn from_nameservers(servers: Vec<Arc<NameServer<P>>>, cx: Arc<PoolContext>) -> Self {
91 Self {
92 state: Arc::new(PoolState {
93 servers,
94 cx,
95 next: AtomicUsize::new(0),
96 }),
97 active_requests: Arc::new(Mutex::new(HashMap::new())),
98 ttl: None,
99 zone: None,
100 }
101 }
102
103 pub fn with_ttl(mut self, ttl: Duration) -> Self {
105 self.ttl = Some(TtlInstant::now() + ttl);
106 self
107 }
108
109 pub fn with_zone(mut self, zone: Name) -> Self {
111 self.zone = Some(zone);
112 self
113 }
114
115 pub fn ttl_expired(&self) -> bool {
117 match self.ttl {
118 Some(ttl) => TtlInstant::now() > ttl,
119 None => false,
120 }
121 }
122
123 pub fn context(&self) -> &Arc<PoolContext> {
125 &self.state.cx
126 }
127
128 pub fn zone(&self) -> Option<&Name> {
130 self.zone.as_ref()
131 }
132}
133
// Clock type used for pool TTL bookkeeping: the standard library `Instant` by default, or
// tokio's `Instant` when the `tokio` feature is enabled.
// NOTE(review): presumably the tokio variant is chosen so tests can control time — confirm.
#[cfg(not(feature = "tokio"))]
type TtlInstant = std::time::Instant;
#[cfg(feature = "tokio")]
type TtlInstant = tokio::time::Instant;
139
impl<P: ConnectionProvider> DnsHandle for NameServerPool<P> {
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
    type Runtime = P::RuntimeProvider;

    /// Builds a request from `query` and `options` and sends it through the pool.
    ///
    /// The `case_randomization` flag in `options` is overwritten with the pool's configured
    /// value before the request is built.
    fn lookup(&self, query: Query, mut options: DnsRequestOptions) -> Self::Response {
        debug!("querying: {} {:?}", query.name(), query.query_type());
        options.case_randomization = self.state.cx.options.case_randomization;
        self.send(DnsRequest::from_query(query, options))
    }

    /// Sends `request` through the pool and returns a single-item response stream.
    ///
    /// Requests with the same [`CacheKey`] that are in flight at the same time share one
    /// lookup. Unless the answer address filter allows everything, A/AAAA records whose
    /// address is denied are removed from the response.
    ///
    /// # Errors
    /// The stream yields an error when the request carries no query, when the lookup fails,
    /// or (as `NXDomain`) when filtering removed every answer, or every authority record of
    /// a response that has no answers left.
    fn send(&self, request: DnsRequest) -> Self::Response {
        let state = self.state.clone();
        let acs = self.state.cx.answer_address_filter.clone();
        let active_requests = self.active_requests.clone();

        Box::pin(once(async move {
            debug!("sending request: {:?}", request.queries);
            let query = match request.queries.first() {
                Some(q) => q.clone(),
                None => return Err("no query in request".into()),
            };

            let key = Arc::new(CacheKey::from_request(&request));

            // The lock is held only for this block; it is released before anything is awaited.
            let (lookup, is_creator) = {
                let mut active = active_requests.lock();
                if let Some(existing) = active.get(&key) {
                    debug!(%query, "query currently in progress - returning shared lookup");
                    (existing.clone(), false)
                } else {
                    debug!(%query, "creating new shared lookup");

                    let lookup = async move {
                        match state.try_send(request).await {
                            Ok(response) => Some(Ok(response)),
                            Err(e) => Some(Err(e)),
                        }
                    }
                    .boxed()
                    .shared();

                    let shared_lookup = SharedLookup(lookup);
                    active.insert(key.clone(), shared_lookup.clone());
                    (shared_lookup, true)
                }
            };

            // Only the task that inserted the entry removes it again. The guard lives until
            // this future completes or is dropped, so the entry is also removed on cancellation.
            let _cleanup = is_creator.then(|| ActiveRequestCleanup {
                active_requests: active_requests.clone(),
                key: key.clone(),
            });

            let response = lookup.await;
            let mut response = response?;

            // Fast path: nothing to filter.
            if acs.allows_all() {
                return Ok(response);
            }

            // Keeps every record except A/AAAA records whose address is denied by the filter.
            let answer_filter = |record: &Record| {
                let ip = match &record.data {
                    RData::A(A(ipv4)) => (*ipv4).into(),
                    RData::AAAA(AAAA(ipv6)) => (*ipv6).into(),
                    _ => return true,
                };

                if acs.denied(ip) {
                    error!(
                        %query,
                        %ip,
                        "removing ip from response: answer filter matched"
                    );

                    false
                } else {
                    true
                }
            };

            // Remember the pre-filter sizes so "emptied by the filter" can be told apart from
            // "was empty to begin with".
            let answers_len = response.answers.len();
            let authorities_len = response.authorities.len();

            response.additionals.retain(answer_filter);
            response.answers.retain(answer_filter);
            response.authorities.retain(answer_filter);

            // Report NXDomain if the filter removed all answers, or if there are no answers
            // and the filter removed all authority records.
            if response.answers.is_empty() && answers_len != 0
                || (response.answers.is_empty()
                    && response.authorities.is_empty()
                    && authorities_len != 0)
            {
                return Err(NoRecords::new(Box::new(query.clone()), ResponseCode::NXDomain).into());
            }

            // Rebuild the response from the filtered message.
            DnsResponse::from_message(response.into_message()).map_err(NetError::from)
        }))
    }
}
245
/// State shared between all clones of a [`NameServerPool`].
struct PoolState<P: ConnectionProvider> {
    /// Name servers available to the pool, in the order they were supplied.
    servers: Vec<Arc<NameServer<P>>>,
    /// Options and shared context used when sending requests.
    cx: Arc<PoolContext>,
    /// Running counter from which the rotation offset for round-robin ordering is derived.
    next: AtomicUsize,
}
251
impl<P: ConnectionProvider> PoolState<P> {
    /// Sends `request` to the pool's servers until one of them yields a usable response.
    ///
    /// Servers are ordered according to `server_ordering_strategy` and then queried in
    /// batches of up to `num_concurrent_reqs` (at least one). Servers that report
    /// `NetError::Busy` are retried after an exponentially growing backoff that starts at
    /// 20ms and stops being retried once it reaches 300ms.
    ///
    /// # Errors
    /// Returns `NetError::Timeout` when the deadline derived from the `timeout` option has
    /// passed at the start of a round, the most specific error seen so far when no server is
    /// left to try, and immediately any error that is not handled as retryable below.
    async fn try_send(&self, request: DnsRequest) -> Result<DnsResponse, NetError> {
        let mut servers = self.servers.clone();
        match self.cx.options.server_ordering_strategy {
            ServerOrderingStrategy::QueryStatistics => {
                // Order by decayed SRTT, smallest key first.
                sort_servers_by_query_statistics(&mut servers);
            }
            ServerOrderingStrategy::UserProvidedOrder => {}
            ServerOrderingStrategy::RoundRobin => {
                let num_concurrent_reqs = if self.cx.options.num_concurrent_reqs > 1 {
                    self.cx.options.num_concurrent_reqs
                } else {
                    1
                };
                // Rotating only matters when one batch does not already cover every server.
                if num_concurrent_reqs < servers.len() {
                    let index = self
                        .next
                        .fetch_add(num_concurrent_reqs, AtomicOrdering::SeqCst)
                        % servers.len();
                    servers.rotate_left(index);
                }
            }
        }

        // Overall deadline for this call; checked at the top of every round and used to cap
        // the backoff sleep.
        let deadline = Instant::now() + self.cx.options.timeout;

        let mut servers = VecDeque::from(servers);
        let mut backoff = Duration::from_millis(20);
        // Servers that answered `Busy` and may be retried after the backoff.
        let mut busy = SmallVec::<[Arc<NameServer<P>>; 2]>::new();
        // Error returned if every server fails; replaced as more specific errors are seen.
        let mut err = NetError::NoConnections;
        let mut policy = ConnectionPolicy::default();

        loop {
            if Instant::now() >= deadline {
                return Err(NetError::Timeout);
            }

            // Take the next batch of servers that the current policy still allows.
            let mut par_servers = SmallVec::<[_; 2]>::new();
            while !servers.is_empty()
                && par_servers.len() < Ord::max(self.cx.options.num_concurrent_reqs, 1)
            {
                if let Some(server) = servers.pop_front() {
                    if policy.allows_server(&server) {
                        par_servers.push(server);
                    }
                }
            }

            if par_servers.is_empty() {
                // Nothing left to try right now: wait and retry the busy servers, unless the
                // backoff has already grown to its limit.
                if !busy.is_empty() && backoff < Duration::from_millis(300) {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        return Err(NetError::Timeout);
                    }
                    <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
                        backoff.min(remaining),
                    ).await;
                    servers.extend(busy.drain(..).filter(|ns| policy.allows_server(ns)));
                    backoff *= 2;
                    continue;
                }
                return Err(err);
            }

            // Keep handles to every server in this batch so that the ones still pending when
            // a winner arrives can be recorded as cancelled.
            let in_flight = par_servers.iter().cloned().collect::<SmallVec<[_; 2]>>();

            let batch_start = Instant::now();
            let mut requests = par_servers
                .into_iter()
                .map(|server| {
                    let mut request = request.clone();

                    // Retry interval is 1.2 times the server's decayed SRTT; the value is
                    // passed to `Duration::from_micros`.
                    let retry_interval =
                        Duration::from_micros((server.decayed_srtt() * 1.2) as u64);
                    request.options_mut().retry_interval = retry_interval;
                    debug!(?retry_interval, ip = ?server.ip(), "setting retry_interval");

                    let future = server.clone().send(request, policy, &self.cx);
                    async { (server, future.await) }
                })
                .collect::<FuturesUnordered<_>>();

            // IPs of the servers in this batch that have already produced a result.
            let mut completed = SmallVec::<[IpAddr; 2]>::new();

            while let Some((server, result)) = requests.next().await {
                completed.push(server.ip());
                let e = match result {
                    Ok(response) if response.truncation => {
                        // Put the same server back at the front of the queue with UDP
                        // disabled for the rest of this call.
                        debug!("truncated response received, retrying over TCP");
                        policy.disable_udp = true;
                        err = NetError::from("received truncated response");
                        servers.push_front(server);
                        continue;
                    }
                    Ok(response) => {
                        let winner_rtt = batch_start.elapsed();
                        // Servers of this batch that have not completed are abandoned when
                        // we return; record that against them.
                        for abandoned in &in_flight {
                            if !completed.contains(&abandoned.ip()) {
                                debug!(ip = ?abandoned.ip(), ?winner_rtt, "recording cancelled parallel server");
                                abandoned.record_cancelled(winner_rtt);
                            }
                        }
                        return Ok(response);
                    }
                    Err(e) => e,
                };

                match &e {
                    // Retry the same server with UDP disabled.
                    NetError::QueryCaseMismatch => {
                        servers.push_front(server);
                        policy.disable_udp = true;
                        continue;
                    }
                    // Remember the busy server so it can be retried after the backoff.
                    NetError::Busy => busy.push(server),
                    // Connection-level failures: move on to the next server.
                    NetError::Io(_) | NetError::NoConnections | NetError::Timeout => {}
                    // NXDomain from a server whose negative responses are not trusted: move
                    // on to the next server.
                    NetError::Dns(DnsError::NoRecordsFound(NoRecords {
                        response_code: ResponseCode::NXDomain,
                        ..
                    })) if !server.trust_negative_responses() => {}
                    // Anything else is returned to the caller immediately.
                    _ => return Err(e),
                }

                err = most_specific(err, e);
            }
        }
    }
}
408
409fn most_specific(previous: NetError, current: NetError) -> NetError {
411 match (&previous, ¤t) {
412 (
413 NetError::Dns(DnsError::NoRecordsFound { .. }),
414 NetError::Dns(DnsError::NoRecordsFound { .. }),
415 ) => return previous,
416 (NetError::Dns(DnsError::NoRecordsFound { .. }), _) => return previous,
417 (_, NetError::Dns(DnsError::NoRecordsFound { .. })) => return current,
418 _ => (),
419 }
420
421 match (&previous, ¤t) {
422 (NetError::Io { .. }, NetError::Io { .. }) => return previous,
423 (NetError::Io { .. }, _) => return current,
424 (_, NetError::Io { .. }) => return previous,
425 _ => (),
426 }
427
428 match (&previous, ¤t) {
429 (NetError::Timeout, NetError::Timeout) => return previous,
430 (NetError::Timeout, _) => return previous,
431 (_, NetError::Timeout) => return current,
432 _ => (),
433 }
434
435 previous
436}
437
438pub(crate) fn sort_servers_by_query_statistics<P: ConnectionProvider>(
445 servers: &mut [Arc<NameServer<P>>],
446) {
447 servers.sort_by_cached_key(|s| s.decayed_srtt().to_bits());
450}
451
/// Configuration and shared state used by a name server pool and its servers.
#[non_exhaustive]
pub struct PoolContext {
    /// Resolver options applied by the pool.
    pub options: ResolverOpts,
    /// TLS client configuration (only present with the `__tls` feature).
    #[cfg(feature = "__tls")]
    pub tls: rustls::ClientConfig,
    /// Budget for opportunistic encryption probes; starts at the `AtomicU8` default.
    // NOTE(review): how the budget is consumed is not visible in this file section.
    pub opportunistic_probe_budget: AtomicU8,
    /// Opportunistic encryption setting; starts at its default value.
    pub opportunistic_encryption: OpportunisticEncryption,
    /// Per-server encrypted transport state, guarded by an async mutex.
    pub transport_state: AsyncMutex<NameServerTransportState>,
    /// Filter applied to A/AAAA addresses in responses.
    pub answer_address_filter: AccessControlSet,
}
469
470impl PoolContext {
471 #[cfg_attr(not(feature = "__tls"), expect(unused_variables))]
473 pub fn new(options: ResolverOpts, tls: TlsConfig) -> Self {
474 Self {
475 answer_address_filter: options.answer_address_filter(),
476 options,
477 #[cfg(feature = "__tls")]
478 tls: tls.config,
479 opportunistic_probe_budget: AtomicU8::default(),
480 opportunistic_encryption: OpportunisticEncryption::default(),
481 transport_state: AsyncMutex::new(NameServerTransportState::default()),
482 }
483 }
484
485 pub fn with_probe_budget(self, budget: u8) -> Self {
487 self.opportunistic_probe_budget
488 .store(budget, AtomicOrdering::SeqCst);
489 self
490 }
491
492 pub fn with_answer_filter(mut self, answer_filter: AccessControlSet) -> Self {
494 self.answer_address_filter = answer_filter;
495 self
496 }
497
498 #[cfg(any(feature = "__tls", feature = "__quic"))]
500 pub fn with_opportunistic_encryption(mut self) -> Self {
501 self.opportunistic_encryption = OpportunisticEncryption::Enabled {
502 config: OpportunisticEncryptionConfig::default(),
503 };
504 self
505 }
506
507 pub fn with_transport_state(mut self, transport_state: NameServerTransportState) -> Self {
509 self.transport_state = AsyncMutex::new(transport_state);
510 self
511 }
512
513 pub(crate) async fn transport_state(&self) -> MutexGuard<'_, NameServerTransportState> {
514 self.transport_state.lock().await
515 }
516}
517
/// Per-name-server record of encrypted transport state, keyed by the server's IP address.
///
/// Serializable when the `serde` feature is enabled.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct NameServerTransportState(HashMap<IpAddr, ProtocolTransportState>);
523
524impl NameServerTransportState {
525 pub fn nameserver_count(&self) -> usize {
527 self.0.len()
528 }
529
530 pub(crate) fn initiate_connection(&mut self, ip: IpAddr, protocol: Protocol) {
532 let protocol_state = self.0.entry(ip).or_default();
533 *protocol_state.get_mut(protocol) = TransportState::default();
534 }
535
536 pub(crate) fn complete_connection(&mut self, ip: IpAddr, protocol: Protocol) {
538 let protocol_state = self.0.entry(ip).or_default();
539 *protocol_state.get_mut(protocol) = TransportState::Success {
540 last_response: None,
541 };
542 }
543
544 pub(crate) fn response_received(&mut self, ip: IpAddr, protocol: Protocol) {
546 let Some(protocol_state) = self.0.get_mut(&ip) else {
547 return;
548 };
549 let TransportState::Success { last_response, .. } = protocol_state.get_mut(protocol) else {
550 return;
551 };
552 *last_response = Some(SystemTime::now());
553 }
554
555 pub(crate) fn error_received(&mut self, ip: IpAddr, protocol: Protocol, error: &NetError) {
557 let protocol_state = self.0.entry(ip).or_default();
558 *protocol_state.get_mut(protocol) = match &error {
559 NetError::Timeout => TransportState::TimedOut {
560 #[cfg(any(feature = "__tls", feature = "__quic"))]
561 completed_at: SystemTime::now(),
562 },
563 _ => TransportState::Failed {
564 #[cfg(any(feature = "__tls", feature = "__quic"))]
565 completed_at: SystemTime::now(),
566 },
567 };
568 }
569
570 #[cfg(any(feature = "__tls", feature = "__quic"))]
573 pub(crate) fn any_recent_success(&self, ip: IpAddr, config: &OpportunisticEncryption) -> bool {
574 #[allow(unused_assignments, unused_mut)]
575 let mut tls_success = false;
576 #[allow(unused_assignments, unused_mut)]
577 let mut quic_success = false;
578
579 #[cfg(feature = "__tls")]
580 {
581 tls_success = self.recent_success(ip, Protocol::Tls, config);
582 }
583
584 #[cfg(feature = "__quic")]
585 {
586 quic_success = self.recent_success(ip, Protocol::Quic, config);
587 }
588
589 tls_success || quic_success
590 }
591
592 #[cfg(not(any(feature = "__tls", feature = "__quic")))]
594 pub(crate) fn any_recent_success(
595 &self,
596 _ip: IpAddr,
597 _config: &OpportunisticEncryption,
598 ) -> bool {
599 false
600 }
601
602 #[cfg(any(feature = "__tls", feature = "__quic"))]
608 pub(crate) fn recent_success(
609 &self,
610 ip: IpAddr,
611 protocol: Protocol,
612 config: &OpportunisticEncryption,
613 ) -> bool {
614 let OpportunisticEncryption::Enabled { config } = config else {
615 return false;
616 };
617
618 let Some(protocol_state) = self.0.get(&ip) else {
619 return false;
620 };
621
622 let TransportState::Success { last_response, .. } = protocol_state.get(protocol) else {
623 return false;
624 };
625
626 let Some(last_response) = last_response else {
627 return false;
628 };
629
630 last_response.elapsed().unwrap_or(Duration::MAX) <= config.persistence_period
631 }
632
633 #[cfg(not(any(feature = "__tls", feature = "__quic")))]
638 pub(crate) fn recent_success(
639 &self,
640 _ip: IpAddr,
641 _protocol: Protocol,
642 _config: &OpportunisticEncryption,
643 ) -> bool {
644 false
645 }
646
647 #[cfg(any(feature = "__tls", feature = "__quic"))]
649 pub(crate) fn should_probe_encrypted(
650 &self,
651 ip: IpAddr,
652 protocol: Protocol,
653 config: &OpportunisticEncryption,
654 ) -> bool {
655 debug_assert!(protocol.is_encrypted());
656
657 let OpportunisticEncryption::Enabled { config, .. } = config else {
658 return false;
659 };
660
661 let Some(protocol_state) = self.0.get(&ip) else {
662 return true;
663 };
664
665 match protocol_state.get(protocol) {
666 TransportState::Initiated => false,
667 TransportState::Success { .. } => true,
668 TransportState::Failed { completed_at } | TransportState::TimedOut { completed_at } => {
669 completed_at.elapsed().unwrap_or(Duration::MAX) > config.damping_period
670 }
671 }
672 }
673
674 #[cfg(not(any(feature = "__tls", feature = "__quic")))]
676 pub(crate) fn should_probe_encrypted(
677 &self,
678 _ip: IpAddr,
679 _protocol: Protocol,
680 _config: &OpportunisticEncryption,
681 ) -> bool {
682 false
683 }
684
685 #[cfg(all(test, feature = "__tls"))]
687 pub(crate) fn set_last_response(&mut self, ip: IpAddr, protocol: Protocol, when: SystemTime) {
688 let Some(protocol_state) = self.0.get_mut(&ip) else {
689 return;
690 };
691
692 let TransportState::Success { last_response, .. } = protocol_state.get_mut(protocol) else {
693 return;
694 };
695
696 *last_response = Some(when);
697 }
698
699 #[cfg(all(test, feature = "__tls"))]
701 pub(crate) fn set_failure_time(&mut self, ip: IpAddr, protocol: Protocol, when: SystemTime) {
702 let protocol_state = self.0.entry(ip).or_default();
703 *protocol_state.get_mut(protocol) = TransportState::Failed { completed_at: when };
704 }
705}
706
/// Transport state of a single name server, one slot per encrypted protocol compiled in.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct ProtocolTransportState {
    /// State of the TLS transport (only with the `__tls` feature).
    #[cfg(feature = "__tls")]
    tls: TransportState,
    /// State of the QUIC transport (only with the `__quic` feature).
    #[cfg(feature = "__quic")]
    quic: TransportState,
}
715
impl ProtocolTransportState {
    /// Returns a mutable reference to the state tracked for `protocol`.
    ///
    /// # Panics
    /// Panics via `unreachable!` when `protocol` is not TLS or QUIC, or when the feature for
    /// that protocol is not compiled in.
    #[cfg_attr(not(any(feature = "__tls", feature = "__quic")), allow(dead_code))]
    fn get_mut(&mut self, protocol: Protocol) -> &mut TransportState {
        match protocol {
            #[cfg(feature = "__tls")]
            Protocol::Tls => &mut self.tls,
            #[cfg(feature = "__quic")]
            Protocol::Quic => &mut self.quic,
            _ => unreachable!("unsupported opportunistic encryption protocol: {protocol:?}"),
        }
    }

    /// Returns a shared reference to the state tracked for `protocol`.
    ///
    /// # Panics
    /// Panics via `unreachable!` when `protocol` is not TLS or QUIC, or when the feature for
    /// that protocol is not compiled in.
    #[cfg_attr(not(any(feature = "__tls", feature = "__quic")), allow(dead_code))]
    fn get(&self, protocol: Protocol) -> &TransportState {
        match protocol {
            #[cfg(feature = "__tls")]
            Protocol::Tls => &self.tls,
            #[cfg(feature = "__quic")]
            Protocol::Quic => &self.quic,
            _ => unreachable!("unsupported opportunistic encryption protocol: {protocol:?}"),
        }
    }
}
739
/// State of one encrypted transport towards a name server.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum TransportState {
    /// A connection attempt has been initiated; this is also the default state.
    #[default]
    Initiated,
    /// A connection completed successfully.
    Success {
        /// Time of the most recent response, or `None` when none has been recorded yet.
        last_response: Option<SystemTime>,
    },
    /// An error other than a timeout was recorded.
    Failed {
        /// Time at which the failure was recorded.
        #[cfg(any(feature = "__tls", feature = "__quic"))]
        completed_at: SystemTime,
    },
    /// A timeout was recorded.
    TimedOut {
        /// Time at which the timeout was recorded.
        #[cfg(any(feature = "__tls", feature = "__quic"))]
        completed_at: SystemTime,
    },
}
765
766#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
767pub use opportunistic_encryption_persistence::OpportunisticEncryptionStatePersistTask;
768
#[cfg(all(feature = "toml", any(feature = "__tls", feature = "__quic")))]
mod opportunistic_encryption_persistence {
    #[cfg(unix)]
    use std::fs::File;
    use std::{
        fs::{self, OpenOptions},
        io::{self, Write},
        marker::PhantomData,
        path::{Path, PathBuf},
    };

    use tracing::trace;

    use super::*;
    use crate::config::OpportunisticEncryptionPersistence;
    use crate::net::runtime::Spawn;

    /// Background task that periodically writes the pool's [`NameServerTransportState`] to a
    /// TOML file.
    pub struct OpportunisticEncryptionStatePersistTask<T> {
        /// Context whose transport state is persisted.
        cx: Arc<PoolContext>,
        /// Destination file.
        path: PathBuf,
        /// Delay between two saves.
        save_interval: Duration,
        /// Timer implementation used to wait between saves.
        _time: PhantomData<T>,
    }

    impl<T: Time> OpportunisticEncryptionStatePersistTask<T> {
        /// Saves the state once, then spawns the periodic save loop in the background.
        ///
        /// The initial save happens before the task is spawned, so an unusable path is
        /// reported to the caller instead of only being logged later.
        ///
        /// # Errors
        /// Returns a message naming the path and the cause when the initial save fails.
        pub async fn start<P: RuntimeProvider>(
            config: OpportunisticEncryptionPersistence,
            pool_context: &Arc<PoolContext>,
            conn_provider: P,
        ) -> Result<Option<P::Handle>, String> {
            info!(
                path = %config.path.display(),
                save_interval = ?config.save_interval,
                "spawning encrypted transport state persistence task"
            );

            let new =
                OpportunisticEncryptionStatePersistTask::<P::Timer>::new(config, pool_context);

            new.save(&*new.cx.transport_state.lock().await)
                .map_err(|err| {
                    format!(
                        "failed to save opportunistic encryption state: {path}: {err}",
                        path = new.path.display()
                    )
                })?;

            let mut handle = conn_provider.create_handle();
            handle.spawn_bg(new.run());
            Ok(Some(handle))
        }

        /// Creates a task for `cx` using the path and interval from `config`.
        fn new(config: OpportunisticEncryptionPersistence, cx: &Arc<PoolContext>) -> Self {
            Self {
                cx: cx.clone(),
                path: config.path,
                save_interval: config.save_interval,
                _time: PhantomData,
            }
        }

        /// Runs forever: waits `save_interval`, then saves the current state.
        ///
        /// Save failures are logged and do not stop the loop.
        async fn run(self) {
            let Self {
                save_interval,
                path,
                cx,
                ..
            } = &self;

            loop {
                T::delay_for(*save_interval).await;
                trace!(path = %path.display(), ?save_interval, "persisting opportunistic encryption state");
                if let Err(e) = self.save(&*cx.transport_state.lock().await) {
                    error!("failed to save opportunistic encryption state: {e}");
                }
            }
        }

        /// Serializes `state` to TOML and atomically replaces the destination file.
        ///
        /// The content is written to `<path>.tmp`, synced, and renamed over the destination.
        /// On Unix the parent directory is synced after the rename so that the rename itself
        /// is durable.
        ///
        /// # Errors
        /// Returns `InvalidData` when serialization fails, or the underlying I/O error from
        /// creating directories, writing, syncing or renaming.
        fn save(&self, state: &NameServerTransportState) -> Result<(), io::Error> {
            let toml_content = toml::to_string_pretty(state).map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("failed to serialize state to TOML: {e}"),
                )
            })?;

            if let Some(parent) = parent_directory(&self.path) {
                fs::create_dir_all(parent)?;
            }

            // Append ".tmp" to the full file name, keeping any existing extension.
            let temp_path = {
                let mut temp = self.path.as_os_str().to_os_string();
                temp.push(".tmp");
                PathBuf::from(temp)
            };

            // Scoped so the temporary file is closed before it is renamed.
            {
                let mut temp_file = OpenOptions::new()
                    .write(true)
                    .create(true)
                    .truncate(true)
                    .open(&temp_path)?;

                temp_file.write_all(toml_content.as_bytes())?;
                temp_file.sync_all()?;
            }

            fs::rename(&temp_path, &self.path)?;

            // Syncing the directory only after the rename persists the new directory entry;
            // syncing it beforehand would leave the rename itself unsynced.
            #[cfg(unix)]
            if let Some(parent) = parent_directory(&self.path) {
                File::open(parent)?.sync_all()?;
            }

            debug!(state_file = %self.path.display(), "saved opportunistic encryption state");
            Ok(())
        }
    }

    /// Returns the directory containing `path`, or `None` when `path` has no parent.
    ///
    /// A bare file name has the empty path as its parent; that is mapped to `"."` so the
    /// result can be passed to filesystem calls.
    fn parent_directory(path: &Path) -> Option<&Path> {
        let parent = path.parent()?;
        Some(match parent == Path::new("") {
            true => Path::new("."),
            false => parent,
        })
    }
}
902
/// Drop guard that removes one entry from the map of in-flight requests.
struct ActiveRequestCleanup {
    /// Map from which the entry is removed.
    active_requests: Arc<Mutex<HashMap<Arc<CacheKey>, SharedLookup>>>,
    /// Key of the entry to remove.
    key: Arc<CacheKey>,
}
913
914impl Drop for ActiveRequestCleanup {
915 fn drop(&mut self) {
916 self.active_requests.lock().remove(&self.key);
917 }
918}
919
/// Identifies requests that may share a single in-flight lookup.
///
/// Two requests share a lookup only when all of these fields are equal.
#[derive(PartialEq, Eq, Hash)]
struct CacheKey {
    /// Op code of the request.
    op_code: OpCode,
    /// The request's `recursion_desired` flag.
    recursion_desired: bool,
    /// The request's `checking_disabled` flag.
    checking_disabled: bool,
    /// All queries carried by the request.
    queries: Vec<Query>,
    /// EDNS `dnssec_ok` flag; `false` when the request has no EDNS section.
    dnssec_ok: bool,
    /// EDNS client subnet option, when present.
    client_subnet: Option<ClientSubnet>,
}
930
931impl CacheKey {
932 fn from_request(request: &DnsRequest) -> Self {
933 let dnssec_ok;
934 let client_subnet;
935 if let Some(edns) = &request.edns {
936 dnssec_ok = edns.flags().dnssec_ok;
937 if let Some(EdnsOption::Subnet(subnet)) = edns.option(EdnsCode::Subnet) {
938 client_subnet = Some(*subnet);
939 } else {
940 client_subnet = None;
941 }
942 } else {
943 dnssec_ok = false;
944 client_subnet = None;
945 }
946 Self {
947 op_code: request.op_code,
948 recursion_desired: request.recursion_desired,
949 checking_disabled: request.checking_disabled,
950 queries: request.queries.clone(),
951 dnssec_ok,
952 client_subnet,
953 }
954 }
955}
956
/// A lookup future that can be cloned and awaited by several callers.
///
/// The inner future resolves to `Option<Result<..>>`; the `Future` impl for this type turns
/// a `None` into an error.
#[derive(Clone)]
pub(crate) struct SharedLookup(Shared<BoxFuture<'static, Option<Result<DnsResponse, NetError>>>>);
959
960impl Future for SharedLookup {
961 type Output = Result<DnsResponse, NetError>;
962
963 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
964 self.0.poll_unpin(cx).map(|o| match o {
965 Some(r) => r,
966 None => Err("no response from nameserver".into()),
967 })
968 }
969}
970
971#[cfg(test)]
972#[cfg(feature = "tokio")]
973mod tests {
974 use std::collections::HashSet;
975 use std::future::Future;
976 use std::io;
977 use std::net::{IpAddr, SocketAddr};
978 use std::pin::Pin;
979 use std::str::FromStr;
980 use std::sync::atomic::{AtomicBool, Ordering};
981 use std::thread;
982 use std::time::Duration;
983
984 use futures_util::future;
985 use test_support::{
986 MockNetworkHandler, MockProvider, MockRecord, MockTcpStream, MockUdpSocket, subscribe,
987 };
988 use tokio::runtime::Runtime;
989
990 use super::*;
991 use crate::config::{NameServerConfig, ResolverConfig, ServerOrderingStrategy};
992 use crate::net::runtime::{RuntimeProvider, TokioHandle, TokioRuntimeProvider, TokioTime};
993 use crate::net::xfer::{DnsHandle, FirstAnswer};
994 use crate::proto::op::{DnsRequestOptions, Query};
995 use crate::proto::rr::{Name, RecordType};
996
997 #[ignore]
998 #[test]
1000 #[allow(clippy::uninlined_format_args)]
1001 fn test_failed_then_success_pool() {
1002 subscribe();
1003
1004 let mut config1 = NameServerConfig::udp(IpAddr::from([127, 0, 0, 252]));
1005 config1.trust_negative_responses = false;
1006 let config2 = NameServerConfig::udp(IpAddr::from([8, 8, 8, 8]));
1007
1008 let mut resolver_config = ResolverConfig::default();
1009 resolver_config.add_name_server(config1);
1010 resolver_config.add_name_server(config2);
1011
1012 let io_loop = Runtime::new().unwrap();
1013 let pool = NameServerPool::from_config(
1014 resolver_config.name_servers,
1015 Arc::new(PoolContext::new(
1016 ResolverOpts::default(),
1017 TlsConfig::new().unwrap(),
1018 )),
1019 TokioRuntimeProvider::new(),
1020 );
1021
1022 let name = Name::parse("www.example.com.", None).unwrap();
1023
1024 for i in 0..2 {
1026 assert!(
1027 io_loop
1028 .block_on(
1029 pool.lookup(
1030 Query::query(name.clone(), RecordType::A),
1031 DnsRequestOptions::default()
1032 )
1033 .first_answer()
1034 )
1035 .is_err(),
1036 "iter: {}",
1037 i
1038 );
1039 }
1040
1041 for i in 0..10 {
1042 assert!(
1043 io_loop
1044 .block_on(
1045 pool.lookup(
1046 Query::query(name.clone(), RecordType::A),
1047 DnsRequestOptions::default()
1048 )
1049 .first_answer()
1050 )
1051 .is_ok(),
1052 "iter: {}",
1053 i
1054 );
1055 }
1056 }
1057
1058 #[tokio::test]
1059 async fn test_multi_use_conns() {
1060 subscribe();
1061
1062 let conn_provider = TokioRuntimeProvider::default();
1063 let opts = ResolverOpts {
1064 try_tcp_on_error: true,
1065 ..ResolverOpts::default()
1066 };
1067
1068 let tcp = NameServerConfig::tcp(IpAddr::from([8, 8, 8, 8]));
1069 let name_server = Arc::new(NameServer::new([], tcp, &opts, conn_provider));
1070 let name_servers = vec![name_server];
1071 let pool = NameServerPool::from_nameservers(
1072 name_servers.clone(),
1073 Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
1074 );
1075
1076 let name = Name::from_str("www.example.com.").unwrap();
1077
1078 let response = pool
1080 .lookup(
1081 Query::query(name.clone(), RecordType::A),
1082 DnsRequestOptions::default(),
1083 )
1084 .first_answer()
1085 .await
1086 .expect("lookup failed");
1087
1088 assert!(!response.answers.is_empty());
1089
1090 assert!(
1091 name_servers[0].is_connected(),
1092 "if this is failing then the NameServers aren't being properly shared."
1093 );
1094
1095 let response = pool
1097 .lookup(
1098 Query::query(name, RecordType::AAAA),
1099 DnsRequestOptions::default(),
1100 )
1101 .first_answer()
1102 .await
1103 .expect("lookup failed");
1104
1105 assert!(!response.answers.is_empty());
1106
1107 assert!(
1108 name_servers[0].is_connected(),
1109 "if this is failing then the NameServers aren't being properly shared."
1110 );
1111 }
1112
1113 #[tokio::test]
1120 async fn test_pool_retries_on_timeout() {
1121 subscribe();
1122
1123 let timeout_ip = IpAddr::from([10, 0, 0, 1]);
1124 let good_ip = IpAddr::from([10, 0, 0, 2]);
1125 let query_name = Name::from_str("example.com.").unwrap();
1126
1127 let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
1129 let handler = MockNetworkHandler::new(responses);
1130 let mock_provider = MockProvider::new(handler);
1131
1132 let provider = TimeoutProvider::new(mock_provider, vec![timeout_ip]);
1134
1135 let opts = ResolverOpts {
1136 num_concurrent_reqs: 1,
1137 server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
1138 ..ResolverOpts::default()
1139 };
1140
1141 let pool = NameServerPool::from_nameservers(
1142 vec![
1143 Arc::new(NameServer::new(
1144 [].into_iter(),
1145 NameServerConfig::udp(timeout_ip),
1146 &opts,
1147 provider.clone(),
1148 )),
1149 Arc::new(NameServer::new(
1150 [].into_iter(),
1151 NameServerConfig::udp(good_ip),
1152 &opts,
1153 provider.clone(),
1154 )),
1155 ],
1156 Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
1157 );
1158
1159 let response = pool
1162 .lookup(
1163 Query::query(query_name.clone(), RecordType::A),
1164 DnsRequestOptions::default(),
1165 )
1166 .first_answer()
1167 .await
1168 .expect("pool should retry on timeout and succeed with the second server");
1169
1170 assert!(
1171 !response.answers.is_empty(),
1172 "expected A record in response"
1173 );
1174 }
1175
1176 #[tokio::test]
1179 async fn test_timeout_penalizes_server_srtt() {
1180 subscribe();
1181
1182 let timeout_ip = IpAddr::from([10, 0, 0, 1]);
1183 let good_ip = IpAddr::from([10, 0, 0, 2]);
1184 let query_name = Name::from_str("example.com.").unwrap();
1185
1186 let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
1187 let handler = MockNetworkHandler::new(responses);
1188 let mock_provider = MockProvider::new(handler);
1189 let provider = TimeoutProvider::new(mock_provider, vec![timeout_ip]);
1190
1191 let opts = ResolverOpts {
1192 num_concurrent_reqs: 1,
1193 server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
1194 ..ResolverOpts::default()
1195 };
1196
1197 let ns_timeout = Arc::new(NameServer::new(
1198 [].into_iter(),
1199 NameServerConfig::udp(timeout_ip),
1200 &opts,
1201 provider.clone(),
1202 ));
1203 let ns_good = Arc::new(NameServer::new(
1204 [].into_iter(),
1205 NameServerConfig::udp(good_ip),
1206 &opts,
1207 provider.clone(),
1208 ));
1209
1210 let initial_srtt_timeout = ns_timeout.decayed_srtt();
1211
1212 let pool = NameServerPool::from_nameservers(
1213 vec![ns_timeout.clone(), ns_good.clone()],
1214 Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
1215 );
1216
1217 let _response = pool
1219 .lookup(
1220 Query::query(query_name.clone(), RecordType::A),
1221 DnsRequestOptions::default(),
1222 )
1223 .first_answer()
1224 .await
1225 .expect("lookup should succeed via second server");
1226
1227 assert!(
1229 ns_timeout.decayed_srtt() > initial_srtt_timeout,
1230 "timeout server SRTT should increase after failure: {} should be > {}",
1231 ns_timeout.decayed_srtt(),
1232 initial_srtt_timeout,
1233 );
1234
1235 let failure_penalty = 5_000_000.0_f64; assert!(
1240 ns_good.decayed_srtt() < failure_penalty,
1241 "good server SRTT should not be penalized: {}",
1242 ns_good.decayed_srtt(),
1243 );
1244 }
1245
    /// Runtime provider for tests that fails connection setup to selected IP addresses with
    /// a `TimedOut` I/O error and delegates everything else to a [`MockProvider`].
    #[derive(Clone)]
    struct TimeoutProvider {
        /// Provider used for all addresses not listed in `timeout_ips`.
        inner: MockProvider,
        /// Addresses whose TCP connects and UDP binds fail.
        timeout_ips: Arc<HashSet<IpAddr>>,
    }
1254
1255 impl TimeoutProvider {
1256 fn new(inner: MockProvider, timeout_ips: Vec<IpAddr>) -> Self {
1257 Self {
1258 inner,
1259 timeout_ips: Arc::new(timeout_ips.into_iter().collect()),
1260 }
1261 }
1262 }
1263
1264 impl RuntimeProvider for TimeoutProvider {
1265 type Handle = TokioHandle;
1266 type Timer = TokioTime;
1267 type Udp = MockUdpSocket;
1268 type Tcp = MockTcpStream;
1269
1270 fn create_handle(&self) -> Self::Handle {
1271 self.inner.create_handle()
1272 }
1273
1274 fn connect_tcp(
1275 &self,
1276 server_addr: SocketAddr,
1277 bind_addr: Option<SocketAddr>,
1278 timeout: Option<Duration>,
1279 ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
1280 if self.timeout_ips.contains(&server_addr.ip()) {
1281 Box::pin(future::ready(Err(io::Error::from(io::ErrorKind::TimedOut))))
1282 } else {
1283 self.inner.connect_tcp(server_addr, bind_addr, timeout)
1284 }
1285 }
1286
1287 fn bind_udp(
1288 &self,
1289 local_addr: SocketAddr,
1290 server_addr: SocketAddr,
1291 ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
1292 if self.timeout_ips.contains(&server_addr.ip()) {
1293 Box::pin(future::ready(Err(io::Error::from(io::ErrorKind::TimedOut))))
1294 } else {
1295 self.inner.bind_udp(local_addr, server_addr)
1296 }
1297 }
1298 }
1299
1300 #[tokio::test]
1309 async fn test_cancelled_parallel_server_is_penalized() {
1310 subscribe();
1311
1312 let unreachable_ip = IpAddr::from([10, 0, 0, 1]);
1313 let good_ip = IpAddr::from([10, 0, 0, 2]);
1314 let query_name = Name::from_str("example.com.").unwrap();
1315
1316 let responses = vec![MockRecord::a(good_ip, &query_name, good_ip)];
1317 let handler = MockNetworkHandler::new(responses);
1318 let mock_provider = MockProvider::new(handler);
1319 let provider = PendingProvider::new(mock_provider, vec![unreachable_ip]);
1320
1321 let opts = ResolverOpts {
1322 num_concurrent_reqs: 2,
1324 server_ordering_strategy: ServerOrderingStrategy::UserProvidedOrder,
1325 ..ResolverOpts::default()
1326 };
1327
1328 let ns_unreachable = Arc::new(NameServer::new(
1329 [].into_iter(),
1330 NameServerConfig::udp(unreachable_ip),
1331 &opts,
1332 provider.clone(),
1333 ));
1334 let ns_good = Arc::new(NameServer::new(
1335 [].into_iter(),
1336 NameServerConfig::udp(good_ip),
1337 &opts,
1338 provider.clone(),
1339 ));
1340
1341 let initial_srtt = ns_unreachable.decayed_srtt();
1342
1343 let pool = NameServerPool::from_nameservers(
1344 vec![ns_unreachable.clone(), ns_good.clone()],
1345 Arc::new(PoolContext::new(opts, TlsConfig::new().unwrap())),
1346 );
1347
1348 let _response = pool
1350 .lookup(
1351 Query::query(query_name.clone(), RecordType::A),
1352 DnsRequestOptions::default(),
1353 )
1354 .first_answer()
1355 .await
1356 .expect("lookup should succeed via good server");
1357
1358 assert!(
1361 ns_unreachable.decayed_srtt() > initial_srtt,
1362 "unreachable server SRTT should increase after being cancelled: {} should be > {}",
1363 ns_unreachable.decayed_srtt(),
1364 initial_srtt,
1365 );
1366
1367 let failure_penalty = 5_000_000.0_f64;
1369 assert!(
1370 ns_good.decayed_srtt() < failure_penalty,
1371 "good server SRTT should not be penalized: {}",
1372 ns_good.decayed_srtt(),
1373 );
1374 }
1375
1376 #[derive(Clone)]
1381 struct PendingProvider {
1382 inner: MockProvider,
1383 pending_ips: Arc<HashSet<IpAddr>>,
1384 }
1385
1386 impl PendingProvider {
1387 fn new(inner: MockProvider, pending_ips: Vec<IpAddr>) -> Self {
1388 Self {
1389 inner,
1390 pending_ips: Arc::new(pending_ips.into_iter().collect()),
1391 }
1392 }
1393 }
1394
1395 impl RuntimeProvider for PendingProvider {
1396 type Handle = TokioHandle;
1397 type Timer = TokioTime;
1398 type Udp = MockUdpSocket;
1399 type Tcp = MockTcpStream;
1400
1401 fn create_handle(&self) -> Self::Handle {
1402 self.inner.create_handle()
1403 }
1404
1405 fn connect_tcp(
1406 &self,
1407 server_addr: SocketAddr,
1408 bind_addr: Option<SocketAddr>,
1409 timeout: Option<Duration>,
1410 ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
1411 if self.pending_ips.contains(&server_addr.ip()) {
1412 Box::pin(future::pending())
1413 } else {
1414 self.inner.connect_tcp(server_addr, bind_addr, timeout)
1415 }
1416 }
1417
1418 fn bind_udp(
1419 &self,
1420 local_addr: SocketAddr,
1421 server_addr: SocketAddr,
1422 ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
1423 if self.pending_ips.contains(&server_addr.ip()) {
1424 Box::pin(future::pending())
1425 } else {
1426 self.inner.bind_udp(local_addr, server_addr)
1427 }
1428 }
1429 }
1430
/// Stress test: repeatedly sorts a list of name servers by their query
/// statistics while a second thread keeps recording failures on the very same
/// servers.
///
/// The test passes when none of the 100 000 sort calls panics and the writer
/// thread can be joined afterwards.
// NOTE(review): presumably guards against the sort observing keys that change
// mid-sort — confirm against `sort_servers_by_query_statistics`.
#[test]
fn test_sort_by_decayed_srtt_does_not_panic() {
    /// Raises the stop flag on drop, so the writer thread is told to finish
    /// even when this test unwinds before reaching the explicit shutdown.
    struct StopGuard(Arc<AtomicBool>);

    impl Drop for StopGuard {
        fn drop(&mut self) {
            self.0.store(true, Ordering::Relaxed);
        }
    }

    let opts = ResolverOpts::default();
    let mock_provider = MockProvider::new(MockNetworkHandler::new(vec![]));

    // Fifty servers at 10.0.0.1 ..= 10.0.0.50, each with one failure already
    // recorded before the concurrent phase starts.
    let mut servers = Vec::with_capacity(50);
    for last_octet in 1..=50 {
        let server = Arc::new(NameServer::new(
            [],
            NameServerConfig::udp(IpAddr::from([10, 0, 0, last_octet])),
            &opts,
            mock_provider.clone(),
        ));
        server.test_record_failure();
        servers.push(server);
    }

    // The writer works on its own vector of the same `Arc`s, so it mutates
    // the statistics of the servers being sorted below.
    let writer_servers = servers.clone();
    let stop = Arc::new(AtomicBool::new(false));
    let writer_stop = Arc::clone(&stop);
    let writer = thread::spawn(move || {
        while !writer_stop.load(Ordering::Relaxed) {
            for server in &writer_servers {
                server.test_record_failure();
            }
        }
    });

    let _guard = StopGuard(Arc::clone(&stop));

    for _round in 0..100_000 {
        sort_servers_by_query_statistics(&mut servers);
    }

    // Normal shutdown path: signal the writer and wait for it to exit.
    stop.store(true, Ordering::Relaxed);
    writer.join().unwrap();
}
1493}