use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Phase of the BBR-style state machine driven by `BandwidthEstimator`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BbrState {
    /// Initial phase: pacing and cwnd gains of 2.0 until bandwidth growth stalls.
    Startup,
    /// Steady state: the pacing gain cycles through `PROBE_BW_GAINS`.
    ProbeBW,
    /// Follows Startup; paces below the estimate until inflight falls to the BDP.
    Drain,
    /// Periodic low-cwnd phase (BBR's min-RTT probe) lasting `PROBE_RTT_DURATION`.
    ProbeRTT,
    /// Entered on loss; reduced pacing gain until inflight falls to the BDP.
    FastRecovery,
}
/// One acknowledged packet, as reported to `BandwidthEstimator::on_ack`.
#[derive(Debug, Clone, Copy)]
pub struct DeliverySample {
    /// NOTE(review): not read by `on_ack`; presumably the connection's
    /// delivered-byte count when this packet was sent — confirm with callers.
    pub delivered_bytes: u64,
    /// When the packet was sent.
    pub sent_at: Instant,
    /// When its acknowledgement was processed; used as the estimator's clock.
    pub acked_at: Instant,
    /// Size of the acknowledged packet in bytes.
    pub packet_bytes: u64,
    /// If set, the sample is excluded from the bottleneck-bandwidth filter.
    pub is_app_limited: bool,
    /// Peer-reported acknowledgement delay in microseconds, subtracted from the RTT sample.
    pub ack_delay_us: u64,
}
/// Sliding-window extremum tracker over timestamped `u64` samples.
///
/// Holds a monotonic deque of `(timestamp, value)` pairs so the current
/// windowed maximum (or minimum) always sits at the front. Samples whose age
/// exceeds `window_size` are dropped lazily, on the next `update_*` call.
/// An instance is meant to be driven by only one of `update_max` /
/// `update_min`; mixing them on one deque breaks the monotonic ordering.
#[derive(Debug)]
struct WindowFilter {
    /// Candidate extrema, oldest first.
    window: VecDeque<(Instant, u64)>,
    /// How long a sample remains eligible.
    window_size: Duration,
}
impl WindowFilter {
    /// Creates an empty filter whose samples live for `window_size`.
    fn new(window_size: Duration) -> Self {
        WindowFilter {
            window_size,
            window: VecDeque::new(),
        }
    }

    /// Discards samples older than `window_size` as seen from `now`.
    fn expire(&mut self, now: Instant) {
        let horizon = self.window_size;
        while matches!(self.window.front(), Some(&(stamp, _)) if now.duration_since(stamp) > horizon) {
            self.window.pop_front();
        }
    }

    /// Records `value` at `now` and returns the maximum over the live window.
    fn update_max(&mut self, now: Instant, value: u64) -> u64 {
        self.expire(now);
        // Entries not larger than the newcomer can never be the maximum again.
        while matches!(self.window.back(), Some(&(_, held)) if held <= value) {
            self.window.pop_back();
        }
        self.window.push_back((now, value));
        self.window.front().map_or(value, |&(_, best)| best)
    }

