#[cfg(not(test))]
use std::time::{Duration, Instant};
use std::{
cmp,
fmt::Debug,
marker::PhantomData,
net::IpAddr,
sync::{
Arc,
atomic::{AtomicU8, AtomicU32, Ordering},
},
};
use futures_util::lock::Mutex as AsyncMutex;
use parking_lot::Mutex as SyncMutex;
#[cfg(test)]
use tokio::time::{Duration, Instant};
use tracing::{debug, error, warn};
#[cfg(feature = "metrics")]
use crate::metrics::ResolverMetrics;
#[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
use crate::metrics::opportunistic_encryption::ProbeMetrics;
use crate::{
config::{
ConnectionConfig, NameServerConfig, OpportunisticEncryption, ResolverOpts,
ServerOrderingStrategy,
},
connection_provider::ConnectionProvider,
name_server_pool::{NameServerTransportState, PoolContext},
net::{
DnsError, NetError, NoRecords,
runtime::{RuntimeProvider, Spawn},
xfer::{DnsHandle, FirstAnswer, Protocol},
},
proto::{
op::{DnsRequest, DnsRequestOptions, DnsResponse, Query, ResponseCode},
rr::{Name, RecordType},
},
};
/// A single upstream name server together with the connections opened to it.
///
/// Two smoothed round-trip-time estimates are maintained: one per connection (used to
/// choose among this server's connections) and one for the server as a whole (exposed via
/// [`NameServer::decayed_srtt`]).
pub struct NameServer<P: ConnectionProvider> {
    /// IP address and candidate connection configurations for this server.
    config: NameServerConfig,
    /// Open connections; entries whose status is `Failed` are pruned before each selection.
    connections: AsyncMutex<Vec<ConnectionState<P>>>,
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    opportunistic_probe_metrics: ProbeMetrics,
    #[cfg(feature = "metrics")]
    resolver_metrics: ResolverMetrics,
    /// Server-wide smoothed RTT, updated from every query outcome on any connection.
    server_srtt: DecayingSrtt,
    /// Opens new connections and provides the runtime used to spawn background probes.
    connection_provider: P,
}

impl<P: ConnectionProvider> NameServer<P> {
    /// Creates a name server from already-open `connections` and its `config`.
    ///
    /// Unless `options.server_ordering_strategy` is
    /// [`ServerOrderingStrategy::UserProvidedOrder`], the supplied connections are stably
    /// reordered so that UDP connections come first. The server-wide SRTT starts at a small
    /// random value of 1–31µs (presumably to break ties between otherwise identical servers).
    pub fn new(
        connections: impl IntoIterator<Item = (Protocol, P::Conn)>,
        config: NameServerConfig,
        options: &ResolverOpts,
        connection_provider: P,
    ) -> Self {
        let mut connections = connections
            .into_iter()
            .map(|(protocol, handle)| ConnectionState::new(handle, protocol))
            .collect::<Vec<_>>();

        if options.server_ordering_strategy != ServerOrderingStrategy::UserProvidedOrder {
            // `sort_by_key` is stable and `false < true`, so UDP moves to the front while the
            // relative order of all other connections is preserved.
            connections.sort_by_key(|ns| ns.protocol != Protocol::Udp);
        }

        Self {
            config,
            connections: AsyncMutex::new(connections),
            server_srtt: DecayingSrtt::new(Duration::from_micros(rand::random_range(1..32))),
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            opportunistic_probe_metrics: ProbeMetrics::default(),
            #[cfg(feature = "metrics")]
            resolver_metrics: ResolverMetrics::default(),
            connection_provider,
        }
    }

    /// Sends `request` over a connection chosen by `policy`.
    ///
    /// Updates the connection's status, the per-connection and server-wide SRTT, and, when
    /// opportunistic encryption is enabled and the transport is encrypted, the pool's
    /// encrypted-transport state.
    ///
    /// # Errors
    ///
    /// Returns an error if no connection could be obtained, if the exchange itself failed,
    /// or if [`DnsError::from_response`] turned the response into an error.
    pub(crate) async fn send(
        self: Arc<Self>,
        request: DnsRequest,
        policy: ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) -> Result<DnsResponse, NetError> {
        let (handle, meta, protocol) = self.connected_mut_client(policy, cx).await?;
        #[cfg(feature = "metrics")]
        self.resolver_metrics.increment_outgoing_query(&protocol);

        // Whether outcomes on this transport feed the opportunistic-encryption state.
        let track_encrypted = cx.opportunistic_encryption.is_enabled() && protocol.is_encrypted();

        let now = Instant::now();
        let response = handle.send(request).first_answer().await;
        let rtt = now.elapsed();

        match response {
            Ok(response) => {
                meta.set_status(Status::Established);
                let error = match DnsError::from_response(response) {
                    Ok(response) => {
                        meta.srtt.record(rtt);
                        self.server_srtt.record(rtt);
                        if track_encrypted {
                            cx.transport_state()
                                .await
                                .response_received(self.config.ip, protocol);
                        }
                        return Ok(response);
                    }
                    Err(error) => error,
                };

                // The server did answer, so the measured latency is meaningful. SERVFAIL
                // means the server could not do its job and is penalised. Every other
                // `NoRecordsFound` answer (presumably NXDOMAIN / NODATA) is a legitimate
                // response, so its RTT is recorded like a success.
                match error {
                    DnsError::NoRecordsFound(NoRecords {
                        response_code: ResponseCode::ServFail,
                        ..
                    }) => {
                        meta.srtt.record_failure();
                        self.server_srtt.record_failure();
                    }
                    DnsError::NoRecordsFound(NoRecords { .. }) => {
                        meta.srtt.record(rtt);
                        self.server_srtt.record(rtt);
                    }
                    _ => {}
                }

                let err = NetError::from(error);
                if track_encrypted {
                    cx.transport_state()
                        .await
                        .error_received(self.config.ip, protocol, &err);
                }
                Err(err)
            }
            Err(error) => {
                debug!(config = ?self.config, %error, "failed to connect to name server");
                meta.set_status(Status::Failed);

                // Only transport-level failures count against the RTT estimates.
                let transport_failure = match &error {
                    NetError::Busy | NetError::Io(_) | NetError::Timeout => true,
                    #[cfg(feature = "__quic")]
                    NetError::QuinnConfigError(_)
                    | NetError::QuinnConnect(_)
                    | NetError::QuinnConnection(_)
                    | NetError::QuinnTlsConfigError(_) => true,
                    #[cfg(feature = "__tls")]
                    NetError::RustlsError(_) => true,
                    _ => false,
                };
                if transport_failure {
                    meta.srtt.record_failure();
                    self.server_srtt.record_failure();
                }

                if track_encrypted {
                    cx.transport_state()
                        .await
                        .error_received(self.config.ip, protocol, &error);
                }
                Err(error)
            }
        }
    }

