use std::collections::VecDeque;
use tor_cell::relaycell::RelayCmd;
use tor_cell::relaycell::UnparsedRelayMsg;
use tor_error::internal;
use crate::{Error, Result};
// Concrete window types: one send/receive pair for circuits, one for streams.

/// Sending-side window using circuit-level parameters.
pub(crate) type CircSendWindow = SendWindow<CircParams>;
/// Receiving-side window using circuit-level parameters.
pub(crate) type CircRecvWindow = RecvWindow<CircParams>;
/// Sending-side window using stream-level parameters.
pub(crate) type StreamSendWindow = SendWindow<StreamParams>;
/// Receiving-side window using stream-level parameters.
pub(crate) type StreamRecvWindow = RecvWindow<StreamParams>;
/// A sending-side flow-control window whose increment and maximum come from
/// the parameter type `P`.
#[derive(Clone, Debug)]
pub(crate) struct SendWindow<P: WindowParams> {
    /// Current value of the window.
    window: u16,
    /// Zero-sized marker tying this window to its parameter type `P`.
    _dummy: std::marker::PhantomData<P>,
}
/// Helper: parametrizes a flow-control window with its maximum, its
/// increment, and its starting value.
pub(crate) trait WindowParams {
    /// Largest value the window may reach.
    ///
    /// `SendWindow::put` rejects any increment that would take the window
    /// above this value. This method is therefore used, and needs no
    /// `dead_code` allowance.
    fn maximum() -> u16;
    /// Amount a window grows by on `put()`.
    ///
    /// It is also the spacing at which `SendWindow::should_record_tag` and
    /// `RecvWindow::take` report `true`.
    fn increment() -> u16;
    /// Starting value for a window of this kind.
    // NOTE(review): not referenced within this module; confirm at callers.
    fn start() -> u16;
}
/// Window parameters used for circuit-level windows.
#[derive(Clone, Debug)]
pub(crate) struct CircParams;

impl WindowParams for CircParams {
    /// Circuit windows never exceed 1000.
    fn maximum() -> u16 {
        1_000
    }
    /// Each `put()` on a circuit window adds 100.
    fn increment() -> u16 {
        100
    }
    /// Circuit windows start at their maximum, 1000.
    fn start() -> u16 {
        1_000
    }
}
/// Window parameters used for stream-level windows.
#[derive(Clone, Debug)]
pub(crate) struct StreamParams;

impl WindowParams for StreamParams {
    /// Stream windows never exceed 500.
    fn maximum() -> u16 {
        5_00
    }
    /// Each `put()` on a stream window adds 50.
    fn increment() -> u16 {
        50
    }
    /// Stream windows start at their maximum, 500.
    fn start() -> u16 {
        5_00
    }
}
/// FIFO of tags we expect to see echoed back in incoming SENDMEs.
///
/// Tags are validated in the same order they were recorded.
#[derive(Clone, Debug)]
pub(crate) struct SendmeValidator<T: PartialEq + Eq + Clone> {
    /// Tags added by `record()` and not yet consumed by `validate()`,
    /// oldest first.
    tags: VecDeque<T>,
}

impl<T: PartialEq + Eq + Clone> SendmeValidator<T> {
    /// Create a validator that is not yet expecting any SENDME.
    pub(crate) fn new() -> Self {
        let tags = VecDeque::new();
        Self { tags }
    }

    /// Append a copy of `tag`, converted into `T`, to the back of the
    /// expected-tag queue.
    pub(crate) fn record<U>(&mut self, tag: &U)
    where
        U: Clone + Into<T>,
    {
        let expected: T = tag.clone().into();
        self.tags.push_back(expected);
    }

    /// Check an incoming SENDME against the oldest expected tag.
    ///
    /// If `tag` is `None`, any pending expectation is accepted. On success
    /// the oldest expected tag is removed.
    ///
    /// # Errors
    ///
    /// Returns `Error::CircProto` if no tag was expected, or if `tag` is
    /// present but differs from the oldest expected tag. In both cases the
    /// queue is left unchanged.
    pub(crate) fn validate<U>(&mut self, tag: Option<U>) -> Result<()>
    where
        T: PartialEq<U>,
    {
        let Some(expected) = self.tags.front() else {
            return Err(Error::CircProto(
                "Received a SENDME when none was expected".into(),
            ));
        };
        if tag.is_some_and(|received| *expected != received) {
            return Err(Error::CircProto("Mismatched tag on circuit SENDME".into()));
        }
        self.tags.pop_front();
        Ok(())
    }

