use std::cmp::Ordering;
use std::sync::{
atomic::{self, AtomicU32},
Arc,
};
use parking_lot::Mutex;
use rand::Rng as _;
#[cfg(not(test))]
use std::time::{Duration, Instant};
#[cfg(test)]
use tokio::time::{Duration, Instant};
/// Round-trip-time bookkeeping for a single name server.
///
/// The smoothed RTT estimate lives in an atomic so it can be read without taking
/// the lock; the timestamp of the most recent update sits behind a mutex.
pub(crate) struct NameServerStats {
    /// When the estimate was last updated; `None` until the first update.
    last_update: Arc<Mutex<Option<Instant>>>,
    /// Smoothed round-trip time, in microseconds.
    srtt_microseconds: AtomicU32,
}
/// Builds stats whose initial SRTT is a random value of 1 to 31 microseconds.
///
/// NOTE(review): the random jitter presumably exists so that servers that have
/// never been measured do not all compare equal; confirm with callers.
impl Default for NameServerStats {
    fn default() -> Self {
        let jitter_micros = rand::thread_rng().gen_range(1..32);
        Self::new(Duration::from_micros(jitter_micros))
    }
}
/// Returns the weight `exp(-max(elapsed_secs, 1) / weight)` for a value last updated at
/// `last_update`.
///
/// Elapsed time below one second is treated as exactly one second. A larger `weight`
/// makes the factor decay more slowly. The result lies in `(0, 1)` for positive `weight`.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    let floored_elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-floored_elapsed_secs / f64::from(weight)).exp()
}
impl NameServerStats {
    // The `as u32` casts on these two constants are exact: both values fit in a `u32`.
    /// Penalty added to the SRTT for each recorded connection failure: 150 ms, in microseconds.
    const CONNECTION_FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
    /// Ceiling applied to every SRTT computed by `update_srtt`: 5 s, in microseconds.
    const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;

    /// Creates stats with `initial_srtt` as the starting estimate and no recorded update time.
    ///
    /// The initial value is not clamped to `MAX_SRTT_MICROS`. Durations whose microsecond
    /// count exceeds `u32::MAX` saturate to `u32::MAX` instead of wrapping around.
    pub(crate) fn new(initial_srtt: Duration) -> Self {
        Self {
            srtt_microseconds: AtomicU32::new(Self::saturating_micros(initial_srtt)),
            last_update: Arc::new(Mutex::new(None)),
        }
    }

    /// Folds a measured round-trip time into the SRTT estimate.
    ///
    /// The first update (no previous update time) replaces the estimate with `rtt`.
    /// Later updates use an exponentially weighted moving average. The previous estimate
    /// is weighted by `compute_srtt_factor(last_update, 3)` and the new sample by the
    /// remainder, so the longer since the last update, the more `rtt` dominates.
    /// The stored result is clamped to `MAX_SRTT_MICROS`.
    pub(crate) fn record_rtt(&self, rtt: Duration) {
        // Saturate rather than use `as u32`. An integer `as` cast keeps only the low 32 bits,
        // which would turn e.g. a 4295 s first sample into roughly 32.7 ms instead of the cap.
        let rtt_micros = Self::saturating_micros(rtt);
        self.update_srtt(rtt_micros, |cur_srtt_microseconds, last_update| {
            let factor = compute_srtt_factor(last_update, 3);
            let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
                + factor * f64::from(cur_srtt_microseconds);
            // Float-to-integer `as` casts saturate (NaN becomes 0), so this cannot wrap.
            new_srtt.round() as u32
        });
    }

    /// Penalises the server for a failed connection.
    ///
    /// Adds `CONNECTION_FAILURE_PENALTY` to the current estimate (saturating), or sets the
    /// estimate to the penalty when there has been no previous update. The result is
    /// clamped to `MAX_SRTT_MICROS`.
    pub(crate) fn record_connection_failure(&self) {
        self.update_srtt(
            Self::CONNECTION_FAILURE_PENALTY,
            |cur_srtt_microseconds, _last_update| {
                cur_srtt_microseconds.saturating_add(Self::CONNECTION_FAILURE_PENALTY)
            },
        );
    }

    /// Returns the stored SRTT estimate without any time-based decay.
    fn srtt(&self) -> Duration {
        Duration::from_micros(u64::from(
            self.srtt_microseconds.load(atomic::Ordering::Acquire),
        ))
    }

    /// Returns the SRTT in microseconds, scaled by `compute_srtt_factor(last_update, 180)`.
    ///
    /// The estimate therefore shrinks the longer the server goes without an update.
    /// If there has never been an update, the raw estimate is returned.
    fn decayed_srtt(&self) -> f64 {
        let srtt = f64::from(self.srtt_microseconds.load(atomic::Ordering::Acquire));
        self.last_update.lock().map_or(srtt, |last_update| {
            srtt * compute_srtt_factor(last_update, 180)
        })
    }