    /// Returns a connection suitable for `policy`, opening a new one if none qualifies.
    ///
    /// Failed connections are pruned first. When opportunistic encryption is enabled, a new
    /// encrypted connection is reported to the pool's transport state. Choosing a plaintext
    /// transport may instead trigger a background probe of an encrypted one.
    ///
    /// # Errors
    ///
    /// [`NetError::NoConnections`] if `policy` permits none of the configured transports,
    /// or any error raised while creating or establishing the new connection.
    async fn connected_mut_client(
        &self,
        policy: ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) -> Result<(P::Conn, Arc<ConnectionMeta>, Protocol), NetError> {
        let mut connections = self.connections.lock().await;
        connections.retain(|conn| matches!(conn.meta.status(), Status::Init | Status::Established));

        // Each `transport_state()` guard below is a temporary that is dropped at the end of
        // its statement, so it is released before the next acquisition.
        if let Some(conn) = policy.select_connection(
            self.config.ip,
            &*cx.transport_state().await,
            &cx.opportunistic_encryption,
            &connections,
        ) {
            return Ok((conn.handle.clone(), conn.meta.clone(), conn.protocol));
        }

        debug!(config = ?self.config, "connecting");
        let config = policy
            .select_connection_config(
                self.config.ip,
                &*cx.transport_state().await,
                &cx.opportunistic_encryption,
                &self.config.connections,
            )
            .ok_or(NetError::NoConnections)?;

        let protocol = config.protocol.to_protocol();
        let opportunistic = cx.opportunistic_encryption.is_enabled();
        if opportunistic && protocol.is_encrypted() {
            cx.transport_state()
                .await
                .initiate_connection(self.config.ip, protocol);
        } else if opportunistic {
            self.consider_probe_encrypted_transport(&policy, cx).await;
        }

        let handle = Box::pin(self.connection_provider.new_connection(
            self.config.ip,
            config,
            cx,
        )?)
        .await?;

        if opportunistic && protocol.is_encrypted() {
            cx.transport_state()
                .await
                .complete_connection(self.config.ip, protocol);
        }

        let state = ConnectionState::new(handle.clone(), protocol);
        let meta = state.meta.clone();
        connections.push(state);
        Ok((handle, meta, protocol))
    }

    /// Protocols of every configured connection, in configuration order.
    pub(super) fn protocols(&self) -> impl Iterator<Item = Protocol> + '_ {
        self.config
            .connections
            .iter()
            .map(|conn| conn.protocol.to_protocol())
    }

    /// IP address of this name server.
    pub(super) fn ip(&self) -> IpAddr {
        self.config.ip
    }

    /// Server-wide SRTT in microseconds, scaled down by the time since its last update.
    pub(crate) fn decayed_srtt(&self) -> f64 {
        self.server_srtt.current()
    }

    /// Records a cancelled query as `winner_rtt` plus a fixed 5 ms penalty.
    ///
    /// `winner_rtt` is presumably the RTT of the server whose answer won the race — confirm
    /// against callers.
    pub(super) fn record_cancelled(&self, winner_rtt: Duration) {
        const CANCEL_PENALTY: Duration = Duration::from_millis(5);
        self.server_srtt.record(winner_rtt + CANCEL_PENALTY);
    }

    /// Test helper: applies one failure penalty to the server-wide SRTT.
    #[cfg(test)]
    pub(crate) fn test_record_failure(&self) {
        self.server_srtt.record_failure();
    }

    /// Test helper: whether any connection is `Init` or `Established`.
    ///
    /// Optimistically returns `true` when the connection list is currently locked.
    #[cfg(test)]
    #[allow(dead_code)] // not every test configuration calls this helper
    pub(crate) fn is_connected(&self) -> bool {
        let Some(connections) = self.connections.try_lock() else {
            return true;
        };

        connections.iter().any(|conn| match conn.meta.status() {
            Status::Established | Status::Init => true,
            Status::Failed => false,
        })
    }

    /// Whether negative responses from this server are trusted (taken from its config).
    pub(crate) fn trust_negative_responses(&self) -> bool {
        self.config.trust_negative_responses
    }

    /// Spawns a background probe of an encrypted transport if the transport state says one
    /// is due for this server and budget remains. Failures are logged, not returned.
    async fn consider_probe_encrypted_transport(
        &self,
        policy: &ConnectionPolicy,
        cx: &Arc<PoolContext>,
    ) {
        let Some(probe_config) =
            policy.select_encrypted_connection_config(&self.config.connections)
        else {
            warn!("no encrypted connection configs available for probing");
            return;
        };

        let probe_protocol = probe_config.protocol.to_protocol();
        // Scope the guard so the transport state is unlocked before spawning the probe.
        let should_probe = {
            let state = cx.transport_state().await;
            state.should_probe_encrypted(
                self.config.ip,
                probe_protocol,
                &cx.opportunistic_encryption,
            )
        };

        if !should_probe {
            return;
        }

        if let Err(err) = self.probe_encrypted_transport(cx, probe_config) {
            error!(%err, "opportunistic encrypted probe attempt failed");
        }
    }

    /// Reserves one unit of the pool's probe budget and spawns a [`ProbeRequest`] with it.
    ///
    /// Returns `Ok(())` without probing when the budget is exhausted. The spawned probe
    /// returns the reserved unit when it finishes.
    ///
    /// # Errors
    ///
    /// Propagates the error from creating the probe connection. In that case the reserved
    /// budget unit is handed back before returning.
    fn probe_encrypted_transport(
        &self,
        cx: &Arc<PoolContext>,
        probe_config: &ConnectionConfig,
    ) -> Result<(), NetError> {
        let mut budget = cx.opportunistic_probe_budget.load(Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        self.opportunistic_probe_metrics.probe_budget.set(budget);

        // Atomically decrement the budget unless it is already zero.
        loop {
            if budget == 0 {
                debug!("no remaining budget for opportunistic probing");
                return Ok(());
            }

            match cx.opportunistic_probe_budget.compare_exchange_weak(
                budget,
                budget - 1,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(current) => budget = current,
            }
        }

        let connect = match ProbeRequest::new(
            probe_config,
            self,
            cx,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            self.opportunistic_probe_metrics.clone(),
        ) {
            Ok(connect) => connect,
            Err(err) => {
                // Nothing was spawned, so nobody else will return the unit reserved above.
                // Give it back now, otherwise the probe budget shrinks permanently.
                let _prev = cx
                    .opportunistic_probe_budget
                    .fetch_add(1, Ordering::Relaxed);
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                self.opportunistic_probe_metrics.probe_budget.set(_prev + 1);
                return Err(err);
            }
        };

        self.connection_provider
            .runtime_provider()
            .create_handle()
            .spawn_bg(connect.run());

        Ok(())
    }
}
/// A one-shot background probe of an encrypted transport to a name server.
///
/// [`ProbeRequest::new`] creates the provider's connection future. [`ProbeRequest::run`]
/// then:
/// - awaits that future,
/// - sends a single `. NS` query over the connection,
/// - reports each outcome to the pool's transport state,
/// - adds one unit back to `opportunistic_probe_budget`, which the caller is expected to
///   have reserved beforehand.
struct ProbeRequest<P: ConnectionProvider> {
    /// Address of the probed name server.
    ip: IpAddr,
    /// Encrypted protocol being probed.
    proto: Protocol,
    /// Pending connection returned by the connection provider.
    connecting: P::FutureConn,
    /// Pool context used for transport state and the probe budget.
    context: Arc<PoolContext>,
    #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
    metrics: ProbeMetrics,
    provider: PhantomData<P>,
}

impl<P: ConnectionProvider> ProbeRequest<P> {
    /// Builds a probe for `config` against `ns`, creating the connection future up front.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ConnectionProvider::new_connection`.
    fn new(
        config: &ConnectionConfig,
        ns: &NameServer<P>,
        cx: &Arc<PoolContext>,
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics: ProbeMetrics,
    ) -> Result<Self, NetError> {
        Ok(Self {
            ip: ns.config.ip,
            proto: config.protocol.to_protocol(),
            connecting: ns
                .connection_provider
                .new_connection(ns.config.ip, config, cx)?,
            context: cx.clone(),
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: PhantomData,
        })
    }

