use dashmap::DashMap;
use std::collections::HashSet;
use std::time::{Duration, Instant};
#[derive(Debug, Clone)]
pub struct LatencyEntry {
pub smoothed_rtt: Duration,
pub jitter: Duration,
pub last_updated: Instant,
pub samples: u64,
}
pub struct LatencyMatrix {
entries: DashMap<(String, String), LatencyEntry>,
}
impl Default for LatencyMatrix {
fn default() -> Self {
Self {
entries: DashMap::new(),
}
}
}
impl LatencyMatrix {
pub fn new() -> Self {
Self::default()
}
pub fn update(&self, from: &str, to: &str, rtt: Duration) {
let key = (from.to_string(), to.to_string());
self.entries
.entry(key)
.and_modify(|e| {
let alpha = 0.125;
let rtt_us = rtt.as_micros() as f64;
let old_us = e.smoothed_rtt.as_micros() as f64;
let new_us = old_us * (1.0 - alpha) + rtt_us * alpha;
let diff = (rtt_us - old_us).abs();
let old_jitter = e.jitter.as_micros() as f64;
let new_jitter = old_jitter * (1.0 - alpha) + diff * alpha;
e.smoothed_rtt = Duration::from_micros(new_us as u64);
e.jitter = Duration::from_micros(new_jitter as u64);
e.last_updated = Instant::now();
e.samples += 1;
})
.or_insert_with(|| LatencyEntry {
smoothed_rtt: rtt,
jitter: Duration::ZERO,
last_updated: Instant::now(),
samples: 1,
});
}
pub fn get_rtt(&self, from: &str, to: &str) -> Option<Duration> {
let key = (from.to_string(), to.to_string());
self.entries.get(&key).map(|e| e.smoothed_rtt)
}
pub fn get_entry(&self, from: &str, to: &str) -> Option<LatencyEntry> {
let key = (from.to_string(), to.to_string());
self.entries.get(&key).map(|e| e.clone())
}
pub fn all_edges(&self) -> Vec<(String, String, Duration)> {
self.entries
.iter()
.map(|e| {
let (from, to) = e.key();
(from.clone(), to.clone(), e.value().smoothed_rtt)
})
.collect()
}
pub fn nodes(&self) -> Vec<String> {
let mut set = HashSet::new();
for entry in self.entries.iter() {
let (from, to) = entry.key();
set.insert(from.clone());
set.insert(to.clone());
}
set.into_iter().collect()
}
pub fn path_count(&self) -> usize {
self.entries.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn update_and_query() {
let m = LatencyMatrix::new();
m.update("syd", "sgp", Duration::from_millis(50));
assert_eq!(m.get_rtt("syd", "sgp").unwrap(), Duration::from_millis(50));
assert!(m.get_rtt("sgp", "syd").is_none());
}
#[test]
fn ewma_smoothing() {
let m = LatencyMatrix::new();
m.update("a", "b", Duration::from_millis(100));
m.update("a", "b", Duration::from_millis(200));
let rtt = m.get_rtt("a", "b").unwrap();
assert!(rtt.as_millis() >= 112 && rtt.as_millis() <= 113);
}
#[test]
fn all_edges_and_nodes() {
let m = LatencyMatrix::new();
m.update("syd", "sgp", Duration::from_millis(50));
m.update("sgp", "lon", Duration::from_millis(80));
assert_eq!(m.path_count(), 2);
let nodes = m.nodes();
assert!(nodes.contains(&"syd".to_string()));
assert!(nodes.contains(&"sgp".to_string()));
assert!(nodes.contains(&"lon".to_string()));
}
#[test]
fn jitter_tracking() {
let m = LatencyMatrix::new();
m.update("a", "b", Duration::from_millis(100));
let entry = m.entries.get(&("a".into(), "b".into())).unwrap();
assert_eq!(entry.jitter, Duration::ZERO);
drop(entry);
m.update("a", "b", Duration::from_millis(200));
let entry = m.entries.get(&("a".into(), "b".into())).unwrap();
assert!(entry.jitter > Duration::ZERO);
}
#[test]
fn multiple_ewma_convergence() {
let m = LatencyMatrix::new();
for _ in 0..50 {
m.update("x", "y", Duration::from_millis(100));
}
let rtt = m.get_rtt("x", "y").unwrap();
assert_eq!(rtt.as_millis(), 100);
}
#[test]
fn unknown_path_returns_none() {
let m = LatencyMatrix::new();
assert!(m.get_rtt("no", "path").is_none());
}
}