use std::collections::VecDeque;
use tor_cell::relaycell::msg::RelayMsg;
use tor_cell::relaycell::RelayCell;
use tor_error::internal;
use crate::{Error, Result};
/// A 20-byte tag recorded by a circuit send window and matched against the
/// tag carried by incoming circuit SENDMEs.
///
/// All comparisons go through `crate::util::ct::bytes_eq`
/// (NOTE(review): presumably a constant-time compare, judging by the `ct`
/// module name — confirm there).
#[derive(Clone, Debug)]
pub(crate) struct CircTag([u8; 20]);

impl From<[u8; 20]> for CircTag {
    fn from(bytes: [u8; 20]) -> Self {
        CircTag(bytes)
    }
}

impl PartialEq for CircTag {
    fn eq(&self, rhs: &Self) -> bool {
        crate::util::ct::bytes_eq(&self.0[..], &rhs.0[..])
    }
}

impl Eq for CircTag {}

/// Lets a stored tag be compared directly against a raw 20-byte array.
impl PartialEq<[u8; 20]> for CircTag {
    fn eq(&self, rhs: &[u8; 20]) -> bool {
        crate::util::ct::bytes_eq(&self.0[..], &rhs[..])
    }
}
/// Tag type for windows whose SENDMEs carry no tag (a unit placeholder).
pub(crate) type NoTag = ();

/// Sending-side window for a circuit; its SENDMEs are checked against [`CircTag`]s.
pub(crate) type CircSendWindow = SendWindow<CircParams, CircTag>;
/// Receiving-side window for a circuit.
pub(crate) type CircRecvWindow = RecvWindow<CircParams>;

/// Sending-side window for a stream; its SENDMEs carry no tag.
pub(crate) type StreamSendWindow = SendWindow<StreamParams, NoTag>;
/// Receiving-side window for a stream.
pub(crate) type StreamRecvWindow = RecvWindow<StreamParams>;
/// A sending window: the remaining send capacity plus the queue of tags that
/// upcoming SENDMEs are expected to carry.
///
/// `P` supplies the window parameters; `T` is the tag type queued by
/// `SendWindow::take` and checked by `SendWindow::put`.
pub(crate) struct SendWindow<P: WindowParams, T: Eq + Clone> {
    /// Remaining capacity: decremented by `take`, refilled by `put`.
    window: u16,
    /// Tags expected on future SENDMEs, oldest first.
    tags: VecDeque<T>,
    /// Ties the window to its parameter type without storing a value.
    _dummy: std::marker::PhantomData<P>,
}
/// Type-level parameters describing a flow-control window.
pub(crate) trait WindowParams {
    /// How much a window grows per SENDME; also the spacing at which
    /// `take` on either window type reports/records a SENDME boundary.
    fn increment() -> u16;
    /// Maximum window size (NOTE(review): not read by the code in this
    /// file — presumably the initial/maximum window; confirm at call sites).
    fn maximum() -> u16;
}
/// Window parameters for circuit-level flow control: a 1000-cell window
/// that grows by 100 per SENDME.
///
/// Derives `Clone` and `Debug` like `StreamParams`, so that
/// `RecvWindow<CircParams>` (whose `Debug` derive requires `P: Debug`) is
/// debuggable just like the stream variant.
#[derive(Clone, Debug)]
pub(crate) struct CircParams;
impl WindowParams for CircParams {
    fn maximum() -> u16 {
        1000
    }
    fn increment() -> u16 {
        100
    }
}
/// Window parameters for stream-level flow control: a 500-cell window that
/// grows by 50 per SENDME.
#[derive(Clone, Debug)]
pub(crate) struct StreamParams;

impl WindowParams for StreamParams {
    fn increment() -> u16 {
        50
    }

