use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use crate::simulation::TimeSource;
use super::config::{BASE_HISTORY_SIZE, DELAY_FILTER_SIZE};
/// Sentinel marking a never-written slot. `u64::MAX` cannot collide with a
/// meaningful RTT in practice, and readers skip slots holding this value.
pub(crate) const EMPTY_DELAY_NANOS: u64 = u64::MAX;

/// Lock-free ring buffer of recent RTT samples (nanoseconds) that reports
/// the minimum sample currently held. Shareable across threads (`&self` API,
/// all state in atomics).
pub(crate) struct AtomicDelayFilter {
    // Ring of RTT samples in nanos; EMPTY_DELAY_NANOS marks unused slots.
    samples: [AtomicU64; DELAY_FILTER_SIZE],
    // Monotonically increasing cursor; slot index = value % DELAY_FILTER_SIZE.
    write_index: AtomicUsize,
    // Number of samples recorded so far, saturating at DELAY_FILTER_SIZE.
    sample_count: AtomicUsize,
}
impl AtomicDelayFilter {
    /// Creates an empty filter with every slot marked unused.
    pub(crate) fn new() -> Self {
        Self {
            samples: std::array::from_fn(|_| AtomicU64::new(EMPTY_DELAY_NANOS)),
            write_index: AtomicUsize::new(0),
            sample_count: AtomicUsize::new(0),
        }
    }

    /// Records one RTT sample, overwriting the oldest slot once the ring is full.
    ///
    /// Lock-free: each concurrent caller claims a distinct slot via `fetch_add`.
    pub(crate) fn add_sample(&self, rtt: Duration) {
        // `as_nanos()` is u128; a u64 overflows only past ~584 years of RTT,
        // so the truncating cast is safe for any real sample.
        let nanos = rtt.as_nanos() as u64;
        // Relaxed suffices on the cursor: only the slot store itself must be
        // made visible (Release) to readers in filtered_delay().
        let idx = self.write_index.fetch_add(1, Ordering::Relaxed) % DELAY_FILTER_SIZE;
        self.samples[idx].store(nanos, Ordering::Release);
        // Saturating increment: returning None from the closure stops the CAS
        // loop once the count has reached DELAY_FILTER_SIZE.
        self.sample_count
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count| {
                (count < DELAY_FILTER_SIZE).then_some(count + 1)
            })
            .ok();
    }

    /// Returns the minimum RTT currently held, or `None` if no sample has
    /// ever been recorded.
    pub(crate) fn filtered_delay(&self) -> Option<Duration> {
        self.samples
            .iter()
            .map(|slot| slot.load(Ordering::Acquire))
            .filter(|&nanos| nanos != EMPTY_DELAY_NANOS)
            .min()
            .map(Duration::from_nanos)
    }

    /// True once at least two samples have been recorded.
    pub(crate) fn is_ready(&self) -> bool {
        self.sample_count.load(Ordering::Acquire) >= 2
    }
}
/// Sliding history of per-minute minimum RTTs, used to estimate a base
/// (propagation) delay. All state lives in atomics so updates are lock-free.
pub(crate) struct AtomicBaseDelayHistory<T: TimeSource> {
    // Ring of archived per-minute minimum RTTs (nanos); EMPTY_DELAY_NANOS = unused.
    buckets: [AtomicU64; BASE_HISTORY_SIZE],
    // Number of buckets filled so far, saturating at BASE_HISTORY_SIZE.
    bucket_count: AtomicUsize,
    // Monotonic cursor; bucket index = value % BASE_HISTORY_SIZE.
    bucket_write_index: AtomicUsize,
    // Minimum RTT (nanos) observed within the current minute; sentinel when empty.
    current_minute_min: AtomicU64,
    // Start of the current minute, in nanos relative to `epoch_nanos`.
    current_minute_start_nanos: AtomicU64,
    // Clock abstraction (real or simulated time).
    time_source: T,
    // Time-source reading captured at construction; all timestamps are offsets
    // from this, keeping values small and origin-independent.
    epoch_nanos: u64,
}
impl<T: TimeSource> AtomicBaseDelayHistory<T> {
    /// Length of one history bucket: one minute, in nanoseconds.
    const MINUTE_NANOS: u64 = 60_000_000_000;
    /// Fallback base delay (10 ms) reported before any sample has arrived.
    const DEFAULT_BASE_DELAY_NANOS: u64 = 10_000_000;

