use super::{CircuitCmd, CloseStreamBehavior};
use crate::circuit::circhop::{CircHopInbound, CircHopOutbound, HopSettings, SendRelayCell};
use crate::client::reactor::circuit::path::PathEntry;
use crate::congestion::CongestionControl;
use crate::crypto::cell::HopNum;
use crate::stream::StreamMpscReceiver;
use crate::stream::cmdcheck::AnyCmdChecker;
use crate::stream::flow_ctrl::state::StreamRateLimit;
use crate::stream::flow_ctrl::xon_xoff::reader::DrainRateRequest;
use crate::stream::queue::StreamQueueSender;
use crate::streammap::{self, StreamEntMut, StreamMap};
use crate::tunnel::TunnelScopedCircId;
use crate::util::notify::NotifySender;
use crate::util::tunnel_activity::TunnelActivity;
use crate::{Error, Result};
use futures::Stream;
use futures::stream::FuturesUnordered;
use postage::watch;
use smallvec::SmallVec;
use tor_cell::chancell::BoxedCellBody;
use tor_cell::relaycell::flow_ctrl::{Xoff, Xon, XonKbpsEwma};
use tor_cell::relaycell::msg::AnyRelayMsg;
use tor_cell::relaycell::{
AnyRelayMsgOuter, RelayCellDecoder, RelayCellDecoderResult, RelayCellFormat, StreamId,
UnparsedRelayMsg,
};
use web_time_compat::Instant;
use safelog::sensitive as sv;
use tor_error::Bug;
use tracing::instrument;
use std::result::Result as StdResult;
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::Poll;
#[cfg(test)]
use tor_cell::relaycell::msg::SendmeTag;
/// Number of `CircHop`s that a `CircHopList` keeps inline in its `SmallVec`.
const NUM_HOPS: usize = 3;
/// An ordered list of the hops of a circuit.
///
/// The hop stored at index `i` is the one returned for a `HopNum` that converts to
/// `i` (see `CircHopList::hop` / `CircHopList::get_mut`).
#[derive(Default)]
pub(crate) struct CircHopList {
    /// The hops, in order; up to `NUM_HOPS` of them are held inline by the `SmallVec`.
    hops: SmallVec<[CircHop; NUM_HOPS]>,
}
impl CircHopList {
    /// Return a reference to the hop identified by `hopnum`, or `None` if there is no such hop.
    pub(super) fn hop(&self, hopnum: HopNum) -> Option<&CircHop> {
        self.hops.get(Into::<usize>::into(hopnum))
    }

    /// Return a mutable reference to the hop identified by `hopnum`, or `None` if there is no
    /// such hop.
    pub(super) fn get_mut(&mut self, hopnum: HopNum) -> Option<&mut CircHop> {
        self.hops.get_mut(Into::<usize>::into(hopnum))
    }

    /// Append `hop` to the end of the list.
    pub(crate) fn push(&mut self, hop: CircHop) {
        self.hops.push(hop);
    }

    /// Return true if the list contains no hops.
    pub(crate) fn is_empty(&self) -> bool {
        self.hops.is_empty()
    }

    /// Return the number of hops in the list.
    pub(crate) fn len(&self) -> usize {
        self.hops.len()
    }

