use std::sync::atomic::{AtomicU64, Ordering};
use super::config::BW_FILTER_SIZE;
/// Windowed-max filter over recent bandwidth samples.
///
/// Samples live in a fixed ring of `BW_FILTER_SIZE` slots; once the ring is
/// full, each new sample overwrites the slot written `BW_FILTER_SIZE` updates
/// earlier. [`max_bw`](Self::max_bw) reports the largest sample whose
/// timestamp is no older than `window_nanos` before the query time.
///
/// All state is atomic, so every method takes `&self`. A slot's sample and
/// timestamp are stored separately, so a reader racing with `update` or
/// `reset` may briefly observe a mix of old and new values for that slot.
pub(crate) struct BandwidthFilter {
    /// Bandwidth value held by each ring slot.
    samples: [AtomicU64; BW_FILTER_SIZE],
    /// Encoded timestamp for each slot: `0` means the slot is empty (never
    /// written, or cleared by `reset`); any other value is the sample's
    /// `now_nanos + 1` (saturating). The `+ 1` offset keeps a sample taken at
    /// time `0` distinguishable from an empty slot.
    timestamps: [AtomicU64; BW_FILTER_SIZE],
    /// Number of updates since construction or the last `reset`; the next
    /// slot written is `write_idx % BW_FILTER_SIZE`.
    write_idx: AtomicU64,
    /// Lookback window length in nanoseconds.
    window_nanos: u64,
}

impl BandwidthFilter {
    /// Creates an empty filter that considers samples from the last `window`.
    ///
    /// Windows longer than `u64::MAX` nanoseconds (~584 years) are clamped to
    /// `u64::MAX` instead of being silently truncated.
    // `ZERO` is only used as an array-repeat initializer; every use of a
    // `const` yields a fresh value, so each slot gets its own atomic.
    #[allow(clippy::declare_interior_mutable_const)]
    pub(crate) fn new(window: std::time::Duration) -> Self {
        const ZERO: AtomicU64 = AtomicU64::new(0);
        Self {
            samples: [ZERO; BW_FILTER_SIZE],
            timestamps: [ZERO; BW_FILTER_SIZE],
            write_idx: AtomicU64::new(0),
            window_nanos: u64::try_from(window.as_nanos()).unwrap_or(u64::MAX),
        }
    }

    /// Records bandwidth sample `bw` observed at `now_nanos`, overwriting the
    /// oldest slot once the ring is full.
    pub(crate) fn update(&self, bw: u64, now_nanos: u64) {
        let ticket = self.write_idx.fetch_add(1, Ordering::AcqRel);
        // Reduce in u64 *before* narrowing so the slot sequence stays
        // contiguous even where `usize` is 32 bits.
        let idx = (ticket % BW_FILTER_SIZE as u64) as usize;
        self.samples[idx].store(bw, Ordering::Release);
        // Publish the timestamp last: a reader that Acquire-loads this new
        // timestamp is guaranteed to also observe the sample stored above.
        self.timestamps[idx].store(Self::encode_timestamp(now_nanos), Ordering::Release);
    }

    /// Returns the largest sample recorded at or after
    /// `now_nanos - window_nanos` (saturating at 0), or `0` if there is none.
    pub(crate) fn max_bw(&self, now_nanos: u64) -> u64 {
        self.live_slots(now_nanos)
            .map(|slot| self.samples[slot].load(Ordering::Acquire))
            .max()
            .unwrap_or(0)
    }

    /// Empties every slot and rewinds the write cursor.
    ///
    /// This is a sequence of independent stores, not a single atomic
    /// operation; concurrent `update` calls may interleave with it.
    pub(crate) fn reset(&self) {
        for (sample, timestamp) in self.samples.iter().zip(&self.timestamps) {
            sample.store(0, Ordering::Release);
            timestamp.store(0, Ordering::Release);
        }
        self.write_idx.store(0, Ordering::Release);
    }

