use std::time::Duration;
use statrs::function::erf::erf;
use super::exp_weighted_olr::ExpWeightedOnlineLinearRegression;
/// Predicts the round-trip time of a transfer from its size and concurrency level.
///
/// All state lives in the wrapped regression model; this type only converts
/// observations and queries into the model's inputs.
#[derive(Clone)]
pub struct RTTPredictor {
/// Regression model that `update` feeds and `predict` queries.
model: ExpWeightedOnlineLinearRegression,
}
/// Number of bytes in one size unit (1 MiB); byte counts are divided by this
/// before they are handed to the regression model.
const BASE_SIZE_UNIT: f64 = 1024.0 * 1024.0;
impl RTTPredictor {
    /// Builds a predictor around a fresh regression model.
    ///
    /// `decay_half_life_count` is passed unchanged to
    /// `ExpWeightedOnlineLinearRegression::new`; its meaning is defined by that type.
    pub fn new(decay_half_life_count: f64) -> Self {
        Self {
            model: ExpWeightedOnlineLinearRegression::new(decay_half_life_count),
        }
    }

    /// Regression input for a transfer: its size in `BASE_SIZE_UNIT`s multiplied by the
    /// concurrency level. A concurrency below 1 (or NaN) counts as 1, because
    /// `f64::max` returns the non-NaN operand.
    fn effective_size(size_bytes: u64, avg_concurrent: f64) -> f64 {
        let size_units = (size_bytes as f64) / BASE_SIZE_UNIT;
        let concurrency_factor = avg_concurrent.max(1.0);
        size_units * concurrency_factor
    }

    /// Records one observed transfer in the model.
    ///
    /// * `size_bytes` - size of the transfer in bytes.
    /// * `duration` - observed round-trip time; anything shorter than 1e-9 s is raised to 1e-9 s.
    /// * `avg_concurrent` - concurrency level for the transfer; values below 1 are treated as 1.
    /// * `weight` - sample weight; negative values are treated as 0.
    pub fn update(&mut self, size_bytes: u64, duration: Duration, avg_concurrent: f64, weight: f64) {
        let w = weight.max(0.0);
        let x_eff = Self::effective_size(size_bytes, avg_concurrent);
        let y = duration.as_secs_f64().max(1e-9);
        self.model.update(w, x_eff, y);
    }

    /// Queries the model for a transfer of `size_bytes` at concurrency `avg_concurrent`.
    ///
    /// Returns the pair produced by the model's `predict`: the predicted RTT in seconds,
    /// floored at 0.0, and the model's spread estimate passed through unchanged. Either
    /// element is `None` when the model does not supply it.
    pub fn predict(&self, size_bytes: u64, avg_concurrent: f64) -> (Option<f64>, Option<f64>) {
        let x_eff = Self::effective_size(size_bytes, avg_concurrent);
        let (mean, std_dev) = self.model.predict(x_eff);
        (mean.map(|m| m.max(0.0)), std_dev)
    }

    /// Predicted RTT in seconds (never negative), or `None` if the model has no estimate.
    pub fn predicted_rtt(&self, size_bytes: u64, avg_concurrent: f64) -> Option<f64> {
        self.predict(size_bytes, avg_concurrent).0
    }

    /// Spread component of [`predict`](Self::predict); only compiled for tests.
    #[cfg(test)]
    pub fn prediction_standard_error(&self, size_bytes: u64, avg_concurrent: f64) -> Option<f64> {
        self.predict(size_bytes, avg_concurrent).1
    }

    /// Bandwidth in bytes per second implied by the predicted RTT of a 10 MiB
    /// transfer at concurrency 1.
    ///
    /// Returns `None` when no RTT prediction is available. The RTT is floored at
    /// 1e-6 s so the division cannot blow up on a zero prediction.
    pub fn predicted_bandwidth(&self) -> Option<f64> {
        let query_bytes = 10 * 1024 * 1024;
        let min_rtt = self.predicted_rtt(query_bytes, 1.)?;
        Some(query_bytes as f64 / min_rtt.max(1e-6))
    }