    /// Records `value` at `now` and returns the minimum over the live window.
    fn update_min(&mut self, now: Instant, value: u64) -> u64 {
        self.expire(now);
        // Entries not smaller than the newcomer can never be the minimum again.
        while matches!(self.window.back(), Some(&(_, held)) if held >= value) {
            self.window.pop_back();
        }
        self.window.push_back((now, value));
        self.window.front().map_or(value, |&(_, best)| best)
    }
}
/// ProbeBW pacing-gain cycle, indexed by `round_count % 4`
/// (>1.0 paces above the bandwidth estimate, <1.0 below it).
const PROBE_BW_GAINS: [f64; 4] = [1.25, 0.75, 1.0, 1.0];
/// Relative growth of `btl_bw` (25 %) that still counts as "growing" in Startup.
const STARTUP_GROWTH_THRESHOLD: f64 = 0.25;
/// Consecutive Startup rounds without sufficient growth before entering Drain.
const STARTUP_ROUNDS_LIMIT: u32 = 3;
/// How long Drain/ProbeBW may run before ProbeRTT is forced.
const PROBE_RTT_INTERVAL: Duration = Duration::from_millis(10_000);
/// Minimum time spent in ProbeRTT before resuming the prior phase.
const PROBE_RTT_DURATION: Duration = Duration::from_millis(200);
/// Congestion window in ProbeRTT, and the cwnd floor otherwise, in packets.
const PROBE_RTT_CWND_PACKETS: u64 = 4;
/// Bytes per packet used to turn packet counts into byte windows.
const MIN_PACKET_SIZE: u64 = 1_400;
/// Pacing gain applied while in FastRecovery.
const FAST_RECOVERY_PACING_GAIN: f64 = 0.5;
/// FastRecovery ends once inflight is at most this fraction of the BDP.
const FAST_RECOVERY_EXIT_FRACTION: f64 = 1.0;
/// BBR-style bandwidth/RTT estimator and congestion-control state machine.
///
/// Drive it with `on_send`, `on_ack` and `on_loss`; read the resulting pacing
/// rate, congestion window and path estimates through its accessors.
pub struct BandwidthEstimator {
    // ---- state machine ----
    /// Current phase.
    state: BbrState,
    /// Phase resumed when ProbeRTT or FastRecovery ends.
    prior_state: BbrState,
    /// Multiplier on `btl_bw` yielding the pacing rate.
    pacing_gain: f64,
    /// Multiplier on the BDP yielding the congestion window.
    cwnd_gain: f64,
    /// Advanced once per ack in Startup and ProbeBW; selects the ProbeBW gain phase.
    round_count: u32,

    // ---- path model ----
    /// Windowed-maximum delivery rate, in bytes per second.
    btl_bw: u64,
    /// Windowed-minimum RTT; starts at 100 ms until the first RTT sample.
    min_rtt: Duration,
    /// Max filter producing `btl_bw`.
    bw_filter: WindowFilter,
    /// Min filter producing `min_rtt` (values in microseconds).
    rtt_filter: WindowFilter,

    // ---- delivery accounting ----
    /// Total bytes acknowledged so far.
    delivered_bytes: u64,
    /// `acked_at` of the most recent ack.
    last_delivery: Instant,
    /// Bytes sent and not yet acknowledged or declared lost.
    inflight_bytes: u64,

    // ---- Startup exit detection ----
    /// Set once Startup concludes that the pipe is full.
    filled_pipe: bool,
    /// Bandwidth the Startup growth check compares `btl_bw` against.
    prev_bw: u64,
    /// Consecutive Startup rounds without sufficient growth.
    rounds_without_growth: u32,

    // ---- ProbeRTT timing ----
    /// When ProbeRTT last ended (initially: construction time).
    last_probe_rtt_time: Instant,
    /// When the current ProbeRTT began, if one is in progress.
    probe_rtt_entered: Option<Instant>,

    // ---- application-limited tracking ----
    /// Whether the sender has declared itself application-limited.
    app_limited: bool,
    /// `app_limited` clears once `delivered_bytes` exceeds this mark.
    app_limited_at_delivered: u64,