    /// Creates an empty history anchored at the time source's current reading.
    pub(crate) fn new(time_source: T) -> Self {
        let epoch_nanos = time_source.now_nanos();
        Self {
            buckets: std::array::from_fn(|_| AtomicU64::new(EMPTY_DELAY_NANOS)),
            bucket_count: AtomicUsize::new(0),
            bucket_write_index: AtomicUsize::new(0),
            current_minute_min: AtomicU64::new(EMPTY_DELAY_NANOS),
            current_minute_start_nanos: AtomicU64::new(0),
            time_source,
            epoch_nanos,
        }
    }

    /// Feeds one RTT sample into the history, rolling the minute bucket over
    /// first if a full minute has elapsed.
    ///
    /// The CAS elects exactly one rollover winner among racing threads; losers
    /// skip the archive step. Every caller then competes for the (possibly
    /// freshly reset) minute minimum, so no sample is ever dropped.
    ///
    /// NOTE(review): a sample recorded by a racing thread between the winner's
    /// CAS and its swap may be attributed to the just-closed minute. It is
    /// never lost — it still lands in a bucket and participates in the overall
    /// minimum — it is merely archived one minute early.
    pub(crate) fn update(&self, rtt_sample: Duration) {
        let rtt_nanos = rtt_sample.as_nanos() as u64;
        // NOTE(review): assumes the time source is monotonic and never reads
        // below `epoch_nanos`; otherwise this subtraction underflows — confirm.
        let now_nanos = self.time_source.now_nanos() - self.epoch_nanos;
        let minute_start = self.current_minute_start_nanos.load(Ordering::Acquire);
        if now_nanos.saturating_sub(minute_start) >= Self::MINUTE_NANOS
            && self
                .current_minute_start_nanos
                .compare_exchange(
                    minute_start,
                    now_nanos,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
        {
            self.archive_current_minute();
        }
        self.update_current_min(rtt_nanos);
    }

    /// Moves the finished minute's minimum into the bucket ring. Called only
    /// by the single thread that won the rollover CAS in `update`.
    fn archive_current_minute(&self) {
        // Swap out and reset the minute minimum atomically so concurrent
        // updaters either contributed to the old minute or start the new one.
        let old_min = self
            .current_minute_min
            .swap(EMPTY_DELAY_NANOS, Ordering::AcqRel);
        if old_min != EMPTY_DELAY_NANOS {
            let idx =
                self.bucket_write_index.fetch_add(1, Ordering::Relaxed) % BASE_HISTORY_SIZE;
            self.buckets[idx].store(old_min, Ordering::Release);
            // Saturating increment, same pattern as AtomicDelayFilter's count.
            self.bucket_count
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count| {
                    (count < BASE_HISTORY_SIZE).then_some(count + 1)
                })
                .ok();
        }
    }