    fn maximum() -> u16 {
        500
    }
}
impl<P, T> SendWindow<P, T>
where
    P: WindowParams,
    T: PartialEq + Eq + Clone,
{
    /// Create a new send window that starts out at `window`.
    ///
    /// The tag queue is pre-sized for the number of tags a full drain of
    /// `window` would record: one per multiple of `P::increment()`, i.e.
    /// `ceil(window / increment)`.
    pub(crate) fn new(window: u16) -> SendWindow<P, T> {
        // Compute the ceiling in `usize`: doing it in `u16` overflows (and
        // panics in debug builds) when `window` is within `increment` of
        // `u16::MAX`.
        let increment = usize::from(P::increment());
        let capacity = (usize::from(window) + increment - 1) / increment;
        SendWindow {
            window,
            tags: VecDeque::with_capacity(capacity),
            _dummy: std::marker::PhantomData,
        }
    }

    /// Consume one unit of the window and return the remaining size.
    ///
    /// Whenever the remaining size lands on a multiple of `P::increment()`,
    /// `tag` is converted into `T` and queued as a tag that a future SENDME
    /// is expected to carry.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CircProto`] if the window is already empty.
    pub(crate) fn take<U>(&mut self, tag: &U) -> Result<u16>
    where
        U: Clone + Into<T>,
    {
        if let Some(val) = self.window.checked_sub(1) {
            self.window = val;
            if self.window % P::increment() == 0 {
                self.tags.push_back(tag.clone().into());
            }
            Ok(val)
        } else {
            Err(Error::CircProto(
                "Called SendWindow::take() on empty SendWindow".into(),
            ))
        }
    }

    /// Handle an incoming SENDME that optionally carries `tag`.
    ///
    /// The SENDME is only accepted if a tag is queued. A `Some(tag)` must
    /// equal the oldest queued tag; `None` accepts whatever tag is queued.
    /// On success the oldest tag is dropped, the window grows by
    /// `P::increment()`, and the new window size is returned.
    ///
    /// # Errors
    ///
    /// * [`Error::CircProto`] if no SENDME was expected or the tag does not
    ///   match.
    /// * An internal error if growing the window would overflow `u16`.
    ///
    /// In every error case the window and the tag queue are left unchanged.
    #[must_use = "didn't check whether SENDME was expected and tag was right."]
    pub(crate) fn put<U>(&mut self, tag: Option<U>) -> Result<u16>
    where
        T: PartialEq<U>,
    {
        match (self.tags.front(), tag) {
            (Some(t), Some(tag)) if t == &tag => {} // matches the oldest queued tag
            (Some(_), None) => {}                   // tagless SENDME: nothing to check
            (Some(_), Some(_)) => {
                return Err(Error::CircProto("Mismatched tag on circuit SENDME".into()));
            }
            (None, _) => {
                return Err(Error::CircProto(
                    "Received a SENDME when none was expected".into(),
                ));
            }
        }
        // Do the fallible addition before mutating anything, so an overflow
        // does not consume a queued tag without growing the window.
        let v = self
            .window
            .checked_add(P::increment())
            .ok_or_else(|| Error::from(internal!("Overflow on SENDME window")))?;
        self.tags.pop_front();
        self.window = v;
        Ok(v)
    }

    /// Current size of the window.
    pub(crate) fn window(&self) -> u16 {
        self.window
    }

    /// Test helper: the current window and a copy of the queued tags,
    /// oldest first.
    #[cfg(test)]
    pub(crate) fn window_and_expected_tags(&self) -> (u16, Vec<T>) {
        let tags = self.tags.iter().cloned().collect();
        (self.window, tags)
    }
}
/// A receiving window: how many more window-counted cells we will accept.
///
/// `P` supplies the window parameters.
#[derive(Clone, Debug)]
pub(crate) struct RecvWindow<P>
where
    P: WindowParams,
{
    /// Remaining capacity: decremented as cells arrive, refilled by `put`.
    window: u16,
    /// Ties the window to its parameter type without storing a value.
    _dummy: std::marker::PhantomData<P>,
}
impl<P: WindowParams> RecvWindow<P> {
pub(crate) fn new(window: u16) -> RecvWindow<P> {
RecvWindow {
window,
_dummy: std::marker::PhantomData,
}
}
pub(crate) fn take(&mut self) -> Result<bool> {
let v = self.window.checked_sub(1);
if let Some(x) = v {
self.window = x;
Ok(x % P::increment() == 0)
} else {
Err(Error::CircProto(
"Received a data cell in violation of a window".into(),
))
}
}
pub(crate) fn decrement_n(&mut self, n: u16) -> crate::Result<()> {
let v = self.window.checked_sub(n);
if let Some(x) = v {
self.window = x;
Ok(())
} else {
Err(crate::Error::CircProto(
"Received too many cells on a stream".into(),
))
}
}
pub(crate) fn put(&mut self) {
self.window = self
.window
.checked_add(P::increment())
.expect("Overflow detected while attempting to increment window");
}
}
/// Return true if `msg` consumes space in a flow-control window.
///
/// Only DATA messages count.
pub(crate) fn msg_counts_towards_windows(msg: &RelayMsg) -> bool {
    matches!(*msg, RelayMsg::Data(..))
}

