use crate::channel::{Channel, ChannelSender};
use crate::circuit::CircuitRxReceiver;
use crate::circuit::UniqId;
use crate::circuit::create::{Create2Wrap, CreateHandshakeWrap};
use crate::circuit::reactor::ControlHandler;
use crate::circuit::reactor::backward::BackwardReactorCmd;
use crate::circuit::reactor::forward::{ForwardCellDisposition, ForwardHandler};
use crate::circuit::reactor::hop_mgr::HopMgr;
use crate::crypto::cell::OutboundRelayLayer;
use crate::crypto::cell::RelayCellBody;
use crate::relay::RelayCircChanMsg;
use crate::util::err::ReactorError;
use crate::{Error, HopNum, Result};
use crate::client::circuit::padding::QueuedCellPaddingInfo;
use crate::relay::channel_provider::{ChannelProvider, ChannelResult, OutboundChanSender};
use crate::relay::reactor::CircuitAccount;
use tor_cell::chancell::msg::{AnyChanMsg, Destroy, PaddingNegotiate, Relay};
use tor_cell::chancell::{AnyChanCell, BoxedCellBody, ChanMsg, CircId};
use tor_cell::relaycell::msg::{Extend2, Extended2, SendmeTag};
use tor_cell::relaycell::{RelayCellDecoderResult, RelayCellFormat, RelayCmd, UnparsedRelayMsg};
use tor_error::{internal, into_internal, warn_report};
use tor_linkspec::decode::Strictness;
use tor_linkspec::{OwnedChanTarget, OwnedChanTargetBuilder};
use tor_rtcompat::{Runtime, SpawnExt as _};
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _, future};
use tracing::{debug, trace};
use std::result::Result as StdResult;
use std::sync::Arc;
use std::task::Poll;
/// Upper bound on RELAY_EARLY cells accepted on one circuit.
///
/// `Forward::handle_relay_cell` treats the (N+1)-th RELAY_EARLY cell as a
/// protocol violation (presumably mirroring the tor-spec limit of 8 RELAY_EARLY
/// cells per circuit — TODO confirm against the spec section).
const MAX_RELAY_EARLY_CELLS_PER_CIRCUIT: usize = 8;

