use bitvec::prelude::*;
use derive_deftly::Deftly;
use oneshot_fused_workaround as oneshot;
use postage::watch;
use tor_cell::relaycell::msg::AnyRelayMsg;
use tor_cell::relaycell::{RelayCellFormat, RelayCmd, StreamId, UnparsedRelayMsg, msg};
use tor_cell::restricted_msg;
use tor_error::internal;
use tor_memquota::derive_deftly_template_HasMemoryCost;
use tor_memquota::mq_queue::{self, MpscSpec};
use tor_rtcompat::DynTimeProvider;
use crate::circuit::CircHopSyncView;
use crate::stream::cmdcheck::{AnyCmdChecker, CmdChecker, StreamStatus};
use crate::stream::{CloseStreamBehavior, StreamComponents};
use crate::{Error, Result};
use crate::client::stream::DataStream;
use crate::memquota::StreamAccount;
use crate::stream::StreamMpscSender;
use crate::stream::flow_ctrl::state::StreamRateLimit;
use crate::stream::flow_ctrl::xon_xoff::reader::DrainRateRequest;
use crate::stream::queue::StreamQueueReceiver;
use crate::util::notify::NotifyReceiver;
use crate::{HopLocation, HopNum};
use std::mem::size_of;
/// A [`CmdChecker`] for an incoming stream that now carries data.
///
/// It permits DATA (stream stays open) and END (stream closes), and rejects
/// every other relay command as a stream protocol violation.
#[derive(Debug, Default)]
pub(crate) struct InboundDataCmdChecker;
// The set of messages that `InboundDataCmdChecker` decodes once it has
// accepted them: only DATA and END.
restricted_msg! {
    /// Messages permitted on an incoming stream that carries data.
    enum IncomingDataStreamMsg:RelayMsg {
        Data, End,
    }
}
impl CmdChecker for InboundDataCmdChecker {
    /// Classify an incoming message by its command.
    ///
    /// DATA keeps the stream [`StreamStatus::Open`]; END marks it
    /// [`StreamStatus::Closed`]; anything else is an
    /// [`Error::StreamProto`].
    fn check_msg(&mut self, msg: &tor_cell::relaycell::UnparsedRelayMsg) -> Result<StreamStatus> {
        use StreamStatus::*;
        match msg.cmd() {
            RelayCmd::DATA => Ok(Open),
            RelayCmd::END => Ok(Closed),
            _ => Err(Error::StreamProto(format!(
                "Unexpected {} on an incoming data stream!",
                msg.cmd()
            ))),
        }
    }

