#[cfg(test)]
pub(crate) mod test_utils;
mod fixed;
pub mod params;
mod rtt;
pub(crate) mod sendme;
mod vegas;
use crate::{Error, Result};
use self::{
params::{Algorithm, CongestionControlParams, CongestionWindowParams},
rtt::RoundtripTimeEstimator,
sendme::SendmeValidator,
};
use tor_cell::relaycell::msg::SendmeTag;
use tor_rtcompat::{DynTimeProvider, SleepProvider};
/// A pluggable congestion control algorithm.
///
/// [`CongestionControl`] owns one of these as a `Box<dyn CongestionControlAlgorithm>`
/// and forwards data/SENDME events to it.
pub(crate) trait CongestionControlAlgorithm: Send + std::fmt::Debug {
/// Return true if this algorithm uses stream-level SENDMEs.
/// (Exact semantics are defined by implementors.)
fn uses_stream_sendme(&self) -> bool;
/// Return true if this algorithm uses XON/XOFF.
/// (Exact semantics are defined by implementors.)
fn uses_xon_xoff(&self) -> bool;
/// Return true if the cell just accounted for by [`data_sent`](Self::data_sent)
/// is one whose tag must be recorded for later SENDME validation.
///
/// `CongestionControl::note_data_sent` queries this immediately after
/// calling `data_sent`.
fn is_next_cell_sendme(&self) -> bool;
/// Return true if the algorithm currently allows sending.
fn can_send(&self) -> bool;
/// Return a copy of this algorithm's congestion window, or `None` if it has none.
///
/// When this is `None`, `CongestionControl` neither records SENDME timestamps
/// nor updates its round-trip time estimator.
fn cwnd(&self) -> Option<CongestionWindow>;
/// Account for a received data cell.
///
/// NOTE(review): the meaning of the returned `bool` is implementor-defined;
/// presumably "a SENDME should now be sent" — TODO confirm.
fn data_received(&mut self) -> Result<bool>;
/// Account for a sent data cell.
fn data_sent(&mut self) -> Result<()>;
/// Handle a SENDME that has already passed tag validation.
///
/// `state` is the caller's shared congestion state, which the algorithm may
/// change. `rtt` is the caller's estimator; if [`cwnd`](Self::cwnd) is `Some`,
/// it has already been updated for this SENDME before this call.
fn sendme_received(
&mut self,
state: &mut State,
rtt: &mut RoundtripTimeEstimator,
signals: CongestionSignals,
) -> Result<()>;
/// Account for a SENDME that we sent.
fn sendme_sent(&mut self) -> Result<()>;
/// Return the in-flight count, if this algorithm tracks one.
/// (Presumably a number of cells — TODO confirm.)
#[cfg(feature = "conflux")]
fn inflight(&self) -> Option<u32>;
/// Return the current send window (test builds only).
#[cfg(test)]
fn send_window(&self) -> u32;
/// Return the [`Algorithm`] describing this implementation.
fn algorithm(&self) -> Algorithm;
}
/// Channel-level signals handed to the algorithm when a SENDME is received.
#[derive(Copy, Clone)]
pub(crate) struct CongestionSignals {
    /// Whether the channel is currently blocked.
    pub(crate) channel_blocked: bool,
    /// Size of the channel's outbound queue, saturated to `u32::MAX`.
    /// (Units are not visible here; presumably cells — TODO confirm.)
    pub(crate) channel_outbound_size: u32,
}

impl CongestionSignals {
    /// Build a new set of signals.
    ///
    /// `channel_outbound_size` is clamped to `u32::MAX` rather than truncated.
    pub(crate) fn new(channel_blocked: bool, channel_outbound_size: usize) -> Self {
        Self {
            channel_blocked,
            // A plain `as u32` cast would wrap large values (e.g. 2^32 -> 0) on
            // 64-bit targets, hiding a huge queue from the algorithm. Saturate instead.
            channel_outbound_size: u32::try_from(channel_outbound_size).unwrap_or(u32::MAX),
        }
    }
}
/// The congestion control phase.
///
/// [`CongestionControl`] keeps one of these and hands it mutably to its
/// algorithm whenever a SENDME arrives. A fresh state is [`State::SlowStart`].
#[derive(Copy, Clone, Default)]
pub(crate) enum State {
    /// Slow start; this is the default.
    #[default]
    SlowStart,
    /// Any phase other than slow start.
    Steady,
}

impl State {
    /// Return `true` exactly when this is [`State::SlowStart`].
    pub(crate) fn in_slow_start(&self) -> bool {
        match *self {
            State::SlowStart => true,
            State::Steady => false,
        }
    }
}
/// A congestion window: its current value plus the parameters bounding and stepping it.
///
/// NOTE(review): units are presumably cells — TODO confirm against `params`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct CongestionWindow {
    /// Initial value, bounds, and step sizes.
    params: CongestionWindowParams,
    /// Current window value.
    value: u32,
    /// Fullness flag as last judged by [`CongestionWindow::eval_fullness`].
    is_full: bool,
}

