use std::collections::VecDeque;
use std::time::Instant;
use crate::mmp::{EWMA_LONG_ALPHA, EWMA_SHORT_ALPHA};
/// Interarrival-jitter estimator using the RFC 3550 smoothing rule
/// `J += (|D| - J) / 16`.
///
/// The estimate is stored in Q4 fixed point (multiplied by 16), so the 1/16
/// gain becomes a right shift and fractional precision is kept between updates.
pub struct JitterEstimator {
    /// Smoothed jitter scaled by 16 (Q4 fixed point).
    jitter_q4: i64,
}

impl JitterEstimator {
    /// Returns an estimator whose jitter starts at zero.
    pub fn new() -> Self {
        JitterEstimator { jitter_q4: 0 }
    }

    /// Folds one transit-time difference `D` into the estimate.
    ///
    /// Only the magnitude of `transit_delta` matters. The unit is whatever
    /// unit [`jitter_us`](Self::jitter_us) reports, i.e. microseconds.
    pub fn update(&mut self, transit_delta: i32) {
        // Widening to i64 first means even i32::MIN has a representable magnitude.
        let magnitude = i64::from(transit_delta).abs();
        // In Q4 form: 16J' = 16J - J + |D|, i.e. J' = J + (|D| - J) / 16.
        let decay = self.jitter_q4 >> 4;
        self.jitter_q4 = self.jitter_q4 - decay + magnitude;
    }

    /// Current jitter with the Q4 fraction truncated away.
    pub fn jitter_us(&self) -> u32 {
        let whole = self.jitter_q4 >> 4;
        whole as u32
    }
}

impl Default for JitterEstimator {
    fn default() -> Self {
        JitterEstimator::new()
    }
}
/// Smoothed round-trip-time estimator following the RFC 6298 update rules.
///
/// All quantities are integer microseconds. The gains are alpha = 1/8 for
/// SRTT and beta = 1/4 for RTTVAR, applied with arithmetic shifts.
pub struct SrttEstimator {
    /// Smoothed RTT (SRTT), in microseconds.
    srtt_us: i64,
    /// RTT variation (RTTVAR), in microseconds.
    rttvar_us: i64,
    /// Whether at least one sample has been folded in.
    initialized: bool,
}

impl SrttEstimator {
    /// Lower bound on the retransmission timeout. RFC 6298 §2.4 rounds any
    /// RTO below 1 second up to 1 second.
    const MIN_RTO_US: i64 = 1_000_000;
    /// Clock-granularity term `G` of RFC 6298 §2.3. It keeps the variance
    /// term non-zero when RTTVAR collapses on a very stable path.
    /// NOTE(review): 1 ms is assumed as the tick of the timer that enforces
    /// the RTO; confirm against the timer implementation.
    const CLOCK_GRANULARITY_US: i64 = 1_000;
    /// RTTVAR multiplier `K` of RFC 6298.
    const RTTVAR_MULTIPLIER: i64 = 4;

    /// Creates an estimator with no samples.
    ///
    /// Until the first [`update`](Self::update), [`rto_us`](Self::rto_us)
    /// reports the 1 second floor.
    pub fn new() -> Self {
        Self {
            srtt_us: 0,
            rttvar_us: 0,
            initialized: false,
        }
    }

    /// Folds one RTT measurement (microseconds) into the estimate.
    ///
    /// The first sample seeds `SRTT = R` and `RTTVAR = R / 2`. Later samples
    /// apply `RTTVAR = 3/4 RTTVAR + 1/4 |SRTT - R|`, then
    /// `SRTT = 7/8 SRTT + 1/8 R`.
    pub fn update(&mut self, rtt_us: i64) {
        if !self.initialized {
            self.srtt_us = rtt_us;
            self.rttvar_us = rtt_us / 2;
            self.initialized = true;
        } else {
            // RTTVAR must be updated first: it uses the SRTT from before this sample.
            let err = (self.srtt_us - rtt_us).abs();
            self.rttvar_us = self.rttvar_us - (self.rttvar_us >> 2) + (err >> 2);
            self.srtt_us = self.srtt_us - (self.srtt_us >> 3) + (rtt_us >> 3);
        }
    }

    /// Smoothed RTT in microseconds (0 before the first sample).
    pub fn srtt_us(&self) -> i64 {
        self.srtt_us
    }

    /// RTT variation in microseconds (0 before the first sample).
    pub fn rttvar_us(&self) -> i64 {
        self.rttvar_us
    }

    /// Whether at least one sample has been recorded.
    pub fn initialized(&self) -> bool {
        self.initialized
    }