    /// Drives the probe to completion; every outcome is reported rather than returned.
    async fn run(self) {
        let Self {
            ip,
            proto,
            connecting,
            context,
            #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
            metrics,
            provider: _,
        } = self;

        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        let start = Instant::now();

        context
            .transport_state()
            .await
            .initiate_connection(ip, proto);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        metrics.increment_attempts(proto);

        let conn = match connecting.await {
            Ok(conn) => conn,
            Err(err) => {
                // Include the error, matching the "probe query failed" log below.
                debug!(?proto, ?err, "probe connection failed");
                // Return this probe's budget unit before reporting the failure.
                let _prev = context
                    .opportunistic_probe_budget
                    .fetch_add(1, Ordering::Relaxed);
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                {
                    metrics.increment_errors(proto, &err);
                    metrics.probe_budget.set(_prev + 1);
                    metrics.record_probe_duration(proto, start.elapsed());
                }
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
                return;
            }
        };

        debug!(?proto, "probe connection succeeded");
        context
            .transport_state()
            .await
            .complete_connection(ip, proto);

        match conn
            .send(DnsRequest::from_query(
                Query::query(Name::root(), RecordType::NS),
                DnsRequestOptions::default(),
            ))
            .first_answer()
            .await
        {
            Ok(_) => {
                debug!(?proto, "probe query succeeded");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_successes(proto);
                context.transport_state().await.response_received(ip, proto);
            }
            Err(err) => {
                debug!(?proto, ?err, "probe query failed");
                #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
                metrics.increment_errors(proto, &err);
                context
                    .transport_state()
                    .await
                    .error_received(ip, proto, &err);
            }
        }

        // The probe is finished; hand its budget unit back to the pool.
        let _prev = context
            .opportunistic_probe_budget
            .fetch_add(1, Ordering::Relaxed);
        #[cfg(all(feature = "metrics", any(feature = "__tls", feature = "__quic")))]
        {
            metrics.probe_budget.set(_prev + 1);
            metrics.record_probe_duration(proto, start.elapsed());
        }
    }
}
/// An open connection to a name server, paired with metadata shared with in-flight requests.
struct ConnectionState<P: ConnectionProvider> {
    /// Transport this connection speaks.
    protocol: Protocol,
    /// Handle used to send requests over the connection.
    handle: P::Conn,
    /// Status and per-connection SRTT, shared via `Arc` with callers of `send`.
    meta: Arc<ConnectionMeta>,
}

impl<P: ConnectionProvider> ConnectionState<P> {
    /// Wraps `handle` (speaking `protocol`) with freshly defaulted metadata.
    fn new(handle: P::Conn, protocol: Protocol) -> Self {
        let meta = Arc::new(ConnectionMeta::default());
        Self {
            protocol,
            handle,
            meta,
        }
    }
}
/// Per-connection bookkeeping that can be updated through a shared reference.
struct ConnectionMeta {
    /// [`Status`] stored as its `u8` discriminant so it can be changed without a lock.
    status: AtomicU8,
    /// Smoothed RTT of this particular connection.
    srtt: DecayingSrtt,
}

impl ConnectionMeta {
    /// Publishes `status` with a release store (paired with the acquire load in `status`).
    fn set_status(&self, status: Status) {
        let raw = u8::from(status);
        self.status.store(raw, Ordering::Release);
    }

    /// Reads the current status; raw values that are not a known discriminant decode as
    /// `Status::Failed`.
    fn status(&self) -> Status {
        let raw = self.status.load(Ordering::Acquire);
        raw.into()
    }
}

impl Default for ConnectionMeta {
    /// A fresh connection starts in `Status::Init` with a small random (1–31µs) SRTT.
    fn default() -> Self {
        let initial_srtt = Duration::from_micros(rand::random_range(1..32));
        Self {
            status: AtomicU8::new(u8::from(Status::Init)),
            srtt: DecayingSrtt::new(initial_srtt),
        }
    }
}
/// Smoothed round-trip time whose history is down-weighted by wall-clock age.
///
/// The estimate is kept in whole microseconds. Every value written by `update` is capped
/// at [`DecayingSrtt::MAX_SRTT_MICROS`].
struct DecayingSrtt {
    /// Current estimate in microseconds.
    srtt_microseconds: AtomicU32,
    /// Time of the most recent `record`/`record_failure`; `None` until the first one.
    last_update: SyncMutex<Option<Instant>>,
}

impl DecayingSrtt {
    /// Creates an estimate seeded with `initial_srtt`. The value is not capped, and it is
    /// replaced outright by the first recorded sample.
    fn new(initial_srtt: Duration) -> Self {
        Self {
            srtt_microseconds: AtomicU32::new(Self::saturating_micros(initial_srtt)),
            last_update: SyncMutex::new(None),
        }
    }

    /// Folds a successful round trip of `rtt` into the estimate.
    ///
    /// The first sample is stored as-is. Later samples are blended as
    /// `(1 - f) * rtt + f * current`, where `f = compute_srtt_factor(last_update, 3)`.
    /// Older history therefore counts for less the longer it has been since the last update.
    fn record(&self, rtt: Duration) {
        let rtt_micros = rtt.as_micros();
        self.update(
            Self::saturating_micros(rtt),
            |cur_srtt_microseconds, last_update| {
                let factor = compute_srtt_factor(last_update, 3);
                let new_srtt = (1.0 - factor) * (rtt_micros as f64)
                    + factor * f64::from(cur_srtt_microseconds);
                // Float-to-int `as` saturates, so huge values clamp to `u32::MAX` here.
                new_srtt.round() as u32
            },
        );
    }

    /// Adds a fixed [`DecayingSrtt::FAILURE_PENALTY`] to the estimate. The first update
    /// ever recorded sets the estimate to the penalty itself.
    fn record_failure(&self) {
        self.update(
            Self::FAILURE_PENALTY,
            |cur_srtt_microseconds, _last_update| {
                cur_srtt_microseconds.saturating_add(Self::FAILURE_PENALTY)
            },
        );
    }

    /// The estimate (µs) scaled by `compute_srtt_factor(last_update, 180)`, or the raw
    /// value when nothing has been recorded yet.
    fn current(&self) -> f64 {
        let srtt = f64::from(self.srtt_microseconds.load(Ordering::Acquire));
        self.last_update.lock().map_or(srtt, |last_update| {
            srtt * compute_srtt_factor(last_update, 180)
        })
    }

    /// Stamps `last_update` with now, then atomically replaces the estimate.
    ///
    /// The new value is `default` when there was no previous update, or
    /// `update_fn(current, previous_update_time)` otherwise. Either way it is capped at
    /// [`DecayingSrtt::MAX_SRTT_MICROS`].
    fn update(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
        let last_update = self.last_update.lock().replace(Instant::now());
        let _ = self.srtt_microseconds.fetch_update(
            Ordering::SeqCst,
            Ordering::SeqCst,
            move |cur_srtt_microseconds| {
                Some(
                    last_update
                        .map_or(default, |last_update| {
                            update_fn(cur_srtt_microseconds, last_update)
                        })
                        .min(Self::MAX_SRTT_MICROS),
                )
            },
        );
    }

    /// Converts `duration` to whole microseconds, saturating at `u32::MAX`.
    ///
    /// A plain `as u32` cast would keep only the low 32 bits, so a duration of 2^32 µs
    /// (about 71 minutes) or more could wrap to a tiny value.
    fn saturating_micros(duration: Duration) -> u32 {
        u32::try_from(duration.as_micros()).unwrap_or(u32::MAX)
    }

    #[cfg(all(test, feature = "tokio"))]
    fn as_duration(&self) -> Duration {
        Duration::from_micros(u64::from(self.srtt_microseconds.load(Ordering::Acquire)))
    }

