use std::time::Duration;
use crate::adapter::net::subnet::SubnetId;
/// Heuristic model of one-way message latency between subnets.
///
/// A path's latency is estimated as
/// `base_hop_latency_nanos * hops * level_multipliers[depth]`, where `depth` is the
/// crossing depth between the two subnets (0 for the same subnet). The base hop
/// latency can be refined from measured round-trip times.
pub struct PropagationModel {
    /// Estimated one-way latency of a single hop, in nanoseconds.
    pub base_hop_latency_nanos: u64,
    /// Latency multipliers indexed by crossing depth (index 0 = same subnet).
    pub level_multipliers: [f32; 4],
    /// Number of calibration samples blended into `base_hop_latency_nanos` so far.
    sample_count: u64,
}
impl PropagationModel {
    /// Default one-way latency of a single hop: 100 µs.
    pub const DEFAULT_BASE_HOP_NANOS: u64 = 100_000;
    /// Default multipliers for crossing depths 0 through 3.
    pub const DEFAULT_MULTIPLIERS: [f32; 4] = [1.0, 5.0, 50.0, 500.0];

    /// Creates an uncalibrated model using [`Self::DEFAULT_BASE_HOP_NANOS`] and
    /// [`Self::DEFAULT_MULTIPLIERS`].
    pub fn new() -> Self {
        Self::with_base_latency(Self::DEFAULT_BASE_HOP_NANOS)
    }

    /// Creates an uncalibrated model with a custom per-hop latency (in nanoseconds)
    /// and the default multipliers.
    pub fn with_base_latency(base_nanos: u64) -> Self {
        Self {
            base_hop_latency_nanos: base_nanos,
            level_multipliers: Self::DEFAULT_MULTIPLIERS,
            sample_count: 0,
        }
    }

    /// Estimates the one-way latency from `source` to `dest` over `hop_count` hops.
    ///
    /// The estimate is `base_hop_latency_nanos * hops * multiplier`. The multiplier is
    /// selected by the crossing depth between the two subnets; depths past the end of
    /// the table reuse its last entry. A `hop_count` of 0 counts as a single hop.
    /// Results too large for `u64` nanoseconds saturate, and a NaN multiplier yields
    /// a zero duration (float-to-int `as` semantics).
    pub fn estimate_latency(&self, source: SubnetId, dest: SubnetId, hop_count: u8) -> Duration {
        let multiplier = self.multiplier_for_depth(crossing_depth(source, dest));
        Duration::from_nanos(self.estimate_nanos(Self::effective_hops(hop_count), multiplier))
    }

    /// Associated-function form of the module-level `crossing_depth`.
    pub fn crossing_depth(source: SubnetId, dest: SubnetId) -> u8 {
        crossing_depth(source, dest)
    }

    /// Folds a measured round-trip time into the per-hop latency estimate.
    ///
    /// The RTT is halved, then divided by `hop_count` and by the multiplier for the
    /// `source` → `dest` crossing depth. The per-hop value is capped at one second and
    /// blended into `base_hop_latency_nanos` with an exponentially weighted moving
    /// average. The new sample's weight is 0.5 for the first 10 samples and 0.1 after.
    ///
    /// The sample is ignored when `hop_count` is 0, when the multiplier is non-finite
    /// or not positive, or when the derived per-hop value is non-finite or negative.
    pub fn calibrate(
        &mut self,
        source: SubnetId,
        dest: SubnetId,
        hop_count: u8,
        measured_rtt_nanos: u64,
    ) {
        if hop_count == 0 {
            return;
        }
        let multiplier = self.multiplier_for_depth(crossing_depth(source, dest));
        // A NaN/inf/zero/negative multiplier would poison the EWMA; skip the sample.
        if !multiplier.is_finite() || multiplier <= 0.0 {
            return;
        }
        let per_hop_f =
            measured_rtt_nanos as f64 / (2.0 * f64::from(hop_count) * f64::from(multiplier));
        if !per_hop_f.is_finite() || per_hop_f < 0.0 {
            return;
        }
        // Clamp pathological RTTs so one bad sample cannot drag the estimate to ~u64::MAX.
        const MAX_REASONABLE_PER_HOP_NANOS: f64 = 1_000_000_000.0;
        let per_hop = per_hop_f.min(MAX_REASONABLE_PER_HOP_NANOS) as u64;
        // Adapt quickly while few samples exist, then smooth more heavily.
        let alpha = if self.sample_count < 10 { 0.5 } else { 0.1 };
        self.base_hop_latency_nanos =
            (self.base_hop_latency_nanos as f64 * (1.0 - alpha) + per_hop as f64 * alpha) as u64;
        self.sample_count = self.sample_count.saturating_add(1);
    }

    /// Returns the deepest crossing depth (0..=3) whose estimated latency over
    /// `hop_count` hops fits within `max_latency`. Returns `None` if no depth fits.
    ///
    /// A `hop_count` of 0 counts as a single hop.
    pub fn max_depth_within(&self, max_latency: Duration, hop_count: u8) -> Option<u8> {
        // `as_nanos` is u128. Saturate rather than truncate; otherwise budgets of 2^64 ns
        // or more would wrap around to arbitrary (possibly tiny) values.
        let budget_nanos = max_latency.as_nanos().min(u128::from(u64::MAX)) as u64;
        let hops = Self::effective_hops(hop_count);
        // 4 == the length of the `level_multipliers` table.
        (0..4u8).rev().find(|&depth| {
            self.estimate_nanos(hops, self.multiplier_for_depth(depth)) <= budget_nanos
        })
    }