impl CongestionWindow {
    /// Create a window at `params.cwnd_init()`, initially not full.
    fn new(params: CongestionWindowParams) -> Self {
        Self {
            value: params.cwnd_init(),
            params,
            is_full: false,
        }
    }

    /// Shrink the window by one increment, never going below `cwnd_min`.
    pub(crate) fn dec(&mut self) {
        self.value = self
            .value
            .saturating_sub(self.increment())
            .max(self.params.cwnd_min());
    }

    /// Grow the window by one increment, never going above `cwnd_max`.
    pub(crate) fn inc(&mut self) {
        self.value = self
            .value
            .saturating_add(self.increment())
            .min(self.params.cwnd_max());
    }

    /// Return the current window value.
    pub(crate) fn get(&self) -> u32 {
        self.value
    }

    /// Return how many SENDMEs should pass between window updates.
    ///
    /// In slow start this is always 1. Otherwise it is
    /// `cwnd / (cwnd_inc_rate * sendme_inc)`, rounded to the nearest integer.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) outside slow start if `cwnd_inc_rate * sendme_inc == 0`.
    pub(crate) fn update_rate(&self, state: &State) -> u32 {
        if state.in_slow_start() {
            return 1;
        }
        // Widen to u64 so `cwnd + acks / 2` and the product cannot overflow.
        let acks = u64::from(self.increment_rate()) * u64::from(self.sendme_inc());
        let rate = (u64::from(self.get()) + acks / 2) / acks;
        // rate <= cwnd whenever acks >= 1, so this never actually saturates.
        u32::try_from(rate).unwrap_or(u32::MAX)
    }

    /// Return the configured minimum window (`cwnd_min`).
    pub(crate) fn min(&self) -> u32 {
        self.params.cwnd_min()
    }

    /// Overwrite the window value. No clamping to `cwnd_min`/`cwnd_max` is applied.
    pub(crate) fn set(&mut self, value: u32) {
        self.value = value;
    }

    /// Return the step used by [`inc`](Self::inc) and [`dec`](Self::dec) (`cwnd_inc`).
    pub(crate) fn increment(&self) -> u32 {
        self.params.cwnd_inc()
    }

    /// Return the configured `cwnd_inc_rate`.
    pub(crate) fn increment_rate(&self) -> u32 {
        self.params.cwnd_inc_rate()
    }

    /// Return the fullness flag last set by [`eval_fullness`](Self::eval_fullness).
    pub(crate) fn is_full(&self) -> bool {
        self.is_full
    }

    /// Clear the fullness flag.
    pub(crate) fn reset_full(&mut self) {
        self.is_full = false;
    }

    /// Return `cwnd / sendme_inc`, rounded to the nearest integer.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `sendme_inc == 0`.
    pub(crate) fn sendme_per_cwnd(&self) -> u32 {
        let inc = u64::from(self.sendme_inc());
        // Widened so `cwnd + inc / 2` cannot overflow for windows near u32::MAX.
        let per = (u64::from(self.get()) + inc / 2) / inc;
        u32::try_from(per).unwrap_or(u32::MAX)
    }

    /// Apply one slow-start increment, named after RFC 3742 (Limited Slow-Start).
    ///
    /// The increment depends on whether the window is at most `ss_cap`:
    ///
    /// - At or below `ss_cap`: `cwnd_inc_pct_ss` percent of `sendme_inc`, rounded.
    /// - Above `ss_cap`: `(sendme_inc * ss_cap + cwnd) / (2 * cwnd)`, but at least 1.
    ///
    /// The increment is added to the window, saturating at `u32::MAX`; it is not
    /// clamped to `cwnd_max` here. Returns the increment applied.
    pub(crate) fn rfc3742_ss_inc(&mut self, ss_cap: u32) -> u32 {
        // All arithmetic is done in u64: `cwnd * 2` and `sendme_inc * ss_cap`
        // can exceed u32 for large windows/caps.
        let cwnd = u64::from(self.get());
        let sendme_inc = u64::from(self.sendme_inc());
        let inc = if cwnd <= u64::from(ss_cap) {
            (u64::from(self.params.cwnd_inc_pct_ss().as_percent()) * sendme_inc + 50) / 100
        } else {
            // Here cwnd > ss_cap >= 0, so the divisor is non-zero.
            ((sendme_inc * u64::from(ss_cap) + cwnd) / (cwnd * 2)).max(1)
        };
        let inc = u32::try_from(inc).unwrap_or(u32::MAX);
        self.value = self.value.saturating_add(inc);
        inc
    }

    /// Re-evaluate the fullness flag with hysteresis.
    ///
    /// The flag is updated as follows:
    ///
    /// - It becomes set when `inflight` is within `full_gap` SENDME increments of the window.
    /// - It becomes cleared when `inflight` is below `full_minpct` percent of the window.
    /// - It is left unchanged otherwise.
    pub(crate) fn eval_fullness(&mut self, inflight: u32, full_gap: u32, full_minpct: u32) {
        // Widen to u64: e.g. `100 * inflight` or `full_minpct * cwnd` overflow u32
        // for windows in the tens of millions.
        let inflight = u64::from(inflight);
        let cwnd = u64::from(self.get());
        let gap = u64::from(self.sendme_inc()) * u64::from(full_gap);
        if inflight + gap >= cwnd {
            self.is_full = true;
        } else if 100 * inflight < u64::from(full_minpct) * cwnd {
            self.is_full = false;
        }
    }

