use super::{CircEvent, ExtendResult, Outbound};
use crate::Error;
use crate::circuit::UniqId;
use crate::circuit::create::{Create2Wrap, CreateHandshakeWrap};
use crate::peer::PeerInfo;
use crate::relay::channel_provider::{ChannelProvider, ChannelResult, OutboundChanSender};
use crate::relay::reactor::CircuitAccount;
use crate::util::err::ReactorError;
use tor_cell::chancell::AnyChanCell;
use tor_cell::relaycell::UnparsedRelayMsg;
use tor_cell::relaycell::msg::{Extend2, Extended2};
use tor_error::{internal, into_internal, warn_report};
use tor_linkspec::decode::Strictness;
use tor_linkspec::{HasRelayIds, OwnedChanTarget, OwnedChanTargetBuilder};
use tor_rtcompat::{Runtime, SpawnExt as _};
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use tracing::{debug, trace};
use std::result::Result as StdResult;
use std::sync::Arc;
/// Per-circuit handler for incoming EXTEND2 requests.
///
/// Validates EXTEND2 messages and drives the extension of the circuit to the
/// requested next hop, reporting the outcome through `event_tx`.
pub(super) struct ExtendRequestHandler {
    /// Identifier of the circuit this handler belongs to; used in log
    /// messages and when requesting a channel from `chan_provider`.
    unique_id: UniqId,
    /// Whether an EXTEND2 has already been handled on this circuit.
    /// A second EXTEND2 is rejected as a protocol violation.
    have_seen_extend2: bool,
    /// Provider used to get or launch the outbound channel to the next hop.
    chan_provider: Arc<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
    /// The peer on the inbound side of this circuit; used to refuse
    /// extending the circuit back to the previous hop.
    inbound_peer: Arc<PeerInfo>,
    /// Sender on which the spawned extend task delivers
    /// `CircEvent::ExtendResult`.
    event_tx: mpsc::Sender<CircEvent>,
    /// Memory-quota account handed to the outbound channel when the new
    /// outbound circuit is allocated.
    memquota: CircuitAccount,
}
impl ExtendRequestHandler {
pub(super) fn new(
unique_id: UniqId,
chan_provider: Arc<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
inbound_peer: Arc<PeerInfo>,
event_tx: mpsc::Sender<CircEvent>,
memquota: CircuitAccount,
) -> Self {
Self {
unique_id,
have_seen_extend2: false,
chan_provider,
inbound_peer,
event_tx,
memquota,
}
}
pub(super) fn handle_extend2<R: Runtime>(
&mut self,
runtime: &R,
early: bool,
msg: UnparsedRelayMsg,
) -> StdResult<(), ReactorError> {
if !early {
return Err(Error::CircProto("got EXTEND2 in a RELAY cell?!".into()).into());
}
if self.have_seen_extend2 {
return Err(Error::CircProto("got 2 EXTEND2 on the same circuit?!".into()).into());
}
self.have_seen_extend2 = true;
let to_bytes_err = |e| Error::from_bytes_err(e, "EXTEND2 message");
let extend2 = msg.decode::<Extend2>().map_err(to_bytes_err)?.into_msg();
let chan_target = OwnedChanTargetBuilder::from_encoded_linkspecs(
Strictness::Standard,
extend2.linkspecs(),
)
.map_err(|err| Error::LinkspecDecodeErr {
object: "EXTEND2",
err,
})?
.build()
.map_err(|_| {
Error::CircProto("Invalid channel target".into())
})?;
if chan_target.has_any_relay_id_from(&*self.inbound_peer) {
return Err(Error::CircProto("Cannot extend circuit to previous hop".into()).into());
}
let (chan_tx, chan_rx) = mpsc::unbounded();
let chan_tx = OutboundChanSender(chan_tx);
Arc::clone(&self.chan_provider).get_or_launch(self.unique_id, chan_target, chan_tx)?;
let mut result_tx = self.event_tx.clone();
let rt = runtime.clone();
let unique_id = self.unique_id;
let memquota = self.memquota.clone();
runtime
.spawn(async move {
let res = Self::extend_circuit(rt, unique_id, extend2, chan_rx, memquota).await;
let _ = result_tx.send(CircEvent::ExtendResult(res)).await;
})
.map_err(into_internal!("failed to spawn extend task?!"))?;
Ok(())
}
async fn extend_circuit<R: Runtime>(
_runtime: R,
unique_id: UniqId,
extend2: Extend2,
mut chan_rx: mpsc::UnboundedReceiver<ChannelResult>,
memquota: CircuitAccount,
) -> StdResult<ExtendResult, ReactorError> {
let chan_res = chan_rx
.next()
.await
.ok_or_else(|| internal!("channel provider task exited"))?;
let channel = match chan_res {
Ok(c) => c,
Err(e) => {
warn_report!(e, "Failed to launch outgoing channel");
return Err(ReactorError::Shutdown);
}
};
debug!(
circ_id = %unique_id,
"Launched channel to the next hop"
);
let (circ_id, outbound_chan_rx, createdreceiver) =
channel.new_outbound_circ(memquota).await?;
let create2_wrap = Create2Wrap {
handshake_type: extend2.handshake_type(),
};
let create2 = create2_wrap.to_chanmsg(extend2.handshake().into());
let mut outbound_chan_tx = channel.sender();
let cell = AnyChanCell::new(Some(circ_id), create2);
trace!(
circ_id = %unique_id,
"Sending CREATE2 to the next hop"
);
outbound_chan_tx.send((cell, None)).await?;
let response = createdreceiver
.await
.map_err(|_| internal!("channel disappeared?"))?;
trace!(
circ_id = %unique_id,
"Got CREATED2 response from next hop"
);
let outbound = Outbound {
circ_id,
channel: Arc::clone(&channel),
outbound_chan_tx,
};
let created2_body = create2_wrap.decode_chanmsg(response)?;
let extended2 = Extended2::new(created2_body);
Ok(ExtendResult {
extended2,
outbound,
outbound_chan_rx,
})
}
}