    /// Retransmission timeout in microseconds:
    /// `max(1 s, SRTT + max(G, K * RTTVAR))`.
    ///
    /// `G` is deliberately small. Using the 1 s floor in its place would add
    /// up to a full second to every RTO that is already above the floor.
    pub fn rto_us(&self) -> i64 {
        let variance_term = self
            .rttvar_us
            .saturating_mul(Self::RTTVAR_MULTIPLIER)
            .max(Self::CLOCK_GRANULARITY_US);
        self.srtt_us
            .saturating_add(variance_term)
            .max(Self::MIN_RTO_US)
    }
}

impl Default for SrttEstimator {
    fn default() -> Self {
        Self::new()
    }
}
/// Two exponentially weighted moving averages over one sample stream.
///
/// The "short" average uses gain `EWMA_SHORT_ALPHA` and the "long" average
/// uses `EWMA_LONG_ALPHA`; both constants come from `crate::mmp`. Presumably
/// the short gain is the larger one, so that average reacts faster.
/// The first sample seeds both averages directly.
pub struct DualEwma {
    short: f64,
    long: f64,
    initialized: bool,
}

impl DualEwma {
    /// Returns an empty pair of averages (both 0.0, not yet seeded).
    pub fn new() -> Self {
        DualEwma {
            short: 0.0,
            long: 0.0,
            initialized: false,
        }
    }

    /// Folds `sample` into both averages, or seeds them if this is the first sample.
    pub fn update(&mut self, sample: f64) {
        if self.initialized {
            // Standard EWMA step: avg += alpha * (sample - avg).
            self.short += EWMA_SHORT_ALPHA * (sample - self.short);
            self.long += EWMA_LONG_ALPHA * (sample - self.long);
            return;
        }
        self.short = sample;
        self.long = sample;
        self.initialized = true;
    }

    /// Current value of the short-gain average.
    pub fn short(&self) -> f64 {
        self.short
    }

    /// Current value of the long-gain average.
    pub fn long(&self) -> f64 {
        self.long
    }

    /// Whether any sample has been folded in yet.
    pub fn initialized(&self) -> bool {
        self.initialized
    }
}

impl Default for DualEwma {
    fn default() -> Self {
        DualEwma::new()
    }
}
/// Sliding-window detector for the trend of one-way-delay (OWD) samples.
///
/// Keeps the most recent `capacity` `(sequence number, OWD in µs)` pairs. It
/// fits an ordinary least-squares line of OWD against sequence number; a
/// positive slope means delay is growing from packet to packet.
pub struct OwdTrendDetector {
    /// Retained samples, oldest at the front and newest at the back.
    samples: VecDeque<(u32, i64)>,
    /// Maximum number of samples retained; 0 means nothing is retained.
    capacity: usize,
}