    /// CAS loop: lowers `current_minute_min` to `rtt_nanos` if it is smaller.
    /// Returning None (sample not smaller) simply ends the loop.
    fn update_current_min(&self, rtt_nanos: u64) {
        self.current_minute_min
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                (rtt_nanos < current).then_some(rtt_nanos)
            })
            .ok();
    }

    /// Returns the smallest RTT across all archived buckets and the current
    /// minute, or a 10 ms default when nothing has been recorded yet.
    pub(crate) fn base_delay(&self) -> Duration {
        let historical_min = self
            .buckets
            .iter()
            .map(|bucket| bucket.load(Ordering::Acquire))
            .filter(|&nanos| nanos != EMPTY_DELAY_NANOS)
            .min();
        let current_min = Some(self.current_minute_min.load(Ordering::Acquire))
            .filter(|&nanos| nanos != EMPTY_DELAY_NANOS);
        let result_nanos = match (historical_min, current_min) {
            (Some(hist), Some(cur)) => hist.min(cur),
            (Some(hist), None) => hist,
            (None, Some(cur)) => cur,
            (None, None) => Self::DEFAULT_BASE_DELAY_NANOS,
        };
        Duration::from_nanos(result_nanos)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::RealTime;
    use std::sync::atomic::Ordering;

    /// Hammers `AtomicDelayFilter` from 4 threads and checks the filtered
    /// minimum lands within the range of values fed in (10..=50 ms).
    #[test]
    fn test_atomic_delay_filter_concurrent() {
        use std::sync::Arc;
        use std::thread;
        let filter = Arc::new(AtomicDelayFilter::new());
        let num_threads = 4;
        let iterations = 100;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let filter = Arc::clone(&filter);
                thread::spawn(move || {
                    for j in 0..iterations {
                        // Per-thread RTT pattern: base 10ms + thread offset + jitter.
                        let rtt_ms = 10 + i * 5 + (j % 10);
                        filter.add_sample(Duration::from_millis(rtt_ms as u64));
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert!(filter.is_ready());
        let min_delay = filter.filtered_delay().expect("Should have samples");
        // Smallest value any thread can send is 10ms (i=0, j%10==0); largest
        // is under 50ms; the min filter must report something in between.
        assert!(min_delay >= Duration::from_millis(10));
        assert!(min_delay <= Duration::from_millis(50));
    }

    /// Same concurrent-hammer pattern against `AtomicBaseDelayHistory`;
    /// only bounds are asserted since the minute never rolls over in real time.
    #[test]
    fn test_atomic_base_delay_history_concurrent() {
        use std::sync::Arc;
        use std::thread;
        let history = Arc::new(AtomicBaseDelayHistory::new(RealTime::new()));
        let num_threads = 4;
        let iterations = 100;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let history = Arc::clone(&history);
                thread::spawn(move || {
                    for j in 0..iterations {
                        let rtt_ms = 20 + i * 10 + (j % 15);
                        history.update(Duration::from_millis(rtt_ms as u64));
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        let base = history.base_delay();
        assert!(base >= Duration::from_millis(20));
        assert!(base <= Duration::from_millis(100));
    }

    /// Stresses update() with 16 threads and frequent yields to probe the
    /// rollover CAS path; the base delay must end up in the smallest
    /// thread's value range regardless of interleaving.
    #[test]
    fn test_minute_rollover_race_condition() {
        use std::sync::Arc;
        use std::thread;
        let history = Arc::new(AtomicBaseDelayHistory::new(RealTime::new()));
        // Seed a large value first so the race must drive the minimum down.
        history.update(Duration::from_millis(100));
        assert_eq!(history.base_delay(), Duration::from_millis(100));
        // Shadow tracker: the true minimum sent, maintained with the same
        // CAS-min pattern the implementation uses.
        let expected_min = Arc::new(AtomicU64::new(u64::MAX));
        let num_threads = 16;
        let iterations = 100;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let history = Arc::clone(&history);
                let expected_min = Arc::clone(&expected_min);
                thread::spawn(move || {
                    for j in 0..iterations {
                        let rtt_ms = 10 + i + (j % 10);
                        let rtt_nanos = rtt_ms as u64 * 1_000_000;
                        expected_min
                            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                                if rtt_nanos < current {
                                    Some(rtt_nanos)
                                } else {
                                    None
                                }
                            })
                            .ok();
                        history.update(Duration::from_millis(rtt_ms as u64));
                        // Periodic yields widen the interleaving window around
                        // the rollover CAS.
                        if j % 10 == 0 {
                            thread::yield_now();
                        }
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        let actual_min = history.base_delay();
        // NOTE(review): computed but unused — the asserts below use fixed
        // bounds instead; consider removing or asserting against it.
        let _expected = Duration::from_nanos(expected_min.load(Ordering::Relaxed));
        assert!(
            actual_min <= Duration::from_millis(20),
            "Base delay {} should be <= 20ms (smallest thread's range)",
            actual_min.as_millis()
        );
        assert!(
            actual_min >= Duration::from_millis(10),
            "Base delay {} should be >= 10ms (smallest value sent)",
            actual_min.as_millis()
        );
    }

    /// Each of 8 threads repeatedly sends its own fixed RTT (5..=12 ms) with
    /// yields between updates; the global minimum (5 ms) must survive any
    /// rollover races exactly.
    #[test]
    fn test_no_minimum_value_lost_during_rollover() {
        use std::sync::Arc;
        use std::thread;
        let history = Arc::new(AtomicBaseDelayHistory::new(RealTime::new()));
        let num_threads = 8;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let history = Arc::clone(&history);
                thread::spawn(move || {
                    let rtt_ms = 5 + i;
                    for _ in 0..50 {
                        history.update(Duration::from_millis(rtt_ms as u64));
                        thread::yield_now();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        let base = history.base_delay();
        assert_eq!(
            base,
            Duration::from_millis(5),
            "Base delay should be 5ms (the smallest value sent), got {}ms",
            base.as_millis()
        );
    }
}