    /// Returns `true` if at least one sample was recorded at or after
    /// `now_nanos - window_nanos` (saturating at 0), including samples
    /// recorded at time `0`.
    pub(crate) fn has_samples(&self, now_nanos: u64) -> bool {
        self.live_slots(now_nanos).next().is_some()
    }

    /// Encodes a timestamp for storage, reserving `0` as the empty-slot
    /// marker. `u64::MAX` saturates and decodes as `u64::MAX - 1`.
    fn encode_timestamp(now_nanos: u64) -> u64 {
        now_nanos.saturating_add(1)
    }

    /// Returns the decoded timestamp held in `slot`, or `None` if it is empty.
    fn slot_timestamp(&self, slot: usize) -> Option<u64> {
        match self.timestamps[slot].load(Ordering::Acquire) {
            0 => None,
            encoded => Some(encoded - 1),
        }
    }

    /// Yields the indices of non-empty slots whose timestamp is within the
    /// window ending at `now_nanos`.
    fn live_slots(&self, now_nanos: u64) -> impl Iterator<Item = usize> + '_ {
        let cutoff = now_nanos.saturating_sub(self.window_nanos);
        (0..BW_FILTER_SIZE)
            .filter(move |&slot| matches!(self.slot_timestamp(slot), Some(ts) if ts >= cutoff))
    }
}
impl std::fmt::Debug for BandwidthFilter {
    /// Shows the window length and the current write cursor; the per-slot
    /// samples and timestamps are intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BandwidthFilter");
        builder.field("window_nanos", &self.window_nanos);
        // A relaxed snapshot is enough for diagnostic output.
        let cursor = self.write_idx.load(Ordering::Relaxed);
        builder.field("write_idx", &cursor);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// One second expressed in nanoseconds, for readable timestamps.
    const SEC: u64 = 1_000_000_000;

    /// The largest in-window sample is reported.
    #[test]
    fn test_bandwidth_filter_basic() {
        let bw_filter = BandwidthFilter::new(Duration::from_secs(10));
        bw_filter.update(1_000_000, SEC);
        bw_filter.update(2_000_000, 2 * SEC);
        bw_filter.update(1_500_000, 3 * SEC);

        let peak = bw_filter.max_bw(5 * SEC);
        assert_eq!(peak, 2_000_000);
    }

    /// A sample drops out once it is older than the window.
    #[test]
    fn test_bandwidth_filter_expiration() {
        let bw_filter = BandwidthFilter::new(Duration::from_secs(5));
        bw_filter.update(1_000_000, 0);

        assert_eq!(bw_filter.max_bw(4 * SEC), 1_000_000);
        assert_eq!(bw_filter.max_bw(6 * SEC), 0);
    }

    /// As the window slides forward, the max falls back to younger samples.
    #[test]
    fn test_bandwidth_filter_windowed_max() {
        let bw_filter = BandwidthFilter::new(Duration::from_secs(5));
        bw_filter.update(100, SEC);
        bw_filter.update(500, 2 * SEC);
        bw_filter.update(300, 3 * SEC);

        assert_eq!(bw_filter.max_bw(4 * SEC), 500);
        assert_eq!(bw_filter.max_bw(7 * SEC), 500);
        assert_eq!(bw_filter.max_bw(8 * SEC), 300);
    }

    /// `reset` discards every recorded sample.
    #[test]
    fn test_bandwidth_filter_reset() {
        let bw_filter = BandwidthFilter::new(Duration::from_secs(10));
        bw_filter.update(1_000_000, SEC);
        assert_eq!(bw_filter.max_bw(2 * SEC), 1_000_000);

        bw_filter.reset();
        assert_eq!(bw_filter.max_bw(2 * SEC), 0);
    }

    /// `has_samples` tracks whether anything is inside the window.
    #[test]
    fn test_bandwidth_filter_has_samples() {
        let bw_filter = BandwidthFilter::new(Duration::from_secs(5));
        assert!(!bw_filter.has_samples(SEC));

        bw_filter.update(1000, SEC);
        assert!(bw_filter.has_samples(2 * SEC));
        assert!(!bw_filter.has_samples(7 * SEC));
    }
}