/// Control messages accepted by this handler: the unit type, i.e. none carry data.
type CtrlMsg = ();
/// Control commands accepted by this handler: the unit type, i.e. none carry data.
type CtrlCmd = ();
/// Handler for the outbound direction of a relay circuit.
///
/// Cells arriving on the circuit are run through our outbound crypto layer.
/// Cells addressed to this relay are handed back to the reactor for handling.
/// All other cells are forwarded to the next hop, once the circuit has been
/// extended.
pub(crate) struct Forward {
    /// Identifier of this circuit. It is used in log messages and passed to
    /// the channel provider when launching the next-hop channel.
    unique_id: UniqId,
    /// The circuit to the next hop.
    ///
    /// `None` until an EXTEND2 has completed successfully (set in
    /// `handle_extend_result`).
    outbound: Option<Outbound>,
    /// Our outbound relay crypto layer. Its `decrypt_outbound` tells us whether
    /// a cell is addressed to us.
    crypto_out: Box<dyn OutboundRelayLayer + Send>,
    /// Source of channels to the next hop, used when handling EXTEND2.
    chan_provider: Arc<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
    /// Whether an EXTEND2 has already been received on this circuit.
    /// A second EXTEND2 is a protocol error.
    have_seen_extend2: bool,
    /// Number of RELAY_EARLY cells received so far. It is checked against
    /// `MAX_RELAY_EARLY_CELLS_PER_CIRCUIT`.
    relay_early_count: usize,
    /// Sender used by tasks spawned from this handler (the extend task) to
    /// report results back as [`CircEvent`]s.
    event_tx: mpsc::Sender<CircEvent>,
    /// Account passed to `Channel::new_outbound_circ` when the outbound circuit
    /// is created (a memory-quota account, judging by the name — TODO confirm).
    memquota: CircuitAccount,
}
/// Events produced by background tasks spawned by [`Forward`].
///
/// They are delivered through `Forward::event_tx` and processed by
/// `ForwardHandler::handle_event`.
pub(crate) enum CircEvent {
    /// Outcome of the extend task spawned while handling an EXTEND2 message.
    ExtendResult(StdResult<ExtendResult, ReactorError>),
}
/// The result of successfully extending the circuit to the next hop.
pub(crate) struct ExtendResult {
    /// The EXTENDED2 message built from the next hop's CREATED2 reply.
    /// It is handed to the backward reactor.
    extended2: Extended2,
    /// The newly created circuit on the next-hop channel.
    outbound: Outbound,
    /// Receiver for cells that the next hop sends on the new circuit.
    outbound_chan_rx: CircuitRxReceiver,
}
/// Our end of the circuit towards the next hop.
struct Outbound {
    /// Circuit ID of this circuit on `channel`.
    circ_id: CircId,
    /// The channel to the next hop.
    channel: Arc<Channel>,
    /// Sender used to queue cells on `channel`.
    outbound_chan_tx: ChannelSender,
}
/// Outcome of running an incoming relay cell through our outbound crypto layer.
enum CellDecodeResult {
    /// The cell is addressed to us.
    ///
    /// Holds the SENDME tag returned by the crypto layer and the decoded
    /// relay message(s).
    Recognized(SendmeTag, RelayCellDecoderResult),
    /// The cell is not for us.
    ///
    /// Holds the relay body, after our crypto layer has been applied, so that
    /// it can be forwarded to the next hop.
    // NOTE(review): variant name is misspelled; rename to `Unrecognized` together
    // with all of its uses.
    Unrecognizd(RelayCellBody),
}
impl Forward {
pub(crate) fn new(
unique_id: UniqId,
crypto_out: Box<dyn OutboundRelayLayer + Send>,
chan_provider: Arc<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
event_tx: mpsc::Sender<CircEvent>,
memquota: CircuitAccount,
) -> Self {
Self {
unique_id,
outbound: None,
crypto_out,
chan_provider,
have_seen_extend2: false,
relay_early_count: 0,
event_tx,
memquota,
}
}
fn decode_relay_cell<R: Runtime>(
&mut self,
hop_mgr: &mut HopMgr<R>,
cell: Relay,
) -> Result<(Option<HopNum>, CellDecodeResult)> {
let hopnum = None;
let cmd = cell.cmd();
let mut body = cell.into_relay_body().into();
let Some(tag) = self.crypto_out.decrypt_outbound(cmd, &mut body) else {
return Ok((hopnum, CellDecodeResult::Unrecognizd(body)));
};
let mut hops = hop_mgr.hops().write().expect("poisoned lock");
let decode_res = hops
.get_mut(hopnum)
.ok_or_else(|| internal!("msg from non-existent hop???"))?
.inbound
.decode(body.into())?;
Ok((hopnum, CellDecodeResult::Recognized(tag, decode_res)))
}
#[allow(clippy::unnecessary_wraps)] fn handle_drop(&mut self) -> StdResult<(), ReactorError> {
cfg_if::cfg_if! {
if #[cfg(feature = "circ-padding")] {
Err(internal!("relay circuit padding not yet supported").into())
} else {
Ok(())
}
}
}
fn handle_extend2<R: Runtime>(
&mut self,
runtime: &R,
early: bool,
msg: UnparsedRelayMsg,
) -> StdResult<(), ReactorError> {
if !early {
return Err(Error::CircProto("got EXTEND2 in a RELAY cell?!".into()).into());
}
if self.have_seen_extend2 {
return Err(Error::CircProto("got 2 EXTEND2 on the same circuit?!".into()).into());
}
self.have_seen_extend2 = true;
let to_bytes_err = |e| Error::from_bytes_err(e, "EXTEND2 message");
let extend2 = msg.decode::<Extend2>().map_err(to_bytes_err)?.into_msg();
let chan_target = OwnedChanTargetBuilder::from_encoded_linkspecs(
Strictness::Standard,
extend2.linkspecs(),
)
.map_err(|err| Error::LinkspecDecodeErr {
object: "EXTEND2",
err,
})?
.build()
.map_err(|_| {
Error::CircProto("Invalid channel target".into())
})?;
let (chan_tx, chan_rx) = mpsc::unbounded();
let chan_tx = OutboundChanSender(chan_tx);
Arc::clone(&self.chan_provider).get_or_launch(self.unique_id, chan_target, chan_tx)?;
let mut result_tx = self.event_tx.clone();
let rt = runtime.clone();
let unique_id = self.unique_id;
let memquota = self.memquota.clone();
runtime
.spawn(async move {
let res = Self::extend_circuit(rt, unique_id, extend2, chan_rx, memquota).await;
let _ = result_tx.send(CircEvent::ExtendResult(res)).await;
})
.map_err(into_internal!("failed to spawn extend task?!"))?;
Ok(())
}
fn handle_extend_result(
&mut self,
res: StdResult<ExtendResult, ReactorError>,
) -> StdResult<Option<BackwardReactorCmd>, ReactorError> {
let ExtendResult {
extended2,
outbound,
outbound_chan_rx,
} = res?;
self.outbound = Some(outbound);
Ok(Some(BackwardReactorCmd::HandleCircuitExtended {
hop: None,
extended2,
outbound_chan_rx,
}))
}
#[allow(unused_variables)] async fn extend_circuit<R: Runtime>(
_runtime: R,
unique_id: UniqId,
extend2: Extend2,
mut chan_rx: mpsc::UnboundedReceiver<ChannelResult>,
memquota: CircuitAccount,
) -> StdResult<ExtendResult, ReactorError> {
let chan_res = chan_rx
.next()
.await
.ok_or_else(|| internal!("channel provider task exited"))?;
let channel = match chan_res {
Ok(c) => c,
Err(e) => {
warn_report!(e, "Failed to launch outgoing channel");
return Err(ReactorError::Shutdown);
}
};
debug!(
circ_id = %unique_id,
"Launched channel to the next hop"
);
let (circ_id, outbound_chan_rx, createdreceiver) =
channel.new_outbound_circ(memquota).await?;
let create2_wrap = Create2Wrap {
handshake_type: extend2.handshake_type(),
};
let create2 = create2_wrap.to_chanmsg(extend2.handshake().into());
let mut outbound_chan_tx = channel.sender();
let cell = AnyChanCell::new(Some(circ_id), create2);
trace!(
circ_id = %unique_id,
"Sending CREATE2 to the next hop"
);
outbound_chan_tx.send((cell, None)).await?;
let response = createdreceiver
.await
.map_err(|_| internal!("channel disappeared?"))?;
trace!(
circ_id = %unique_id,
"Got CREATED2 response from next hop"
);
let outbound = Outbound {
circ_id,
channel: Arc::clone(&channel),
outbound_chan_tx,
};
let created2_body = create2_wrap.decode_chanmsg(response)?;
let extended2 = Extended2::new(created2_body);
Ok(ExtendResult {
extended2,
outbound,
outbound_chan_rx,
})
}
fn handle_relay_cell<R: Runtime>(
&mut self,
hop_mgr: &mut HopMgr<R>,
cell: Relay,
early: bool,
) -> StdResult<Option<ForwardCellDisposition>, ReactorError> {
if early {
self.relay_early_count += 1;
if self.relay_early_count > MAX_RELAY_EARLY_CELLS_PER_CIRCUIT {
return Err(
Error::CircProto("Circuit received too many RELAY_EARLY cells".into()).into(),
);
}
}
let (hopnum, res) = self.decode_relay_cell(hop_mgr, cell)?;
let (tag, decode_res) = match res {
CellDecodeResult::Unrecognizd(body) => {
self.handle_unrecognized_cell(body, None, early)?;
return Ok(None);
}
CellDecodeResult::Recognized(tag, res) => (tag, res),
};
Ok(Some(ForwardCellDisposition::HandleRecognizedRelay {
cell: decode_res,
early,
hopnum,
tag,
}))
}
fn handle_unrecognized_cell(
&mut self,
body: RelayCellBody,
info: Option<QueuedCellPaddingInfo>,
early: bool,
) -> StdResult<(), ReactorError> {
trace!(
circ_id = %self.unique_id,
"Forwarding unrecognized cell"
);
let Some(chan) = self.outbound.as_mut() else {
return Err(Error::CircProto(
"Asked to forward cell before the circuit was extended?!".into(),
)
.into());
};
let msg = Relay::from(BoxedCellBody::from(body));
let relay = if early {
AnyChanMsg::RelayEarly(msg.into())
} else {
AnyChanMsg::Relay(msg)
};
let cell = AnyChanCell::new(Some(chan.circ_id), relay);
chan.outbound_chan_tx.start_send_unpin((cell, info))?;
Ok(())
}
#[allow(clippy::unused_async)] async fn handle_truncate(&mut self) -> StdResult<(), ReactorError> {
Err(internal!("TRUNCATE is not implemented").into())
}
#[allow(clippy::needless_pass_by_value)] fn handle_destroy_cell(&mut self, _cell: Destroy) -> StdResult<(), ReactorError> {
Err(internal!("DESTROY is not implemented").into())
}
#[allow(clippy::needless_pass_by_value)] fn handle_padding_negotiate(&mut self, _cell: PaddingNegotiate) -> StdResult<(), ReactorError> {
Err(internal!("PADDING_NEGOTIATE is not implemented").into())
}
}
impl ForwardHandler for Forward {
type BuildSpec = OwnedChanTarget;
type CircChanMsg = RelayCircChanMsg;
type CircEvent = CircEvent;
async fn handle_meta_msg<R: Runtime>(
&mut self,
runtime: &R,
early: bool,
_hopnum: Option<HopNum>,
msg: UnparsedRelayMsg,
_relay_cell_format: RelayCellFormat,
) -> StdResult<(), ReactorError> {
match msg.cmd() {
RelayCmd::DROP => self.handle_drop(),
RelayCmd::EXTEND2 => self.handle_extend2(runtime, early, msg),
RelayCmd::TRUNCATE => self.handle_truncate().await,
cmd => Err(internal!("relay cmd {cmd} not supported").into()),
}
}
async fn handle_forward_cell<R: Runtime>(
&mut self,
hop_mgr: &mut HopMgr<R>,
cell: RelayCircChanMsg,
) -> StdResult<Option<ForwardCellDisposition>, ReactorError> {
use RelayCircChanMsg::*;
match cell {
Relay(r) => self.handle_relay_cell(hop_mgr, r, false),
RelayEarly(r) => self.handle_relay_cell(hop_mgr, r.into(), true),
Destroy(d) => {
self.handle_destroy_cell(d)?;
Ok(None)
}
PaddingNegotiate(p) => {
self.handle_padding_negotiate(p)?;
Ok(None)
}
}
}
fn handle_event(
&mut self,
event: Self::CircEvent,
) -> StdResult<Option<BackwardReactorCmd>, ReactorError> {
match event {
CircEvent::ExtendResult(res) => self.handle_extend_result(res),
}
}
async fn outbound_chan_ready(&mut self) -> Result<()> {
future::poll_fn(|cx| match &mut self.outbound {
Some(chan) => {
let _ = chan.outbound_chan_tx.poll_flush_unpin(cx);
chan.outbound_chan_tx.poll_ready_unpin(cx)
}
None => {
Poll::Ready(Ok(()))
}
})
.await
}
}
impl ControlHandler for Forward {
    type CtrlMsg = CtrlMsg;
    type CtrlCmd = CtrlCmd;

    /// Handle a control command.
    ///
    /// `CtrlCmd` is `()`, so there is nothing to act on. Returning the unit
    /// value directly ties this body to that type, so the compiler will flag
    /// this handler once commands carry data.
    fn handle_cmd(&mut self, cmd: Self::CtrlCmd) -> StdResult<(), ReactorError> {
        Ok(cmd)
    }

    /// Handle a control message.
    ///
    /// `CtrlMsg` is `()`. As with `handle_cmd`, the value is returned directly
    /// so that a change of type will not compile silently.
    fn handle_msg(&mut self, msg: Self::CtrlMsg) -> StdResult<(), ReactorError> {
        Ok(msg)
    }
}
impl Drop for Forward {
fn drop(&mut self) {
if let Some(outbound) = self.outbound.as_mut() {
let _ = outbound.channel.close_circuit(outbound.circ_id);
}
}
}