use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::Uplink;
use crate::types::Bandwidth;
/// MTU assumed for a link before path MTU discovery has converged.
pub const DEFAULT_MTU: u32 = 1500;
/// Lower bound for PMTUD (1280 — presumably chosen as the IPv6 minimum link MTU).
pub const MIN_MTU: u32 = 1280;
/// Upper bound for PMTUD (jumbo-frame territory).
pub const MAX_MTU: u32 = 9000;
/// Granularity of the pacing timer.
pub const DEFAULT_PACING_INTERVAL: Duration = Duration::from_micros(100);
/// Tunables for throughput optimization: probing cadence, PMTUD, pacing,
/// BBR-style estimation, and frame batching.
///
/// Every field carries a serde default, so a partial (or empty) config
/// document deserializes cleanly; `Default` mirrors the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputConfig {
    /// Whether periodic bandwidth/RTT probes are sent at all.
    #[serde(default = "default_probing_enabled")]
    pub probing_enabled: bool,
    /// Minimum time between probes on a given uplink.
    #[serde(default = "default_probe_interval", with = "humantime_serde")]
    pub probe_interval: Duration,
    /// Payload size of a single probe, in bytes.
    #[serde(default = "default_probe_size")]
    pub probe_size: usize,
    /// Number of probes sent back-to-back per probing round.
    #[serde(default = "default_probe_burst")]
    pub probe_burst_count: usize,
    /// Whether path MTU discovery is enabled.
    #[serde(default = "default_pmtud_enabled")]
    pub pmtud_enabled: bool,
    /// Whether sends are paced (see `ThroughputOptimizer::pacing_interval`).
    #[serde(default = "default_pacing_enabled")]
    pub pacing_enabled: bool,
    /// Fraction of the BDP to target as congestion window
    /// (see `BdpEstimator::optimal_cwnd`).
    #[serde(default = "default_buffer_target")]
    pub buffer_target_fraction: f64,
    /// RTT budget; beyond this the latency penalty becomes quadratic
    /// (see `EffectiveThroughput::calculate`).
    #[serde(default = "default_max_latency_ms")]
    pub max_acceptable_latency_ms: u32,
    /// Weight of latency in the composite score, within the budget.
    #[serde(default = "default_latency_weight")]
    pub latency_weight: f64,
    /// Whether BBR-style bandwidth/RTT estimation is enabled.
    #[serde(default = "default_bbr_enabled")]
    pub bbr_estimation: bool,
    /// Number of RTT samples kept for the min-RTT window.
    #[serde(default = "default_min_rtt_window")]
    pub min_rtt_window_samples: usize,
    /// Pacing gain while probing for more bandwidth.
    #[serde(default = "default_probe_gain")]
    pub probe_gain: f64,
    /// Pacing gain while draining a queue.
    #[serde(default = "default_drain_gain")]
    pub drain_gain: f64,
    /// Whether small frames are coalesced before sending.
    #[serde(default = "default_batching_enabled")]
    pub frame_batching: bool,
    /// Longest a frame may wait in the batcher before a flush.
    #[serde(default = "default_batch_delay", with = "humantime_serde")]
    pub max_batch_delay: Duration,
    /// Byte threshold that triggers an immediate batch flush.
    #[serde(default = "default_batch_size")]
    pub max_batch_size: usize,
}
// Serde default providers for `ThroughputConfig`. Each one backs a
// `#[serde(default = "...")]` attribute above and is reused by the `Default`
// impl, keeping the two construction paths in agreement.
fn default_probing_enabled() -> bool {
    true
}
fn default_probe_interval() -> Duration {
    Duration::from_secs(1)
}
fn default_probe_size() -> usize {
    // Below DEFAULT_MTU (1500), leaving headroom for protocol headers.
    1400
}
fn default_probe_burst() -> usize {
    10
}
fn default_pmtud_enabled() -> bool {
    true
}
fn default_pacing_enabled() -> bool {
    true
}
fn default_buffer_target() -> f64 {
    0.5
}
fn default_max_latency_ms() -> u32 {
    500
}
fn default_latency_weight() -> f64 {
    0.4
}
fn default_bbr_enabled() -> bool {
    true
}
fn default_min_rtt_window() -> usize {
    10
}
fn default_probe_gain() -> f64 {
    // Matches BBR's classic 1.25 probe-up gain.
    1.25
}
fn default_drain_gain() -> f64 {
    0.75
}
fn default_batching_enabled() -> bool {
    false
}
fn default_batch_delay() -> Duration {
    Duration::from_millis(1)
}
fn default_batch_size() -> usize {
    16384
}
impl Default for ThroughputConfig {
    /// Builds a config from the same provider functions serde uses, so
    /// `ThroughputConfig::default()` and an empty config document agree.
    fn default() -> Self {
        Self {
            probing_enabled: default_probing_enabled(),
            probe_interval: default_probe_interval(),
            probe_size: default_probe_size(),
            probe_burst_count: default_probe_burst(),
            pmtud_enabled: default_pmtud_enabled(),
            pacing_enabled: default_pacing_enabled(),
            buffer_target_fraction: default_buffer_target(),
            max_acceptable_latency_ms: default_max_latency_ms(),
            latency_weight: default_latency_weight(),
            bbr_estimation: default_bbr_enabled(),
            min_rtt_window_samples: default_min_rtt_window(),
            probe_gain: default_probe_gain(),
            drain_gain: default_drain_gain(),
            frame_batching: default_batching_enabled(),
            max_batch_delay: default_batch_delay(),
            max_batch_size: default_batch_size(),
        }
    }
}
/// Phases of the BBR-style congestion-control state machine driven by
/// `BdpEstimator::maybe_transition_state`.
///
/// The hand-written `Default` impl is replaced by `#[derive(Default)]` with a
/// `#[default]` variant marker (stable since Rust 1.62) — same behavior,
/// less code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BbrState {
    /// Initial phase: ramp up aggressively until bandwidth growth plateaus.
    #[default]
    Startup,
    /// Drain the queue built up during startup.
    Drain,
    /// Steady state: cruise at the estimated bandwidth, probing occasionally.
    ProbeBandwidth,
    /// Periodically shrink the window to re-measure the path's minimum RTT.
    ProbeRtt,
}
/// BBR-flavored bandwidth-delay-product estimator for a single uplink.
///
/// Tracks windowed bandwidth/RTT samples, runs the four-phase `BbrState`
/// machine, and derives a pacing rate and congestion window from them.
#[derive(Debug)]
pub struct BdpEstimator {
    /// Highest delivery rate observed (bytes/sec); only decays on `reset`.
    max_bandwidth: RwLock<f64>,
    /// Lowest RTT observed; `Duration::MAX` is the "no sample yet" sentinel.
    min_rtt: RwLock<Duration>,
    /// Timestamped RTT samples within the sliding window.
    rtt_samples: RwLock<VecDeque<(Instant, Duration)>>,
    /// Timestamped bandwidth samples within the sliding window.
    bw_samples: RwLock<VecDeque<(Instant, f64)>>,
    /// Width of both sample windows above.
    window_duration: Duration,
    /// Current BBR phase.
    bbr_state: RwLock<BbrState>,
    /// Sample events seen since the last phase transition.
    state_cycles: AtomicU32,
    /// When the last phase transition happened.
    last_transition: RwLock<Instant>,
    /// Current pacing rate in bytes/sec (0 = unpaced).
    pacing_rate: AtomicU64,
    /// Congestion window, in bytes.
    cwnd_bytes: AtomicU64,
    /// Bytes sent but not yet acknowledged.
    inflight_bytes: AtomicU64,
    /// Most recent raw delivery-rate sample (bytes/sec).
    delivery_rate: RwLock<f64>,
    /// Ack-driven round counter (one "round" per ack — an approximation).
    round_count: AtomicU64,
    /// Round at which the current ProbeRtt phase started, if one is active.
    probe_rtt_round: RwLock<Option<u64>>,
}
impl BdpEstimator {
    /// Creates an estimator whose bandwidth/RTT sample windows span
    /// `window_duration`.
    pub fn new(window_duration: Duration) -> Self {
        Self {
            max_bandwidth: RwLock::new(0.0),
            // Duration::MAX is the "no RTT sample yet" sentinel.
            min_rtt: RwLock::new(Duration::MAX),
            rtt_samples: RwLock::new(VecDeque::with_capacity(100)),
            bw_samples: RwLock::new(VecDeque::with_capacity(100)),
            window_duration,
            bbr_state: RwLock::new(BbrState::default()),
            state_cycles: AtomicU32::new(0),
            last_transition: RwLock::new(Instant::now()),
            pacing_rate: AtomicU64::new(0),
            // Initial congestion window: 10 full-size (1500-byte) packets.
            cwnd_bytes: AtomicU64::new(10 * 1500),
            inflight_bytes: AtomicU64::new(0),
            delivery_rate: RwLock::new(0.0),
            round_count: AtomicU64::new(0),
            probe_rtt_round: RwLock::new(None),
        }
    }
    /// Feeds one RTT sample: updates the running minimum, trims samples that
    /// aged out of the window, and lets the state machine react.
    ///
    /// NOTE(review): `min_rtt` only ever decreases — it is not re-derived from
    /// the windowed samples, so a route change to a longer path is not
    /// observed until `reset()`. Left as-is; confirm intended.
    pub fn record_rtt(&self, rtt: Duration) {
        let now = Instant::now();
        {
            let mut min = self.min_rtt.write();
            if rtt < *min {
                *min = rtt;
            }
        }
        {
            let mut samples = self.rtt_samples.write();
            samples.push_back((now, rtt));
            // Evict samples older than the sliding window.
            while let Some((ts, _)) = samples.front() {
                if now.duration_since(*ts) > self.window_duration {
                    samples.pop_front();
                } else {
                    break;
                }
            }
        }
        self.maybe_transition_state();
    }
    /// Feeds one delivery-rate sample (bytes/sec), then refreshes the pacing
    /// rate and congestion window.
    pub fn record_bandwidth(&self, bytes_per_sec: f64) {
        let now = Instant::now();
        {
            let mut max = self.max_bandwidth.write();
            if bytes_per_sec > *max {
                *max = bytes_per_sec;
            }
        }
        *self.delivery_rate.write() = bytes_per_sec;
        {
            let mut samples = self.bw_samples.write();
            samples.push_back((now, bytes_per_sec));
            // Evict samples older than the sliding window.
            while let Some((ts, _)) = samples.front() {
                if now.duration_since(*ts) > self.window_duration {
                    samples.pop_front();
                } else {
                    break;
                }
            }
        }
        self.update_pacing_rate();
    }
    /// Accounts `bytes` as sent and currently in flight.
    pub fn record_send(&self, bytes: u64) {
        self.inflight_bytes.fetch_add(bytes, Ordering::Relaxed);
    }
    /// Accounts `bytes` as acknowledged, clamping in-flight at zero.
    pub fn record_ack(&self, bytes: u64) {
        let prev = self.inflight_bytes.fetch_sub(bytes, Ordering::Relaxed);
        if bytes > prev {
            // fetch_sub wrapped below zero; repair to an empty pipe.
            self.inflight_bytes.store(0, Ordering::Relaxed);
        }
        // Each ack advances the (approximate) round counter used by ProbeRtt.
        self.round_count.fetch_add(1, Ordering::Relaxed);
    }
    /// Bandwidth-delay product in bytes. Falls back to a 10-packet default
    /// until both a bandwidth and an RTT sample exist.
    pub fn bdp(&self) -> u64 {
        let bandwidth = *self.max_bandwidth.read();
        let rtt = *self.min_rtt.read();
        if rtt == Duration::MAX || bandwidth == 0.0 {
            return 10 * 1500;
        }
        (bandwidth * rtt.as_secs_f64()) as u64
    }
    /// Congestion window targeting `target_fraction` of the BDP, floored at
    /// two full-size packets.
    pub fn optimal_cwnd(&self, target_fraction: f64) -> u64 {
        let bdp = self.bdp();
        ((bdp as f64) * target_fraction).max(2.0 * 1500.0) as u64
    }
    /// Current pacing rate in bytes/sec (0 until bandwidth is sampled).
    pub fn pacing_rate(&self) -> u64 {
        self.pacing_rate.load(Ordering::Relaxed)
    }
    /// Current BBR phase.
    pub fn state(&self) -> BbrState {
        *self.bbr_state.read()
    }
    /// Minimum observed RTT, or `Duration::ZERO` before any sample.
    pub fn min_rtt(&self) -> Duration {
        let rtt = *self.min_rtt.read();
        if rtt == Duration::MAX {
            Duration::ZERO
        } else {
            rtt
        }
    }
    /// Maximum observed bandwidth.
    pub fn max_bandwidth(&self) -> Bandwidth {
        Bandwidth::from_bps(*self.max_bandwidth.read())
    }
    /// Bytes currently in flight.
    pub fn inflight(&self) -> u64 {
        self.inflight_bytes.load(Ordering::Relaxed)
    }
    /// True while the congestion window has room for more data.
    pub fn can_send(&self) -> bool {
        self.inflight_bytes.load(Ordering::Relaxed) < self.cwnd_bytes.load(Ordering::Relaxed)
    }
    /// Inter-packet gap needed to pace `packet_size` bytes at the current
    /// rate; `ZERO` when unpaced.
    pub fn pacing_interval(&self, packet_size: usize) -> Duration {
        let rate = self.pacing_rate.load(Ordering::Relaxed);
        if rate == 0 {
            return Duration::ZERO;
        }
        let seconds = packet_size as f64 / rate as f64;
        Duration::from_secs_f64(seconds)
    }
    /// Recomputes pacing rate and cwnd from the max-bandwidth estimate and
    /// the phase-specific gains.
    fn update_pacing_rate(&self) {
        let bw = *self.max_bandwidth.read();
        let state = *self.bbr_state.read();
        let gain = match state {
            BbrState::Startup => 2.0,
            BbrState::Drain => 0.5,
            BbrState::ProbeBandwidth => 1.0,
            BbrState::ProbeRtt => 1.0,
        };
        let rate = (bw * gain) as u64;
        self.pacing_rate.store(rate, Ordering::Relaxed);
        let bdp = self.bdp();
        let cwnd = match state {
            BbrState::Startup => bdp * 2,
            BbrState::Drain => bdp,
            BbrState::ProbeBandwidth => bdp,
            // ProbeRtt shrinks the window to 4 packets to drain the queue.
            BbrState::ProbeRtt => 4 * 1500,
        };
        // Never let the window fall below 4 full-size packets.
        self.cwnd_bytes.store(cwnd.max(4 * 1500), Ordering::Relaxed);
    }
    /// Runs the BBR phase transitions; called after each RTT sample.
    fn maybe_transition_state(&self) {
        let mut state = self.bbr_state.write();
        let cycles = self.state_cycles.load(Ordering::Relaxed);
        let transition_elapsed = self.last_transition.read().elapsed();
        let new_state = match *state {
            BbrState::Startup => {
                // Leave startup once bandwidth growth flattens out
                // (< 25% growth over the last two samples).
                let bw_samples = self.bw_samples.read();
                if bw_samples.len() >= 3 {
                    let recent: Vec<f64> =
                        bw_samples.iter().rev().take(3).map(|(_, b)| *b).collect();
                    let growth = if recent.len() >= 2 && recent[1] > 0.0 {
                        recent[0] / recent[1]
                    } else {
                        2.0
                    };
                    if growth < 1.25 {
                        Some(BbrState::Drain)
                    } else {
                        None
                    }
                } else {
                    None
                }
            }
            BbrState::Drain => {
                // Drain until in-flight data fits within one BDP.
                let inflight = self.inflight_bytes.load(Ordering::Relaxed);
                let bdp = self.bdp();
                if inflight <= bdp {
                    Some(BbrState::ProbeBandwidth)
                } else {
                    None
                }
            }
            BbrState::ProbeBandwidth => {
                // Periodically re-probe the path's minimum RTT.
                if transition_elapsed > Duration::from_secs(10) {
                    Some(BbrState::ProbeRtt)
                } else {
                    None
                }
            }
            BbrState::ProbeRtt => {
                // Exit once in-flight has dropped to the 4-packet floor and a
                // full round has elapsed at that level.
                let inflight = self.inflight_bytes.load(Ordering::Relaxed);
                let probe_round = self.probe_rtt_round.read();
                let round = self.round_count.load(Ordering::Relaxed);
                if inflight <= 4 * 1500 {
                    if let Some(start_round) = *probe_round {
                        if round > start_round {
                            // Drop the read guard before re-locking for write.
                            drop(probe_round);
                            *self.probe_rtt_round.write() = None;
                            Some(BbrState::ProbeBandwidth)
                        } else {
                            None
                        }
                    } else {
                        drop(probe_round);
                        *self.probe_rtt_round.write() = Some(round);
                        None
                    }
                } else {
                    None
                }
            }
        };
        if let Some(new) = new_state {
            *state = new;
            self.state_cycles.store(0, Ordering::Relaxed);
            *self.last_transition.write() = Instant::now();
        } else {
            self.state_cycles.fetch_add(1, Ordering::Relaxed);
        }
        // `cycles` is read for symmetry with the counter update above but not
        // consulted by any transition rule; kept to preserve the load.
        let _ = cycles;
    }
    /// Returns the estimator to its freshly-constructed state.
    pub fn reset(&self) {
        *self.max_bandwidth.write() = 0.0;
        *self.min_rtt.write() = Duration::MAX;
        self.rtt_samples.write().clear();
        self.bw_samples.write().clear();
        *self.bbr_state.write() = BbrState::Startup;
        self.state_cycles.store(0, Ordering::Relaxed);
        *self.last_transition.write() = Instant::now();
        self.pacing_rate.store(0, Ordering::Relaxed);
        self.cwnd_bytes.store(10 * 1500, Ordering::Relaxed);
        self.inflight_bytes.store(0, Ordering::Relaxed);
        // Fix: these three were previously left stale across resets, which
        // could carry an in-progress ProbeRtt bookkeeping entry and an old
        // delivery rate into the next measurement cycle.
        *self.delivery_rate.write() = 0.0;
        self.round_count.store(0, Ordering::Relaxed);
        *self.probe_rtt_round.write() = None;
    }
}
/// Path-MTU-discovery state machine for one uplink (binary search between
/// `MIN_MTU` and `MAX_MTU`).
#[derive(Debug)]
pub struct PmtudState {
    /// MTU currently in use for sizing payloads.
    current_mtu: AtomicU32,
    /// Largest size confirmed to traverse the path (search lower bound).
    last_good_mtu: AtomicU32,
    /// Size the next probe should test.
    probe_mtu: AtomicU32,
    /// Search / Stable / Verify phase.
    state: RwLock<PmtudPhase>,
    /// When the last probe was sent (used to rate-limit re-probing).
    last_probe: RwLock<Instant>,
    /// Consecutive failures since the last success.
    failures: AtomicU32,
    /// Recent (size, success) outcomes; capped at 20 entries.
    history: RwLock<VecDeque<(u32, bool)>>,
}
/// Phase of the PMTUD state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PmtudPhase {
    /// Actively binary-searching for the largest workable MTU.
    Search,
    /// Converged; only occasional re-validation probes are sent.
    Stable,
    /// Re-checking a previously accepted MTU after suspicion of failure.
    Verify,
}
impl PmtudState {
    /// Starts a discovery run with `initial_mtu` as both the working MTU and
    /// the first probe size.
    pub fn new(initial_mtu: u32) -> Self {
        Self {
            current_mtu: AtomicU32::new(initial_mtu),
            last_good_mtu: AtomicU32::new(MIN_MTU),
            probe_mtu: AtomicU32::new(initial_mtu),
            state: RwLock::new(PmtudPhase::Search),
            last_probe: RwLock::new(Instant::now()),
            failures: AtomicU32::new(0),
            history: RwLock::new(VecDeque::with_capacity(20)),
        }
    }
    /// Currently validated path MTU, in bytes.
    pub fn mtu(&self) -> u32 {
        self.current_mtu.load(Ordering::Relaxed)
    }
    /// MTU minus `header_overhead`; saturates at zero.
    pub fn optimal_payload_size(&self, header_overhead: u32) -> u32 {
        self.current_mtu
            .load(Ordering::Relaxed)
            .saturating_sub(header_overhead)
    }
    /// Records that a datagram of `size` bytes traversed the path.
    pub fn record_success(&self, size: u32) {
        let current = self.current_mtu.load(Ordering::Relaxed);
        // Fix: track the largest size known to work regardless of the
        // configured MTU. Previously successes below `current_mtu` were
        // discarded, leaving the binary-search lower bound pinned at MIN_MTU
        // and slowing convergence in `record_failure`.
        if size > self.last_good_mtu.load(Ordering::Relaxed) {
            self.last_good_mtu.store(size, Ordering::Relaxed);
        }
        if size >= current {
            self.failures.store(0, Ordering::Relaxed);
            let mut history = self.history.write();
            history.push_back((size, true));
            if history.len() > 20 {
                history.pop_front();
            }
        }
        let mut state = self.state.write();
        if *state == PmtudPhase::Search {
            let probe = self.probe_mtu.load(Ordering::Relaxed);
            if size >= probe {
                // Fix: a confirmed probe size is known-good — adopt it as the
                // working MTU immediately instead of waiting for the search
                // to terminate.
                if probe > current {
                    self.current_mtu.store(probe, Ordering::Relaxed);
                }
                if probe < MAX_MTU {
                    // Bisect upward toward the absolute ceiling.
                    let new_probe = ((probe + MAX_MTU) / 2).min(MAX_MTU);
                    self.probe_mtu.store(new_probe, Ordering::Relaxed);
                } else {
                    *state = PmtudPhase::Stable;
                }
            }
        }
    }
    /// Records that a datagram of `size` bytes was lost or rejected.
    pub fn record_failure(&self, size: u32) {
        let failures = self.failures.fetch_add(1, Ordering::Relaxed) + 1;
        {
            let mut history = self.history.write();
            history.push_back((size, false));
            if history.len() > 20 {
                history.pop_front();
            }
        }
        let mut state = self.state.write();
        match *state {
            PmtudPhase::Search => {
                let probe = self.probe_mtu.load(Ordering::Relaxed);
                let good = self.last_good_mtu.load(Ordering::Relaxed);
                if probe - good < 32 {
                    // Search interval has collapsed: settle on the last
                    // known-good size.
                    self.current_mtu.store(good, Ordering::Relaxed);
                    self.probe_mtu.store(good, Ordering::Relaxed);
                    *state = PmtudPhase::Stable;
                } else {
                    // Bisect downward toward the known-good floor.
                    let new_probe = (probe + good) / 2;
                    self.probe_mtu.store(new_probe, Ordering::Relaxed);
                }
            }
            PmtudPhase::Stable => {
                // Tolerate sporadic loss; after 3 consecutive failures assume
                // the path shrank, back off by 25%, and re-search.
                if failures >= 3 {
                    let current = self.current_mtu.load(Ordering::Relaxed);
                    let reduced = (current * 3 / 4).max(MIN_MTU);
                    self.current_mtu.store(reduced, Ordering::Relaxed);
                    self.probe_mtu.store(reduced, Ordering::Relaxed);
                    self.failures.store(0, Ordering::Relaxed);
                    *state = PmtudPhase::Search;
                }
            }
            PmtudPhase::Verify => {
                // A failure during verification immediately backs off by 25%
                // and restarts the search.
                let current = self.current_mtu.load(Ordering::Relaxed);
                let reduced = (current * 3 / 4).max(MIN_MTU);
                self.current_mtu.store(reduced, Ordering::Relaxed);
                self.probe_mtu.store(reduced, Ordering::Relaxed);
                self.failures.store(0, Ordering::Relaxed);
                *state = PmtudPhase::Search;
            }
        }
    }
    /// Size the next probe packet should use.
    pub fn next_probe_mtu(&self) -> u32 {
        self.probe_mtu.load(Ordering::Relaxed)
    }
    /// Whether a probe is warranted right now (always during Search/Verify;
    /// once a minute while Stable).
    pub fn should_probe(&self) -> bool {
        let state = *self.state.read();
        match state {
            PmtudPhase::Search => true,
            PmtudPhase::Stable => {
                self.last_probe.read().elapsed() > Duration::from_secs(60)
            }
            PmtudPhase::Verify => true,
        }
    }
    /// Marks that a probe was just sent (resets the Stable re-probe timer).
    pub fn probe_sent(&self) {
        *self.last_probe.write() = Instant::now();
    }
}
impl Default for PmtudState {
    /// Starts discovery at the standard Ethernet MTU.
    fn default() -> Self {
        Self::new(DEFAULT_MTU)
    }
}
/// Snapshot of an uplink's quality combined into a single comparable score.
#[derive(Debug, Clone, Copy)]
pub struct EffectiveThroughput {
    /// Raw bandwidth, bytes/sec.
    pub bandwidth_bps: f64,
    /// Round-trip time.
    pub rtt: Duration,
    /// Packet-loss ratio in [0, 1].
    pub loss_ratio: f64,
    /// RTT jitter.
    pub jitter: Duration,
    /// Composite quality score; higher is better (see `calculate`).
    pub score: f64,
    /// Estimated time to move 1 MiB, including one RTT; capped at one hour.
    pub time_for_1mb: Duration,
    /// Bandwidth-delay product, bytes.
    pub bdp: u64,
}
impl EffectiveThroughput {
    /// Derives a composite throughput assessment from raw link metrics.
    ///
    /// The score is the product of four factors, each in (0, 1]: normalized
    /// bandwidth (vs 125 MB/s ≈ 1 Gbit/s), a latency penalty, a loss factor,
    /// and a jitter penalty. Higher is better.
    pub fn calculate(
        bandwidth_bps: f64,
        rtt: Duration,
        loss_ratio: f64,
        jitter: Duration,
        config: &ThroughputConfig,
    ) -> Self {
        // Clamp RTT away from zero so the divisions below stay finite.
        let rtt_secs = rtt.as_secs_f64().max(0.001);
        let loss_adjusted = (1.0 - loss_ratio).max(0.01);
        // 100 ms of jitter halves the jitter factor.
        let jitter_penalty = 1.0 / (1.0 + jitter.as_secs_f64() * 10.0);
        // Loss reduces usable throughput (doubled weighting per unit loss).
        let loss_throughput_factor = if loss_ratio > 0.0 {
            1.0 / (1.0 + loss_ratio * 2.0)
        } else {
            1.0
        };
        let effective_bandwidth = bandwidth_bps * loss_throughput_factor;
        let max_latency_secs = config.max_acceptable_latency_ms as f64 / 1000.0;
        // Within the latency budget: mild linear penalty weighted by
        // `latency_weight`. Beyond it: quadratic fall-off.
        let latency_penalty = if rtt_secs > max_latency_secs {
            (max_latency_secs / rtt_secs).powi(2)
        } else {
            1.0 - (rtt_secs / max_latency_secs) * config.latency_weight
        };
        let bdp = (bandwidth_bps * rtt_secs) as u64;
        let bytes_1mb = 1024.0 * 1024.0;
        let transfer_time = if effective_bandwidth > 0.0 {
            bytes_1mb / effective_bandwidth
        } else {
            // No measurable bandwidth: treat as one hour (the cap below).
            3600.0
        };
        let total_transfer_time = (transfer_time + rtt_secs).min(3600.0);
        let normalized_bw = (effective_bandwidth / 125_000_000.0).min(1.0);
        let score = normalized_bw * latency_penalty * loss_adjusted * jitter_penalty;
        Self {
            bandwidth_bps,
            rtt,
            loss_ratio,
            jitter,
            score,
            time_for_1mb: Duration::from_secs_f64(total_transfer_time),
            bdp,
        }
    }
    /// True when this link's composite score beats `other`'s.
    pub fn is_better_than(&self, other: &Self) -> bool {
        self.score > other.score
    }
    /// True when this link would finish a `size_bytes` transfer sooner.
    pub fn faster_for_size(&self, other: &Self, size_bytes: u64) -> bool {
        let self_time = self.transfer_time(size_bytes);
        let other_time = other.transfer_time(size_bytes);
        self_time < other_time
    }
    /// Estimated wall-clock time to move `bytes`: bandwidth inflated by a
    /// retransmission factor for loss, plus one RTT; capped at one hour.
    pub fn transfer_time(&self, bytes: u64) -> Duration {
        let loss_factor = 1.0 / (1.0 - self.loss_ratio).max(0.01);
        let bandwidth = self.bandwidth_bps.max(1.0);
        let transfer_secs = (bytes as f64 * loss_factor) / bandwidth;
        let total_secs = (transfer_secs + self.rtt.as_secs_f64()).min(3600.0);
        Duration::from_secs_f64(total_secs)
    }
}
/// Coalesces small frames so they can be sent together, flushing either when
/// the buffered bytes reach `max_size` or when the oldest pending frame has
/// waited `max_delay`.
#[derive(Debug)]
pub struct FrameBatcher {
    /// Frames awaiting the next flush.
    pending: RwLock<Vec<Vec<u8>>>,
    /// Total bytes across `pending` (kept as an atomic for lock-free reads).
    pending_size: AtomicU64,
    /// When the oldest pending frame was enqueued; `None` when empty.
    first_pending: RwLock<Option<Instant>>,
    /// Time-based flush threshold.
    max_delay: Duration,
    /// Size-based flush threshold, in bytes.
    max_size: usize,
}
impl FrameBatcher {
    /// Creates a batcher that flushes after `max_size` buffered bytes or
    /// `max_delay` since the first pending frame, whichever comes first.
    pub fn new(max_delay: Duration, max_size: usize) -> Self {
        Self {
            pending: RwLock::new(Vec::new()),
            pending_size: AtomicU64::new(0),
            first_pending: RwLock::new(None),
            max_delay,
            max_size,
        }
    }
    /// Queues `frame`; returns the whole batch when a flush threshold is hit,
    /// `None` while the batch is still accumulating.
    pub fn add(&self, frame: Vec<u8>) -> Option<Vec<Vec<u8>>> {
        let frame_size = frame.len();
        // Holding the queue write lock serializes concurrent `add`/`flush`
        // callers, keeping the atomic byte counter consistent with the queue.
        let mut pending = self.pending.write();
        let prev_size = self
            .pending_size
            .fetch_add(frame_size as u64, Ordering::Relaxed);
        if pending.is_empty() {
            // First frame of a new batch starts the delay clock.
            *self.first_pending.write() = Some(Instant::now());
        }
        pending.push(frame);
        let total_size = prev_size as usize + frame_size;
        let first_time = *self.first_pending.read();
        let should_flush = total_size >= self.max_size
            || first_time.is_some_and(|t| t.elapsed() >= self.max_delay);
        if should_flush {
            self.pending_size.store(0, Ordering::Relaxed);
            *self.first_pending.write() = None;
            Some(std::mem::take(&mut *pending))
        } else {
            None
        }
    }
    /// Drains and returns all pending frames immediately.
    pub fn flush(&self) -> Vec<Vec<u8>> {
        // Fix: acquire the queue lock *before* resetting the byte counter and
        // timestamp. The previous order reset the counter first, so a
        // concurrent `add` could enqueue a frame between the reset and the
        // take, leaving its bytes (and delay clock) unaccounted for.
        let mut pending = self.pending.write();
        self.pending_size.store(0, Ordering::Relaxed);
        *self.first_pending.write() = None;
        std::mem::take(&mut *pending)
    }
    /// True once the oldest pending frame has waited at least `max_delay`.
    pub fn is_ready(&self) -> bool {
        let first = *self.first_pending.read();
        first.is_some_and(|t| t.elapsed() >= self.max_delay)
    }
    /// Bytes currently buffered.
    pub fn pending_size(&self) -> usize {
        self.pending_size.load(Ordering::Relaxed) as usize
    }
}
/// Per-uplink optimization state bundling the BDP estimator, PMTUD machine,
/// optional frame batcher, and probe bookkeeping.
#[derive(Debug)]
pub struct UplinkThroughputState {
    /// BBR-style bandwidth/RTT estimator.
    pub bdp: BdpEstimator,
    /// Path-MTU-discovery state.
    pub pmtud: PmtudState,
    /// Present only when `ThroughputConfig::frame_batching` is enabled.
    pub batcher: Option<FrameBatcher>,
    /// When the last probe was sent on this uplink.
    pub last_probe: RwLock<Instant>,
    /// Probes sent but not yet answered.
    pub probe_inflight: AtomicU32,
    /// Send timestamps keyed by probe id, for RTT computation.
    pub probe_timestamps: RwLock<HashMap<u64, Instant>>,
    /// Snapshot of `ThroughputConfig::probing_enabled` at construction.
    pub probing_active: bool,
}
impl UplinkThroughputState {
    /// Builds fresh per-uplink state from `config` (a 10-second estimator
    /// window; batcher only when frame batching is enabled).
    pub fn new(config: &ThroughputConfig) -> Self {
        Self {
            bdp: BdpEstimator::new(Duration::from_secs(10)),
            pmtud: PmtudState::new(DEFAULT_MTU),
            batcher: if config.frame_batching {
                Some(FrameBatcher::new(
                    config.max_batch_delay,
                    config.max_batch_size,
                ))
            } else {
                None
            },
            last_probe: RwLock::new(Instant::now()),
            probe_inflight: AtomicU32::new(0),
            probe_timestamps: RwLock::new(HashMap::new()),
            probing_active: config.probing_enabled,
        }
    }
    /// True when probing is enabled, `interval` has elapsed since the last
    /// probe, and no probe is currently outstanding.
    pub fn needs_probe(&self, interval: Duration) -> bool {
        self.probing_active
            && self.last_probe.read().elapsed() >= interval
            && self.probe_inflight.load(Ordering::Relaxed) == 0
    }
    /// Records that probe `probe_id` was just sent.
    pub fn record_probe_sent(&self, probe_id: u64) {
        self.probe_timestamps
            .write()
            .insert(probe_id, Instant::now());
        self.probe_inflight.fetch_add(1, Ordering::Relaxed);
        *self.last_probe.write() = Instant::now();
    }
    /// Records a probe response, feeding RTT and bandwidth samples into the
    /// BDP estimator.
    pub fn record_probe_response(&self, probe_id: u64, bytes: u64) {
        if let Some(sent_time) = self.probe_timestamps.write().remove(&probe_id) {
            // Fix: decrement only for probes we actually tracked. The
            // unconditional decrement previously underflowed `probe_inflight`
            // (wrapping to u32::MAX) on an unknown or duplicate probe_id,
            // which permanently disabled `needs_probe` for this uplink.
            self.probe_inflight.fetch_sub(1, Ordering::Relaxed);
            let rtt = sent_time.elapsed();
            self.bdp.record_rtt(rtt);
            if rtt.as_secs_f64() > 0.0 {
                let bw = bytes as f64 / rtt.as_secs_f64();
                self.bdp.record_bandwidth(bw);
            }
        }
    }
}
/// Facade that tracks throughput-optimization state for every registered
/// uplink and answers scheduling questions (ranking, pacing, batching, MTU).
pub struct ThroughputOptimizer {
    /// Immutable configuration shared by all uplinks.
    config: ThroughputConfig,
    /// Per-uplink state keyed by numeric uplink id.
    uplink_state: RwLock<HashMap<u16, UplinkThroughputState>>,
    /// Set at construction; not updated afterwards in this file —
    /// NOTE(review): possibly vestigial, confirm against other callers.
    last_optimization: RwLock<Instant>,
}
impl ThroughputOptimizer {
    /// Creates an optimizer with no registered uplinks.
    pub fn new(config: ThroughputConfig) -> Self {
        Self {
            config,
            uplink_state: RwLock::new(HashMap::new()),
            last_optimization: RwLock::new(Instant::now()),
        }
    }
    /// Creates (or resets, if already present) state for `uplink_id`.
    pub fn register_uplink(&self, uplink_id: u16) {
        self.uplink_state
            .write()
            .insert(uplink_id, UplinkThroughputState::new(&self.config));
    }
    /// Drops all state for `uplink_id`.
    pub fn unregister_uplink(&self, uplink_id: u16) {
        self.uplink_state.write().remove(&uplink_id);
    }
    /// The configuration this optimizer was built with.
    pub fn config(&self) -> &ThroughputConfig {
        &self.config
    }
    /// Scores one uplink from its live bandwidth/RTT/loss/jitter metrics.
    pub fn effective_throughput(&self, uplink: &Uplink) -> EffectiveThroughput {
        let bandwidth = uplink.bandwidth();
        let rtt = uplink.rtt();
        let loss = uplink.loss_ratio();
        let metrics = uplink.quality_metrics();
        let jitter = metrics.jitter;
        EffectiveThroughput::calculate(bandwidth.bytes_per_sec, rtt, loss, jitter, &self.config)
    }
    /// Scores all usable uplinks, sorted best-first by composite score.
    pub fn rank_uplinks(&self, uplinks: &[Arc<Uplink>]) -> Vec<(u16, EffectiveThroughput)> {
        let mut ranked: Vec<_> = uplinks
            .iter()
            .filter(|u| u.is_usable())
            .map(|u| {
                let throughput = self.effective_throughput(u);
                (u.numeric_id(), throughput)
            })
            .collect();
        // Descending by score; NaN scores compare Equal and keep their order.
        ranked.sort_by(|a, b| {
            b.1.score
                .partial_cmp(&a.1.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked
    }
    /// Uplink that would finish a `size_bytes` transfer soonest, if any
    /// usable uplink exists.
    pub fn best_for_size(&self, uplinks: &[Arc<Uplink>], size_bytes: u64) -> Option<u16> {
        uplinks
            .iter()
            .filter(|u| u.is_usable())
            .map(|u| {
                let throughput = self.effective_throughput(u);
                let time = throughput.transfer_time(size_bytes);
                (u.numeric_id(), time)
            })
            .min_by(|a, b| a.1.cmp(&b.1))
            .map(|(id, _)| id)
    }
    /// Payload size for `uplink_id` after subtracting a 60-byte header
    /// allowance (presumably IP + transport + framing — confirm).
    pub fn optimal_packet_size(&self, uplink_id: u16) -> u32 {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.pmtud.optimal_payload_size(60)
        } else {
            DEFAULT_MTU - 60
        }
    }
    /// Inter-packet pacing gap for `uplink_id`; `ZERO` when pacing is
    /// disabled or the uplink is unknown.
    pub fn pacing_interval(&self, uplink_id: u16, packet_size: usize) -> Duration {
        if !self.config.pacing_enabled {
            return Duration::ZERO;
        }
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.bdp.pacing_interval(packet_size)
        } else {
            Duration::ZERO
        }
    }
    /// True when the congestion window allows sending; unknown uplinks are
    /// never blocked.
    pub fn can_send(&self, uplink_id: u16) -> bool {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.bdp.can_send()
        } else {
            true
        }
    }
    /// Accounts `bytes` as sent on `uplink_id`.
    pub fn record_send(&self, uplink_id: u16, bytes: u64) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.bdp.record_send(bytes);
        }
    }
    /// Accounts an acknowledgment plus its RTT sample on `uplink_id`.
    pub fn record_ack(&self, uplink_id: u16, bytes: u64, rtt: Duration) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.bdp.record_ack(bytes);
            uplink_state.bdp.record_rtt(rtt);
        }
    }
    /// Feeds a delivery-rate sample (bytes/sec) into `uplink_id`'s estimator.
    pub fn record_bandwidth(&self, uplink_id: u16, bytes_per_sec: f64) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.bdp.record_bandwidth(bytes_per_sec);
        }
    }
    /// Routes a PMTUD probe outcome to `uplink_id`'s state machine.
    pub fn record_mtu_result(&self, uplink_id: u16, size: u32, success: bool) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            if success {
                uplink_state.pmtud.record_success(size);
            } else {
                uplink_state.pmtud.record_failure(size);
            }
        }
    }
    /// Ids of uplinks due for a bandwidth probe per `config.probe_interval`.
    pub fn uplinks_needing_probe(&self) -> Vec<u16> {
        self.uplink_state
            .read()
            .iter()
            .filter(|(_, state)| state.needs_probe(self.config.probe_interval))
            .map(|(id, _)| *id)
            .collect()
    }
    /// Records that probe `probe_id` left on `uplink_id`.
    pub fn record_probe_sent(&self, uplink_id: u16, probe_id: u64) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.record_probe_sent(probe_id);
        }
    }
    /// Records a probe response for `uplink_id`, feeding RTT/bandwidth.
    pub fn record_probe_response(&self, uplink_id: u16, probe_id: u64, bytes: u64) {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            uplink_state.record_probe_response(probe_id, bytes);
        }
    }
    /// Current BBR phase of `uplink_id`, if registered.
    pub fn bbr_state(&self, uplink_id: u16) -> Option<BbrState> {
        self.uplink_state
            .read()
            .get(&uplink_id)
            .map(|s| s.bdp.state())
    }
    /// Bandwidth-delay product of `uplink_id` in bytes, if registered.
    pub fn bdp(&self, uplink_id: u16) -> Option<u64> {
        self.uplink_state
            .read()
            .get(&uplink_id)
            .map(|s| s.bdp.bdp())
    }
    /// Adds `frame` to `uplink_id`'s batch. Returns the completed batch when
    /// a flush threshold is hit; when batching is disabled (or the uplink is
    /// unknown) the frame passes straight through as a one-element batch.
    pub fn batch_frame(&self, uplink_id: u16, frame: Vec<u8>) -> Option<Vec<Vec<u8>>> {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            if let Some(ref batcher) = uplink_state.batcher {
                return batcher.add(frame);
            }
        }
        Some(vec![frame])
    }
    /// Forces out whatever is buffered for `uplink_id` (empty when there is
    /// no batcher or nothing pending).
    pub fn flush_batch(&self, uplink_id: u16) -> Vec<Vec<u8>> {
        let state = self.uplink_state.read();
        if let Some(uplink_state) = state.get(&uplink_id) {
            if let Some(ref batcher) = uplink_state.batcher {
                return batcher.flush();
            }
        }
        vec![]
    }
    /// Flushes and returns every batch whose delay threshold has elapsed.
    pub fn ready_batches(&self) -> Vec<(u16, Vec<Vec<u8>>)> {
        let state = self.uplink_state.read();
        let mut result = Vec::new();
        for (&uplink_id, uplink_state) in state.iter() {
            if let Some(ref batcher) = uplink_state.batcher {
                if batcher.is_ready() {
                    result.push((uplink_id, batcher.flush()));
                }
            }
        }
        result
    }
    /// Aggregates a ranking of `uplinks` into a single summary snapshot.
    pub fn summary(&self, uplinks: &[Arc<Uplink>]) -> ThroughputSummary {
        let ranked = self.rank_uplinks(uplinks);
        let total_bandwidth: f64 = ranked.iter().map(|(_, t)| t.bandwidth_bps).sum();
        let best_score = ranked.first().map(|(_, t)| t.score).unwrap_or(0.0);
        let worst_score = ranked.last().map(|(_, t)| t.score).unwrap_or(0.0);
        let avg_rtt = if ranked.is_empty() {
            Duration::ZERO
        } else {
            let total: Duration = ranked.iter().map(|(_, t)| t.rtt).sum();
            total / ranked.len() as u32
        };
        ThroughputSummary {
            uplink_count: ranked.len(),
            total_bandwidth: Bandwidth::from_bps(total_bandwidth),
            best_score,
            worst_score,
            avg_rtt,
            ranked_uplinks: ranked,
        }
    }
}
// Manual Debug: prints only the config plus the number of registered uplinks,
// avoiding a dump of every per-uplink lock-guarded state.
impl std::fmt::Debug for ThroughputOptimizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThroughputOptimizer")
            .field("config", &self.config)
            .field("uplink_count", &self.uplink_state.read().len())
            .finish()
    }
}
/// Aggregate view over all usable uplinks, as produced by
/// `ThroughputOptimizer::summary`.
#[derive(Debug, Clone)]
pub struct ThroughputSummary {
    /// Number of usable uplinks that were ranked.
    pub uplink_count: usize,
    /// Sum of raw bandwidth across ranked uplinks.
    pub total_bandwidth: Bandwidth,
    /// Score of the best-ranked uplink (0.0 when none).
    pub best_score: f64,
    /// Score of the worst-ranked uplink (0.0 when none).
    pub worst_score: f64,
    /// Mean RTT across ranked uplinks.
    pub avg_rtt: Duration,
    /// Full ranking, best-first.
    pub ranked_uplinks: Vec<(u16, EffectiveThroughput)>,
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A fast, low-loss, low-jitter link must outscore a slow, lossy,
    /// high-latency one.
    #[test]
    fn test_effective_throughput_calculation() {
        let config = ThroughputConfig::default();
        let good = EffectiveThroughput::calculate(
            125_000_000.0, // 1 Gbit/s
            Duration::from_millis(10),
            0.0,
            Duration::from_millis(1),
            &config,
        );
        let bad = EffectiveThroughput::calculate(
            12_500_000.0, // 100 Mbit/s
            Duration::from_millis(500),
            0.05,
            Duration::from_millis(50),
            &config,
        );
        assert!(good.score > bad.score);
        assert!(good.is_better_than(&bad));
    }
    /// Small transfers favor low latency; huge transfers favor bandwidth.
    #[test]
    fn test_transfer_time_comparison() {
        let config = ThroughputConfig::default();
        let high_bw_high_lat = EffectiveThroughput::calculate(
            125_000_000.0,
            Duration::from_secs(3),
            0.0,
            Duration::ZERO,
            &config,
        );
        let low_bw_low_lat = EffectiveThroughput::calculate(
            12_500_000.0,
            Duration::from_millis(10),
            0.0,
            Duration::ZERO,
            &config,
        );
        let small_size = 10 * 1024;
        assert!(low_bw_low_lat.faster_for_size(&high_bw_high_lat, small_size));
        let huge_size = 1024 * 1024 * 1024;
        assert!(high_bw_high_lat.faster_for_size(&low_bw_low_lat, huge_size));
    }
    /// 1 Gbit/s at 100 ms RTT yields a BDP around 12.5 MB.
    #[test]
    fn test_bdp_calculation() {
        let estimator = BdpEstimator::new(Duration::from_secs(10));
        estimator.record_rtt(Duration::from_millis(100));
        estimator.record_bandwidth(125_000_000.0);
        let bdp = estimator.bdp();
        assert!(bdp > 10_000_000);
        assert!(bdp < 20_000_000);
    }
    /// A failure bisects the probe size downward.
    #[test]
    fn test_pmtud_binary_search() {
        let pmtud = PmtudState::new(DEFAULT_MTU);
        pmtud.record_success(1400);
        assert!(pmtud.mtu() <= DEFAULT_MTU);
        pmtud.record_failure(1500);
        assert!(pmtud.next_probe_mtu() < 1500);
    }
    /// Frames accumulate until flushed; the size threshold triggers a batch.
    #[test]
    fn test_frame_batcher() {
        let batcher = FrameBatcher::new(Duration::from_millis(10), 1000);
        assert!(batcher.add(vec![0; 100]).is_none());
        assert!(batcher.add(vec![0; 100]).is_none());
        let frames = batcher.flush();
        assert_eq!(frames.len(), 2);
        // 5 x 200 bytes crosses the 1000-byte threshold and yields a batch.
        for _ in 0..10 {
            let result = batcher.add(vec![0; 200]);
            if result.is_some() {
                break;
            }
        }
    }
}