    /// Build a stream of [`CircuitCmd`]s from the ready streams of this circuit's hops.
    ///
    /// A hop is skipped if it equals `exclude`, or if its congestion control reports that it
    /// cannot send. That check happens once, when this method is called. For every other hop,
    /// the returned stream holds one future. Once one of that hop's streams is ready, the
    /// future resolves to either:
    ///
    ///  * `CircuitCmd::CloseStream` (reason `StreamTargetClosed`), if the ready entry
    ///    carries no message; or
    ///  * `CircuitCmd::Send` with the stream's next ready message, addressed to that hop.
    ///
    /// The returned stream owns `Arc` handles to the hops' stream maps and does not borrow
    /// `self`.
    ///
    /// # Panics
    ///
    /// Panics if a stream map lock is poisoned. Also panics if a stream reported as ready
    /// disappears, or no longer has its message, by the time it is read.
    pub(in crate::client::reactor) fn ready_streams_iterator(
        &self,
        exclude: Option<HopNum>,
    ) -> impl Stream<Item = CircuitCmd> + use<> {
        self.hops
            .iter()
            .enumerate()
            .filter_map(|(i, hop)| {
                // `HopNum` is built from a `u8`. A plain `as` cast would silently wrap an
                // out-of-range index onto the wrong hop, so fail loudly instead.
                let hop_num =
                    HopNum::from(u8::try_from(i).expect("hop index does not fit in a HopNum"));
                if exclude == Some(hop_num) {
                    return None;
                }
                if !hop.ccontrol().can_send() {
                    return None;
                }
                // Take an owned handle on the stream map: the returned futures must not
                // borrow `self` (the return type captures no lifetimes, per `use<>`).
                let hop_map = Arc::clone(hop.stream_map());
                Some(futures::future::poll_fn(move |cx| {
                    let mut hop_map = hop_map.lock().expect("lock poisoned");
                    let Some((sid, msg)) = hop_map.poll_ready_streams_iter(cx).next() else {
                        return Poll::Pending;
                    };
                    if msg.is_none() {
                        // The ready entry carries no message: emit a command to close
                        // this stream.
                        return Poll::Ready(CircuitCmd::CloseStream {
                            hop: hop_num,
                            sid,
                            behav: CloseStreamBehavior::default(),
                            reason: streammap::TerminateReason::StreamTargetClosed,
                        });
                    };
                    let msg = hop_map.take_ready_msg(sid).expect("msg disappeared");
                    // `s` is only read by the `debug_assert!` below.
                    #[allow(unused)]
                    let Some(StreamEntMut::Open(s)) = hop_map.get_mut(sid) else {
                        panic!("Stream {sid} disappeared");
                    };
                    debug_assert!(
                        s.can_send(&msg),
                        "Stream {sid} produced a message it can't send: {msg:?}"
                    );
                    let cell = SendRelayCell {
                        hop: Some(hop_num),
                        early: false,
                        cell: AnyRelayMsgOuter::new(Some(sid), msg),
                    };
                    Poll::Ready(CircuitCmd::Send(cell))
                }))
            })
            .collect::<FuturesUnordered<_>>()
    }

    /// Ask every hop's stream map to drop its expired half-streams, as of `now`.
    ///
    /// # Panics
    ///
    /// Panics if a stream map lock is poisoned.
    pub(super) fn remove_expired_halfstreams(&mut self, now: Instant) {
        // `stream_map()` only needs `&CircHop`, so shared iteration is enough here.
        for hop in self.hops.iter() {
            hop.stream_map()
                .lock()
                .expect("lock poisoned")
                .remove_expired_halfstreams(now);
        }
    }

    /// Return true if at least one hop's stream map has an open stream.
    ///
    /// # Panics
    ///
    /// Panics if a stream map lock is poisoned.
    pub(super) fn has_streams(&self) -> bool {
        self.hops.iter().any(|hop| {
            hop.stream_map()
                .lock()
                .expect("lock poisoned")
                .n_open_streams()
                > 0
        })
    }

