use crate::Result;
use crate::congestion::sendme::{StreamRecvWindow, cmd_counts_towards_windows};
use crate::stream::cmdcheck::{AnyCmdChecker, StreamStatus};
use crate::stream::flow_ctrl::state::{FlowCtrlHooks, StreamFlowCtrl};
use tor_cell::relaycell::{RelayCmd, UnparsedRelayMsg};
#[derive(Debug)]
pub(crate) struct HalfStream {
flow_control: StreamFlowCtrl,
recvw: StreamRecvWindow,
cmd_checker: AnyCmdChecker,
}
impl HalfStream {
pub(crate) fn new(
flow_control: StreamFlowCtrl,
recvw: StreamRecvWindow,
cmd_checker: AnyCmdChecker,
) -> Self {
HalfStream {
flow_control,
recvw,
cmd_checker,
}
}
pub(crate) fn handle_msg(&mut self, msg: UnparsedRelayMsg) -> Result<StreamStatus> {
use StreamStatus::*;
match msg.cmd() {
RelayCmd::SENDME => {
self.flow_control.put_for_incoming_sendme(msg)?;
return Ok(Open);
}
RelayCmd::XON => {
self.flow_control.handle_incoming_xon(msg)?;
return Ok(Open);
}
RelayCmd::XOFF => {
self.flow_control.handle_incoming_xoff(msg)?;
return Ok(Open);
}
_ => {}
}
if cmd_counts_towards_windows(msg.cmd()) {
self.recvw.take()?;
}
let status = self.cmd_checker.check_msg(&msg)?;
self.cmd_checker.consume_checked_msg(msg)?;
Ok(status)
}
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::{
        client::stream::OutboundDataCmdChecker,
        congestion::sendme::{StreamRecvWindow, StreamSendWindow},
    };
    use rand::{CryptoRng, Rng};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_cell::relaycell::{
        AnyRelayMsgOuter, RelayCellFormat, StreamId,
        msg::{self, AnyRelayMsg},
    };

    /// Encode `val` as a V0 relay cell on stream 77, then re-parse it as an unparsed message.
    fn to_unparsed<R: Rng + CryptoRng>(rng: &mut R, val: AnyRelayMsg) -> UnparsedRelayMsg {
        let body = AnyRelayMsgOuter::new(StreamId::new(77), val)
            .encode(RelayCellFormat::V0, rng)
            .expect("encoding failed");
        UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, body).unwrap()
    }

    /// Build a half-stream with the given send window and checker, and a receive window of 20.
    fn make_hs(sendw: StreamSendWindow, cmd_checker: AnyCmdChecker) -> HalfStream {
        HalfStream::new(
            StreamFlowCtrl::new_window(sendw),
            StreamRecvWindow::new(20),
            cmd_checker,
        )
    }

    /// Half-stream with send and receive windows of 20 and a fresh outbound-data checker.
    fn fresh_halfstream() -> HalfStream {
        make_hs(StreamSendWindow::new(20), OutboundDataCmdChecker::new_any())
    }

    /// Feed `val` to `hs`, require that it is rejected, and return the error's display text.
    fn rejection_text<R: Rng + CryptoRng>(
        hs: &mut HalfStream,
        rng: &mut R,
        val: AnyRelayMsg,
    ) -> String {
        match hs.handle_msg(to_unparsed(rng, val)) {
            Ok(_) => panic!("message was unexpectedly accepted"),
            Err(e) => e.to_string(),
        }
    }

    #[test]
    fn halfstream_sendme() {
        let mut rng = testing_rng();
        let mut hs = make_hs(StreamSendWindow::new(450), OutboundDataCmdChecker::new_any());
        let sendme = msg::Sendme::new_empty();

        // Starting from 450, one SENDME is acceptable; a second one is not.
        let first = hs.handle_msg(to_unparsed(&mut rng, sendme.clone().into()));
        assert!(first.is_ok());
        assert_eq!(
            rejection_text(&mut hs, &mut rng, sendme.into()),
            "Circuit protocol violation: Unexpected stream SENDME"
        );
    }

    #[test]
    fn halfstream_data() {
        let mut rng = testing_rng();
        let mut hs = fresh_halfstream();

        // Open the stream with a CONNECTED before sending any DATA.
        hs.handle_msg(to_unparsed(&mut rng, msg::Connected::new_empty().into()))
            .unwrap();

        let data = msg::Data::new(&b"this offer is unrepeatable"[..]).unwrap();
        // The receive window of 20 admits exactly twenty DATA messages...
        for _ in 1..=20 {
            let outcome = hs.handle_msg(to_unparsed(&mut rng, data.clone().into()));
            assert!(outcome.is_ok());
        }
        // ...and the twenty-first is a window violation.
        assert_eq!(
            rejection_text(&mut hs, &mut rng, data.into()),
            "Circuit protocol violation: Received a data cell in violation of a window"
        );
    }

    #[test]
    fn halfstream_connected() {
        let mut rng = testing_rng();
        let connected = msg::Connected::new_empty();

        // A fresh half-stream accepts one CONNECTED and rejects a repeat.
        let mut hs = fresh_halfstream();
        let once = hs.handle_msg(to_unparsed(&mut rng, connected.clone().into()));
        assert!(once.is_ok());
        let twice = hs.handle_msg(to_unparsed(&mut rng, connected.clone().into()));
        assert!(twice.is_err());

        // A checker that has already checked a CONNECTED rejects the next one.
        let mut checker = OutboundDataCmdChecker::new_any();
        checker
            .check_msg(&to_unparsed(&mut rng, msg::Connected::new_empty().into()))
            .unwrap();
        let mut hs = make_hs(StreamSendWindow::new(20), checker);
        assert_eq!(
            rejection_text(&mut hs, &mut rng, connected.into()),
            "Stream protocol violation: Received CONNECTED twice on a stream."
        );
    }

    #[test]
    fn halfstream_other() {
        let mut rng = testing_rng();
        let mut hs = fresh_halfstream();
        let extended2 = msg::Extended2::new(Vec::new());
        assert_eq!(
            rejection_text(&mut hs, &mut rng, extended2.into()),
            "Stream protocol violation: Unexpected EXTENDED2 on a data stream!"
        );
    }
}