    // ---- FastRecovery bookkeeping ----
    /// Wall-clock time FastRecovery was entered, if active.
    fast_recovery_entered: Option<Instant>,
    /// Bytes reported lost since the last FastRecovery exit.
    recovery_lost_bytes: u64,
}
impl BandwidthEstimator {
    /// Creates an estimator in [`BbrState::Startup`] with no bandwidth estimate.
    ///
    /// `min_rtt` starts at 100 ms until the first RTT sample arrives; both
    /// windowed filters span 10 s. The ProbeRTT interval is measured from the
    /// moment of construction.
    pub fn new() -> Self {
        let now = Instant::now();
        Self {
            state: BbrState::Startup,
            btl_bw: 0,
            min_rtt: Duration::from_millis(100),
            bw_filter: WindowFilter::new(Duration::from_secs(10)),
            rtt_filter: WindowFilter::new(Duration::from_secs(10)),
            delivered_bytes: 0,
            last_delivery: now,
            // Startup gains; kept identical to `transition_to(BbrState::Startup)`.
            pacing_gain: 2.0,
            cwnd_gain: 2.0,
            round_count: 0,
            filled_pipe: false,
            prev_bw: 0,
            rounds_without_growth: 0,
            inflight_bytes: 0,
            last_probe_rtt_time: now,
            probe_rtt_entered: None,
            prior_state: BbrState::ProbeBW,
            app_limited: false,
            app_limited_at_delivered: 0,
            fast_recovery_entered: None,
            recovery_lost_bytes: 0,
        }
    }

    /// Records `bytes` newly handed to the network.
    pub fn on_send(&mut self, bytes: u64) {
        self.inflight_bytes = self.inflight_bytes.saturating_add(bytes);
    }

    /// Processes one acknowledgement and returns the updated pacing rate
    /// (bytes per second, see [`pacing_rate`](Self::pacing_rate)).
    ///
    /// `sample.acked_at` is used as the estimator's clock.
    /// - RTT sample: `acked_at - sent_at - ack_delay`. Non-zero samples feed
    ///   the windowed-min RTT filter.
    /// - Delivery-rate sample: `packet_bytes / (acked_at - sent_at)`. It feeds
    ///   the windowed-max bandwidth filter unless flagged app-limited.
    pub fn on_ack(&mut self, sample: DeliverySample) -> u64 {
        let now = sample.acked_at;
        self.inflight_bytes = self.inflight_bytes.saturating_sub(sample.packet_bytes);
        self.delivered_bytes = self.delivered_bytes.saturating_add(sample.packet_bytes);
        self.last_delivery = now;

        let send_elapsed = now.duration_since(sample.sent_at);
        let ack_delay = Duration::from_micros(sample.ack_delay_us);
        // NOTE(review): ack delay is subtracted unconditionally. RFC 9002 §5.3
        // only subtracts it when the result does not fall below min_rtt.
        let rtt_us = send_elapsed.saturating_sub(ack_delay).as_micros() as u64;
        if rtt_us > 0 {
            let min_rtt_us = self.rtt_filter.update_min(now, rtt_us);
            self.min_rtt = Duration::from_micros(min_rtt_us);
        }

        // NOTE(review): this is a per-packet rate (one packet over its RTT),
        // not delivered-bytes-over-interval. `sample.delivered_bytes` is unused.
        let delivery_rate = if send_elapsed.is_zero() {
            0
        } else {
            (sample.packet_bytes as f64 / send_elapsed.as_secs_f64()) as u64
        };
        if delivery_rate > 0 && !sample.is_app_limited {
            self.btl_bw = self.bw_filter.update_max(now, delivery_rate);
        }

        // The app-limited period ends once data sent after it began is delivered.
        if self.app_limited && self.delivered_bytes > self.app_limited_at_delivered {
            self.app_limited = false;
        }

        self.update_state(now);
        self.pacing_rate()
    }

    /// Records `bytes` declared lost.
    ///
    /// The bytes leave the inflight count. Unless already in FastRecovery or
    /// ProbeRTT, the estimator enters FastRecovery, remembering the current
    /// phase so it can be resumed on exit.
    pub fn on_loss(&mut self, bytes: u64) {
        self.inflight_bytes = self.inflight_bytes.saturating_sub(bytes);
        self.recovery_lost_bytes = self.recovery_lost_bytes.saturating_add(bytes);
        if !matches!(self.state, BbrState::FastRecovery | BbrState::ProbeRTT) {
            self.prior_state = self.state;
            // Loss reports carry no timestamp, so the wall clock is used here.
            self.fast_recovery_entered = Some(Instant::now());
            self.transition_to(BbrState::FastRecovery);
        }
    }