    /// Return the greatest `TunnelActivity` reported by any hop's stream map.
    ///
    /// Returns `TunnelActivity::never_used()` if there are no hops.
    ///
    /// # Panics
    ///
    /// Panics if a stream map lock is poisoned.
    pub(crate) fn tunnel_activity(&self) -> TunnelActivity {
        self.hops
            .iter()
            .map(|hop| {
                hop.stream_map()
                    .lock()
                    .expect("Poisoned lock")
                    .tunnel_activity()
            })
            .max()
            .unwrap_or_else(TunnelActivity::never_used)
    }
}
/// State for a single hop of a circuit.
///
/// Combines the hop's inbound half (cell decoding) and outbound half (streams, congestion
/// and flow control). It also holds the identifiers used to address the hop in outbound
/// operations.
pub(crate) struct CircHop {
    /// Identifier of the circuit this hop belongs to; passed to outbound stream operations.
    unique_id: TunnelScopedCircId,
    /// The number of this hop within its circuit.
    hop_num: HopNum,
    /// Inbound half: decodes incoming relay cell bodies and tracks the inbound cell limit.
    inbound: CircHopInbound,
    /// Outbound half: owns the stream map, congestion control, XON/XOFF handling and the
    /// outbound cell limit.
    outbound: CircHopOutbound,
}
impl CircHop {
    /// Create the state for hop `hop_num` of the circuit `unique_id`.
    ///
    /// Both halves use the relay cell format implied by the relay crypto protocol in
    /// `settings`. The outbound half gets a fresh congestion controller built from
    /// `settings.ccontrol`, and its own copy of `settings.flow_ctrl_params`.
    pub(crate) fn new(
        unique_id: TunnelScopedCircId,
        hop_num: HopNum,
        settings: &HopSettings,
    ) -> Self {
        let cell_format = settings.relay_crypt_protocol().relay_cell_format();
        let shared_cc = Arc::new(Mutex::new(CongestionControl::new(&settings.ccontrol)));

        let decoder = RelayCellDecoder::new(cell_format);
        let inbound = CircHopInbound::new(decoder, settings);

        let flow_params = Arc::new(settings.flow_ctrl_params.clone());
        let outbound = CircHopOutbound::new(shared_cc, cell_format, flow_params, settings);

        Self {
            unique_id,
            hop_num,
            inbound,
            outbound,
        }
    }

    /// Open a new stream on this hop using `message`.
    ///
    /// Returns the cell to send together with the id allocated for the new stream.
    pub(crate) fn begin_stream(
        &mut self,
        message: AnyRelayMsg,
        sender: StreamQueueSender,
        rx: StreamMpscReceiver<AnyRelayMsg>,
        rate_limit_updater: watch::Sender<StreamRateLimit>,
        drain_rate_requester: NotifySender<DrainRateRequest>,
        cmd_checker: AnyCmdChecker,
    ) -> Result<(SendRelayCell, StreamId)> {
        let target = Some(self.hop_num);
        self.outbound.begin_stream(
            target,
            message,
            sender,
            rx,
            rate_limit_updater,
            drain_rate_requester,
            cmd_checker,
        )
    }

    /// Close the stream `id` on this hop, for reason `why`, using the behavior in `message`.
    ///
    /// `expiry` is forwarded to the outbound half. Returns a cell to send, if closing the
    /// stream requires one.
    pub(crate) fn close_stream(
        &mut self,
        id: StreamId,
        message: CloseStreamBehavior,
        why: streammap::TerminateReason,
        expiry: Instant,
    ) -> Result<Option<SendRelayCell>> {
        let (circ_id, target) = (self.unique_id, Some(self.hop_num));
        self.outbound
            .close_stream(circ_id, id, target, message, why, expiry)
    }

    /// Possibly produce an XON message for stream `id`, advertising `rate`.
    #[instrument(level = "trace", skip_all)]
    pub(crate) fn maybe_send_xon(
        &mut self,
        rate: XonKbpsEwma,
        id: StreamId,
    ) -> Result<Option<Xon>> {
        self.outbound.maybe_send_xon(rate, id)
    }

    /// Possibly produce an XOFF message for stream `id`.
    pub(crate) fn maybe_send_xoff(&mut self, id: StreamId) -> Result<Option<Xoff>> {
        self.outbound.maybe_send_xoff(id)
    }

    /// Return the relay cell format used by this hop.
    pub(crate) fn relay_cell_format(&self) -> RelayCellFormat {
        self.outbound.relay_cell_format()
    }