    /// Position of `observed_rtt_secs` within the predicted RTT distribution, in `[0, 1]`.
    ///
    /// The prediction and spread from [`predict`](Self::predict) are treated as the mean
    /// and standard deviation of a normal distribution, and the normal CDF is evaluated
    /// at the observation.
    ///
    /// Returns the neutral value 0.5 whenever the position cannot be determined: the
    /// model has no prediction or no spread, the spread is non-finite or not positive,
    /// or the z-score is NaN (a NaN observation, or an infinite observation compared
    /// against an infinite prediction).
    pub fn rtt_quantile(&self, observed_rtt_secs: f64, size_bytes: u64, avg_concurrent: f64) -> f64 {
        let (mean_opt, std_dev_opt) = self.predict(size_bytes, avg_concurrent);
        let (Some(predicted), Some(se)) = (mean_opt, std_dev_opt) else {
            return 0.5;
        };
        if !se.is_finite() || se <= 0.0 {
            return 0.5;
        }

        let z = (observed_rtt_secs - predicted) / se;
        // `erf` and `f64::clamp` both propagate NaN, so without this guard a NaN
        // z-score would escape as a NaN "quantile".
        if z.is_nan() {
            return 0.5;
        }

        let quantile = 0.5 * (1.0 + erf(z / std::f64::consts::SQRT_2));
        quantile.clamp(0.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
use approx::assert_abs_diff_eq;
use super::*;
/// After two observations of different sizes the predictor yields a positive RTT.
#[test]
fn test_estimator_update() {
    let mut predictor = RTTPredictor::new(10.0);
    let observations = [(1_000_000_u64, 500_u64), (2_000_000, 1000)];
    for (size_bytes, millis) in observations {
        predictor.update(size_bytes, Duration::from_millis(millis), 1., 1.0);
    }

    let prediction = predictor.predicted_rtt(1_000_000, 1.).unwrap();
    assert!(prediction > 0.0);
}
/// With one off-size sample followed by repeated identical samples, the prediction
/// for the repeated size stays within 0.01 s of the observed value, at every
/// concurrency level tried.
#[test]
fn test_converges_to_constant_observation() {
    let one_second = Duration::from_secs_f64(1.0);
    let two_seconds = Duration::from_secs_f64(2.0);

    for concurrency in [1., 5., 100.] {
        let mut model = RTTPredictor::new(10.0);
        model.update(1000, one_second, concurrency, 1.0);
        model.update(2000, two_seconds, concurrency, 1.0);
        (0..8).for_each(|_| model.update(1000, one_second, concurrency, 1.0));

        let rtt = model.predicted_rtt(1000, concurrency).unwrap();
        assert_abs_diff_eq!(rtt, 1.0, epsilon = 0.01);
    }
}
/// Predictions grow with transfer size and are positive at every trained size.
#[test]
fn test_scaling_with_size() {
    const MIB: u64 = 1024 * 1024;

    // Five repeats of each (size, seconds) observation at concurrency 1.
    let mut predictor = RTTPredictor::new(10.0);
    for (size, secs) in [(MIB, 0.2), (5 * MIB, 0.6), (10 * MIB, 1.1)] {
        let elapsed = Duration::from_secs_f64(secs);
        for _ in 0..5 {
            predictor.update(size, elapsed, 1., 1.0);
        }
    }

    let [predicted_1mb, predicted_5mb, predicted_10mb] =
        [1_u64, 5, 10].map(|mb| predictor.predicted_rtt(mb * MIB, 1.).unwrap());

    assert!(predicted_5mb > predicted_1mb);
    assert!(predicted_10mb > predicted_5mb);
    for rtt in [predicted_1mb, predicted_5mb, predicted_10mb] {
        assert!(rtt > 0.0);
    }
}
/// Predictions grow with the queried concurrency, and each doubling of concurrency
/// raises the prediction by a factor in (1.0, 2.5].
#[test]
fn test_scaling_with_concurrency() {
    const MIB: u64 = 1024 * 1024;

    // Training data is all at concurrency 1: five repeats of each (size, seconds) pair.
    let mut predictor = RTTPredictor::new(10.0);
    for (size, secs) in [(MIB, 0.2), (5 * MIB, 0.6), (10 * MIB, 1.1)] {
        let elapsed = Duration::from_secs_f64(secs);
        for _ in 0..5 {
            predictor.update(size, elapsed, 1., 1.0);
        }
    }

    let test_size = 5 * MIB;
    let rtt_at = |concurrency: f64| predictor.predicted_rtt(test_size, concurrency).unwrap();
    let predicted_1 = rtt_at(1.);
    let predicted_2 = rtt_at(2.);
    let predicted_4 = rtt_at(4.);

    assert!(predicted_2 > predicted_1, "predicted_2 ({}) should be > predicted_1 ({})", predicted_2, predicted_1);
    assert!(predicted_4 > predicted_2, "predicted_4 ({}) should be > predicted_2 ({})", predicted_4, predicted_2);

    let ratio_2_1 = predicted_2 / predicted_1;
    let ratio_4_2 = predicted_4 / predicted_2;
    assert!(ratio_2_1 > 1.0 && ratio_2_1 <= 2.5, "ratio_2_1 should be between 1.0 and 2.5, got {}", ratio_2_1);
    assert!(ratio_4_2 > 1.0 && ratio_4_2 <= 2.5, "ratio_4_2 should be between 1.0 and 2.5, got {}", ratio_4_2);
}
/// After four observations at distinct sizes, both a prediction and a non-negative
/// standard error are available.
#[test]
fn test_variance_tracking() {
    const MIB: u64 = 1024 * 1024;

    let mut predictor = RTTPredictor::new(10.0);
    for size_mb in [1_u64, 2, 5, 10] {
        // Duration is linear in the size: 0.1 s plus 0.1 s per MiB.
        let secs = 0.1 + 0.1 * size_mb as f64;
        predictor.update(size_mb * MIB, Duration::from_secs_f64(secs), 1., 1.0);
    }

    let query_size = 5 * MIB;
    assert!(predictor.predicted_rtt(query_size, 1.0).is_some());
    let std_err = predictor.prediction_standard_error(query_size, 1.0).unwrap();
    assert!(std_err >= 0.0);
}
/// The quantile is about 0.5 at the predicted RTT, below 0.5 for a faster observation,
/// above 0.5 for a slower one, and strictly ordered across the three.
#[test]
fn test_rtt_quantile() {
    const MIB: u64 = 1024 * 1024;

    let mut predictor = RTTPredictor::new(10.0);
    for size_mb in [1_u64, 2, 5, 10] {
        let secs = 0.1 + 0.1 * size_mb as f64;
        predictor.update(size_mb * MIB, Duration::from_secs_f64(secs), 1., 1.0);
    }

    let test_size = 5 * MIB;
    let predicted = predictor.predicted_rtt(test_size, 1.0).unwrap();
    let quantile_of = |observed: f64| predictor.rtt_quantile(observed, test_size, 1.0);

    let quantile_at_predicted = quantile_of(predicted);
    assert!(
        (quantile_at_predicted - 0.5).abs() < 0.1,
        "Quantile at predicted should be ~0.5, got {}",
        quantile_at_predicted
    );

    let quantile_low = quantile_of(predicted * 0.5);
    assert!(quantile_low < 0.5, "Quantile for low RTT should be < 0.5, got {}", quantile_low);

    let quantile_high = quantile_of(predicted * 2.0);
    assert!(quantile_high > 0.5, "Quantile for high RTT should be > 0.5, got {}", quantile_high);

    assert!(quantile_low < quantile_at_predicted);
    assert!(quantile_at_predicted < quantile_high);
}
/// With a single observation the quantile falls back to the neutral 0.5.
#[test]
fn test_rtt_quantile_with_insufficient_samples() {
    let one_mib = 1024 * 1024;
    let mut predictor = RTTPredictor::new(10.0);
    predictor.update(one_mib, Duration::from_secs_f64(0.2), 1., 1.0);

    let quantile = predictor.rtt_quantile(0.5, one_mib, 1.0);
    assert_eq!(quantile, 0.5, "Should return 0.5 with insufficient samples");
}
/// Generates `num_samples` reproducible `(size_bytes, duration)` observations.
///
/// Sizes are drawn uniformly from 32..64 MiB. Each duration is one second per MiB
/// plus normal noise with standard deviation `scale_std_dev`, floored at 0.1 s.
/// The same `seed` always yields the same sequence.
fn create_sample_pool(num_samples: usize, scale_std_dev: f64, seed: u64) -> Vec<(u64, Duration)> {
    use rand::rngs::StdRng;
    use rand::{RngExt, SeedableRng};
    use rand_distr::{Distribution, Normal};

    const SECS_PER_MB: f64 = 1.0;

    let mut rng = StdRng::seed_from_u64(seed);
    let noise_dist = Normal::new(0.0, scale_std_dev).unwrap();

    (0..num_samples)
        .map(|_| {
            // Draw order matters for reproducibility: size first, then noise.
            let size_mb: f64 = rng.random_range(32.0..64.0);
            let noise: f64 = noise_dist.sample(&mut rng);

            let size_bytes = (size_mb * 1024.0 * 1024.0) as u64;
            let duration_secs = (size_mb * SECS_PER_MB + noise).max(0.1);
            (size_bytes, Duration::from_secs_f64(duration_secs))
        })
        .collect()
}
/// Builds a predictor and feeds it each sample with its weight, at concurrency 1.
///
/// Samples and weights are paired with `zip`, so surplus entries in the longer
/// slice are ignored.
fn create_predictor(half_life: f64, sample_pool: &[(u64, Duration)], weights: &[f64]) -> RTTPredictor {
    let mut predictor = RTTPredictor::new(half_life);
    let weighted_samples = sample_pool.iter().copied().zip(weights.iter().copied());
    for ((size_bytes, duration), weight) in weighted_samples {
        predictor.update(size_bytes, duration, 1.0, weight);
    }
    predictor
}
/// Asserts that two predictors agree: their models are approximately equal, and at ten
/// random (size, concurrency) queries seeded by `seed` their predictions and standard
/// errors match within 1e-6. A query is only compared when both sides return a value.
fn assert_models_similar(p1: &RTTPredictor, p2: &RTTPredictor, seed: u64) {
    use rand::rngs::StdRng;
    use rand::{RngExt, SeedableRng};

    let mut rng = StdRng::seed_from_u64(seed);

    assert!(p1.model.approx_equals(&p2.model, 1e-5), "Regression models should be approximately equal");

    for _ in 0..10 {
        // Draw order: size first, then concurrency.
        let size_mb: f64 = rng.random_range(32.0..64.0);
        let concurrency: f64 = rng.random_range(1.0..10.0);
        let size_bytes = (size_mb * 1024.0 * 1024.0) as u64;

        let rtts = p1.predicted_rtt(size_bytes, concurrency).zip(p2.predicted_rtt(size_bytes, concurrency));
        if let Some((rtt_a, rtt_b)) = rtts {
            assert_abs_diff_eq!(rtt_a, rtt_b, epsilon = 1e-6);
        }

        let errors = p1
            .prediction_standard_error(size_bytes, concurrency)
            .zip(p2.prediction_standard_error(size_bytes, concurrency));
        if let Some((err_a, err_b)) = errors {
            assert_abs_diff_eq!(err_a, err_b, epsilon = 1e-6);
        }
    }
}
/// With a very long half-life, data lying exactly on a line through the origin is
/// reproduced at the trained size, with a small standard error. The same prediction
/// results when the repeated sample is replaced by fractional weights.
#[test]
fn test_weighted_mean_and_variance_fundamentals() {
    let half_life = 1e10;
    let test_size = 10 * 1024 * 1024;
    let test_duration = Duration::from_secs_f64(1.1);
    let expected_secs = test_duration.as_secs_f64();
    let doubled_duration = Duration::from_secs_f64(expected_secs * 2.0);

    // Unit weights: one double-size sample among nine identical ones.
    let mut predictor = RTTPredictor::new(half_life);
    predictor.update(test_size, test_duration, 1.0, 1.0);
    predictor.update(test_size * 2, doubled_duration, 1.0, 1.0);
    (0..8).for_each(|_| predictor.update(test_size, test_duration, 1.0, 1.0));

    let predicted = predictor.predicted_rtt(test_size, 1.0).unwrap();
    assert_abs_diff_eq!(predicted, expected_secs, epsilon = 1e-6);

    let se = predictor.prediction_standard_error(test_size, 1.0).unwrap();
    assert!(se < 0.1, "Standard error should be low for many identical samples");

    // Fractional weights over three samples.
    let mut predictor2 = RTTPredictor::new(half_life);
    let weighted_updates = [
        (test_size, test_duration, 0.1),
        (test_size * 2, doubled_duration, 0.1),
        (test_size, test_duration, 0.9),
    ];
    for (size, duration, weight) in weighted_updates {
        predictor2.update(size, duration, 1.0, weight);
    }

    let predicted2 = predictor2.predicted_rtt(test_size, 1.0).unwrap();
    assert_abs_diff_eq!(predicted2, expected_secs, epsilon = 1e-6);
}
/// Multiplying both the half-life and every weight by ten leaves the prediction
/// unchanged to within 1%.
#[test]
fn test_decay_and_weight_scaling() {
    let sample_pool = create_sample_pool(20, 0.1, 1);
    let base_weights = [0.1; 20];
    let scaled_weights: Vec<f64> = base_weights.iter().map(|w| w * 10.0).collect();

    let predictor1 = create_predictor(100.0, &sample_pool, &base_weights);
    let predictor2 = create_predictor(1000.0, &sample_pool, &scaled_weights);

    let query_size = 48 * 1024 * 1024;
    let pred1 = predictor1.predicted_rtt(query_size, 1.0).unwrap();
    let pred2 = predictor2.predicted_rtt(query_size, 1.0).unwrap();
    assert_abs_diff_eq!(pred1, pred2, epsilon = pred1 * 0.01);
}
/// With a very long half-life, feeding the same samples in reverse order yields an
/// equivalent model.
#[test]
fn test_order_independence_with_slow_decay() {
    let half_life = 1e10;
    let weights = [1.0; 20];

    let forward_pool = create_sample_pool(20, 0.1, 8);
    let mut backward_pool = forward_pool.clone();
    backward_pool.reverse();

    let forward = create_predictor(half_life, &forward_pool, &weights);
    let backward = create_predictor(half_life, &backward_pool, &weights);
    assert_models_similar(&forward, &backward, 100);
}
/// Same order-reversal check as the slow-decay test, on a sample pool drawn from a
/// different seed.
#[test]
fn test_variance_order_independence() {
    let half_life = 1e10;
    let weights = [1.0; 20];

    let forward_pool = create_sample_pool(20, 0.1, 4);
    let mut backward_pool = forward_pool.clone();
    backward_pool.reverse();

    let forward = create_predictor(half_life, &forward_pool, &weights);
    let backward = create_predictor(half_life, &backward_pool, &weights);
    assert_models_similar(&forward, &backward, 100);
}
/// Reversing samples together with their non-uniform weights yields an equivalent
/// model when the half-life is very long.
#[test]
fn test_variance_same_samples_different_weight_order() {
    let half_life = 1e10;

    // Ten light samples followed by one heavy sample.
    let mut weights = [0.1; 11];
    weights[10] = 1.0;
    let mut reversed_weights = weights;
    reversed_weights.reverse();

    let sample_pool = create_sample_pool(11, 0.1, 5);
    let mut reversed_pool = sample_pool.clone();
    reversed_pool.reverse();

    let forward = create_predictor(half_life, &sample_pool, &weights);
    let backward = create_predictor(half_life, &reversed_pool, &reversed_weights);
    assert_models_similar(&forward, &backward, 100);
}
/// Uniformly rescaling every weight must not change the prediction, and both
/// predictors must still report a positive, finite standard error.
#[test]
fn test_variance_same_samples_different_weight_scale() {
    let half_life = 1e10;
    let weights1 = [0.5; 11];
    let weights2 = [1.0; 11];

    // Generate exactly one sample per weight. `create_predictor` zips samples with
    // weights, so samples beyond `weights.len()` would be silently dropped.
    let sample_pool = create_sample_pool(weights1.len(), 0.1, 6);

    let predictor1 = create_predictor(half_life, &sample_pool, &weights1);
    let predictor2 = create_predictor(half_life, &sample_pool, &weights2);

    let test_size = 48 * 1024 * 1024;
    let pred1 = predictor1.predicted_rtt(test_size, 1.0).unwrap();
    let pred2 = predictor2.predicted_rtt(test_size, 1.0).unwrap();
    assert_abs_diff_eq!(pred1, pred2, epsilon = 1e-6);

    let se1 = predictor1.prediction_standard_error(test_size, 1.0).unwrap();
    let se2 = predictor2.prediction_standard_error(test_size, 1.0).unwrap();
    assert!(se1 > 0.0 && se1.is_finite());
    assert!(se2 > 0.0 && se2.is_finite());
}
/// Order reversal with mixed weights (predictors 1 and 2) and with unit weights
/// (predictors 3 and 4): predictions agree within 1e-6 and standard errors within 10%.
/// The predictions are then queried again and compared with a 5% tolerance.
#[test]
fn test_variance_with_partial_weights() {
    let half_life = 1e10;
    let test_size = 48 * 1024 * 1024;

    let mixed_weights = [0.5, 0.25, 0.5, 1.0, 0.5, 0.4, 0.5, 0.6, 0.7, 0.5, 1.0];
    let mut reversed_weights = mixed_weights;
    reversed_weights.reverse();
    let unit_weights = [1.0; 11];

    let sample_pool = create_sample_pool(11, 0.1, 7);
    let mut reversed_pool = sample_pool.clone();
    reversed_pool.reverse();

    let predictor1 = create_predictor(half_life, &sample_pool, &mixed_weights);
    let predictor2 = create_predictor(half_life, &reversed_pool, &reversed_weights);
    let predictor3 = create_predictor(half_life, &sample_pool, &unit_weights);
    let predictor4 = create_predictor(half_life, &reversed_pool, &unit_weights);
    let predictors = [&predictor1, &predictor2, &predictor3, &predictor4];

    let [pred1, pred2, pred3, pred4] = predictors.map(|p| p.predicted_rtt(test_size, 1.0).unwrap());
    assert_abs_diff_eq!(pred1, pred2, epsilon = 1e-6);
    assert_abs_diff_eq!(pred3, pred4, epsilon = 1e-6);

    let [se1, se2, se3, se4] = predictors.map(|p| p.prediction_standard_error(test_size, 1.0).unwrap());
    assert_abs_diff_eq!(se1, se2, epsilon = se1.min(se2) * 0.1);
    assert_abs_diff_eq!(se3, se4, epsilon = se3.min(se4) * 0.1);

    // Second round of prediction queries, checked with a relative tolerance.
    let [pred1, pred2, pred3, pred4] = predictors.map(|p| p.predicted_rtt(test_size, 1.0).unwrap());
    assert_abs_diff_eq!(pred1, pred2, epsilon = pred1.min(pred2) * 0.05);
    assert_abs_diff_eq!(pred3, pred4, epsilon = pred3.min(pred4) * 0.05);
}
/// Compares the predictor with a closed-form ordinary least-squares fit computed
/// directly from the same unit-weight samples.
#[test]
fn test_linear_regression_against_fixed_calculation() {
    let half_life = 1e10;
    let sample_pool = create_sample_pool(10, 0.1, 9);
    let weights = [1.0; 10];
    let predictor = create_predictor(half_life, &sample_pool, &weights);

    // Running sums: weight, x, y, x*x, x*y (x in MiB, y in seconds).
    let (sw, sx, sy, sxx, sxy) = sample_pool.iter().fold(
        (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64),
        |(sw, sx, sy, sxx, sxy), (size_bytes, duration)| {
            let x = (*size_bytes as f64) / (1024.0 * 1024.0);
            let y = duration.as_secs_f64();
            (sw + 1.0, sx + x, sy + y, sxx + x * x, sxy + x * y)
        },
    );

    // Centered sums give the slope; the intercept follows from the means.
    let xbar = sx / sw;
    let ybar = sy / sw;
    let sxx_c = sxx - sx * sx / sw;
    let sxy_c = sxy - sx * sy / sw;
    let b_expected = sxy_c / sxx_c;
    let a_expected = ybar - b_expected * xbar;

    let test_size = 48 * 1024 * 1024;
    let x_test = (test_size as f64) / (1024.0 * 1024.0);
    let y_expected = a_expected + b_expected * x_test;

    let y_predicted = predictor.predicted_rtt(test_size, 1.0).unwrap();
    assert_abs_diff_eq!(y_predicted, y_expected, epsilon = 1e-6);
}
/// Splitting each duplicated sample's weight as 0.25 + 0.75 predicts the same as
/// weighting both copies 0.5, at a fixed query and at ten random queries. Both
/// predictors must also report a positive, finite standard error.
#[test]
fn test_weights_scale_like_multiple_samples() {
    use rand::rngs::StdRng;
    use rand::{RngExt, SeedableRng};

    let half_life = 1e10;
    let sample_pool = create_sample_pool(5, 0.1, 10);

    // Every sample appears twice in a row; only the weight split differs.
    let duplicated: Vec<(u64, Duration)> = sample_pool.iter().flat_map(|&sample| [sample, sample]).collect();
    let weights1: Vec<f64> = sample_pool.iter().flat_map(|_| [0.25, 0.75]).collect();
    let weights2 = [0.5; 10];

    let predictor1 = create_predictor(half_life, &duplicated, &weights1);
    let predictor2 = create_predictor(half_life, &duplicated, &weights2);

    let fixed_size = 48 * 1024 * 1024;
    let pred1 = predictor1.predicted_rtt(fixed_size, 1.0).unwrap();
    let pred2 = predictor2.predicted_rtt(fixed_size, 1.0).unwrap();
    assert_abs_diff_eq!(pred1, pred2, epsilon = 1e-6);

    let mut rng = StdRng::seed_from_u64(100);
    for _ in 0..10 {
        // Draw order: size first, then concurrency.
        let size_mb: f64 = rng.random_range(32.0..64.0);
        let concurrency: f64 = rng.random_range(1.0..10.0);
        let size_bytes = (size_mb * 1024.0 * 1024.0) as u64;
        assert_abs_diff_eq!(
            predictor1.predicted_rtt(size_bytes, concurrency).unwrap(),
            predictor2.predicted_rtt(size_bytes, concurrency).unwrap(),
            epsilon = 1e-6
        );
    }

    let se1 = predictor1.prediction_standard_error(fixed_size, 1.0).unwrap();
    let se2 = predictor2.prediction_standard_error(fixed_size, 1.0).unwrap();
    assert!(se1 > 0.0 && se1.is_finite());
    assert!(se2 > 0.0 && se2.is_finite());
}
}