    /// Marks the sender as application-limited.
    ///
    /// The mark covers everything currently in flight: the flag clears only
    /// once more than `delivered + inflight` bytes have been delivered. That
    /// happens when an ack arrives for data sent after this call. Marking only
    /// the current delivered count would clear the flag on the very next ack,
    /// even though the remaining in-flight packets were sent app-limited.
    pub fn set_app_limited(&mut self) {
        self.app_limited = true;
        self.app_limited_at_delivered = self.delivered_bytes.saturating_add(self.inflight_bytes);
    }

    /// Whether the sender is currently marked application-limited.
    pub fn is_app_limited(&self) -> bool {
        self.app_limited
    }

    /// Pacing rate in bytes per second: `max(btl_bw, 1) * pacing_gain`.
    pub fn pacing_rate(&self) -> u64 {
        let base = self.btl_bw.max(1);
        (base as f64 * self.pacing_gain) as u64
    }

    /// Congestion window in bytes.
    ///
    /// In ProbeRTT it is pinned to `PROBE_RTT_CWND_PACKETS` packets. Otherwise
    /// it is `bdp * cwnd_gain`, floored at that same size.
    pub fn cwnd(&self) -> u64 {
        let floor = PROBE_RTT_CWND_PACKETS * MIN_PACKET_SIZE;
        if self.state == BbrState::ProbeRTT {
            return floor;
        }
        (self.bdp() as f64 * self.cwnd_gain).max(floor as f64) as u64
    }

    /// Bandwidth-delay product in bytes: `btl_bw * min_rtt`.
    pub fn bdp(&self) -> u64 {
        (self.btl_bw as f64 * self.min_rtt.as_secs_f64()) as u64
    }

    /// Bytes sent and not yet acknowledged or declared lost.
    pub fn inflight_bytes(&self) -> u64 {
        self.inflight_bytes
    }

    /// Windowed-maximum delivery rate, in bytes per second.
    pub fn bottleneck_bandwidth(&self) -> u64 {
        self.btl_bw
    }

    /// Windowed-minimum RTT (100 ms until the first sample).
    pub fn min_rtt(&self) -> Duration {
        self.min_rtt
    }

    /// Current state-machine phase.
    pub fn state(&self) -> BbrState {
        self.state
    }

    /// Total bytes acknowledged so far.
    pub fn delivered_bytes(&self) -> u64 {
        self.delivered_bytes
    }

    /// Per-ack counter advanced in Startup and ProbeBW.
    pub fn round_count(&self) -> u32 {
        self.round_count
    }

    /// Advances the state machine after an ack processed at `now`.
    fn update_state(&mut self, now: Instant) {
        if self.should_enter_probe_rtt(now) {
            self.prior_state = self.state;
            self.transition_to(BbrState::ProbeRTT);
            self.probe_rtt_entered = Some(now);
            return;
        }
        match self.state {
            BbrState::Startup => {
                self.round_count = self.round_count.wrapping_add(1);
                self.check_full_pipe();
            }
            BbrState::Drain => {
                let bdp = self.bdp();
                if bdp == 0 || self.inflight_bytes <= bdp {
                    self.transition_to(BbrState::ProbeBW);
                }
            }
            BbrState::ProbeBW => {
                let phase = self.round_count as usize % PROBE_BW_GAINS.len();
                self.pacing_gain = PROBE_BW_GAINS[phase];
                self.cwnd_gain = 2.0;
                // Wrapping is harmless: 2^32 is a multiple of the cycle length.
                self.round_count = self.round_count.wrapping_add(1);
            }
            BbrState::ProbeRTT => {
                match self.probe_rtt_entered.map(|entered| now.duration_since(entered)) {
                    Some(elapsed) if elapsed >= PROBE_RTT_DURATION => {
                        self.last_probe_rtt_time = now;
                        self.probe_rtt_entered = None;
                        self.transition_to(self.prior_state);
                    }
                    Some(_) => {}
                    // In ProbeRTT without an entry time (e.g. state set
                    // directly): fall back to ProbeBW.
                    None => self.transition_to(BbrState::ProbeBW),
                }
            }
            BbrState::FastRecovery => {
                let bdp = self.bdp();
                let exit_threshold = (bdp as f64 * FAST_RECOVERY_EXIT_FRACTION) as u64;
                if bdp == 0 || self.inflight_bytes <= exit_threshold {
                    self.recovery_lost_bytes = 0;
                    self.fast_recovery_entered = None;
                    self.transition_to(self.prior_state);
                }
            }
        }
    }