    /// Test-only: return a snapshot of the expected tags, oldest first.
    #[cfg(test)]
    pub(crate) fn expected_tags(&self) -> Vec<T> {
        self.tags.iter().cloned().collect()
    }
}
impl<P> SendWindow<P>
where
P: WindowParams,
{
pub(crate) fn new(window: u16) -> SendWindow<P> {
SendWindow {
window,
_dummy: std::marker::PhantomData,
}
}
pub(crate) fn should_record_tag(&self) -> bool {
self.window.is_multiple_of(P::increment())
}
pub(crate) fn take(&mut self) -> Result<()> {
self.window = self.window.checked_sub(1).ok_or(Error::CircProto(
"Called SendWindow::take() on empty SendWindow".into(),
))?;
Ok(())
}
#[must_use = "didn't check whether SENDME was expected."]
pub(crate) fn put(&mut self) -> Result<()> {
let new_window = self
.window
.checked_add(P::increment())
.ok_or(Error::from(internal!("Overflow on SENDME window")))?;
if new_window > P::maximum() {
return Err(Error::CircProto("Unexpected stream SENDME".into()));
}
self.window = new_window;
Ok(())
}
pub(crate) fn window(&self) -> u16 {
self.window
}
}
#[derive(Clone, Debug)]
pub(crate) struct RecvWindow<P: WindowParams> {
window: u16,
_dummy: std::marker::PhantomData<P>,
}
impl<P: WindowParams> RecvWindow<P> {
pub(crate) fn new(window: u16) -> RecvWindow<P> {
RecvWindow {
window,
_dummy: std::marker::PhantomData,
}
}
pub(crate) fn take(&mut self) -> Result<bool> {
let v = self.window.checked_sub(1);
if let Some(x) = v {
self.window = x;
Ok(x % P::increment() == 0)
} else {
Err(Error::CircProto(
"Received a data cell in violation of a window".into(),
))
}
}
pub(crate) fn decrement_n(&mut self, n: u16) -> crate::Result<()> {
self.window = self.window.checked_sub(n).ok_or(Error::CircProto(
"Received too many cells on a stream".into(),
))?;
Ok(())
}
pub(crate) fn put(&mut self) {
self.window = self
.window
.checked_add(P::increment())
.expect("Overflow detected while attempting to increment window");
}
}
/// Return true if a relay message with command `cmd` counts against
/// flow-control windows. Only `RelayCmd::DATA` does.
pub(crate) fn cmd_counts_towards_windows(cmd: RelayCmd) -> bool {
    RelayCmd::DATA == cmd
}
/// Test-only helper: return true if the parsed relay message `msg` counts
/// against flow-control windows.
#[cfg(test)]
pub(crate) fn msg_counts_towards_windows(msg: &tor_cell::relaycell::msg::AnyRelayMsg) -> bool {
    use tor_cell::relaycell::RelayMsg;
    let command = msg.cmd();
    cmd_counts_towards_windows(command)
}
/// Return true if the not-yet-parsed relay message `cell` counts against
/// flow-control windows, judged only by its command.
pub(crate) fn cell_counts_towards_windows(cell: &UnparsedRelayMsg) -> bool {
    let command = cell.cmd();
    cmd_counts_towards_windows(command)
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use tor_basic_utils::test_rng::testing_rng;
    use tor_cell::relaycell::{AnyRelayMsgOuter, RelayCellFormat, StreamId, msg};

    /// Only DATA messages count towards windows, both as parsed messages and
    /// after an encode / unparse round trip.
    #[test]
    fn what_counts() {
        let mut rng = testing_rng();
        // Use one format binding for both encoding and unparsing.
        let fmt = RelayCellFormat::V0;

        let m = msg::Begin::new("www.torproject.org", 443, 0)
            .unwrap()
            .into();
        assert!(!msg_counts_towards_windows(&m));
        assert!(!cell_counts_towards_windows(
            &UnparsedRelayMsg::from_singleton_body(
                fmt,
                AnyRelayMsgOuter::new(StreamId::new(77), m)
                    .encode(fmt, &mut rng)
                    .unwrap()
            )
            .unwrap()
        ));

        let m = msg::Data::new(
            &b"Education is not a prerequisite to political control-political control is the cause of popular education."[..],
        )
        .unwrap()
        .into();
        assert!(msg_counts_towards_windows(&m));
        assert!(cell_counts_towards_windows(
            &UnparsedRelayMsg::from_singleton_body(
                fmt,
                AnyRelayMsgOuter::new(StreamId::new(128), m)
                    .encode(fmt, &mut rng)
                    .unwrap()
            )
            .unwrap()
        ));
    }

    #[test]
    fn recvwindow() {
        let mut w: RecvWindow<StreamParams> = RecvWindow::new(500);
        for _ in 0..49 {
            assert!(!w.take().unwrap());
        }
        assert!(w.take().unwrap());
        assert_eq!(w.window, 450);
        assert!(w.decrement_n(123).is_ok());
        assert_eq!(w.window, 327);
        w.put();
        assert_eq!(w.window, 377);
        assert!(w.decrement_n(400).is_err());
        assert!(w.decrement_n(377).is_ok());
        assert!(w.take().is_err());
    }

    /// A circuit send window at its default starting value.
    fn new_sendwindow() -> SendWindow<CircParams> {
        SendWindow::new(1000)
    }

    #[test]
    fn sendwindow_basic() -> Result<()> {
        let mut w = new_sendwindow();
        w.take()?;
        assert_eq!(w.window(), 999);
        for _ in 0_usize..98 {
            w.take()?;
        }
        assert_eq!(w.window(), 901);
        w.take()?;
        assert_eq!(w.window(), 900);
        w.take()?;
        assert_eq!(w.window(), 899);
        w.put()?;
        assert_eq!(w.window(), 999);
        for _ in 0_usize..300 {
            w.take()?;
        }
        w.put()?;
        assert_eq!(w.window(), 799);
        Ok(())
    }

    #[test]
    fn sendwindow_erroring() -> Result<()> {
        let mut w = new_sendwindow();
        for _ in 0_usize..1000 {
            w.take()?;
        }
        assert_eq!(w.window(), 0);
        let ready = w.take();
        assert!(ready.is_err());
        Ok(())
    }

    /// `put()` must refuse to grow the window past `maximum()` and leave it
    /// unchanged when it does.
    #[test]
    fn sendwindow_put_above_maximum() -> Result<()> {
        let mut w = new_sendwindow();
        assert!(w.put().is_err());
        assert_eq!(w.window(), 1000);
        for _ in 0_usize..100 {
            w.take()?;
        }
        w.put()?;
        assert_eq!(w.window(), 1000);
        assert!(w.put().is_err());
        Ok(())
    }

    /// `should_record_tag()` is true exactly on multiples of the increment.
    #[test]
    fn sendwindow_should_record_tag() -> Result<()> {
        let mut w = new_sendwindow();
        assert!(w.should_record_tag());
        w.take()?;
        assert!(!w.should_record_tag());
        for _ in 0_usize..99 {
            w.take()?;
        }
        assert_eq!(w.window(), 900);
        assert!(w.should_record_tag());
        Ok(())
    }

    /// Tags are consumed in FIFO order. A mismatched tag is rejected without
    /// consuming anything, and a missing tag is accepted.
    #[test]
    fn sendme_validator() {
        let mut v: SendmeValidator<[u8; 20]> = SendmeValidator::new();
        assert!(v.validate::<[u8; 20]>(None).is_err());
        v.record(&[1_u8; 20]);
        v.record(&[2_u8; 20]);
        assert_eq!(v.expected_tags(), vec![[1_u8; 20], [2_u8; 20]]);
        assert!(v.validate(Some([3_u8; 20])).is_err());
        assert_eq!(v.expected_tags().len(), 2);
        assert!(v.validate(Some([1_u8; 20])).is_ok());
        assert!(v.validate::<[u8; 20]>(None).is_ok());
        assert!(v.expected_tags().is_empty());
    }
}