    /// Test helper: the outbound send window and the SENDME tags currently expected.
    #[cfg(test)]
    pub(crate) fn send_window_and_expected_tags(&self) -> (u32, Vec<SendmeTag>) {
        self.outbound.send_window_and_expected_tags()
    }

    /// Lock and return this hop's congestion-control state.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub(crate) fn ccontrol(&self) -> MutexGuard<'_, CongestionControl> {
        let cc_mutex = self.outbound.ccontrol();
        cc_mutex.lock().expect("poisoned lock")
    }

    /// Borrow the outbound half of this hop.
    pub(crate) fn outbound(&self) -> &CircHopOutbound {
        &self.outbound
    }

    /// Tell the outbound half that `msg` is about to be sent on stream `stream_id`.
    pub(crate) fn about_to_send(&mut self, stream_id: StreamId, msg: &AnyRelayMsg) -> Result<()> {
        let circ_id = self.unique_id;
        self.outbound.about_to_send(circ_id, stream_id, msg)
    }

    /// Register a stream with the caller-chosen id `stream_id` on this hop.
    #[cfg(feature = "hs-service")]
    pub(crate) fn add_ent_with_id(
        &self,
        sink: StreamQueueSender,
        rx: StreamMpscReceiver<AnyRelayMsg>,
        rate_limit_updater: watch::Sender<StreamRateLimit>,
        drain_rate_requester: NotifySender<DrainRateRequest>,
        stream_id: StreamId,
        cmd_checker: AnyCmdChecker,
    ) -> Result<()> {
        self.outbound.add_ent_with_id(
            sink,
            rx,
            rate_limit_updater,
            drain_rate_requester,
            stream_id,
            cmd_checker,
        )
    }

    /// Tell the outbound half that an ending message was received for stream `stream_id`.
    #[cfg(feature = "hs-service")]
    pub(crate) fn ending_msg_received(&self, stream_id: StreamId) -> Result<()> {
        self.outbound.ending_msg_received(stream_id)
    }

    /// Decode an incoming relay cell body with this hop's inbound decoder.
    pub(crate) fn decode(&mut self, cell: BoxedCellBody) -> Result<RelayCellDecoderResult> {
        self.inbound.decode(cell)
    }

    /// Hand an incoming message for stream `streamid` to the outbound half.
    ///
    /// `hop_detail` is used only to build an `Error::UnknownStream` if the outbound half
    /// needs one. Returns whatever the outbound half returns.
    pub(super) fn handle_msg(
        &self,
        hop_detail: &PathEntry,
        cell_counts_toward_windows: bool,
        streamid: StreamId,
        msg: UnparsedRelayMsg,
        now: Instant,
    ) -> Result<Option<UnparsedRelayMsg>> {
        let unknown_stream_err = |sid: StreamId| Error::UnknownStream {
            src: sv(hop_detail.clone()),
            streamid: sid,
        };
        self.outbound.handle_msg(
            unknown_stream_err,
            cell_counts_toward_windows,
            streamid,
            msg,
            now,
        )
    }

    /// Borrow the shared handle to this hop's stream map.
    pub(crate) fn stream_map(&self) -> &Arc<Mutex<StreamMap>> {
        self.outbound.stream_map()
    }

    /// Replace this hop's stream map with `map`; errors are reported by the outbound half.
    pub(crate) fn set_stream_map(&mut self, map: Arc<Mutex<StreamMap>>) -> StdResult<(), Bug> {
        self.outbound.set_stream_map(map)
    }

    /// Decrement the outbound cell limit of this hop.
    pub(crate) fn decrement_outbound_cell_limit(&mut self) -> Result<()> {
        self.outbound.decrement_cell_limit()
    }

    /// Decrement the inbound cell limit of this hop.
    pub(crate) fn decrement_inbound_cell_limit(&mut self) -> Result<()> {
        self.inbound.decrement_cell_limit()
    }
}