    /// True when a steady-state phase (Drain or ProbeBW) has run for
    /// `PROBE_RTT_INTERVAL` since the last ProbeRTT ended.
    fn should_enter_probe_rtt(&self, now: Instant) -> bool {
        matches!(self.state, BbrState::Drain | BbrState::ProbeBW)
            && now.duration_since(self.last_probe_rtt_time) >= PROBE_RTT_INTERVAL
    }

    /// Startup full-pipe check: leave for Drain after `STARTUP_ROUNDS_LIMIT`
    /// rounds without `STARTUP_GROWTH_THRESHOLD` growth over the baseline.
    ///
    /// The baseline (`prev_bw`) only moves when growth is significant. If it
    /// moved every round, steady sub-threshold growth (e.g. +10%/round, +33%
    /// over three rounds) would be mistaken for a plateau.
    fn check_full_pipe(&mut self) {
        if self.prev_bw == 0 {
            // No baseline yet: adopt the current estimate (possibly still 0).
            self.prev_bw = self.btl_bw;
            return;
        }
        let growth = (self.btl_bw as f64 - self.prev_bw as f64) / self.prev_bw as f64;
        if growth >= STARTUP_GROWTH_THRESHOLD {
            self.prev_bw = self.btl_bw;
            self.rounds_without_growth = 0;
            return;
        }
        self.rounds_without_growth += 1;
        if self.rounds_without_growth >= STARTUP_ROUNDS_LIMIT {
            self.filled_pipe = true;
            self.transition_to(BbrState::Drain);
        }
    }

    /// Switches to `new_state`, installing that phase's pacing and cwnd gains.
    fn transition_to(&mut self, new_state: BbrState) {
        let (pacing_gain, cwnd_gain) = match new_state {
            BbrState::Startup => (2.0, 2.0),
            BbrState::Drain => (0.75, 2.0),
            BbrState::ProbeBW => (1.0, 2.0),
            BbrState::ProbeRTT => (1.0, 1.0),
            BbrState::FastRecovery => (FAST_RECOVERY_PACING_GAIN, 1.0),
        };
        self.pacing_gain = pacing_gain;
        self.cwnd_gain = cwnd_gain;
        self.state = new_state;
    }
}
impl Default for BandwidthEstimator {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for BandwidthEstimator {
    /// Summarises the estimator's key state.
    ///
    /// `btl_bw` is tracked in bytes per second. It is reported in kilobits per
    /// second (`bytes * 8 / 1000`) so the value matches its `_kbps` label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let btl_bw_kbps = self.btl_bw.saturating_mul(8) / 1000;
        f.debug_struct("BandwidthEstimator")
            .field("state", &self.state)
            .field("btl_bw_kbps", &btl_bw_kbps)
            .field("min_rtt_ms", &self.min_rtt.as_millis())
            .field("pacing_gain", &self.pacing_gain)
            .field("inflight_bytes", &self.inflight_bytes)
            .field("delivered_bytes", &self.delivered_bytes)
            .field("app_limited", &self.app_limited)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-app-limited sample acked `rtt_ms` after `sent_at`, with no ack delay.
    fn make_sample(sent_at: Instant, rtt_ms: u64, packet_bytes: u64) -> DeliverySample {
        DeliverySample {
            delivered_bytes: 0,
            sent_at,
            acked_at: sent_at + Duration::from_millis(rtt_ms),
            packet_bytes,
            is_app_limited: false,
            ack_delay_us: 0,
        }
    }