impl OwdTrendDetector {
    /// Creates an empty detector that retains at most `capacity` samples.
    pub fn new(capacity: usize) -> Self {
        Self {
            samples: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Discards all retained samples.
    pub fn clear(&mut self) {
        self.samples.clear();
    }

    /// Records one sample, evicting the oldest ones so at most `capacity` remain.
    ///
    /// With `capacity == 0` the sample is dropped, so the window never grows.
    pub fn push(&mut self, seq: u32, owd_us: i64) {
        if self.capacity == 0 {
            return;
        }
        while self.samples.len() >= self.capacity {
            self.samples.pop_front();
        }
        self.samples.push_back((seq, owd_us));
    }

    /// Least-squares slope of OWD versus sequence number, multiplied by 1000.
    ///
    /// Returns 0 in two cases: fewer than two samples are held, or every
    /// sample has the same sequence number (the slope is undefined).
    ///
    /// Sequence numbers are measured as a wrapping, signed distance from the
    /// oldest retained sample. A window that straddles the `u32` wrap point,
    /// or holds mildly reordered packets, therefore still gives the right slope.
    ///
    /// NOTE(review): the raw slope is µs of OWD per packet. The `* 1000`
    /// converts it to µs per second only if packets are spaced 1 ms apart;
    /// confirm the intended unit with callers. The final `as i32` saturates.
    pub fn trend_us_per_sec(&self) -> i32 {
        let n = self.samples.len();
        if n < 2 {
            return 0;
        }
        let base_seq = self.samples[0].0;
        // Reinterpreting the wrapped u32 difference as i32 gives the signed
        // distance from the oldest sample. This holds while the window spans
        // fewer than 2^31 sequence numbers.
        let x_of = |seq: u32| f64::from(seq.wrapping_sub(base_seq) as i32);

        let n_f = n as f64;
        let mean_x = self.samples.iter().map(|&(seq, _)| x_of(seq)).sum::<f64>() / n_f;
        let mean_y = self.samples.iter().map(|&(_, owd)| owd as f64).sum::<f64>() / n_f;

        let (num, den) = self
            .samples
            .iter()
            .fold((0.0_f64, 0.0_f64), |(num, den), &(seq, owd)| {
                let dx = x_of(seq) - mean_x;
                let dy = owd as f64 - mean_y;
                (num + dx * dy, den + dx * dx)
            });

        // `den` is a sum of squares; it is (near) zero only when all x are equal.
        if den.abs() < f64::EPSILON {
            return 0;
        }
        let slope_per_packet = num / den;
        (slope_per_packet * 1000.0) as i32
    }

    /// Number of samples currently retained.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Whether no samples are retained.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }
}
/// Expected transmission count (ETX) of a link from its delivery ratios.
///
/// Computes `1 / (d_forward * d_reverse)`, clamped to `[1.0, 100.0]`.
/// A ratio that is zero, negative, or NaN marks the link as unusable and
/// yields the 100.0 ceiling. Without this check, NaN would propagate
/// through `f64::clamp`, and a pair of negative ratios would give a finite ETX.
pub fn compute_etx(d_forward: f64, d_reverse: f64) -> f64 {
    const ETX_MIN: f64 = 1.0;
    const ETX_MAX: f64 = 100.0;

    // Phrased as a positive test so that NaN, which fails every comparison,
    // also lands in the unusable-link branch.
    let usable = d_forward > 0.0 && d_reverse > 0.0;
    if !usable {
        return ETX_MAX;
    }
    // A product that underflows to 0.0 gives +inf here, which clamps to ETX_MAX.
    (1.0 / (d_forward * d_reverse)).clamp(ETX_MIN, ETX_MAX)
}
/// Per-connection latency spin-bit state (modelled on the QUIC spin bit).
///
/// The initiator flips its outgoing bit each time the peer echoes its current
/// value back. The time between two consecutive flips is one RTT sample. The
/// responder reflects the bit of the newest packet it has received.
pub struct SpinBitState {
    /// Whether this side drives the spin (and measures RTT).
    is_initiator: bool,
    /// Bit value placed on outgoing packets.
    current_value: bool,
    /// Highest packet counter accepted so far; `None` until the first packet.
    highest_counter_for_spin: Option<u64>,
    /// When the initiator last observed an edge.
    last_edge_time: Option<Instant>,
}

impl SpinBitState {
    /// Creates spin state for the given role with the outgoing bit cleared.
    pub fn new(is_initiator: bool) -> Self {
        Self {
            is_initiator,
            current_value: false,
            highest_counter_for_spin: None,
            last_edge_time: None,
        }
    }

    /// Whether this side is the initiator.
    pub fn is_initiator(&self) -> bool {
        self.is_initiator
    }

    /// The spin bit to put on the next outgoing packet.
    pub fn tx_bit(&self) -> bool {
        self.current_value
    }

    /// Processes the spin bit of a received packet numbered `counter`, seen at `now`.
    ///
    /// The first packet is always accepted. After that, a packet whose
    /// `counter` is not greater than every previously accepted counter is
    /// ignored in both roles: the state is unchanged and `None` is returned.
    /// Reordered packets therefore cannot flip the bit or produce a spurious
    /// RTT sample.
    ///
    /// Returns an RTT sample only on the initiator, and only from the second
    /// edge onward (the first edge has nothing to measure against).
    pub fn rx_observe(
        &mut self,
        received_bit: bool,
        counter: u64,
        now: Instant,
    ) -> Option<std::time::Duration> {
        if matches!(self.highest_counter_for_spin, Some(highest) if counter <= highest) {
            return None;
        }
        self.highest_counter_for_spin = Some(counter);

        if !self.is_initiator {
            // Responder: echo the newest bit back unchanged.
            self.current_value = received_bit;
            return None;
        }

        if received_bit != self.current_value {
            // The peer has not echoed our current value yet, so there is no edge.
            return None;
        }
        // Edge: the peer echoed our bit. Flip it and measure from the previous edge.
        let rtt = self.last_edge_time.map(|prev| now.duration_since(prev));
        self.last_edge_time = Some(now);
        self.current_value = !self.current_value;
        rtt
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a 32-sample OWD detector pre-filled with `samples`, in order.
    fn owd_detector_from(samples: impl IntoIterator<Item = (u32, i64)>) -> OwdTrendDetector {
        let mut detector = OwdTrendDetector::new(32);
        for (seq, owd) in samples {
            detector.push(seq, owd);
        }
        detector
    }

    #[test]
    fn test_jitter_zero_input() {
        let mut est = JitterEstimator::new();
        est.update(0);
        assert_eq!(est.jitter_us(), 0);
    }

    #[test]
    fn test_jitter_convergence() {
        let mut est = JitterEstimator::new();
        (0..200).for_each(|_| est.update(1000));
        let jitter = est.jitter_us();
        assert!(
            (901..1100).contains(&jitter),
            "jitter={jitter}, expected ~1000"
        );
    }

    #[test]
    fn test_srtt_first_sample() {
        let mut est = SrttEstimator::new();
        est.update(10_000);
        assert_eq!(est.srtt_us(), 10_000);
        assert_eq!(est.rttvar_us(), 5_000);
        assert!(est.initialized());
    }

    #[test]
    fn test_srtt_convergence() {
        let mut est = SrttEstimator::new();
        (0..100).for_each(|_| est.update(50_000));
        let srtt = est.srtt_us();
        assert!((srtt - 50_000).abs() < 1000, "srtt={srtt}, expected ~50000");
    }

    #[test]
    fn test_dual_ewma_initialization() {
        let mut ewma = DualEwma::new();
        assert!(!ewma.initialized());
        ewma.update(100.0);
        assert!(ewma.initialized());
        assert_eq!((ewma.short(), ewma.long()), (100.0, 100.0));
    }

    #[test]
    fn test_dual_ewma_short_tracks_faster() {
        let mut ewma = DualEwma::new();
        ewma.update(0.0);
        (0..20).for_each(|_| ewma.update(100.0));
        let (short, long) = (ewma.short(), ewma.long());
        assert!(short > long, "short={short} long={long}");
    }

    #[test]
    fn test_owd_trend_flat() {
        let detector = owd_detector_from((0..20u32).map(|seq| (seq, 5000)));
        assert_eq!(
            detector.trend_us_per_sec(),
            0,
            "flat OWD should have zero trend"
        );
    }

    #[test]
    fn test_owd_trend_increasing() {
        let detector =
            owd_detector_from((0..20u32).map(|seq| (seq, 5000 + i64::from(seq) * 100)));
        let trend = detector.trend_us_per_sec();
        assert!(
            trend > 0,
            "increasing OWD should have positive trend, got {trend}"
        );
    }

    #[test]
    fn test_owd_trend_insufficient_samples() {
        let detector = owd_detector_from([(0, 5000)]);
        assert_eq!(detector.trend_us_per_sec(), 0);
    }

    #[test]
    fn test_etx_perfect_link() {
        let etx = compute_etx(1.0, 1.0);
        assert!((etx - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_etx_lossy_link() {
        let etx = compute_etx(0.9, 0.95);
        assert!(etx > 1.0 && etx < 2.0, "etx={etx}");
    }

    #[test]
    fn test_etx_zero_delivery() {
        for (forward, reverse) in [(0.0, 1.0), (1.0, 0.0)] {
            assert_eq!(compute_etx(forward, reverse), 100.0);
        }
    }

    #[test]
    fn test_spin_bit_initiator_rtt() {
        let mut initiator = SpinBitState::new(true);
        let mut responder = SpinBitState::new(false);
        let t0 = Instant::now();
        let t1 = t0 + Duration::from_millis(10);
        let t2 = t0 + Duration::from_millis(20);

        // Round 1: the initiator sends 0 and the responder reflects it.
        let sent = initiator.tx_bit();
        assert!(!sent);
        responder.rx_observe(sent, 1, t0);
        let echoed = responder.tx_bit();
        assert!(!echoed);
        // The first edge has no earlier edge to measure against.
        assert!(initiator.rx_observe(echoed, 2, t1).is_none());

        // Round 2: the initiator has flipped to 1; the echo closes one RTT.
        let sent = initiator.tx_bit();
        assert!(sent);
        responder.rx_observe(sent, 3, t1);
        let echoed = responder.tx_bit();
        assert!(echoed);
        assert!(initiator.rx_observe(echoed, 4, t2).is_some());
    }

    #[test]
    fn test_spin_bit_responder_counter_guard() {
        let mut responder = SpinBitState::new(false);
        responder.rx_observe(true, 5, Instant::now());
        assert!(responder.tx_bit());
        // Counter 3 is older than 5, so its bit must be ignored.
        responder.rx_observe(false, 3, Instant::now());
        assert!(responder.tx_bit());
    }
}