    /// Penalty added per recorded failure: 150 ms.
    const FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
    /// Upper bound on any stored estimate: 5 s.
    const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
}
/// Weight given to an SRTT value last updated at `last_update`: `exp(-t / weight)`.
///
/// `t` is the elapsed time in seconds, floored at one second, so even a value updated just
/// now is weighted by at most `exp(-1 / weight)`. A larger `weight` makes history decay
/// more slowly.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let age_secs = last_update.elapsed().as_secs_f64().max(1.0);
    let exponent = -age_secs / f64::from(weight);
    exponent.exp()
}
/// Lifecycle state of a single connection, stored as a `u8` inside `ConnectionMeta`.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
#[repr(u8)]
enum Status {
    /// The connection failed and should no longer be used.
    Failed = 0,
    /// The connection exists but has not yet produced a response.
    Init = 1,
    /// The connection has received at least one response.
    Established = 2,
}

impl From<Status> for u8 {
    fn from(status: Status) -> Self {
        status as u8
    }
}

impl From<u8> for Status {
    /// Decodes a stored discriminant; any unknown value is treated as `Failed`.
    fn from(raw: u8) -> Self {
        const ESTABLISHED: u8 = Status::Established as u8;
        const INIT: u8 = Status::Init as u8;
        match raw {
            ESTABLISHED => Self::Established,
            INIT => Self::Init,
            _ => Self::Failed,
        }
    }
}
/// Constraints on which transports may be used for a query, plus the rules for ranking
/// connections and connection configurations.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub(crate) struct ConnectionPolicy {
    /// When set, UDP connections and configurations are never selected.
    pub(crate) disable_udp: bool,
}

impl ConnectionPolicy {
    /// Whether `server` is configured with at least one transport this policy permits.
    pub(crate) fn allows_server<P: ConnectionProvider>(&self, server: &NameServer<P>) -> bool {
        let mut protocols = server.protocols();
        protocols.any(|protocol| self.allows_protocol(protocol))
    }