    /// Records "now" as the latest update time and atomically rewrites the estimate.
    ///
    /// When there was no previous update, the new estimate is `default`. Otherwise it is
    /// `update_fn(current_estimate, previous_update_time)`. Either way the value is
    /// clamped to `MAX_SRTT_MICROS`.
    fn update_srtt(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
        // The guard is a temporary, so the lock is released at the end of this statement.
        let last_update = self.last_update.lock().replace(Instant::now());
        // NOTE: the timestamp swap above and the atomic update below are not done under one
        // lock, so concurrent updates may interleave between the two steps.
        // `fetch_update` only returns `Err` when the closure yields `None`, which it never
        // does here, so ignoring the result loses nothing.
        let _ = self.srtt_microseconds.fetch_update(
            atomic::Ordering::SeqCst,
            atomic::Ordering::SeqCst,
            move |cur_srtt_microseconds| {
                let updated = match last_update {
                    Some(last_update) => update_fn(cur_srtt_microseconds, last_update),
                    None => default,
                };
                Some(updated.min(Self::MAX_SRTT_MICROS))
            },
        );
    }

    /// Converts `duration` to whole microseconds, saturating at `u32::MAX`.
    fn saturating_micros(duration: Duration) -> u32 {
        duration.as_micros().min(u128::from(u32::MAX)) as u32
    }
}
/// Two stats are equal when their stored (undecayed) SRTT estimates are equal.
///
/// NOTE(review): `Ord` for this type compares *decayed* SRTTs, which also depend on each
/// side's last update time. So `a == b` does not guarantee `a.cmp(&b) == Ordering::Equal`.
impl PartialEq for NameServerStats {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.srtt();
        let rhs = other.srtt();
        lhs == rhs
    }
}

/// Equality of the underlying `Duration`s is a full equivalence relation.
impl Eq for NameServerStats {}
/// Total ordering over `f64` values, using the same bit-mapping algorithm as `f64::total_cmp`.
///
/// Every value, including NaNs and signed zeros, has a defined position:
/// `-0.0 < +0.0`, and a positive NaN sorts above `+inf`.
fn total_cmp(x: f64, y: f64) -> Ordering {
    // Reinterpret the bits as a signed integer. For negative floats (sign bit set),
    // flip every bit except the sign so that signed-integer order matches float order.
    fn ordering_key(value: f64) -> i64 {
        let bits = value.to_bits() as i64;
        let mask = (((bits >> 63) as u64) >> 1) as i64;
        bits ^ mask
    }
    ordering_key(x).cmp(&ordering_key(y))
}
/// Orders servers by decayed SRTT: the smaller decayed SRTT compares as `Less`.
///
/// The decay depends on the time elapsed since each server's last update, so comparing
/// the same pair can give different results at different moments. The ordering is also
/// not guaranteed to agree with `PartialEq`, which compares undecayed SRTTs.
impl Ord for NameServerStats {
    fn cmp(&self, other: &Self) -> Ordering {
        let (ours, theirs) = (self.decayed_srtt(), other.decayed_srtt());
        total_cmp(ours, theirs)
    }
}