/// Return true if the message inside `cell` consumes space in a
/// flow-control window (see [`msg_counts_towards_windows`]).
pub(crate) fn cell_counts_towards_windows(cell: &RelayCell) -> bool {
    let msg = cell.msg();
    msg_counts_towards_windows(msg)
}
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use tor_cell::relaycell::{msg, RelayCell};

    /// Only DATA messages, and cells carrying them, count against windows.
    #[test]
    fn what_counts() {
        let begin = msg::Begin::new("www.torproject.org", 443, 0)
            .unwrap()
            .into();
        assert!(!msg_counts_towards_windows(&begin));
        let begin_cell = RelayCell::new(77.into(), begin);
        assert!(!cell_counts_towards_windows(&begin_cell));

        let data = msg::Data::new(
            &b"Education is not a prerequisite to political control-political control is the cause of popular education."[..],
        )
        .unwrap()
        .into();
        assert!(msg_counts_towards_windows(&data));
        let data_cell = RelayCell::new(128.into(), data);
        assert!(cell_counts_towards_windows(&data_cell));
    }

    #[test]
    fn recvwindow() {
        let mut recv: RecvWindow<StreamParams> = RecvWindow::new(500);

        // 49 cells go by quietly; the 50th lands on an increment boundary.
        for _ in 0..49 {
            assert!(!recv.take().unwrap());
        }
        assert!(recv.take().unwrap());
        assert_eq!(recv.window, 450);

        assert!(recv.decrement_n(123).is_ok());
        assert_eq!(recv.window, 327);

        recv.put();
        assert_eq!(recv.window, 377);

        // Overdrawing fails; draining to exactly zero succeeds, after which
        // not even one more cell fits.
        assert!(recv.decrement_n(400).is_err());
        assert!(recv.decrement_n(377).is_ok());
        assert!(recv.take().is_err());
    }

    fn new_sendwindow() -> SendWindow<CircParams, &'static str> {
        SendWindow::new(1000)
    }

    #[test]
    fn sendwindow_basic() -> Result<()> {
        let mut send = new_sendwindow();

        assert_eq!(send.take(&"Hello")?, 999);
        for _ in 0..98 {
            send.take(&"world")?;
        }
        assert_eq!(send.window, 901);
        assert!(send.tags.is_empty());

        // Landing on a multiple of the increment queues the tag.
        assert_eq!(send.take(&"and")?, 900);
        assert_eq!(send.tags.len(), 1);
        assert_eq!(send.tags[0], "and");
        assert_eq!(send.take(&"goodbye")?, 899);
        assert_eq!(send.tags.len(), 1);

        // A matching SENDME refills the window and consumes the tag.
        assert_eq!(send.put(Some("and"))?, 999);
        assert!(send.tags.is_empty());

        for _ in 0..300 {
            send.take(&"dreamland")?;
        }
        assert_eq!(send.tags.len(), 3);

        // A tagless SENDME is accepted whenever one is expected.
        let no_tag: Option<&str> = None;
        assert_eq!(send.put(no_tag)?, 799);
        assert_eq!(send.tags.len(), 2);
        Ok(())
    }

    #[test]
    fn sendwindow_bad_put() -> Result<()> {
        let mut send = new_sendwindow();
        for _ in 0..250 {
            send.take(&"correct")?;
        }
        assert_eq!(send.window, 750);

        // Wrong tag is rejected; the right one is accepted once per queued tag.
        assert!(send.put(Some("incorrect")).is_err());
        assert_eq!(send.put(Some("correct"))?, 850);
        assert_eq!(send.put(Some("correct"))?, 950);

        // Queue exhausted: further SENDMEs fail and leave the window alone.
        assert!(send.put(Some("correct")).is_err());
        assert_eq!(send.window, 950);
        let no_tag: Option<&str> = None;
        assert!(send.put(no_tag).is_err());
        assert_eq!(send.window, 950);
        Ok(())
    }

    #[test]
    fn sendwindow_erroring() -> Result<()> {
        let mut send = new_sendwindow();
        for _ in 0..1000 {
            send.take(&"here a string")?;
        }
        assert_eq!(send.window, 0);
        assert!(send.take(&"there a string").is_err());
        Ok(())
    }
}