    /// Picks the best permitted existing connection according to `compare_connections`.
    ///
    /// Returns `None` when nothing is permitted. It also returns `None` when opportunistic
    /// encryption is enabled, the best candidate is plaintext, and an encrypted transport
    /// to `ip` has succeeded recently; the caller then opens a new connection instead.
    fn select_connection<'a, P: ConnectionProvider>(
        &self,
        ip: IpAddr,
        encrypted_transport_state: &NameServerTransportState,
        opportunistic_encryption: &OpportunisticEncryption,
        connections: &'a [ConnectionState<P>],
    ) -> Option<&'a ConnectionState<P>> {
        let best = connections
            .iter()
            .filter(|conn| self.allows_protocol(conn.protocol))
            .min_by(|a, b| self.compare_connections(opportunistic_encryption.is_enabled(), a, b))?;

        let upgrade_wanted = opportunistic_encryption.is_enabled()
            && !best.protocol.is_encrypted()
            && encrypted_transport_state.any_recent_success(ip, opportunistic_encryption);

        if upgrade_wanted { None } else { Some(best) }
    }

    /// Picks the best permitted configuration for a new connection, ranked by
    /// `compare_connection_configs`. Returns `None` when no configuration is permitted.
    fn select_connection_config<'a>(
        &self,
        ip: IpAddr,
        encrypted_transport_state: &NameServerTransportState,
        opportunistic_encryption: &OpportunisticEncryption,
        connection_configs: &'a [ConnectionConfig],
    ) -> Option<&'a ConnectionConfig> {
        let permitted = connection_configs
            .iter()
            .filter(|config| self.allows_protocol(config.protocol.to_protocol()));

        permitted.min_by(|a, b| {
            self.compare_connection_configs(
                ip,
                encrypted_transport_state,
                opportunistic_encryption,
                a,
                b,
            )
        })
    }

    /// The first permitted configuration, in configuration order, whose transport is
    /// encrypted.
    fn select_encrypted_connection_config<'a>(
        &self,
        connection_config: &'a [ConnectionConfig],
    ) -> Option<&'a ConnectionConfig> {
        connection_config.iter().find(|config| {
            let protocol = config.protocol.to_protocol();
            self.allows_protocol(protocol) && protocol.is_encrypted()
        })
    }

    /// Every protocol is permitted except UDP when `disable_udp` is set.
    fn allows_protocol(&self, protocol: Protocol) -> bool {
        !self.disable_udp || protocol != Protocol::Udp
    }

    /// Orders existing connections; `Less` means `a` is preferred.
    ///
    /// Ranking, each key consulted only when the previous ones tie:
    /// 1. Encrypted transports first (only when `opportunistic_encryption`).
    /// 2. UDP before any other transport.
    /// 3. Lower decayed per-connection SRTT.
    fn compare_connections<P: ConnectionProvider>(
        &self,
        opportunistic_encryption: bool,
        a: &ConnectionState<P>,
        b: &ConnectionState<P>,
    ) -> cmp::Ordering {
        let encryption_rank = if opportunistic_encryption {
            // `false < true`, so comparing b against a puts encrypted first.
            b.protocol.is_encrypted().cmp(&a.protocol.is_encrypted())
        } else {
            cmp::Ordering::Equal
        };
        let udp_rank = (b.protocol == Protocol::Udp).cmp(&(a.protocol == Protocol::Udp));

        encryption_rank
            .then(udp_rank)
            .then_with(|| a.meta.srtt.current().total_cmp(&b.meta.srtt.current()))
    }

    /// Orders connection configurations; `Less` means `a` is preferred.
    ///
    /// When opportunistic encryption is enabled, an encrypted transport that recently
    /// succeeded for `ip` wins first. Otherwise UDP is preferred over any other transport,
    /// and all remaining pairs compare equal.
    fn compare_connection_configs(
        &self,
        ip: IpAddr,
        encrypted_transport_state: &NameServerTransportState,
        opportunistic_encryption: &OpportunisticEncryption,
        a: &ConnectionConfig,
        b: &ConnectionConfig,
    ) -> cmp::Ordering {
        let a_protocol = a.protocol.to_protocol();
        let b_protocol = b.protocol.to_protocol();

        if opportunistic_encryption.is_enabled() {
            let recently_succeeded = |protocol: Protocol| {
                protocol.is_encrypted()
                    && encrypted_transport_state.recent_success(
                        ip,
                        protocol,
                        opportunistic_encryption,
                    )
            };
            let a_recent = recently_succeeded(a_protocol);
            let b_recent = recently_succeeded(b_protocol);
            let by_recent_success = b_recent.cmp(&a_recent);
            if by_recent_success.is_ne() {
                return by_recent_success;
            }
        }

        (b_protocol == Protocol::Udp).cmp(&(a_protocol == Protocol::Udp))
    }
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
// Unit tests for `NameServer` and `DecayingSrtt`. The first test talks to 8.8.8.8, so it
// needs network access. The SRTT tests use paused tokio time (`start_paused = true`) so
// the decay arithmetic is deterministic.
use std::cmp;
use std::net::{IpAddr, Ipv4Addr};
use std::str::FromStr;
use std::time::Duration;
use test_support::subscribe;
use tokio::net::UdpSocket;
use tokio::spawn;
use super::*;
use crate::config::{ConnectionConfig, ProtocolConfig};
use crate::connection_provider::TlsConfig;
use crate::net::runtime::TokioRuntimeProvider;
use crate::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
use crate::proto::rr::rdata::NULL;
use crate::proto::rr::{Name, RData, Record, RecordType};
/// Sends an A query to 8.8.8.8 over UDP and expects NOERROR (requires network access).
#[tokio::test]
async fn test_name_server() {
subscribe();
let options = ResolverOpts::default();
let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
let name_server = Arc::new(NameServer::new(
[].into_iter(),
config,
&options,
TokioRuntimeProvider::default(),
));
let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
let name = Name::parse("www.example.com.", None).unwrap();
let response = name_server
.send(
DnsRequest::from_query(
Query::query(name.clone(), RecordType::A),
DnsRequestOptions::default(),
),
ConnectionPolicy::default(),
&cx,
)
.await
.expect("query failed");
assert_eq!(response.response_code, ResponseCode::NoError);
}
/// Queries a loopback address (presumably with no DNS server listening) using a 1 ms
/// timeout and expects `send` to return an error.
#[tokio::test]
async fn test_failed_name_server() {
subscribe();
let options = ResolverOpts {
timeout: Duration::from_millis(1), ..ResolverOpts::default()
};
let config = NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)));
let name_server = Arc::new(NameServer::new(
[],
config,
&options,
TokioRuntimeProvider::default(),
));
let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
let name = Name::parse("www.example.com.", None).unwrap();
assert!(
name_server
.send(
DnsRequest::from_query(
Query::query(name.clone(), RecordType::A),
DnsRequestOptions::default(),
),
ConnectionPolicy::default(),
&cx
)
.await
.is_err()
);
}
/// A local UDP server echoes the request's query section back. The name in the returned
/// response must match the originally requested `name` under `eq_case`, even though
/// case randomization is enabled for the request.
#[tokio::test]
async fn case_randomization_query_preserved() {
subscribe();
let provider = TokioRuntimeProvider::default();
let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
let server_addr = server.local_addr().unwrap();
let name = Name::from_str("dead.beef.").unwrap();
let data = b"DEADBEEF";
// One-shot responder: answer the first datagram with the request's own queries plus a
// NULL record, then exit.
spawn({
let name = name.clone();
async move {
let mut buffer = [0_u8; 512];
let (len, addr) = server.recv_from(&mut buffer).await.unwrap();
let request = Message::from_vec(&buffer[0..len]).unwrap();
let mut response = Message::response(request.id, request.op_code);
response.add_queries(request.queries.to_vec());
response.add_answer(Record::from_rdata(
name,
0,
RData::NULL(NULL::with(data.to_vec())),
));
let response_buffer = response.to_vec().unwrap();
server.send_to(&response_buffer, addr).await.unwrap();
}
});
let config = NameServerConfig {
ip: server_addr.ip(),
trust_negative_responses: true,
connections: vec![ConnectionConfig {
port: server_addr.port(),
protocol: ProtocolConfig::Udp,
bind_addr: None,
}],
};
let resolver_opts = ResolverOpts {
case_randomization: true,
..Default::default()
};
let cx = Arc::new(PoolContext::new(resolver_opts, TlsConfig::new().unwrap()));
let mut request_options = DnsRequestOptions::default();
request_options.case_randomization = true;
let ns = Arc::new(NameServer::new([], config, &cx.options, provider));
let response = ns
.send(
DnsRequest::from_query(
Query::query(name.clone(), RecordType::NULL),
request_options,
),
ConnectionPolicy::default(),
&cx,
)
.await
.unwrap();
let response_query_name = response.queries.first().unwrap().name();
assert!(response_query_name.eq_case(&name));
}
// Compile-time helper: only instantiable for `Send + Sync` types.
#[allow(clippy::extra_unused_type_parameters)]
fn is_send_sync<S: Sync + Send>() -> bool {
true
}
/// `ConnectionMeta` is shared across tasks via `Arc`, so it must be `Send + Sync`.
#[test]
fn stats_are_sync() {
assert!(is_send_sync::<ConnectionMeta>());
}
/// Interleaves samples, failures and time advances between two estimates and checks which
/// one ranks lower after each step. Every assertion depends on the exact order of the
/// `record`/`advance` calls.
#[tokio::test(start_paused = true)]
async fn test_stats_cmp() {
use std::cmp::Ordering;
let srtt_a = DecayingSrtt::new(Duration::from_micros(10));
let srtt_b = DecayingSrtt::new(Duration::from_micros(20));
assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
srtt_a.record(Duration::from_millis(30));
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);
srtt_b.record(Duration::from_millis(50));
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
srtt_a.record_failure();
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Greater);
// Keep feeding b 50 ms samples until a's 150 ms failure penalty no longer dominates.
while cmp(&srtt_a, &srtt_b) != Ordering::Less {
srtt_b.record(Duration::from_millis(50));
tokio::time::advance(Duration::from_secs(5)).await;
}
srtt_a.record(Duration::from_millis(30));
tokio::time::advance(Duration::from_secs(3)).await;
assert_eq!(cmp(&srtt_a, &srtt_b), Ordering::Less);
}
// Compares two estimates by their decayed (`current`) values.
fn cmp(a: &DecayingSrtt, b: &DecayingSrtt) -> cmp::Ordering {
a.current().total_cmp(&b.current())
}
/// The first sample replaces the initial value. After 3 s the factor is
/// exp(-3/3) ≈ 0.3679, so the second sample gives
/// 0.6321 * 100_000 + 0.3679 * 50_000 ≈ 81_606 µs.
#[tokio::test(start_paused = true)]
async fn test_record_rtt() {
let srtt = DecayingSrtt::new(Duration::from_micros(10));
let first_rtt = Duration::from_millis(50);
srtt.record(first_rtt);
assert_eq!(srtt.as_duration(), first_rtt);
tokio::time::advance(Duration::from_secs(3)).await;
srtt.record(Duration::from_millis(100));
assert_eq!(srtt.as_duration(), Duration::from_micros(81606));
}
/// An enormous RTT is clamped to `MAX_SRTT_MICROS`.
#[test]
fn test_record_rtt_maximum_value() {
let srtt = DecayingSrtt::new(Duration::from_micros(10));
srtt.record(Duration::MAX);
assert_eq!(
srtt.as_duration(),
Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
);
}
/// Each failure adds `FAILURE_PENALTY` (150 ms), giving 150/300/450 ms. A later 50 ms
/// sample 3 s after the last failure blends to
/// 0.6321 * 50_000 + 0.3679 * 450_000 ≈ 197_152 µs.
#[tokio::test(start_paused = true)]
async fn test_record_connection_failure() {
let srtt = DecayingSrtt::new(Duration::from_micros(10));
for failure_count in 1..4 {
srtt.record_failure();
assert_eq!(
srtt.as_duration(),
Duration::from_micros(
DecayingSrtt::FAILURE_PENALTY
.checked_mul(failure_count)
.expect("checked_mul overflow")
.into()
)
);
tokio::time::advance(Duration::from_secs(3)).await;
}
srtt.record(Duration::from_millis(50));
assert_eq!(srtt.as_duration(), Duration::from_micros(197152));
}
/// Enough failures to exceed the cap leave the estimate at `MAX_SRTT_MICROS`.
#[test]
fn test_record_connection_failure_maximum_value() {
let srtt = DecayingSrtt::new(Duration::from_micros(10));
let num_failures = (DecayingSrtt::MAX_SRTT_MICROS / DecayingSrtt::FAILURE_PENALTY) + 1;
for _ in 0..num_failures {
srtt.record_failure();
}
assert_eq!(
srtt.as_duration(),
Duration::from_micros(DecayingSrtt::MAX_SRTT_MICROS.into())
);
}
/// `current()` returns the raw value before any update. Afterwards it is scaled by
/// exp(-max(t, 1) / 180): 0.5 s → exp(-1/180) ≈ 0.99446 (99_445),
/// 5.5 s → exp(-5.5/180) ≈ 0.96991 (96_990).
#[tokio::test(start_paused = true)]
async fn test_decayed_srtt() {
let initial_srtt = 10;
let srtt = DecayingSrtt::new(Duration::from_micros(initial_srtt));
assert_eq!(srtt.current() as u32, initial_srtt as u32);
tokio::time::advance(Duration::from_secs(5)).await;
srtt.record(Duration::from_millis(100));
tokio::time::advance(Duration::from_millis(500)).await;
assert_eq!(srtt.current() as u32, 99445);
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(srtt.current() as u32, 96990);
}
}
#[cfg(all(test, feature = "__tls"))]
mod opportunistic_enc_tests {
use std::io;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
#[cfg(feature = "metrics")]
use metrics::{Label, Unit, with_local_recorder};
#[cfg(feature = "metrics")]
use metrics_util::debugging::DebuggingRecorder;
use mock_provider::{MockClientHandle, MockProvider};
use test_support::subscribe;
#[cfg(feature = "metrics")]
use test_support::{assert_counter_eq, assert_gauge_eq, assert_histogram_sample_count_eq};
use crate::config::{
NameServerConfig, OpportunisticEncryption, OpportunisticEncryptionConfig, ProtocolConfig,
ResolverOpts,
};
use crate::connection_provider::TlsConfig;
#[cfg(feature = "metrics")]
use crate::metrics::opportunistic_encryption::{
PROBE_ATTEMPTS_TOTAL, PROBE_BUDGET_TOTAL, PROBE_DURATION_SECONDS, PROBE_ERRORS_TOTAL,
PROBE_SUCCESSES_TOTAL, PROBE_TIMEOUTS_TOTAL,
};
use crate::name_server::{ConnectionPolicy, ConnectionState, NameServer, mock_provider};
use crate::name_server_pool::{NameServerTransportState, PoolContext};
use crate::net::NetError;
use crate::net::xfer::Protocol;
#[tokio::test]
async fn test_select_connection_opportunistic_enc_disabled() {
let mut policy = ConnectionPolicy::default();
let connections = vec![
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let state = NameServerTransportState::default();
let opp_enc = OpportunisticEncryption::Disabled;
let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
policy.disable_udp = true;
let selected = policy.select_connection(ns_ip, &state, &opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled() {
let policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
mock_connection(Protocol::Tls),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Tls);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled_no_state() {
let mut policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
policy.disable_udp = true;
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Tcp);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled_failed_probe() {
let policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mut state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
state.error_received(
ns_ip,
Protocol::Tls,
&NetError::from(io::Error::new(
io::ErrorKind::ConnectionRefused,
"nameserver refused TLS connection",
)),
);
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled_in_progress_probe() {
let policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mut state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
state.initiate_connection(ns_ip, Protocol::Tls);
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
state.complete_connection(ns_ip, Protocol::Tls);
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled_stale_probe() {
let policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mut state = NameServerTransportState::default();
let opp_enc_config = OpportunisticEncryptionConfig {
persistence_period: Duration::from_secs(10),
..OpportunisticEncryptionConfig::default()
};
let opp_enc = &OpportunisticEncryption::Enabled {
config: opp_enc_config.clone(),
};
state.complete_connection(ns_ip, Protocol::Tls);
state.response_received(ns_ip, Protocol::Tls);
let stale_time =
SystemTime::now() - opp_enc_config.persistence_period - Duration::from_secs(1);
state.set_last_response(ns_ip, Protocol::Tls, stale_time);
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, Protocol::Udp);
}
#[tokio::test]
async fn test_select_connection_opportunistic_enc_enabled_good_probe() {
let policy = ConnectionPolicy::default();
let connections = [
mock_connection(Protocol::Udp),
mock_connection(Protocol::Tcp),
];
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mut state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
state.complete_connection(ns_ip, Protocol::Tls);
state.response_received(ns_ip, Protocol::Tls);
let selected = policy.select_connection(ns_ip, &state, opp_enc, &connections);
assert!(selected.is_none());
}
#[tokio::test]
async fn test_select_connection_config_opportunistic_enc_disabled() {
let mut policy = ConnectionPolicy::default();
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
let state = NameServerTransportState::default();
let opp_enc = OpportunisticEncryption::Disabled;
let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
policy.disable_udp = true;
let selected = policy.select_connection_config(ns_ip, &state, &opp_enc, &configs);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
}
#[tokio::test]
async fn test_select_connection_config_opportunistic_enc_enabled_no_state() {
let mut policy = ConnectionPolicy::default();
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
let state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
policy.disable_udp = true;
let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, ProtocolConfig::Tcp);
}
#[tokio::test]
async fn test_select_connection_config_opportunistic_enc_enabled_failed_probe() {
let policy = ConnectionPolicy::default();
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
let mut state = NameServerTransportState::default();
let opp_enc = &OpportunisticEncryption::Enabled {
config: OpportunisticEncryptionConfig::default(),
};
state.error_received(
ns_ip,
Protocol::Tls,
&NetError::from(io::Error::new(
io::ErrorKind::ConnectionRefused,
"nameserver refused TLS connection",
)),
);
let selected = policy.select_connection_config(ns_ip, &state, opp_enc, &configs);
assert!(selected.is_some());
assert_eq!(selected.unwrap().protocol, ProtocolConfig::Udp);
}
/// A TLS response whose timestamp is older than `persistence_period` does not
/// keep TLS selected: config selection picks UDP.
#[tokio::test]
async fn test_select_connection_config_opportunistic_enc_enabled_stale_probe() {
    let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
    let persistence_period = Duration::from_secs(10);
    let opp_enc = &OpportunisticEncryption::Enabled {
        config: OpportunisticEncryptionConfig {
            persistence_period,
            ..OpportunisticEncryptionConfig::default()
        },
    };

    let mut state = NameServerTransportState::default();
    state.complete_connection(ns_ip, Protocol::Tls);
    state.response_received(ns_ip, Protocol::Tls);
    // Backdate the last TLS response to one second beyond the persistence window.
    let stale_time = SystemTime::now() - persistence_period - Duration::from_secs(1);
    state.set_last_response(ns_ip, Protocol::Tls, stale_time);

    let policy = ConnectionPolicy::default();
    let chosen = policy
        .select_connection_config(ns_ip, &state, opp_enc, &configs)
        .expect("a fallback connection config should be selected");
    assert_eq!(chosen.protocol, ProtocolConfig::Udp);
}
/// A completed TLS connection that has received a response makes config
/// selection pick a TLS config.
#[tokio::test]
async fn test_select_connection_config_opportunistic_enc_enabled_good_probe() {
    let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    let configs = NameServerConfig::opportunistic_encryption(ns_ip).connections;
    let opp_enc = &OpportunisticEncryption::Enabled {
        config: OpportunisticEncryptionConfig::default(),
    };

    // Simulate a TLS connection that completed and answered.
    let mut state = NameServerTransportState::default();
    state.complete_connection(ns_ip, Protocol::Tls);
    state.response_received(ns_ip, Protocol::Tls);

    let policy = ConnectionPolicy::default();
    let chosen = policy
        .select_connection_config(ns_ip, &state, opp_enc, &configs)
        .expect("a TLS connection config should be selected");
    assert!(matches!(chosen.protocol, ProtocolConfig::Tls { .. }));
}
#[tokio::test]
async fn test_opportunistic_probe() {
subscribe();
let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
.with_opportunistic_encryption()
.with_probe_budget(10);
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mock_provider = MockProvider::default();
assert!(
test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
.await
.is_ok()
);
let recorded_calls = mock_provider.new_connection_calls();
assert_eq!(recorded_calls.len(), 2);
let (ips, protocols): (Vec<IpAddr>, Vec<ProtocolConfig>) =
recorded_calls.into_iter().unzip();
assert!(ips.iter().all(|ip| *ip == ns_ip));
let protocols = protocols
.iter()
.map(ProtocolConfig::to_protocol)
.collect::<Vec<_>>();
assert!(protocols.contains(&Protocol::Udp));
assert!(protocols.contains(&Protocol::Tls));
}
/// When a TLS connection to the server is already marked as initiated, the
/// test expects only the UDP connection to be opened.
#[tokio::test]
async fn test_opportunistic_probe_skip_in_progress() {
    subscribe();
    let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
        .with_opportunistic_encryption()
        .with_probe_budget(10);

    // Mark a TLS connection to this server as already in flight.
    cx.transport_state()
        .await
        .initiate_connection(ns_ip, Protocol::Tls);

    let mock_provider = MockProvider::default();
    let outcome = test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider).await;
    assert!(outcome.is_ok());

    let calls = mock_provider.new_connection_calls();
    let [(ip, protocol)] = calls.as_slice() else {
        panic!("expected exactly one new connection, got {calls:?}");
    };
    assert_eq!(*ip, ns_ip);
    assert_eq!(protocol.to_protocol(), Protocol::Udp);
}
#[tokio::test]
async fn test_opportunistic_probe_skip_recent_failure() {
subscribe();
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
.with_opportunistic_encryption()
.with_probe_budget(10);
cx.transport_state().await.error_received(
ns_ip,
Protocol::Tls,
&NetError::from(io::Error::new(
io::ErrorKind::ConnectionRefused,
"connection refused",
)),
);
let mock_provider = MockProvider::default();
assert!(
test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
.await
.is_ok()
);
let recorded_calls = mock_provider.new_connection_calls();
assert_eq!(recorded_calls.len(), 1);
let (ip, protocol) = &recorded_calls[0];
assert_eq!(*ip, ns_ip);
assert_eq!(protocol.to_protocol(), Protocol::Udp);
}
#[tokio::test]
async fn test_opportunistic_probe_stale_failure() {
subscribe();
let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let mut cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
.with_probe_budget(10);
let opp_enc_config = OpportunisticEncryptionConfig {
damping_period: Duration::from_secs(5),
..OpportunisticEncryptionConfig::default()
};
cx.opportunistic_encryption = OpportunisticEncryption::Enabled {
config: opp_enc_config.clone(),
};
{
let mut state = cx.transport_state().await;
let old_failure_time =
SystemTime::now() - opp_enc_config.damping_period - Duration::from_secs(1);
state.set_failure_time(ns_ip, Protocol::Tls, old_failure_time);
}
let mock_provider = MockProvider::default();
assert!(
test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider)
.await
.is_ok()
);
let recorded_calls = mock_provider.new_connection_calls();
assert_eq!(recorded_calls.len(), 2);
let protocols = recorded_calls
.iter()
.map(|(_, protocol)| protocol.to_protocol())
.collect::<Vec<_>>();
assert!(protocols.contains(&Protocol::Udp));
assert!(protocols.contains(&Protocol::Tls));
}
/// With opportunistic encryption enabled but no `with_probe_budget` call on the
/// context, the test expects only the UDP connection to be opened.
#[tokio::test]
async fn test_opportunistic_probe_skip_no_budget() {
    subscribe();
    let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
        .with_opportunistic_encryption();

    let mock_provider = MockProvider::default();
    let outcome = test_connected_mut_client(ns_ip, Arc::new(cx), &mock_provider).await;
    assert!(outcome.is_ok());

    let calls = mock_provider.new_connection_calls();
    let [(ip, protocol)] = calls.as_slice() else {
        panic!("expected exactly one new connection, got {calls:?}");
    };
    assert_eq!(*ip, ns_ip);
    assert_eq!(protocol.to_protocol(), Protocol::Udp);
}
/// Builds a `ConnectionState` over a `MockClientHandle` for the given protocol.
fn mock_connection(protocol: Protocol) -> ConnectionState<MockProvider> {
    ConnectionState::<MockProvider>::new(MockClientHandle, protocol)
}
/// A successful TLS probe records:
/// - one attempt,
/// - one duration sample,
/// - one success and zero errors.
///
/// The probe-budget gauge ends at its initial value.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_success() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            let outcome =
                test_connected_mut_client(ns_ip, Arc::new(cx), &MockProvider::default()).await;
            assert!(outcome.is_ok());
        });
    });

    // Metric keys contain interior-mutable data; they are only read here.
    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];
    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, tls.clone(), 1, Unit::Seconds);
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
/// Without `with_probe_budget`, the test expects:
/// - a probe-budget gauge of 0,
/// - no TLS probe attempts,
/// - no TLS duration samples.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_budget_exhausted() {
    subscribe();
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption();
            let outcome =
                test_connected_mut_client(ns_ip, Arc::new(cx), &MockProvider::default()).await;
            assert!(outcome.is_ok());
        });
    });

    // Metric keys contain interior-mutable data; they are only read here.
    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], 0);
    let tls = vec![Label::new("protocol", "tls")];
    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 0);
    assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, tls, 0, Unit::Seconds);
}
/// When every new connection fails with a refused I/O error, the TLS probe
/// records:
/// - one attempt,
/// - one duration sample,
/// - one error and zero successes.
///
/// The probe-budget gauge ends at its initial value.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_connection_error() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            let failing_provider = MockProvider {
                new_connection_error: Some(NetError::from(io::Error::new(
                    io::ErrorKind::ConnectionRefused,
                    "connection refused",
                ))),
                ..MockProvider::default()
            };
            // The connection outcome is intentionally ignored; only metrics are checked.
            let _ = test_connected_mut_client(ns_ip, Arc::new(cx), &failing_provider).await;
        });
    });

    // Metric keys contain interior-mutable data; they are only read here.
    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];
    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, tls.clone(), 1, Unit::Seconds);
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
/// When every new connection fails with `NetError::Timeout`, the TLS probe
/// records:
/// - one attempt,
/// - one duration sample,
/// - one timeout,
/// - zero errors and zero successes.
///
/// The probe-budget gauge ends at its initial value.
#[cfg(feature = "metrics")]
#[test]
fn test_opportunistic_probe_metrics_connection_timeout_error() {
    subscribe();
    let initial_budget = 10;
    let recorder = DebuggingRecorder::new();
    let snapshotter = recorder.snapshotter();

    // Everything that may emit metrics runs inside the local recorder scope.
    with_local_recorder(&recorder, || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let ns_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
            let cx = PoolContext::new(ResolverOpts::default(), TlsConfig::new().unwrap())
                .with_opportunistic_encryption()
                .with_probe_budget(initial_budget);
            let timing_out_provider = MockProvider {
                new_connection_error: Some(NetError::Timeout),
                ..MockProvider::default()
            };
            // The connection outcome is intentionally ignored; only metrics are checked.
            let _ = test_connected_mut_client(ns_ip, Arc::new(cx), &timing_out_provider).await;
        });
    });

    // Metric keys contain interior-mutable data; they are only read here.
    #[allow(clippy::mutable_key_type)]
    let map = snapshotter.snapshot().into_hashmap();
    let tls = vec![Label::new("protocol", "tls")];
    assert_counter_eq(&map, PROBE_ATTEMPTS_TOTAL, tls.clone(), 1);
    assert_histogram_sample_count_eq(&map, PROBE_DURATION_SECONDS, tls.clone(), 1, Unit::Seconds);
    assert_counter_eq(&map, PROBE_TIMEOUTS_TOTAL, tls.clone(), 1);
    assert_counter_eq(&map, PROBE_ERRORS_TOTAL, tls.clone(), 0);
    assert_counter_eq(&map, PROBE_SUCCESSES_TOTAL, tls, 0);
    assert_gauge_eq(&map, PROBE_BUDGET_TOTAL, vec![], initial_budget);
}
/// Builds a `NameServer` for `ns_ip` with no pre-existing connections, using
/// `NameServerConfig::opportunistic_encryption` and a clone of `provider`.
/// Then runs `connected_mut_client` once with the default policy.
///
/// Returns `Ok(())` on success; the obtained client is discarded.
async fn test_connected_mut_client(
    ns_ip: IpAddr,
    cx: Arc<PoolContext>,
    provider: &MockProvider,
) -> Result<(), NetError> {
    let config = NameServerConfig::opportunistic_encryption(ns_ip);
    let name_server = NameServer::new([], config, &ResolverOpts::default(), provider.clone());
    name_server
        .connected_mut_client(ConnectionPolicy::default(), &cx)
        .await?;
    Ok(())
}
}
#[cfg(all(test, feature = "metrics"))]
mod resolver_metrics_tests {
    use std::net::{IpAddr, Ipv4Addr};