    /// Decode a message (expected to have passed `check_msg`) as an
    /// `IncomingDataStreamMsg`, discarding the decoded value.
    ///
    /// Only a decoding failure is reported.
    fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
        // This checker guards a connected incoming data stream (see
        // `InboundDataCmdChecker::new_connected`), not a half-closed one, so the
        // parse-error context must describe that stream accurately.
        let _ = msg
            .decode::<IncomingDataStreamMsg>()
            .map_err(|err| Error::from_bytes_err(err, "cell on incoming data stream"))?;
        Ok(())
    }
}
impl InboundDataCmdChecker {
pub(crate) fn new_connected() -> AnyCmdChecker {
Box::new(Self)
}
}
/// An incoming stream request that has not yet been accepted or refused.
///
/// Use [`IncomingStream::accept_data`] to turn it into a [`DataStream`], or
/// [`IncomingStream::reject`] / [`IncomingStream::discard`] to refuse it.
#[derive(Debug)]
pub struct IncomingStream {
    /// Time provider passed on to the [`DataStream`] when the stream is accepted.
    time_provider: DynTimeProvider,
    /// The message that requested this stream.
    request: IncomingStreamRequest,
    /// The sender, receiver, flow-control and memory-quota parts backing the
    /// stream (taken apart in `accept_data`).
    components: StreamComponents,
}
impl IncomingStream {
pub(crate) fn new(
time_provider: DynTimeProvider,
request: IncomingStreamRequest,
components: StreamComponents,
) -> Self {
Self {
time_provider,
request,
components,
}
}
pub fn request(&self) -> &IncomingStreamRequest {
&self.request
}
pub async fn accept_data(self, message: msg::Connected) -> Result<DataStream> {
let Self {
time_provider,
request,
components:
StreamComponents {
mut target,
stream_receiver,
xon_xoff_reader_ctrl,
memquota,
},
} = self;
match request {
IncomingStreamRequest::Begin(_) | IncomingStreamRequest::BeginDir(_) => {
target.send(message.into()).await?;
Ok(DataStream::new_connected(
time_provider,
stream_receiver,
xon_xoff_reader_ctrl,
target,
memquota,
))
}
IncomingStreamRequest::Resolve(_) => {
Err(internal!("Cannot accept data on a RESOLVE stream").into())
}
}
}
pub async fn reject(mut self, message: msg::End) -> Result<()> {
let rx = self.reject_inner(CloseStreamBehavior::SendEnd(message))?;
rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
}
fn reject_inner(
&mut self,
message: CloseStreamBehavior,
) -> Result<oneshot::Receiver<Result<()>>> {
self.components.target.close_pending(message)
}
pub async fn discard(mut self) -> Result<()> {
let rx = self.reject_inner(CloseStreamBehavior::SendNothing)?;
rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
}
}
restricted_msg! {
    /// A message that may request a new incoming stream.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(HasMemoryCost)]
    #[non_exhaustive]
    pub enum IncomingStreamRequest: RelayMsg {
        /// A BEGIN message.
        Begin,
        /// A BEGIN_DIR message.
        BeginDir,
        /// A RESOLVE message.
        Resolve,
    }
}
/// A set of relay commands: one bit for each of the 256 possible 8-bit
/// command values, indexed by the command's `u8` value.
type RelayCmdSet = bitvec::BitArr!(for 256);
/// A [`CmdChecker`] for incoming stream requests that accepts only an
/// allow-listed set of relay commands.
#[derive(Debug)]
pub(crate) struct IncomingCmdChecker {
    /// The permitted commands; a set bit means that command is allowed.
    allow_commands: RelayCmdSet,
}
impl IncomingCmdChecker {
pub(crate) fn new_any(allow_commands: &[RelayCmd]) -> AnyCmdChecker {
let mut array = BitArray::ZERO;
for c in allow_commands {
array.set(u8::from(*c) as usize, true);
}
Box::new(Self {
allow_commands: array,
})
}
}
impl CmdChecker for IncomingCmdChecker {
fn check_msg(&mut self, msg: &UnparsedRelayMsg) -> Result<StreamStatus> {
if self.allow_commands[u8::from(msg.cmd()) as usize] {
Ok(StreamStatus::Open)
} else {
Err(Error::StreamProto(format!(
"Unexpected {} on incoming stream",
msg.cmd()
)))
}
}
fn consume_checked_msg(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
let _ = msg
.decode::<IncomingStreamRequest>()
.map_err(|err| Error::from_bytes_err(err, "invalid message on incoming stream"))?;
Ok(())
}
}
/// A filter that decides how each incoming stream request is handled.
pub trait IncomingStreamRequestFilter: Send + 'static {
    /// Decide what to do with the request described by `ctx`.
    ///
    /// `circ` is a synchronous view of the circuit; NOTE(review): presumably
    /// of the hop the request arrived on — confirm against the caller.
    fn disposition(
        &mut self,
        ctx: &IncomingStreamRequestContext<'_>,
        circ: &CircHopSyncView<'_>,
    ) -> Result<IncomingStreamRequestDisposition>;
}
/// The decision an [`IncomingStreamRequestFilter`] returns for a request.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum IncomingStreamRequestDisposition {
    /// Accept the stream request.
    Accept,
    /// Close the circuit.
    CloseCircuit,
    /// Reject the stream request, using the given END message.
    RejectRequest(msg::End),
}
/// Information about an incoming stream request, as given to
/// [`IncomingStreamRequestFilter::disposition`].
pub struct IncomingStreamRequestContext<'a> {
    /// The message that requested the stream.
    pub(crate) request: &'a IncomingStreamRequest,
}
impl<'a> IncomingStreamRequestContext<'a> {
    /// Return the message that requested the stream.
    ///
    /// The returned reference lives for `'a`, independent of this context.
    pub fn request(&self) -> &'a IncomingStreamRequest {
        let Self { request } = *self;
        request
    }
}
/// An incoming stream request, bundled with the stream state set up for it.
///
/// This is the item type carried by the `StreamReqSender` queue.
#[derive(Debug, Deftly)]
#[derive_deftly(HasMemoryCost)]
pub(crate) struct StreamReqInfo {
    /// The request message.
    pub(crate) req: IncomingStreamRequest,
    /// The stream ID associated with the request.
    pub(crate) stream_id: StreamId,
    /// The hop associated with the request, if any.
    /// NOTE(review): confirm when this is `None`.
    pub(crate) hop: Option<HopLocation>,
    /// The relay cell format for this stream.
    // Counted as having no indirect memory cost.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) relay_cell_format: RelayCellFormat,
    /// Receiver for messages arriving on this stream.
    // Counted as having no indirect memory cost.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) receiver: StreamQueueReceiver,
    /// Sender for messages on this stream.
    // Counted as one `AnyRelayMsg` worth of indirect memory.
    #[deftly(has_memory_cost(indirect_size = "size_of::<AnyRelayMsg>()"))]
    pub(crate) msg_tx: StreamMpscSender<AnyRelayMsg>,
    /// Watch channel for this stream's rate limit.
    // Counted as having no indirect memory cost.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) rate_limit_stream: watch::Receiver<StreamRateLimit>,
    /// Notification receiver for drain-rate requests.
    // Counted as having no indirect memory cost.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) drain_rate_request_stream: NotifyReceiver<DrainRateRequest>,
    /// Memory-quota account for this stream.
    // Counted as having no indirect memory cost.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) memquota: StreamAccount,
}
/// A `tor_memquota` MPSC queue sender that carries [`StreamReqInfo`] items.
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) type StreamReqSender = mq_queue::Sender<StreamReqInfo, MpscSpec>;
/// State for handling incoming stream requests.
///
/// It records where requests are delivered, which commands are accepted and
/// which filter decides each request.
#[derive(educe::Educe)]
#[educe(Debug)]
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) struct IncomingStreamRequestHandler {
    /// Queue on which request information is delivered.
    pub(crate) incoming_sender: StreamReqSender,
    /// The hop this handler is associated with, if any.
    /// NOTE(review): confirm the meaning of `None` against the caller.
    pub(crate) hop_num: Option<HopNum>,
    /// Checker for the relay commands permitted for incoming requests.
    pub(crate) cmd_checker: AnyCmdChecker,
    /// Filter that decides each request's disposition.
    // The filter trait does not require `Debug`, so it is left out of the
    // derived `Debug` output.
    #[educe(Debug(ignore))]
    pub(crate) filter: Box<dyn IncomingStreamRequestFilter>,
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use tor_cell::relaycell::{
        AnyRelayMsgOuter, RelayCellFormat,
        msg::{Begin, BeginDir, Data, Resolve},
    };

    use super::*;

    /// `IncomingCmdChecker` must report `Open` for exactly the commands it was
    /// built with, and reject every other command.
    #[test]
    fn incoming_cmd_checker() {
        // Round-trip each message through a V0 relay cell body, so the checker
        // sees an `UnparsedRelayMsg` just as it would for a received cell.
        let unparsed = |msg| {
            let body = AnyRelayMsgOuter::new(None, msg)
                .encode(RelayCellFormat::V0, &mut rand::rng())
                .unwrap();
            UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, body).unwrap()
        };

        let begin = unparsed(Begin::new("allium.example.com", 443, 0).unwrap().into());
        let begin_dir = unparsed(BeginDir::default().into());
        let resolve = unparsed(Resolve::new("allium.example.com").into());
        let data = unparsed(Data::new(&[1, 2, 3]).unwrap().into());

        // Every sample message, tagged with the command it carries.
        let samples = [
            (RelayCmd::BEGIN, &begin),
            (RelayCmd::BEGIN_DIR, &begin_dir),
            (RelayCmd::RESOLVE, &resolve),
            (RelayCmd::DATA, &data),
        ];

        // Allow-lists under test: nothing, BEGIN alone, and all three request
        // commands (but never DATA).
        let allow_lists: [&[RelayCmd]; 3] = [
            &[],
            &[RelayCmd::BEGIN],
            &[RelayCmd::BEGIN, RelayCmd::BEGIN_DIR, RelayCmd::RESOLVE],
        ];

        for allowed in allow_lists {
            let mut checker = IncomingCmdChecker::new_any(allowed);
            for &(cmd, msg) in &samples {
                let verdict = checker.check_msg(msg);
                if allowed.contains(&cmd) {
                    assert_eq!(verdict.unwrap(), StreamStatus::Open);
                } else {
                    assert!(verdict.is_err(), "{cmd} should be rejected");
                }
            }
        }
    }
}