    /// Same as `make_sample` but flagged app-limited.
    fn make_app_limited_sample(sent_at: Instant, rtt_ms: u64, packet_bytes: u64) -> DeliverySample {
        DeliverySample {
            is_app_limited: true,
            ..make_sample(sent_at, rtt_ms, packet_bytes)
        }
    }

    /// Sends one 1400-byte packet at `base + idx * 10 ms` and acks it 10 ms later.
    fn send_and_ack(est: &mut BandwidthEstimator, base: Instant, idx: u64) {
        est.on_send(1400);
        est.on_ack(make_sample(base + Duration::from_millis(idx * 10), 10, 1400));
    }

    #[test]
    fn test_estimator_starts_in_startup() {
        let est = BandwidthEstimator::new();
        assert_eq!(est.state(), BbrState::Startup);
        assert_eq!(est.delivered_bytes(), 0);
        assert_eq!(est.inflight_bytes(), 0);
        assert!(!est.is_app_limited());
    }

    #[test]
    fn test_bandwidth_increases_with_acks() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        for i in 0..10 {
            send_and_ack(&mut est, now, i);
        }
        assert!(
            est.bottleneck_bandwidth() > 0,
            "btl_bw = {} should be > 0",
            est.bottleneck_bandwidth()
        );
        assert_eq!(est.delivered_bytes(), 14_000);
    }