    /// Treats a hop count of 0 as one hop.
    fn effective_hops(hop_count: u8) -> u64 {
        u64::from(hop_count.max(1))
    }

    /// Multiplier for `depth`; depths beyond the table reuse its last entry.
    fn multiplier_for_depth(&self, depth: u8) -> f32 {
        let table = &self.level_multipliers;
        table
            .get(usize::from(depth))
            .or(table.last())
            .copied()
            .unwrap_or(1.0)
    }

    /// `base * hops * multiplier` in nanoseconds. The float-to-int cast saturates and
    /// maps NaN to 0.
    fn estimate_nanos(&self, hops: u64, multiplier: f32) -> u64 {
        (self.base_hop_latency_nanos as f64 * hops as f64 * f64::from(multiplier)) as u64
    }
}
impl Default for PropagationModel {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for PropagationModel {
    /// Formats the model's tunables and its calibration sample count.
    ///
    /// The labels (`base_hop_nanos`, `multipliers`, `samples`) are shortened and
    /// intentionally differ from the struct's field names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PropagationModel");
        out.field("base_hop_nanos", &self.base_hop_latency_nanos);
        out.field("multipliers", &self.level_multipliers);
        out.field("samples", &self.sample_count);
        out.finish()
    }
}
/// Computes the crossing depth between two subnet IDs.
///
/// - `0` if `a` and `b` are the same subnet.
/// - Otherwise, if either ID is global, the larger of the two depths.
/// - Otherwise, the larger depth minus the index of the first of levels 0..4 at which
///   the IDs differ (saturating at 0), or `0` if all four levels match.
pub fn crossing_depth(a: SubnetId, b: SubnetId) -> u8 {
    if a.is_same_subnet(b) {
        return 0;
    }
    let deepest = a.depth().max(b.depth());
    if a.is_global() || b.is_global() {
        return deepest;
    }
    (0..4u8)
        .find(|&lvl| a.level(lvl) != b.level(lvl))
        .map_or(0, |lvl| deepest.saturating_sub(lvl))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_crossing_depth_same() {
        let a = SubnetId::new(&[1, 2, 3]);
        assert_eq!(crossing_depth(a, a), 0);
    }

    #[test]
    fn test_crossing_depth_sibling() {
        let a = SubnetId::new(&[1, 2]);
        let b = SubnetId::new(&[1, 3]);
        assert_eq!(crossing_depth(a, b), 1);
    }

    #[test]
    fn test_crossing_depth_different_region() {
        let a = SubnetId::new(&[1, 2]);
        let b = SubnetId::new(&[2, 3]);
        assert_eq!(crossing_depth(a, b), 2);
    }

    #[test]
    fn test_crossing_depth_global() {
        let a = SubnetId::GLOBAL;
        let b = SubnetId::new(&[1, 2, 3]);
        assert_eq!(crossing_depth(a, b), 3);
    }

    #[test]
    fn test_estimate_latency_same_subnet() {
        let model = PropagationModel::new();
        let subnet = SubnetId::new(&[1, 2]);
        let latency = model.estimate_latency(subnet, subnet, 1);
        // Depth 0 (multiplier 1.0) over one hop is exactly the default base latency.
        assert_eq!(latency, Duration::from_nanos(100_000));
    }

    #[test]
    fn test_estimate_latency_cross_region() {
        let model = PropagationModel::new();
        let a = SubnetId::new(&[1, 1]);
        let b = SubnetId::new(&[2, 1]);
        let latency = model.estimate_latency(a, b, 5);
        assert!(latency > Duration::from_millis(20));
    }

    #[test]
    fn test_calibrate() {
        let mut model = PropagationModel::new();
        let a = SubnetId::new(&[1]);
        let b = SubnetId::new(&[1, 2]);
        model.calibrate(a, b, 2, 20_000);
        assert!(model.base_hop_latency_nanos < PropagationModel::DEFAULT_BASE_HOP_NANOS);
    }

    #[test]
    fn calibrate_rejects_pathological_samples() {
        let mut model = PropagationModel::new();
        let a = SubnetId::new(&[1]);
        let b = SubnetId::new(&[1, 2]);
        // A saturated RTT must be clamped by the per-hop ceiling, not poison the EWMA.
        model.calibrate(a, b, 1, u64::MAX);
        assert!(
            model.base_hop_latency_nanos < 1_000_000_000,
            "EWMA must not be poisoned by pathological RTT (got {} ns)",
            model.base_hop_latency_nanos,
        );
        let after_pathological = model.base_hop_latency_nanos;
        // Healthy samples must still be able to pull the estimate back down.
        for _ in 0..50 {
            model.calibrate(a, b, 2, 20_000);
        }
        assert!(
            model.base_hop_latency_nanos < after_pathological,
            "EWMA stuck after pathological sample (still at {} ns, started this phase at {})",
            model.base_hop_latency_nanos,
            after_pathological,
        );
    }

    #[test]
    fn calibrate_rejects_nan_multiplier() {
        let mut model = PropagationModel::new();
        model.level_multipliers[0] = f32::NAN;
        let a = SubnetId::new(&[1, 2]);
        let baseline = model.base_hop_latency_nanos;
        model.calibrate(a, a, 2, 50_000);
        assert_eq!(
            model.base_hop_latency_nanos, baseline,
            "NaN multiplier must skip calibration, not corrupt the EWMA",
        );
    }

    #[test]
    fn test_max_depth_within() {
        let model = PropagationModel::new();
        // 1 ms budget, 1 hop: depth 1 costs 0.5 ms (fits), depth 2 costs 5 ms (does not).
        let depth = model.max_depth_within(Duration::from_millis(1), 1);
        assert_eq!(depth, Some(1));
    }
}