    /// Return the configured `sendme_inc`.
    pub(crate) fn sendme_inc(&self) -> u32 {
        self.params.sendme_inc()
    }

    /// Return the window parameters.
    #[cfg(any(test, feature = "conflux"))]
    pub(crate) fn params(&self) -> &CongestionWindowParams {
        &self.params
    }
}
pub(crate) struct CongestionControl {
state: State,
sendme_validator: SendmeValidator<SendmeTag>,
rtt: RoundtripTimeEstimator,
algorithm: Box<dyn CongestionControlAlgorithm>,
}
impl CongestionControl {
pub(crate) fn new(params: &CongestionControlParams) -> Self {
let state = State::default();
let algorithm: Box<dyn CongestionControlAlgorithm> = match params.alg() {
Algorithm::FixedWindow(p) => Box::new(fixed::FixedWindow::new(*p)),
Algorithm::Vegas(p) => {
let cwnd = CongestionWindow::new(params.cwnd_params());
Box::new(vegas::Vegas::new(*p, &state, cwnd))
}
};
Self {
algorithm,
rtt: RoundtripTimeEstimator::new(params.rtt_params()),
sendme_validator: SendmeValidator::new(),
state,
}
}
pub(crate) fn uses_stream_sendme(&self) -> bool {
self.algorithm.uses_stream_sendme()
}
pub(crate) fn uses_xon_xoff(&self) -> bool {
self.algorithm.uses_xon_xoff()
}
pub(crate) fn can_send(&self) -> bool {
self.algorithm.can_send()
}
pub(crate) fn note_sendme_received(
&mut self,
runtime: &DynTimeProvider,
tag: SendmeTag,
signals: CongestionSignals,
) -> Result<()> {
self.sendme_validator.validate(Some(tag))?;
let now = runtime.now();
if let Some(cwnd) = self.algorithm.cwnd() {
self.rtt
.update(now, &self.state, &cwnd)
.map_err(|e| Error::CircProto(e.to_string()))?;
}
self.algorithm
.sendme_received(&mut self.state, &mut self.rtt, signals)
}
pub(crate) fn note_sendme_sent(&mut self) -> Result<()> {
self.algorithm.sendme_sent()
}
pub(crate) fn note_data_received(&mut self) -> Result<bool> {
self.algorithm.data_received()
}
pub(crate) fn note_data_sent<U>(&mut self, runtime: &DynTimeProvider, tag: &U) -> Result<()>
where
U: Clone + Into<SendmeTag>,
{
self.algorithm.data_sent()?;
if self.algorithm.is_next_cell_sendme() {
self.sendme_validator.record(tag);
if self.algorithm.cwnd().is_some() {
self.rtt.expect_sendme(runtime.now());
}
}
Ok(())
}
#[cfg(feature = "conflux")]
pub(crate) fn inflight(&self) -> Option<u32> {
self.algorithm.inflight()
}
#[cfg(feature = "conflux")]
pub(crate) fn cwnd(&self) -> Option<CongestionWindow> {
self.algorithm.cwnd()
}
pub(crate) fn rtt(&self) -> &RoundtripTimeEstimator {
&self.rtt
}
#[cfg(feature = "conflux")]
pub(crate) fn algorithm(&self) -> Algorithm {
self.algorithm.algorithm()
}
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use crate::congestion::test_utils::new_cwnd;
    use super::CongestionControl;
    use tor_cell::relaycell::msg::SendmeTag;

    impl CongestionControl {
        /// Test helper returning the algorithm's send window together with the
        /// tags the SENDME validator currently expects.
        pub(crate) fn send_window_and_expected_tags(&self) -> (u32, Vec<SendmeTag>) {
            let window = self.algorithm.send_window();
            let tags = self.sendme_validator.expected_tags();
            (window, tags)
        }
    }

    /// A fresh window reflects its parameters, and one `inc` followed by one
    /// `dec` returns it to the initial value.
    #[test]
    fn test_cwnd() {
        let mut window = new_cwnd();
        // The params are Copy and never change, so a snapshot is equivalent.
        let p = *window.params();

        assert_eq!(window.get(), p.cwnd_init());
        assert_eq!(window.min(), p.cwnd_min());
        assert_eq!(window.increment(), p.cwnd_inc());
        assert_eq!(window.increment_rate(), p.cwnd_inc_rate());
        assert_eq!(window.sendme_inc(), p.sendme_inc());
        assert!(!window.is_full());

        window.inc();
        assert_eq!(window.get(), p.cwnd_init() + p.cwnd_inc());

        window.dec();
        assert_eq!(window.get(), p.cwnd_init());
    }
}