/// Every pair is comparable; this simply wraps `Ord::cmp`.
impl PartialOrd for NameServerStats {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
#[cfg(test)]
// `is_send_sync` has a type parameter that appears only in its bounds, which clippy's
// `extra_unused_type_parameters` lint would otherwise flag.
#[allow(clippy::extra_unused_type_parameters)]
mod tests {
// Under `cfg(test)` the parent module imports `tokio::time::{Duration, Instant}`, so
// `tokio::time::advance` in paused-clock tests moves the clock the stats code reads.
use super::*;
// Compile-time check: this only type-checks when `S` is `Send + Sync`.
fn is_send_sync<S: Sync + Send>() -> bool {
true
}
#[test]
fn stats_are_sync() {
assert!(is_send_sync::<NameServerStats>());
}
#[tokio::test(start_paused = true)]
async fn test_stats_cmp() {
let server_a = NameServerStats::new(Duration::from_micros(10));
let server_b = NameServerStats::new(Duration::from_micros(20));
// Neither server has been updated, so the raw initial SRTTs (10 us vs 20 us) are compared.
assert_eq!(server_a.cmp(&server_b), Ordering::Less);
// server_a's first sample replaces its estimate outright (30 ms), well above server_b's 20 us.
server_a.record_rtt(Duration::from_millis(30));
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(server_a.cmp(&server_b), Ordering::Greater);
// server_b's first sample sets it to 50 ms, above server_a's slightly decayed 30 ms.
server_b.record_rtt(Duration::from_millis(50));
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(server_a.cmp(&server_b), Ordering::Less);
// A connection failure adds the 150 ms penalty to server_a (30 ms -> 180 ms).
server_a.record_connection_failure();
tokio::time::advance(Duration::from_secs(5)).await;
assert_eq!(server_a.cmp(&server_b), Ordering::Greater);
// Keep refreshing server_b at 50 ms until server_a's penalised estimate decays below it.
while server_a.cmp(&server_b) != Ordering::Less {
server_b.record_rtt(Duration::from_millis(50));
tokio::time::advance(Duration::from_secs(5)).await;
}
// After that long gap, the smoothing factor e^(-elapsed/3) is tiny, so server_a's
// estimate becomes about 30 ms and stays below server_b.
server_a.record_rtt(Duration::from_millis(30));
tokio::time::advance(Duration::from_secs(3)).await;
assert_eq!(server_a.cmp(&server_b), Ordering::Less);
}
#[tokio::test(start_paused = true)]
async fn test_record_rtt() {
let server = NameServerStats::new(Duration::from_micros(10));
let first_rtt = Duration::from_millis(50);
// The first sample replaces the initial estimate exactly.
server.record_rtt(first_rtt);
assert_eq!(server.srtt(), first_rtt);
tokio::time::advance(Duration::from_secs(3)).await;
// After 3 s with weight 3 the factor is e^-1 ~= 0.3679:
// 0.6321 * 100_000 + 0.3679 * 50_000 ~= 81_606 us (rounded).
server.record_rtt(Duration::from_millis(100));
assert_eq!(server.srtt(), Duration::from_micros(81606));
}
#[test]
fn test_record_rtt_maximum_value() {
let server = NameServerStats::new(Duration::from_micros(10));
// A first sample far above the cap must end up clamped to MAX_SRTT_MICROS.
server.record_rtt(Duration::MAX);
assert_eq!(
server.srtt(),
Duration::from_micros(NameServerStats::MAX_SRTT_MICROS.into())
);
}
#[tokio::test(start_paused = true)]
async fn test_record_connection_failure() {
let server = NameServerStats::new(Duration::from_micros(10));
// Each failure adds one penalty: the first sets the estimate to 1x the penalty
// (no previous update), the following ones accumulate to 2x and 3x.
for failure_count in 1..4 {
server.record_connection_failure();
assert_eq!(
server.srtt(),
Duration::from_micros(
NameServerStats::CONNECTION_FAILURE_PENALTY
.checked_mul(failure_count)
.expect("checked_mul overflow")
.into()
)
);
tokio::time::advance(Duration::from_secs(3)).await;
}
// Blend of the 450 ms penalised estimate with a 50 ms sample after 3 s (factor e^-1):
// 0.6321 * 50_000 + 0.3679 * 450_000 ~= 197_152 us (rounded).
server.record_rtt(Duration::from_millis(50));
assert_eq!(server.srtt(), Duration::from_micros(197152));
}
#[test]
fn test_record_connection_failure_maximum_value() {
let server = NameServerStats::new(Duration::from_micros(10));
// One more failure than fits under the cap, so the accumulated penalty must be clamped.
let num_failures =
(NameServerStats::MAX_SRTT_MICROS / NameServerStats::CONNECTION_FAILURE_PENALTY) + 1;
for _ in 0..num_failures {
server.record_connection_failure();
}
assert_eq!(
server.srtt(),
Duration::from_micros(NameServerStats::MAX_SRTT_MICROS.into())
);
}
#[tokio::test(start_paused = true)]
async fn test_decayed_srtt() {
let initial_srtt = 10;
let server = NameServerStats::new(Duration::from_micros(initial_srtt));
// With no update time recorded, no decay is applied.
assert_eq!(server.decayed_srtt() as u32, initial_srtt as u32);
tokio::time::advance(Duration::from_secs(5)).await;
server.record_rtt(Duration::from_millis(100));
tokio::time::advance(Duration::from_millis(500)).await;
// 0.5 s elapsed is floored to 1 s: 100_000 * e^(-1/180) ~= 99_445.98, truncated by `as u32`.
assert_eq!(server.decayed_srtt() as u32, 99445);
tokio::time::advance(Duration::from_secs(5)).await;
// 5.5 s elapsed: 100_000 * e^(-5.5/180) ~= 96_990.6, truncated by `as u32`.
assert_eq!(server.decayed_srtt() as u32, 96990);
}
}