use std::num::Saturating;
use std::sync::Arc;
use postage::watch;
use tor_cell::relaycell::flow_ctrl::{FlowCtrlVersion, Xoff, Xon, XonKbpsEwma};
use tor_cell::relaycell::msg::AnyRelayMsg;
use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};
use tracing::trace;
use super::reader::DrainRateRequest;
use crate::stream::flow_ctrl::params::{CellCount, FlowCtrlParameters};
use crate::stream::flow_ctrl::state::{FlowCtrlHooks, StreamRateLimit};
use crate::util::notify::NotifySender;
use crate::{Error, Result};
#[cfg(doc)]
use {crate::client::stream::DataStream, crate::stream::flow_ctrl::state::StreamFlowCtrl};
/// XON/XOFF-based flow control state for a stream.
///
/// XON/XOFF messages received from the peer update the stream's [`StreamRateLimit`]. Based on how
/// much data is buffered, this type also decides when we should send our own XON/XOFF messages.
#[derive(Debug)]
pub(crate) struct XonXoffFlowCtrl {
    /// Flow control parameters (XOFF limits, XON rate, ...).
    params: Arc<FlowCtrlParameters>,
    /// Publishes the rate limit implied by the most recent XON/XOFF received from the peer.
    rate_limit_updater: watch::Sender<StreamRateLimit>,
    /// Notified when we want a drain rate to be provided (see `DrainRateRequest`).
    ///
    /// For example, this happens after we decide to send an XOFF, or while our buffer is still
    /// above the XOFF limit.
    drain_rate_requester: NotifySender<DrainRateRequest>,
    /// The last XON or XOFF that we decided to send, if any.
    last_sent_xon_xoff: Option<XonXoffMsg>,
    /// We want to send an XOFF once more than this many bytes are buffered.
    xoff_limit: CellCount<{ tor_cell::relaycell::PAYLOAD_MAX_SIZE_ALL as u32 }>,
    /// Checks applied to XON/XOFF messages received from the peer.
    ///
    /// `None` if sidechannel mitigations are disabled.
    sidechannel_mitigation: Option<SidechannelMitigation>,
}
impl XonXoffFlowCtrl {
    /// Create XON/XOFF flow control state for a new stream.
    ///
    /// If `use_sidechannel_mitigations` is true, incoming XON/XOFF messages are checked against
    /// how much stream data we have sent (see [`SidechannelMitigation`]).
    ///
    /// Our own XOFF limit is the larger of `cc_xoff_client` and `cc_xoff_exit`. It is therefore
    /// the same whichever end of the stream we are on.
    pub(crate) fn new(
        params: Arc<FlowCtrlParameters>,
        use_sidechannel_mitigations: bool,
        rate_limit_updater: watch::Sender<StreamRateLimit>,
        drain_rate_requester: NotifySender<DrainRateRequest>,
    ) -> Self {
        // Must be computed before `params` is moved into the new value.
        let xoff_limit = std::cmp::max(params.cc_xoff_client, params.cc_xoff_exit);

        let sidechannel_mitigation = if use_sidechannel_mitigations {
            Some(SidechannelMitigation::new())
        } else {
            None
        };

        Self {
            params,
            rate_limit_updater,
            drain_rate_requester,
            // Nothing has been sent yet.
            last_sent_xon_xoff: None,
            xoff_limit,
            sidechannel_mitigation,
        }
    }
}
impl FlowCtrlHooks for XonXoffFlowCtrl {
fn can_send<M: RelayMsg>(&self, _msg: &M) -> bool {
true
}
fn about_to_send(&mut self, msg: &AnyRelayMsg) -> Result<()> {
if let Some(ref mut sidechannel_mitigation) = self.sidechannel_mitigation {
if let AnyRelayMsg::Data(data_msg) = msg {
sidechannel_mitigation.sent_stream_data(data_msg.as_ref().len());
}
}
Ok(())
}
fn put_for_incoming_sendme(&mut self, _msg: UnparsedRelayMsg) -> Result<()> {
let msg = "Stream level SENDME not allowed due to congestion control";
Err(Error::CircProto(msg.into()))
}
fn handle_incoming_xon(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
let xon = msg
.decode::<Xon>()
.map_err(|e| Error::from_bytes_err(e, "failed to decode XON message"))?
.into_msg();
if *xon.version() != 0 {
return Err(Error::CircProto("Unrecognized XON version".into()));
}
if let Some(ref mut sidechannel_mitigation) = self.sidechannel_mitigation {
sidechannel_mitigation.received_xon(&self.params)?;
}
trace!("Received an XON with rate {}", xon.kbps_ewma());
let rate = match xon.kbps_ewma() {
XonKbpsEwma::Limited(rate_kbps) => {
let rate_kbps = u64::from(rate_kbps.get());
StreamRateLimit::new_bytes_per_sec(rate_kbps * 1000 / 8)
}
XonKbpsEwma::Unlimited => StreamRateLimit::MAX,
};
*self.rate_limit_updater.borrow_mut() = rate;
Ok(())
}
fn handle_incoming_xoff(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
let xoff = msg
.decode::<Xoff>()
.map_err(|e| Error::from_bytes_err(e, "failed to decode XOFF message"))?
.into_msg();
if *xoff.version() != 0 {
return Err(Error::CircProto("Unrecognized XOFF version".into()));
}
if let Some(ref mut sidechannel_mitigation) = self.sidechannel_mitigation {
sidechannel_mitigation.received_xoff(&self.params)?;
}
trace!("Received an XOFF");
*self.rate_limit_updater.borrow_mut() = StreamRateLimit::ZERO;
Ok(())
}
fn maybe_send_xon(&mut self, rate: XonKbpsEwma, buffer_len: usize) -> Result<Option<Xon>> {
if buffer_len as u64 > self.xoff_limit.as_bytes() {
debug_assert!(matches!(self.last_sent_xon_xoff, Some(XonXoffMsg::Xoff)));
self.drain_rate_requester.notify();
return Ok(None);
}
self.last_sent_xon_xoff = Some(XonXoffMsg::Xon(rate));
trace!("Want to send an XON with rate {rate}");
Ok(Some(Xon::new(FlowCtrlVersion::V0, rate)))
}
fn maybe_send_xoff(&mut self, buffer_len: usize) -> Result<Option<Xoff>> {
if matches!(self.last_sent_xon_xoff, Some(XonXoffMsg::Xoff)) {
return Ok(None);
}
if buffer_len as u64 <= self.xoff_limit.as_bytes() {
return Ok(None);
}
self.last_sent_xon_xoff = Some(XonXoffMsg::Xoff);
self.drain_rate_requester.notify();
trace!("Want to send an XOFF");
Ok(Some(Xoff::new(FlowCtrlVersion::V0)))
}
}
/// Which kind of flow control message (XON or XOFF) was seen, without any payload.
///
/// Used to remember the last kind of message received from the peer.
#[derive(Debug, PartialEq, Eq)]
enum XonXoff {
    /// An XON message.
    Xon,
    /// An XOFF message.
    Xoff,
}
/// An XON or XOFF message that we decided to send.
#[derive(Debug)]
enum XonXoffMsg {
    /// An XON advertising the given rate.
    // The rate is recorded but not currently read, hence the `dead_code` expectation.
    #[expect(dead_code)]
    Xon(XonKbpsEwma),
    /// An XOFF.
    Xoff,
}
/// Sanity checks on XON/XOFF messages received from the peer.
///
/// The checks compare each message against how much stream data we have sent. A message that
/// arrives earlier, or more often, than the flow control parameters allow is rejected as a
/// circuit protocol violation.
///
/// Per the type name, the purpose is to limit side channels. Presumably that means stopping a
/// peer from using XON/XOFF timing to signal; not verified here.
#[derive(Debug)]
struct SidechannelMitigation {
    /// The kind of the last XON/XOFF received from the peer, if any.
    last_recvd_xon_xoff: Option<XonXoff>,
    /// Total stream bytes sent.
    ///
    /// This saturates at `u32::MAX` instead of wrapping.
    bytes_sent_total: Saturating<u32>,
    /// Stream bytes sent since the last accepted advisory XON.
    ///
    /// An advisory XON is one that does not follow an XOFF.
    bytes_sent_since_recvd_last_advisory_xon: Saturating<u32>,
    /// Stream bytes sent since the last accepted XOFF.
    bytes_sent_since_recvd_last_xoff: Saturating<u32>,
}
impl SidechannelMitigation {
fn new() -> Self {
Self {
last_recvd_xon_xoff: None,
bytes_sent_total: Saturating(0),
bytes_sent_since_recvd_last_advisory_xon: Saturating(0),
bytes_sent_since_recvd_last_xoff: Saturating(0),
}
}
fn peer_xoff_limit_bytes(params: &FlowCtrlParameters) -> u64 {
let min = std::cmp::min(
params.cc_xoff_client.as_bytes(),
params.cc_xoff_exit.as_bytes(),
);
min / 2
}
fn peer_xon_limit_bytes(params: &FlowCtrlParameters) -> u64 {
params.cc_xon_rate.as_bytes() / 2
}
fn sent_stream_data(&mut self, stream_bytes: usize) {
let stream_bytes: u32 = stream_bytes.try_into().unwrap_or(u32::MAX);
self.bytes_sent_total += stream_bytes;
self.bytes_sent_since_recvd_last_advisory_xon += stream_bytes;
self.bytes_sent_since_recvd_last_xoff += stream_bytes;
}
fn received_xon(&mut self, params: &FlowCtrlParameters) -> Result<()> {
if self.bytes_sent_total.0 == 0 {
const MSG: &str = "Received XON before sending any data";
return Err(Error::CircProto(MSG.into()));
}
let is_advisory = match self.last_recvd_xon_xoff {
Some(XonXoff::Xon) => true,
Some(XonXoff::Xoff) => false,
None => true,
};
self.last_recvd_xon_xoff = Some(XonXoff::Xon);
if !is_advisory {
return Ok(());
}
let advisory_not_expected_before = std::cmp::min(
Self::peer_xoff_limit_bytes(params),
Self::peer_xon_limit_bytes(params),
);
if u64::from(self.bytes_sent_total.0) < advisory_not_expected_before {
const MSG: &str = "Received advisory XON too early";
return Err(Error::CircProto(MSG.into()));
}
if u64::from(self.bytes_sent_since_recvd_last_advisory_xon.0)
< Self::peer_xon_limit_bytes(params)
{
const MSG: &str = "Received advisory XON too frequently";
return Err(Error::CircProto(MSG.into()));
}
self.bytes_sent_since_recvd_last_advisory_xon = Saturating(0);
Ok(())
}
fn received_xoff(&mut self, params: &FlowCtrlParameters) -> Result<()> {
if self.bytes_sent_total.0 == 0 {
const MSG: &str = "Received XOFF before sending any data";
return Err(Error::CircProto(MSG.into()));
}
if self.last_recvd_xon_xoff == Some(XonXoff::Xoff) {
const MSG: &str = "Received consecutive XOFF messages";
return Err(Error::CircProto(MSG.into()));
}
if u64::from(self.bytes_sent_total.0) < Self::peer_xoff_limit_bytes(params) {
const MSG: &str = "Received XOFF too early";
return Err(Error::CircProto(MSG.into()));
}
if u64::from(self.bytes_sent_since_recvd_last_xoff.0) < Self::peer_xoff_limit_bytes(params)
{
return Err(Error::CircProto("Received XOFF too frequently".into()));
}
self.bytes_sent_since_recvd_last_xoff = Saturating(0);
self.last_recvd_xon_xoff = Some(XonXoff::Xoff);
Ok(())
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::stream::flow_ctrl::params::CellCount;

    /// Check which XON/XOFF messages from the peer [`SidechannelMitigation`] accepts, depending on
    /// how much stream data we have sent.
    #[test]
    fn sidechannel_mitigation() {
        // In the first parameter set `cc_xon_rate` is larger than both XOFF limits.
        // In the second it is smaller than both.
        let params = [
            FlowCtrlParameters {
                cc_xoff_client: CellCount::new(2),
                cc_xoff_exit: CellCount::new(4),
                cc_xon_rate: CellCount::new(8),
                cc_xon_change_pct: 1,
                cc_xon_ewma_cnt: 1,
            },
            FlowCtrlParameters {
                cc_xoff_client: CellCount::new(8),
                cc_xoff_exit: CellCount::new(4),
                cc_xon_rate: CellCount::new(2),
                cc_xon_change_pct: 1,
                cc_xon_ewma_cnt: 1,
            },
        ];

        for params in params {
            let xon_limit = SidechannelMitigation::peer_xon_limit_bytes(&params);
            let xoff_limit = SidechannelMitigation::peer_xoff_limit_bytes(&params);

            // An XON or XOFF before any data has been sent is rejected.
            let mut x = SidechannelMitigation::new();
            assert!(x.received_xon(&params).is_err());
            let mut x = SidechannelMitigation::new();
            assert!(x.received_xoff(&params).is_err());

            // An XOFF one byte short of the limit is too early.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xoff_limit as usize - 1);
            assert!(x.received_xoff(&params).is_err());

            // An XOFF at the limit is accepted, but two XOFFs in a row are not.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xoff_limit as usize);
            assert!(x.received_xoff(&params).is_ok());
            assert!(x.received_xoff(&params).is_err());

            // Consecutive XOFFs are rejected even with enough data between them.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xoff_limit as usize);
            assert!(x.received_xoff(&params).is_ok());
            x.sent_stream_data(xoff_limit as usize);
            assert!(x.received_xoff(&params).is_err());

            // XOFF, then XON, then XOFF again (with enough data in between) is accepted.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xoff_limit as usize);
            assert!(x.received_xoff(&params).is_ok());
            assert!(x.received_xon(&params).is_ok());
            x.sent_stream_data(xoff_limit as usize);
            assert!(x.received_xoff(&params).is_ok());

            // An advisory XON one byte short of the XON limit is rejected.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xon_limit as usize - 1);
            assert!(x.received_xon(&params).is_err());

            // An advisory XON at the XON limit is accepted; a second one right away is not.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xon_limit as usize);
            assert!(x.received_xon(&params).is_ok());
            assert!(x.received_xon(&params).is_err());

            // Repeated advisory XONs are accepted when enough data is sent between them.
            let mut x = SidechannelMitigation::new();
            x.sent_stream_data(xon_limit as usize);
            assert!(x.received_xon(&params).is_ok());
            x.sent_stream_data(xon_limit as usize);
            assert!(x.received_xon(&params).is_ok());
        }
    }
}