use crate::circuit::circhop::CircHopOutbound;
use crate::circuit::reactor::macros::derive_deftly_template_CircuitReactor;
use crate::circuit::{CircHopSyncView, UniqId};
use crate::congestion::{CongestionControl, sendme};
use crate::memquota::{CircuitAccount, SpecificAccount as _, StreamAccount};
use crate::stream::CloseStreamBehavior;
use crate::stream::cmdcheck::StreamStatus;
use crate::stream::flow_ctrl::state::StreamRateLimit;
use crate::stream::queue::stream_queue;
use crate::streammap;
use crate::util::err::ReactorError;
use crate::util::notify::NotifySender;
use crate::{Error, HopNum};
#[cfg(any(feature = "hs-service", feature = "relay"))]
use crate::stream::incoming::{
InboundDataCmdChecker, IncomingStreamRequest, IncomingStreamRequestContext,
IncomingStreamRequestDisposition, IncomingStreamRequestHandler, StreamReqInfo,
};
use tor_async_utils::{SinkTrySend as _, SinkTrySendError as _};
use tor_cell::relaycell::msg::{AnyRelayMsg, Begin, End, EndReason};
use tor_cell::relaycell::{AnyRelayMsgOuter, RelayCellFormat, StreamId, UnparsedRelayMsg};
use tor_error::into_internal;
use tor_log_ratelim::log_ratelim;
use tor_memquota::mq_queue::{ChannelSpec as _, MpscSpec};
use tor_rtcompat::{DynTimeProvider, Runtime, SleepProvider as _};
use derive_deftly::Deftly;
use futures::SinkExt;
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _, future, select_biased};
use postage::watch;
use tracing::debug;
use std::pin::Pin;
use std::result::Result as StdResult;
use std::sync::{Arc, Mutex};
use std::task::Poll;
use std::time::Duration;
/// Size passed to `MpscSpec::new` when creating the message queue of an accepted
/// incoming stream (see `StreamReactor::handle_incoming_stream_request`).
///
/// NOTE(review): presumably a capacity counted in queued messages — confirm against `MpscSpec`.
const CIRCUIT_BUFFER_SIZE: usize = 128;
/// Hook through which the owner of a `StreamReactor` customizes its behavior.
///
/// The reactor stores the implementation as an `Arc<dyn StreamHandler>`.
pub(crate) trait StreamHandler: Send + Sync + 'static {
/// Return the duration that the reactor adds to the current time to obtain the
/// expiry instant it passes to `CircHopOutbound::close_stream` when it closes a
/// stream on `hop`.
fn halfstream_expiry(&self, hop: &CircHopOutbound) -> Duration;
}
/// Reactor that drives the streams of a single circuit hop.
///
/// Each iteration (`run_once`) either processes one `StreamMsg` received on `cell_rx`,
/// or handles one event from a ready stream in the hop's stream map. Anything that
/// has to be sent as a result is handed over on `bwd_tx` as a `ReadyStreamMsg`.
///
/// The `CircuitReactor` derive-deftly template is applied with `Self::run_once` as its
/// inner function (presumably it generates the `run()` loop that the `must_use`
/// message refers to — confirm against the template definition).
#[derive(Deftly)]
#[derive_deftly(CircuitReactor)]
#[deftly(reactor_name = "stream reactor")]
#[deftly(run_inner_fn = "Self::run_once")]
#[must_use = "If you don't call run() on a reactor, the circuit won't work."]
pub(crate) struct StreamReactor {
/// Hop number of this reactor's hop, if any.
///
/// Copied into every `ReadyStreamMsg`, and compared against the incoming stream
/// request handler's `hop_num` before a request is accepted.
hopnum: Option<HopNum>,
/// Outbound state of the hop; gives access to its stream map, congestion control
/// and relay cell format.
hop: CircHopOutbound,
/// Source of the current time (half-stream expiry) and time provider for the
/// queues created for incoming streams.
time_provider: DynTimeProvider,
/// Identifier passed to `about_to_send` and `close_stream`, and used in log messages.
unique_id: UniqId,
/// Stream messages to process; when this channel is closed the reactor shuts down.
cell_rx: mpsc::Receiver<StreamMsg>,
/// Channel on which messages that are ready to be sent are handed over.
bwd_tx: mpsc::Sender<ReadyStreamMsg>,
/// Shared, optional handler for incoming stream requests.
///
/// While it is `None`, incoming stream requests are treated as a circuit protocol error.
#[cfg(any(feature = "hs-service", feature = "relay"))]
incoming: Arc<Mutex<Option<IncomingStreamRequestHandler>>>,
/// Customization hook (currently only consulted for the half-stream expiry).
inner: Arc<dyn StreamHandler>,
/// Circuit-level memory quota account from which per-stream accounts are created.
memquota: CircuitAccount,
}
#[allow(unused)] impl StreamReactor {
/// Build a stream reactor for one hop.
///
/// All arguments are stored as-is, except for `runtime`, which is only used to
/// construct the reactor's `DynTimeProvider`.
///
/// The returned reactor does nothing until it is run.
#[allow(clippy::too_many_arguments)] pub(crate) fn new<R: Runtime>(
    runtime: R,
    hopnum: Option<HopNum>,
    hop: CircHopOutbound,
    unique_id: UniqId,
    cell_rx: mpsc::Receiver<StreamMsg>,
    bwd_tx: mpsc::Sender<ReadyStreamMsg>,
    inner: Arc<dyn StreamHandler>,
    #[cfg(any(feature = "hs-service", feature = "relay"))] incoming: Arc<Mutex<Option<IncomingStreamRequestHandler>>>,
    memquota: CircuitAccount,
) -> Self {
    // Wrap the runtime up front, so the struct itself need not be generic over `R`.
    let time_provider = DynTimeProvider::new(runtime);

    Self {
        // Identity of the hop we serve.
        hopnum,
        hop,
        unique_id,
        // Channels connecting us to the rest of the circuit machinery.
        cell_rx,
        bwd_tx,
        // Handlers and hooks.
        #[cfg(any(feature = "hs-service", feature = "relay"))]
        incoming,
        inner,
        // Supporting services.
        time_provider,
        memquota,
    }
}
/// Run a single iteration of the reactor.
///
/// First removes the expired half-streams from this hop's stream map. Then waits for
/// whichever of these happens first, and handles it:
///
///  * a `StreamMsg` arrives on `cell_rx` (handled by `handle_reactor_cmd`), or
///  * a stream in the hop's stream map becomes ready (handled by `handle_stream_event`).
///
/// Because `select_biased!` is used, `cell_rx` is polled before the ready-streams future.
///
/// Returns `Err(ReactorError::Shutdown)` if `cell_rx` has been closed; errors from the
/// handlers are propagated unchanged.
async fn run_once(&mut self) -> StdResult<(), ReactorError> {
use postage::prelude::{Sink as _, Stream as _};
// Drop every half-stream whose expiry time has passed.
self.hop
.stream_map()
.lock()
.expect("poisoned lock")
.remove_expired_halfstreams(self.time_provider.now());
// The future below must not borrow `self`, so give it its own handle to the stream map.
let mut streams = Arc::clone(self.hop.stream_map());
// Sampled once per iteration; the future below does not re-check it.
let can_send = self
.hop
.ccontrol()
.lock()
.expect("poisoned lock")
.can_send();
let mut ready_streams_fut = future::poll_fn(move |cx| {
if !can_send {
// Congestion control does not allow sending: never report a ready stream in
// this iteration.
//
// NOTE(review): no waker is registered on this path, so this iteration can only
// make progress via `cell_rx` — confirm that is sufficient when the congestion
// window reopens.
return Poll::Pending;
}
let mut streams = streams.lock().expect("lock poisoned");
// Only the first ready stream is handled per iteration.
let Some((sid, msg)) = streams.poll_ready_streams_iter(cx).next() else {
// No stream is ready.
return Poll::Pending;
};
// A ready stream without a message is reported as closed by the stream target.
if msg.is_none() {
return Poll::Ready(StreamEvent::Closed {
sid,
behav: CloseStreamBehavior::default(),
reason: streammap::TerminateReason::StreamTargetClosed,
});
};
// The iterator just reported a message for `sid`, so taking it is expected to succeed.
let msg = streams.take_ready_msg(sid).expect("msg disappeared");
Poll::Ready(StreamEvent::ReadyMsg { sid, msg })
});
select_biased! {
res = self.cell_rx.next().fuse() => {
let Some(cmd) = res else {
// The sending side of `cell_rx` is gone: shut down.
return Err(ReactorError::Shutdown);
};
self.handle_reactor_cmd(cmd).await?;
}
event = ready_streams_fut.fuse() => {
self.handle_stream_event(event).await?;
}
}
Ok(())
}
/// Process one `StreamMsg` received on `cell_rx`.
///
/// The message is first run through `handle_msg`. If that yields a message to send
/// in response, the hop's cell limit is decremented, `about_to_send` is invoked for
/// messages that count towards windows and carry a stream ID, and the message is
/// then handed over on `bwd_tx`.
///
/// Returns an error if any of those steps fails; nothing is sent in that case.
async fn handle_reactor_cmd(&mut self, msg: StreamMsg) -> StdResult<(), ReactorError> {
    // Handle the message; bail out early if there is nothing to send back.
    let Some(reply) = self.handle_msg(msg.sid, msg.msg, msg.cell_counts_toward_windows)? else {
        return Ok(());
    };

    // Account for the cell we are about to send, failing if we are out of capacity.
    self.hop.decrement_cell_limit()?;

    // Messages that count towards windows must be announced to the hop before
    // they are handed over for sending.
    if sendme::cmd_counts_towards_windows(reply.cmd()) {
        if let Some(stream_id) = reply.stream_id() {
            self.hop
                .about_to_send(self.unique_id, stream_id, reply.msg())?;
        }
    }

    self.send_msg_to_bwd(reply).await
}
/// Handle a relay message received for stream `streamid`.
///
/// The message is passed to the hop. If the hop hands it back, it is treated as an
/// incoming stream request: with the `hs-service` or `relay` feature enabled it is
/// forwarded to `handle_incoming_stream_request`, otherwise an internal error is
/// returned.
///
/// If the hop consumed the message, the hop is asked whether an XOFF should be sent
/// on the stream.
///
/// Returns the message to send in response, if any.
fn handle_msg(
    &mut self,
    streamid: StreamId,
    msg: UnparsedRelayMsg,
    cell_counts_toward_windows: bool,
) -> StdResult<Option<AnyRelayMsgOuter>, ReactorError> {
    // Remember the command now: `msg` is moved into the hop below.
    let cmd = msg.cmd();
    // Builds the error to use if the message turns out to be for an unknown stream.
    let unknown_stream_err = move |streamid: StreamId| {
        Error::StreamProto(format!(
            "Unexpected {cmd:?} message on unknown stream {streamid}"
        ))
    };
    let now = self.time_provider.now();

    let incoming_request = self.hop.handle_msg(
        unknown_stream_err,
        cell_counts_toward_windows,
        streamid,
        msg,
        now,
    )?;

    // The hop gave the message back to us: it is an incoming stream request.
    if let Some(request) = incoming_request {
        cfg_if::cfg_if! {
            if #[cfg(any(feature = "hs-service", feature = "relay"))] {
                return self.handle_incoming_stream_request(streamid, request);
            } else {
                return Err(
                    tor_error::internal!(
                        "incoming stream not rejected, but relay and hs-service features are disabled?!"
                    ).into()
                );
            }
        }
    }

    // The message was consumed by the hop; see whether we now owe the peer an XOFF.
    let xoff = self.hop.maybe_send_xoff(streamid)?;
    Ok(xoff.map(|cell| AnyRelayMsgOuter::new(Some(streamid), cell.into())))
}
/// Handle a message that the hop identified as an incoming stream request.
///
/// Steps, in order:
///
///  1. Fail with a circuit protocol error if no incoming stream request handler is
///     installed, or if this reactor's hop differs from the handler's `hop_num`.
///  2. Validate the message with the handler's command checker; if the message closes
///     the stream, record that in the stream map and return `Ok(None)`.
///  3. Parse the message as a stream request and ask the handler's filter what to do
///     with it (see `should_reject_incoming`).
///  4. Create the memory quota account, queues and notification channels for the new
///     stream, and register the stream with the hop.
///  5. Deliver a `StreamReqInfo` to the handler without blocking.
///
/// Returns `Ok(Some(msg))` if `msg` must be sent in response (the filter's rejection,
/// or an END with reason `RESOURCELIMIT` if the handler's queue is full), and
/// `Ok(None)` if nothing needs to be sent.
#[cfg(any(feature = "hs-service", feature = "relay"))]
fn handle_incoming_stream_request(
&mut self,
sid: StreamId,
msg: UnparsedRelayMsg,
) -> StdResult<Option<AnyRelayMsgOuter>, ReactorError> {
// This guard is held until the function returns.
let mut lock = self.incoming.lock().expect("poisoned lock");
let Some(handler) = lock.as_mut() else {
return Err(
Error::CircProto("Cannot handle BEGIN cells on this circuit".into()).into(),
);
};
// Requests are only accepted from the hop the handler was set up for.
if self.hopnum != handler.hop_num {
let expected_hopnum = match handler.hop_num {
Some(hopnum) => hopnum.display().to_string(),
None => "client".to_string(),
};
let actual_hopnum = match self.hopnum {
Some(hopnum) => hopnum.display().to_string(),
None => "None".to_string(),
};
return Err(Error::CircProto(format!(
"Expecting incoming streams from {}, but received {} cell from unexpected hop {}",
expected_hopnum,
msg.cmd(),
actual_hopnum,
))
.into());
}
let message_closes_stream = handler.cmd_checker.check_msg(&msg)? == StreamStatus::Closed;
if message_closes_stream {
// Record the closing message in the stream map; there is nothing to reply with.
self.hop
.stream_map()
.lock()
.expect("poisoned lock")
.ending_msg_received(sid)?;
return Ok(None);
}
let req = parse_incoming_stream_req(msg)?;
let view = CircHopSyncView::new(&self.hop);
// Let the handler's filter decide; a rejection is sent back as-is.
if let Some(reject) = Self::should_reject_incoming(handler, sid, &req, &view)? {
return Ok(Some(reject));
};
// The request was accepted: set up the per-stream memory account and channels.
let memquota =
StreamAccount::new(&self.memquota).map_err(|e| ReactorError::Err(e.into()))?;
let (sender, receiver) = stream_queue(
#[cfg(not(feature = "flowctl-cc"))]
crate::stream::STREAM_READER_BUFFER,
&memquota,
&self.time_provider,
)
.map_err(|e| ReactorError::Err(e.into()))?;
let (msg_tx, msg_rx) = MpscSpec::new(CIRCUIT_BUFFER_SIZE)
.new_mq(self.time_provider.clone(), memquota.as_raw_account())
.map_err(|e| ReactorError::Err(e.into()))?;
let (rate_limit_tx, rate_limit_rx) = watch::channel_with(StreamRateLimit::MAX);
let mut drain_rate_request_tx = NotifySender::new_typed();
let drain_rate_request_rx = drain_rate_request_tx.subscribe();
let cmd_checker = InboundDataCmdChecker::new_connected();
// Register the stream with the hop under the ID chosen by the peer. This happens
// before the request is delivered to the handler.
self.hop.add_ent_with_id(
sender,
msg_rx,
rate_limit_tx,
drain_rate_request_tx,
sid,
cmd_checker,
)?;
// NOTE(review): `hop` is always `None` here rather than `self.hopnum` — confirm this
// is intended.
let outcome = Pin::new(&mut handler.incoming_sender).try_send(StreamReqInfo {
req,
stream_id: sid,
hop: None,
msg_tx,
receiver,
rate_limit_stream: rate_limit_rx,
drain_rate_request_stream: drain_rate_request_rx,
memquota,
relay_cell_format: self.hop.relay_cell_format(),
});
log_ratelim!("Delivering message to incoming stream handler"; outcome);
if let Err(e) = outcome {
if e.is_full() {
// The handler's queue is full: refuse the stream with an END message.
//
// NOTE(review): the entry registered via `add_ent_with_id` above is not removed
// in this function — confirm that it is cleaned up elsewhere.
let end_msg = AnyRelayMsgOuter::new(
Some(sid),
End::new_with_reason(EndReason::RESOURCELIMIT).into(),
);
return Ok(Some(end_msg));
} else if e.is_disconnected() {
// The receiving end of the handler's queue is gone; this is reported as a
// closed circuit.
debug!(
circ_id = %self.unique_id,
"Incoming stream request receiver dropped",
);
return Err(ReactorError::Err(Error::CircuitClosed));
} else {
// Any other `try_send` failure is treated as an internal error.
return Err(
Error::from((into_internal!("try_send failed unexpectedly"))(e)).into(),
);
}
}
Ok(None)
}
/// Ask the handler's filter how the incoming stream `request` should be treated.
///
/// Returns:
///
///  * `Ok(None)` if the request is accepted,
///  * `Ok(Some(msg))` if the request is rejected, where `msg` is the filter-provided
///    reply addressed to stream `sid`,
///  * `Err(ReactorError::Shutdown)` if the filter asks for the circuit to be closed,
///  * any error returned by the filter itself.
#[cfg(any(feature = "hs-service", feature = "relay"))]
fn should_reject_incoming<'a>(
    handler: &mut IncomingStreamRequestHandler,
    sid: StreamId,
    request: &IncomingStreamRequest,
    view: &CircHopSyncView<'a>,
) -> StdResult<Option<AnyRelayMsgOuter>, ReactorError> {
    let ctx = IncomingStreamRequestContext { request };
    let disposition = handler.filter.as_mut().disposition(&ctx, view)?;

    match disposition {
        IncomingStreamRequestDisposition::Accept => Ok(None),
        IncomingStreamRequestDisposition::CloseCircuit => Err(ReactorError::Shutdown),
        IncomingStreamRequestDisposition::RejectRequest(end) => {
            Ok(Some(AnyRelayMsgOuter::new(Some(sid), end.into())))
        }
    }
}
/// Handle an event produced by one of the hop's ready streams.
///
///  * `StreamEvent::ReadyMsg`: the message is wrapped with its stream ID and handed
///    over on `bwd_tx`.
///  * `StreamEvent::Closed`: the stream is closed on the hop, with an expiry of "now"
///    plus the handler-provided half-stream expiry; if closing produces a cell, that
///    cell is handed over on `bwd_tx`.
///
/// NOTE(review): unlike `handle_reactor_cmd`, this path calls neither
/// `decrement_cell_limit` nor `about_to_send` itself — confirm that the accounting
/// happens elsewhere.
async fn handle_stream_event(&mut self, event: StreamEvent) -> StdResult<(), ReactorError> {
    // Work out what (if anything) needs to be sent, then send it in one place.
    let outgoing = match event {
        StreamEvent::ReadyMsg { sid, msg } => AnyRelayMsgOuter::new(Some(sid), msg),
        StreamEvent::Closed { sid, behav, reason } => {
            let halfstream_ttl = self.inner.halfstream_expiry(&self.hop);
            let deadline = self.time_provider.now() + halfstream_ttl;
            let closing =
                self.hop
                    .close_stream(self.unique_id, sid, None, behav, reason, deadline)?;
            match closing {
                Some(closing) => closing.cell,
                // Closing the stream did not require sending anything.
                None => return Ok(()),
            }
        }
    };

    self.send_msg_to_bwd(outgoing).await
}
/// Hand `msg` over on `bwd_tx`, waiting for channel capacity if necessary.
///
/// The message is bundled with this reactor's hop number, the hop's relay cell
/// format, and a shared handle to the hop's congestion control state.
///
/// Returns `Err(ReactorError::Shutdown)` if the message could not be sent on `bwd_tx`.
async fn send_msg_to_bwd(&mut self, msg: AnyRelayMsgOuter) -> StdResult<(), ReactorError> {
    let relay_cell_format = self.hop.relay_cell_format();
    let ccontrol = Arc::clone(self.hop.ccontrol());
    let ready = ReadyStreamMsg {
        hop: self.hopnum,
        msg,
        relay_cell_format,
        ccontrol,
    };

    match self.bwd_tx.send(ready).await {
        Ok(()) => Ok(()),
        // The details of the send error are deliberately discarded.
        Err(_) => Err(ReactorError::Shutdown),
    }
}
}
/// An event obtained by polling the hop's stream map in `StreamReactor::run_once`,
/// and consumed by `StreamReactor::handle_stream_event`.
enum StreamEvent {
/// A ready stream yielded no message; the reactor reacts by closing the stream.
Closed {
/// ID of the stream to close.
sid: StreamId,
/// Behavior passed through to `close_stream`.
behav: CloseStreamBehavior,
/// Termination reason passed through to `close_stream`.
reason: streammap::TerminateReason,
},
/// A stream has a message ready to be sent.
ReadyMsg {
/// ID of the stream the message belongs to.
sid: StreamId,
/// The message taken from the stream.
msg: AnyRelayMsg,
},
}
/// Decode `msg` as a BEGIN message and wrap it in an `IncomingStreamRequest`.
///
/// Returns an error built with `Error::from_bytes_err` if the message cannot be
/// decoded as a `Begin`.
#[cfg(any(feature = "hs-service", feature = "relay"))]
fn parse_incoming_stream_req(msg: UnparsedRelayMsg) -> crate::Result<IncomingStreamRequest> {
    match msg.decode::<Begin>() {
        Ok(decoded) => Ok(IncomingStreamRequest::Begin(decoded.into_msg())),
        Err(e) => Err(Error::from_bytes_err(e, "Invalid Begin message")),
    }
}
/// A message that the stream reactor hands over on its `bwd_tx` channel for sending.
pub(crate) struct ReadyStreamMsg {
/// Hop number of the stream reactor that produced the message, if any.
pub(crate) hop: Option<HopNum>,
/// The message to send.
pub(crate) msg: AnyRelayMsgOuter,
/// Relay cell format of the hop, as reported by the hop when the message was queued.
///
/// NOTE(review): presumably used by the receiver to encode `msg` — confirm.
pub(crate) relay_cell_format: RelayCellFormat,
/// Shared handle to the hop's congestion control state.
pub(crate) ccontrol: Arc<Mutex<CongestionControl>>,
}
/// A stream-level message delivered to the stream reactor on its `cell_rx` channel.
pub(crate) struct StreamMsg {
/// ID of the stream the message is addressed to.
pub(crate) sid: StreamId,
/// The relay message, not yet decoded.
pub(crate) msg: UnparsedRelayMsg,
/// Forwarded unchanged to `CircHopOutbound::handle_msg`.
///
/// NOTE(review): presumably tells whether the cell carrying `msg` counts towards
/// SENDME windows — confirm against the sender.
pub(crate) cell_counts_toward_windows: bool,
}