use std::sync::atomic::{AtomicU8, Ordering};
/// Congestion-controller state, represented as a single byte (`#[repr(u8)]`).
///
/// The explicit discriminants below are the exact values matched by
/// `CongestionState::from_u8`; keep the two in sync when adding or
/// renumbering variants.
// NOTE(review): the meaning of each phase is not visible in this file. The
// names suggest a slow-start / congestion-avoidance / periodic-slowdown
// cycle; confirm against the controller that drives these states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub(crate) enum CongestionState {
// Encoded as 0.
SlowStart = 0,
// Encoded as 1.
CongestionAvoidance = 1,
// Encoded as 2.
WaitingForSlowdown = 2,
// Encoded as 3.
InSlowdown = 3,
// Encoded as 4.
Frozen = 4,
// Encoded as 5.
RampingUp = 5,
}
impl CongestionState {
pub(crate) fn from_u8(value: u8) -> Option<Self> {
match value {
0 => Some(Self::SlowStart),
1 => Some(Self::CongestionAvoidance),
2 => Some(Self::WaitingForSlowdown),
3 => Some(Self::InSlowdown),
4 => Some(Self::Frozen),
5 => Some(Self::RampingUp),
_ => None,
}
}
}
/// A [`CongestionState`] held in an `AtomicU8` as its `u8` discriminant, so it
/// can be read and replaced through a shared reference.
pub(crate) struct AtomicCongestionState(AtomicU8);
impl AtomicCongestionState {
    /// Creates a new cell initialised to `state`.
    pub(crate) fn new(state: CongestionState) -> Self {
        Self(AtomicU8::new(state as u8))
    }

    /// Decodes a raw byte read from the atomic, reporting bytes that match no
    /// variant.
    ///
    /// Only this type writes to the inner atomic, and it only ever writes
    /// `CongestionState as u8`, so an undecodable byte is unexpected. When one
    /// is seen it is logged, trips a `debug_assert!` in builds with debug
    /// assertions enabled, and otherwise decodes to
    /// [`CongestionState::CongestionAvoidance`].
    fn decode_or_fallback(value: u8) -> CongestionState {
        CongestionState::from_u8(value).unwrap_or_else(|| {
            tracing::error!(
                value,
                "CRITICAL: Invalid congestion state value - possible memory corruption"
            );
            debug_assert!(false, "Invalid congestion state value: {}", value);
            CongestionState::CongestionAvoidance
        })
    }

    /// Returns the current state, read with `Ordering::Acquire`.
    ///
    /// An undecodable byte is reported and replaced by
    /// [`CongestionState::CongestionAvoidance`].
    pub(crate) fn load(&self) -> CongestionState {
        Self::decode_or_fallback(self.0.load(Ordering::Acquire))
    }

    /// Replaces the current state with `state`, written with
    /// `Ordering::Release`.
    pub(crate) fn store(&self, state: CongestionState) {
        self.0.store(state as u8, Ordering::Release);
    }

    /// Returns `true` when the current state is [`CongestionState::SlowStart`].
    pub(crate) fn is_slow_start(&self) -> bool {
        self.load() == CongestionState::SlowStart
    }

    /// Stores [`CongestionState::SlowStart`].
    pub(crate) fn enter_slow_start(&self) {
        self.store(CongestionState::SlowStart);
    }

    /// Stores [`CongestionState::CongestionAvoidance`].
    pub(crate) fn enter_congestion_avoidance(&self) {
        self.store(CongestionState::CongestionAvoidance);
    }

    /// Stores [`CongestionState::WaitingForSlowdown`].
    pub(crate) fn enter_waiting_for_slowdown(&self) {
        self.store(CongestionState::WaitingForSlowdown);
    }

    /// Stores [`CongestionState::InSlowdown`].
    pub(crate) fn enter_in_slowdown(&self) {
        self.store(CongestionState::InSlowdown);
    }

    /// Stores [`CongestionState::Frozen`].
    pub(crate) fn enter_frozen(&self) {
        self.store(CongestionState::Frozen);
    }

    /// Stores [`CongestionState::RampingUp`].
    pub(crate) fn enter_ramping_up(&self) {
        self.store(CongestionState::RampingUp);
    }

    /// Atomically replaces the state with `new` if it currently equals
    /// `expected`.
    ///
    /// Returns `Ok(previous)` when the swap happened and `Err(actual)` with the
    /// state that was found when it did not. Uses `Ordering::AcqRel` on
    /// success and `Ordering::Acquire` on failure.
    ///
    /// An undecodable byte is reported exactly as in [`Self::load`] rather
    /// than being silently mapped to a default.
    // NOTE(review): no caller is visible in this part of the file, hence the
    // allow; drop it once the method is used.
    #[allow(dead_code)]
    pub(crate) fn compare_exchange(
        &self,
        expected: CongestionState,
        new: CongestionState,
    ) -> Result<CongestionState, CongestionState> {
        self.0
            .compare_exchange(
                expected as u8,
                new as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .map(Self::decode_or_fallback)
            .map_err(Self::decode_or_fallback)
    }
}
impl std::fmt::Debug for AtomicCongestionState {
    /// Writes `AtomicCongestionState(<state>)`, where `<state>` is the `Debug`
    /// form of the state returned by `load` at the time of formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.load();
        f.write_fmt(format_args!("AtomicCongestionState({:?})", current))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the `u8` discriminant of every `CongestionState` variant.
    #[test]
    fn test_congestion_state_values() {
        let expected: [(CongestionState, u8); 6] = [
            (CongestionState::SlowStart, 0),
            (CongestionState::CongestionAvoidance, 1),
            (CongestionState::WaitingForSlowdown, 2),
            (CongestionState::InSlowdown, 3),
            (CongestionState::Frozen, 4),
            (CongestionState::RampingUp, 5),
        ];

        for &(state, raw) in expected.iter() {
            assert_eq!(state as u8, raw);
        }
    }
}