    #[test]
    fn test_min_rtt_tracking() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        est.on_ack(make_sample(now, 100, 1400));
        assert!(est.min_rtt() <= Duration::from_millis(101));
        est.on_ack(make_sample(now + Duration::from_millis(200), 5, 1400));
        assert!(
            est.min_rtt() <= Duration::from_millis(6),
            "min_rtt = {:?}",
            est.min_rtt()
        );
    }

    #[test]
    fn test_pacing_rate_positive() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        est.on_ack(make_sample(now, 20, 1400));
        assert!(est.pacing_rate() > 0);
    }

    #[test]
    fn test_cwnd_at_least_minimum() {
        let est = BandwidthEstimator::new();
        let cwnd = est.cwnd();
        assert!(cwnd >= 4 * 1400, "cwnd = {} should be >= {}", cwnd, 4 * 1400);
    }

    #[test]
    fn test_startup_to_drain_transition() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        // A constant delivery rate never grows, so Startup must end.
        for i in 0..20 {
            send_and_ack(&mut est, now, i);
        }
        assert_ne!(
            est.state(),
            BbrState::Startup,
            "expected startup exit, rounds = {}",
            est.round_count()
        );
        assert!(est.filled_pipe);
    }

    #[test]
    fn test_inflight_tracking() {
        let mut est = BandwidthEstimator::new();
        est.on_send(1400);
        est.on_send(1400);
        est.on_send(1400);
        assert_eq!(est.inflight_bytes(), 4200);
        let now = Instant::now();
        est.on_ack(make_sample(now, 10, 1400));
        assert_eq!(est.inflight_bytes(), 2800);
        est.on_loss(1400);
        assert_eq!(est.inflight_bytes(), 1400);
        est.on_ack(make_sample(now + Duration::from_millis(10), 10, 1400));
        assert_eq!(est.inflight_bytes(), 0);
    }

    #[test]
    fn test_inflight_cant_go_negative() {
        let mut est = BandwidthEstimator::new();
        est.on_loss(5000);
        assert_eq!(est.inflight_bytes(), 0);
    }

    #[test]
    fn test_app_limited_filtering() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        for i in 0..5 {
            send_and_ack(&mut est, now, i);
        }
        let real_bw = est.bottleneck_bandwidth();
        assert!(real_bw > 0);
        est.set_app_limited();
        assert!(est.is_app_limited());
        for i in 5..10 {
            let sent = now + Duration::from_millis(i * 1000);
            est.on_ack(make_app_limited_sample(sent, 1000, 100));
        }
        assert!(
            est.bottleneck_bandwidth() >= real_bw,
            "BW should not decrease from app-limited samples: {} < {}",
            est.bottleneck_bandwidth(),
            real_bw
        );
    }

    #[test]
    fn test_drain_waits_for_bdp() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        let mut i = 0;
        while est.state() == BbrState::Startup {
            assert!(i < 50, "never left Startup");
            send_and_ack(&mut est, now, i);
            i += 1;
        }
        assert_eq!(est.state(), BbrState::Drain);
        let bdp = est.bdp();
        assert!(bdp > 0, "BDP should be positive in Drain");

        est.inflight_bytes = bdp * 3;
        est.on_ack(make_sample(now + Duration::from_millis(i * 10), 10, 1400));
        assert!(est.inflight_bytes() > est.bdp());
        assert_eq!(
            est.state(),
            BbrState::Drain,
            "should stay in Drain while inflight ({}) > BDP ({})",
            est.inflight_bytes(),
            est.bdp()
        );

        est.inflight_bytes = 0;
        est.on_ack(make_sample(now + Duration::from_millis((i + 1) * 10), 10, 1400));
        assert_eq!(est.state(), BbrState::ProbeBW);
    }

    #[test]
    fn test_bdp_calculation() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        for i in 0..5 {
            send_and_ack(&mut est, now, i);
        }
        let bdp = est.bdp();
        assert!(bdp > 0, "BDP should be positive, got {}", bdp);
    }

    #[test]
    fn test_cwnd_minimum_in_probe_rtt() {
        let mut est = BandwidthEstimator::new();
        est.state = BbrState::ProbeRTT;
        let cwnd = est.cwnd();
        assert_eq!(
            cwnd,
            PROBE_RTT_CWND_PACKETS * MIN_PACKET_SIZE,
            "ProbeRTT CWND should be {} (4 packets), got {}",
            PROBE_RTT_CWND_PACKETS * MIN_PACKET_SIZE,
            cwnd
        );
    }

    #[test]
    fn test_loss_enters_and_exits_fast_recovery() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        for _ in 0..3 {
            est.on_send(1400);
        }
        est.on_ack(make_sample(now, 10, 1400));
        assert_eq!(est.state(), BbrState::Startup);

        est.on_loss(1400);
        assert_eq!(est.state(), BbrState::FastRecovery);
        let expected = (est.bottleneck_bandwidth() as f64 * FAST_RECOVERY_PACING_GAIN) as u64;
        assert_eq!(est.pacing_rate(), expected);

        // A further loss while recovering must not overwrite the saved phase.
        est.on_loss(0);
        est.on_ack(make_sample(now + Duration::from_millis(10), 10, 1400));
        assert_eq!(est.state(), BbrState::Startup);
    }

    #[test]
    fn test_probe_rtt_entered_and_exited() {
        let mut est = BandwidthEstimator::new();
        let now = Instant::now();
        let mut i = 0;
        while est.state() != BbrState::ProbeBW {
            assert!(i < 50, "never reached ProbeBW, state = {:?}", est.state());
            send_and_ack(&mut est, now, i);
            i += 1;
        }
        // The ProbeRTT clock starts at construction, so an ack >10 s later forces ProbeRTT.
        let late = now + PROBE_RTT_INTERVAL + Duration::from_millis(10);
        est.on_ack(make_sample(late, 10, 1400));
        assert_eq!(est.state(), BbrState::ProbeRTT);
        assert_eq!(est.cwnd(), PROBE_RTT_CWND_PACKETS * MIN_PACKET_SIZE);

        est.on_ack(make_sample(late + PROBE_RTT_DURATION, 10, 1400));
        assert_eq!(est.state(), BbrState::ProbeBW);
    }
}