use std::collections::HashMap;
use std::time::Instant;
use crate::propagate::ClientId;
use super::subscriber::{ClientHint, PerSubscriber};
/// Bandwidth-estimation state for every tracked subscriber, keyed by [`ClientId`].
///
/// A subscriber's entry is created on demand the first time a signal is
/// recorded for it (via `get_or_insert`) and is dropped by `reap_dead`.
#[derive(Default, Debug)]
pub struct BandwidthEstimator {
    /// Per-subscriber estimator state, one entry per tracked client.
    subscribers: HashMap<ClientId, PerSubscriber>,
}
impl BandwidthEstimator {
    /// Creates an estimator that is not tracking any subscribers yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the state for `id`, inserting a fresh `PerSubscriber::new()`
    /// entry first if the subscriber is not tracked yet.
    pub(crate) fn get_or_insert(&mut self, id: ClientId) -> &mut PerSubscriber {
        let slot = self.subscribers.entry(id);
        slot.or_insert_with(PerSubscriber::new)
    }

    /// Stores `bps` as the subscriber's native estimate, replacing any
    /// previous value. Starts tracking the subscriber if needed.
    pub fn record_native_estimate(&mut self, subscriber: ClientId, bps: f64) {
        let state = self.get_or_insert(subscriber);
        state.native_estimate_bps = Some(bps);
    }

    /// Stores a client-reported bandwidth hint of `bps`, stamped with `now`.
    /// It replaces any earlier hint. Starts tracking the subscriber if needed.
    pub fn record_client_hint(&mut self, subscriber: ClientId, bps: u64, now: Instant) {
        let hint = ClientHint {
            bps,
            received_at: now,
        };
        self.get_or_insert(subscriber).client_hint = Some(hint);
    }

    /// Returns the combined estimate for `subscriber` in bits per second.
    /// Returns `None` when the subscriber is not tracked.
    ///
    /// This is a read-only lookup: an unknown subscriber is *not* inserted.
    #[must_use]
    pub fn estimate_bps(&self, subscriber: ClientId, now: Instant) -> Option<u64> {
        let state = self.subscribers.get(&subscriber)?;
        // NOTE(review): assuming `combined_bps` yields a float. An `as` cast
        // from a float truncates toward zero and saturates: negatives and NaN
        // become 0, and values above `u64::MAX` clamp to it.
        Some(state.combined_bps(now) as u64)
    }

    /// Stops tracking `subscriber` and discards all of its state. Removing an
    /// unknown subscriber does nothing.
    pub fn reap_dead(&mut self, subscriber: ClientId) {
        let _ = self.subscribers.remove(&subscriber);
    }

    /// Test-only helper. Rebuilds the subscriber's delay and loss estimators
    /// with `bps` as their constructor argument (presumably the starting
    /// estimate — confirm against `DelayEstimator::new` / `LossEstimator::new`)
    /// and clears any native estimate.
    ///
    /// Any previously recorded client hint is left in place.
    #[cfg(any(test, feature = "test-utils"))]
    #[doc(hidden)]
    pub fn force_high_estimate_for_tests(&mut self, subscriber: ClientId, bps: f64) {
        use super::{kalman::DelayEstimator, loss::LossEstimator};

        let state = self.get_or_insert(subscriber);
        state.delay = DelayEstimator::new(bps);
        state.loss = LossEstimator::new(bps);
        state.native_estimate_bps = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::propagate::ClientId;
    use std::time::Instant;

    /// Builds a subscriber id from a raw number.
    fn cid(raw: u64) -> ClientId {
        ClientId(raw)
    }

    #[test]
    fn estimate_returns_none_for_unknown_subscriber() {
        let bwe = BandwidthEstimator::new();
        let estimate = bwe.estimate_bps(cid(99), Instant::now());
        assert!(estimate.is_none());
    }

    #[test]
    fn native_estimate_acts_as_ceiling_via_estimator() {
        let mut bwe = BandwidthEstimator::new();
        let now = Instant::now();
        bwe.record_native_estimate(cid(1), 600_000.0);
        let bps = bwe.estimate_bps(cid(1), now).unwrap();
        assert!(bps > 0, "expected non-zero estimate");
    }

    #[test]
    fn client_hint_caps_estimate() {
        use super::super::{kalman::DelayEstimator, loss::LossEstimator};

        const HIGH_BPS: f64 = 5_000_000.0;
        let mut bwe = BandwidthEstimator::new();
        let now = Instant::now();

        // Seed delay/loss far above the hint so the hint should be the binding limit.
        let state = bwe.get_or_insert(cid(2));
        state.delay = DelayEstimator::new(HIGH_BPS);
        state.loss = LossEstimator::new(HIGH_BPS);

        bwe.record_client_hint(cid(2), 400_000, now);
        let bps = bwe.estimate_bps(cid(2), now).unwrap();
        assert!(bps <= 400_100, "hint ceiling not applied: {bps}");
    }

    #[test]
    fn reap_dead_removes_subscriber() {
        let mut bwe = BandwidthEstimator::new();
        let now = Instant::now();
        bwe.record_native_estimate(cid(3), 1_000_000.0);
        assert!(bwe.estimate_bps(cid(3), now).is_some());

        bwe.reap_dead(cid(3));
        assert!(bwe.estimate_bps(cid(3), now).is_none());
    }
}
use super::feedback::{TwccFeedback, ingest_twcc};
impl BandwidthEstimator {
    /// Passes a TWCC feedback report for `subscriber` to `ingest_twcc`,
    /// together with the subscriber's state. The subscriber is tracked first
    /// if it was not already.
    pub fn on_twcc_feedback(&mut self, subscriber: ClientId, feedback: &TwccFeedback, now: Instant) {
        let sub = self.get_or_insert(subscriber);
        ingest_twcc(sub, feedback, now);
    }

    /// Records that packet `seq` was sent to `subscriber` at `sent_at`.
    ///
    /// At most 512 send times are kept per subscriber. To make room for a
    /// *new* sequence number at that cap, the lowest sequence numbers are
    /// evicted first (presumably the oldest packets).
    ///
    /// Re-recording a `seq` that is already tracked only overwrites its
    /// timestamp and never evicts another entry.
    pub fn record_send_time(&mut self, subscriber: ClientId, seq: u64, sent_at: Instant) {
        const MAX_SEND_TIMES: usize = 512;

        let sub = self.get_or_insert(subscriber);

        // Overwriting an existing key does not grow the map, so eviction is
        // only needed when `seq` is new.
        if !sub.send_times.contains_key(&seq) {
            // `while` rather than `if`: if the map ever ends up above the cap,
            // it is trimmed back instead of staying oversized forever.
            // NOTE(review): `keys().min()` scans every key. If `send_times` is a
            // `BTreeMap`, `first_key_value()` would be O(log n) — confirm the type.
            while sub.send_times.len() >= MAX_SEND_TIMES {
                match sub.send_times.keys().min().copied() {
                    Some(oldest_seq) => {
                        sub.send_times.remove(&oldest_seq);
                    }
                    // Only reachable when the map is empty (a zero cap);
                    // break to avoid spinning.
                    None => break,
                }
            }
        }

        sub.send_times.insert(seq, sent_at);
    }
}