    use metrics::{Label, with_local_recorder};
    use metrics_util::debugging::DebuggingRecorder;
    use mock_provider::MockProvider;
    use test_support::assert_counter_eq;
    use test_support::subscribe;

    use super::*;
    use crate::connection_provider::TlsConfig;
    use crate::metrics::OUTGOING_QUERIES_TOTAL;

    /// Sends one `A` query for `www.example.com.` through a fresh `NameServer`
    /// built from `config`. Then asserts that `OUTGOING_QUERIES_TOTAL` was
    /// incremented exactly once with the `protocol` label set to `protocol`.
    ///
    /// The query's result is ignored; only the recorded metric is checked.
    fn assert_outgoing_query_counted(config: NameServerConfig, protocol: &'static str) {
        subscribe();
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            runtime.block_on(async {
                let options = ResolverOpts::default();
                let name_server = Arc::new(NameServer::new(
                    [],
                    config,
                    &options,
                    MockProvider::default(),
                ));
                let cx = Arc::new(PoolContext::new(options, TlsConfig::new().unwrap()));
                let name = Name::parse("www.example.com.", None).unwrap();
                let _ = name_server
                    .send(
                        DnsRequest::from_query(
                            Query::query(name, RecordType::A),
                            DnsRequestOptions::default(),
                        ),
                        ConnectionPolicy::default(),
                        &cx,
                    )
                    .await;
            });
        });

        // Metric keys contain interior-mutable data; they are only read here.
        #[allow(clippy::mutable_key_type)]
        let map = snapshotter.snapshot().into_hashmap();
        let labels = vec![Label::new("protocol", protocol)];
        assert_counter_eq(&map, OUTGOING_QUERIES_TOTAL, labels, 1);
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_udp() {
        assert_outgoing_query_counted(
            NameServerConfig::udp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "udp",
        );
    }

    #[test]
    fn test_outgoing_query_protocol_metrics_tcp() {
        assert_outgoing_query_counted(
            NameServerConfig::tcp(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            "tcp",
        );
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn test_outgoing_query_protocol_metrics_tls() {
        assert_outgoing_query_counted(
            NameServerConfig::tls(
                IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
                "dns.google".into(),
            ),
            "tls",
        );
    }
}
// In-process test doubles for the connection provider, DNS handle, and runtime
// used by the tests above.
#[cfg(all(test, any(feature = "metrics", feature = "__tls")))]
mod mock_provider {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;

    use futures_util::stream::once;
    use futures_util::{Stream, future};
    use tokio::net::UdpSocket;

    use super::*;
    use crate::config::ProtocolConfig;
    use crate::net::runtime::TokioTime;
    use crate::net::runtime::iocompat::AsyncIoTokioAsStd;
    use crate::proto::op::Message;

    /// Connection provider that records every `new_connection` request.
    ///
    /// Each request yields a ready future containing either `MockClientHandle`
    /// or a clone of `new_connection_error`.
    #[derive(Clone, Default)]
    pub(super) struct MockProvider {
        pub(super) runtime: MockSyncRuntimeProvider,
        /// Every `(ip, protocol)` pair passed to `new_connection`, in call order.
        /// Shared between clones of the provider.
        pub(super) new_connection_calls: Arc<SyncMutex<Vec<(IpAddr, ProtocolConfig)>>>,
        /// When `Some`, every new connection resolves to a clone of this error.
        pub(super) new_connection_error: Option<NetError>,
    }

    impl MockProvider {
        /// Returns a copy of the `new_connection` calls recorded so far.
        pub(super) fn new_connection_calls(&self) -> Vec<(IpAddr, ProtocolConfig)> {
            self.new_connection_calls.lock().to_vec()
        }
    }

    impl ConnectionProvider for MockProvider {
        type Conn = MockClientHandle;
        type FutureConn = Pin<Box<dyn Send + Future<Output = Result<Self::Conn, NetError>>>>;
        type RuntimeProvider = MockSyncRuntimeProvider;

        fn new_connection(
            &self,
            ip: IpAddr,
            config: &ConnectionConfig,
            _cx: &PoolContext,
        ) -> Result<Self::FutureConn, NetError> {
            self.new_connection_calls
                .lock()
                .push((ip, config.protocol.clone()));

            // Decide the outcome up front; the returned future is always ready.
            let outcome = match self.new_connection_error.clone() {
                None => Ok(MockClientHandle),
                Some(err) => Err(err),
            };
            Ok(Box::pin(future::ready(outcome)))
        }

        fn runtime_provider(&self) -> &Self::RuntimeProvider {
            &self.runtime
        }
    }

    /// DNS handle that answers every request with a single `NoError` response
    /// echoing the request's id, op code and queries.
    #[derive(Clone, Default)]
    pub(super) struct MockClientHandle;

    impl DnsHandle for MockClientHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send>>;
        type Runtime = MockSyncRuntimeProvider;

        fn send(&self, request: DnsRequest) -> Self::Response {
            let mut message = Message::response(request.id, request.op_code);
            message.metadata.response_code = ResponseCode::NoError;
            message.add_queries(request.queries.clone());
            let answer = DnsResponse::from_message(message).unwrap();
            Box::pin(once(future::ready(Ok(answer))))
        }
    }

    /// Runtime provider whose handle runs spawned futures synchronously.
    /// Socket creation is not supported.
    #[derive(Clone, Default)]
    pub(super) struct MockSyncRuntimeProvider;

    impl RuntimeProvider for MockSyncRuntimeProvider {
        type Handle = MockSyncHandle;
        type Timer = TokioTime;
        type Udp = UdpSocket;
        type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            MockSyncHandle
        }

        // The mocks hand out `MockClientHandle`s directly, so no real sockets are needed.
        #[allow(clippy::unimplemented)]
        fn connect_tcp(
            &self,
            _server_addr: std::net::SocketAddr,
            _bind_addr: Option<std::net::SocketAddr>,
            _timeout: Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Tcp, io::Error>> + Send>> {
            unimplemented!();
        }

        // The mocks hand out `MockClientHandle`s directly, so no real sockets are needed.
        #[allow(clippy::unimplemented)]
        fn bind_udp(
            &self,
            _local_addr: std::net::SocketAddr,
            _server_addr: std::net::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Udp, io::Error>> + Send>> {
            unimplemented!();
        }
    }

    /// Spawner that drives each background future to completion on the caller's
    /// thread before `spawn_bg` returns.
    #[derive(Clone)]
    pub(super) struct MockSyncHandle;

    impl Spawn for MockSyncHandle {
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            // Busy-poll with a no-op waker until the future is ready.
            // NOTE(review): nothing can wake the future, so one that stays pending
            // spins forever. This is only suitable for futures that complete without
            // waiting on external events.
            let waker = futures_util::task::noop_waker();
            let mut cx = Context::from_waker(&waker);
            let mut task = Box::pin(future);
            while task.as_mut().poll(&mut cx).is_pending() {}
        }
    }
}