use crate::FlowCtrlParameters;
use crate::ccparams::{
Algorithm, AlgorithmDiscriminants, CongestionControlParams, CongestionWindowParams,
FixedWindowParams, RoundTripEstimatorParams, VegasParams,
};
use crate::channel::Channel;
use crate::circuit::CircuitRxSender;
use crate::circuit::UniqId;
use crate::circuit::celltypes::{CreateRequest, CreateResponse};
use crate::circuit::circhop::{HopNegotiationType, HopSettings};
use crate::client::circuit::CircParameters;
use crate::client::circuit::padding::PaddingController;
use crate::crypto::binding::CircuitBinding;
use crate::crypto::cell::CryptInit as _;
use crate::crypto::cell::{InboundRelayLayer, OutboundRelayLayer, RelayLayer, tor1};
use crate::crypto::handshake::RelayHandshakeError;
use crate::crypto::handshake::ServerHandshake as _;
use crate::crypto::handshake::fast::CreateFastServer;
use crate::crypto::handshake::ntor::{NtorSecretKey, NtorServer};
use crate::memquota::SpecificAccount as _;
use crate::memquota::{ChannelAccount, CircuitAccount};
use crate::relay::channel_provider::ChannelProvider;
use crate::relay::reactor::Reactor;
use crate::relay::{IncomingStreamRequestFilter, RelayCirc};
use smallvec::SmallVec;
use std::sync::{Arc, RwLock, Weak};
use tor_cell::chancell::ChanMsg as _;
use tor_cell::chancell::CircId;
use tor_cell::chancell::msg::{
CreateFast, Created2, CreatedFast, Destroy, DestroyReason, HandshakeType,
};
use tor_error::{Bug, ErrorKind, HasKind, debug_report, internal, into_internal};
use tor_linkspec::OwnedChanTarget;
use tor_llcrypto::cipher::aes::Aes128Ctr;
use tor_llcrypto::d::Sha1;
use tor_llcrypto::pk::ed25519::Ed25519Identity;
use tor_llcrypto::pk::rsa::RsaIdentity;
use tor_memquota::mq_queue::ChannelSpec as _;
use tor_memquota::mq_queue::MpscSpec;
use tor_relay_crypto::pk::{RelayNtorKeypair, RelayNtorKeys};
use tor_rtcompat::SpawnExt as _;
use tor_rtcompat::{DynTimeProvider, Runtime};
use tracing::warn;
/// Handles incoming CREATE* cells (CREATE_FAST and CREATE2) for a relay.
///
/// It performs the server side of the circuit handshake, builds the crypto
/// state and settings for the new hop, and starts a relay circuit reactor.
///
/// The network parameters and ntor keys sit behind [`RwLock`]s so they can be
/// replaced through `&self` while the handler is shared.
#[derive(derive_more::Debug)]
pub struct CreateRequestHandler {
    /// Provider handed to each new circuit reactor.
    ///
    /// It is held weakly and upgraded once per CREATE* request; if it has been
    /// dropped, circuit creation fails with an internal error.
    chan_provider: Weak<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
    /// Current circuit-related network parameters, replaceable at runtime.
    circ_net_params: RwLock<CircNetParameters>,
    /// Our current ntor keys, replaceable at runtime.
    ///
    /// Excluded from `Debug` output.
    #[debug(skip)]
    ntor_keys: RwLock<RelayNtorKeys>,
    /// Produces the incoming-stream filter installed on each new circuit.
    ///
    /// Excluded from `Debug` output.
    #[debug(skip)]
    incoming_filter_factory: Box<dyn IncomingStreamRequestFilterFactory + Send + Sync>,
}
impl CreateRequestHandler {
    /// Create a handler that answers CREATE* requests.
    ///
    /// It uses the given channel provider, circuit network parameters, ntor keys,
    /// and incoming-stream filter factory.
    ///
    /// `chan_provider` is kept as a [`Weak`] reference. It is upgraded each time
    /// a circuit is built, and circuit creation fails with an internal error if
    /// it has already been dropped.
    pub fn new(
        chan_provider: Weak<dyn ChannelProvider<BuildSpec = OwnedChanTarget> + Send + Sync>,
        circ_net_params: CircNetParameters,
        ntor_keys: RelayNtorKeys,
        incoming_filter_factory: Box<dyn IncomingStreamRequestFilterFactory + Send + Sync>,
    ) -> Self {
        Self {
            chan_provider,
            circ_net_params: RwLock::new(circ_net_params),
            ntor_keys: RwLock::new(ntor_keys),
            incoming_filter_factory,
        }
    }

    /// Replace the circuit network parameters used for subsequent CREATE* requests.
    ///
    /// # Panics
    ///
    /// Panics if the parameters lock is poisoned.
    pub fn update_params(&self, circ_net_params: CircNetParameters) {
        *self.circ_net_params.write().expect("rwlock poisoned") = circ_net_params;
    }

    /// Replace the ntor keys used for subsequent CREATE2 ntor handshakes.
    ///
    /// # Panics
    ///
    /// Panics if the keys lock is poisoned.
    pub fn update_ntor_keys(&self, ntor_keys: RelayNtorKeys) {
        *self.ntor_keys.write().expect("rwlock poisoned") = ntor_keys;
    }

    /// Handle an incoming CREATE* request `msg` on `channel` for circuit `circ_id`.
    ///
    /// On success, this returns the CREATED* reply and the components of the
    /// newly-built relay circuit. The circuit's reactor has already been spawned
    /// on `runtime` by then.
    ///
    /// # Errors
    ///
    /// On any failure, the underlying error is logged at debug level and a
    /// [`Destroy`] message with [`DestroyReason::NONE`] is returned instead.
    // NOTE(review): the failure cause is not reflected in the DESTROY reason,
    // presumably to avoid revealing details to the peer. Confirm this is intended.
    #[allow(clippy::too_many_arguments)] // every argument is forwarded to `handle_create_inner`
    pub(crate) fn handle_create<R: Runtime>(
        &self,
        runtime: &R,
        channel: &Arc<Channel>,
        our_ed25519_id: &Ed25519Identity,
        our_rsa_id: &RsaIdentity,
        circ_id: CircId,
        msg: &CreateRequest,
        memquota: &ChannelAccount,
        circ_unique_id: UniqId,
    ) -> Result<(CreateResponse, RelayCircComponents), Destroy> {
        self.handle_create_inner(
            runtime,
            channel,
            our_ed25519_id,
            our_rsa_id,
            circ_id,
            msg,
            memquota,
            circ_unique_id,
        )
        .map_err(|e| {
            let cmd = msg.cmd();
            debug_report!(&e, %cmd, "Failed to handle circuit create request");
            Destroy::new(DestroyReason::NONE)
        })
    }

    /// Do the actual work of [`Self::handle_create`], returning a detailed error.
    ///
    /// The steps are:
    /// 1. Run the server side of the requested handshake.
    /// 2. Create the circuit's memory-quota account and inbound queue.
    /// 3. Create the padding machinery.
    /// 4. Build the [`Reactor`] and spawn it on `runtime`.
    #[allow(clippy::too_many_arguments)] // each argument is consumed while building the circuit
    fn handle_create_inner<R: Runtime>(
        &self,
        runtime: &R,
        channel: &Arc<Channel>,
        our_ed25519_id: &Ed25519Identity,
        our_rsa_id: &RsaIdentity,
        circ_id: CircId,
        msg: &CreateRequest,
        memquota: &ChannelAccount,
        circ_unique_id: UniqId,
    ) -> Result<(CreateResponse, RelayCircComponents), HandleCreateError> {
        let handshake_components = match msg {
            CreateRequest::CreateFast(msg) => self.handle_create_fast(msg)?,
            CreateRequest::Create2(msg) => match msg.handshake_type() {
                HandshakeType::NTOR_V3 => self.handle_create2_ntorv3(msg.body(), our_ed25519_id)?,
                HandshakeType::NTOR => self.handle_create2_ntor(msg.body(), our_rsa_id)?,
                // TAP is deliberately unsupported, as is any handshake type we
                // don't recognise.
                other => return Err(HandleCreateError::Create2HandshakeType(other)),
            },
        };

        let memquota = CircuitAccount::new(memquota)?;
        // A single time provider is shared by the queue and the padding machinery.
        let time_provider = DynTimeProvider::new(runtime.clone());
        let account = memquota.as_raw_account();
        // NOTE(review): the queue bound is a hard-coded 10_000_000. Confirm the
        // intended value, or make it a parameter.
        let (sender, receiver) =
            MpscSpec::new(10_000_000).new_mq(time_provider.clone(), account)?;
        let (sender, receiver) = crate::circuit::circ_sender::channel(sender, receiver);
        let (padding_ctrl, padding_stream) =
            crate::client::circuit::padding::new_padding(time_provider);

        let Some(chan_provider) = self.chan_provider.upgrade() else {
            return Err(internal!("Unable to upgrade weak `ChannelProvider`").into());
        };
        let incoming_filter = self.incoming_filter_factory.current_filter();

        // NOTE(review): the incoming-stream receiver returned by `Reactor::new`
        // is currently discarded here. Confirm stream handling is wired up elsewhere.
        let (reactor, circ, _incoming_streams) = Reactor::new(
            runtime.clone(),
            channel,
            circ_id,
            circ_unique_id,
            receiver,
            handshake_components.crypto_in,
            handshake_components.crypto_out,
            &handshake_components.hop_settings,
            chan_provider,
            padding_ctrl.clone(),
            padding_stream,
            incoming_filter,
            &memquota,
        )
        .map_err(into_internal!("Failed to start circuit reactor"))?;

        // The reactor runs detached; its failure is only logged.
        let () = runtime.spawn(async move {
            if let Err(e) = reactor.run().await {
                debug_report!(e, "Relay circuit reactor exited with an error");
            }
        })?;

        Ok((
            handshake_components.response,
            RelayCircComponents {
                circ,
                sender,
                padding_ctrl,
            },
        ))
    }

    /// Run the server side of a CREATE_FAST handshake.
    ///
    /// The hop uses tor1 relay crypto (AES-128-CTR / SHA-1) and fixed-window
    /// congestion control.
    fn handle_create_fast(
        &self,
        msg: &CreateFast,
    ) -> Result<CompletedHandshakeComponents, HandleCreateError> {
        // CREATE_FAST has no keys to choose from: the key list is a single `()`.
        let (keygen, handshake_msg) = CreateFastServer::server(
            &mut rand::rng(),
            &mut |_: &()| Some(()),
            &[()],
            msg.handshake(),
        )?;
        let crypt = tor1::CryptStatePair::<Aes128Ctr, Sha1>::construct(keygen)
            .map_err(into_internal!("Circuit crypt state construction failed"))?;
        let (crypto_out, crypto_in, _binding) = split_relay_layer(crypt);
        let hop_settings = self.fixed_window_hop_settings()?;
        let response = CreateResponse::CreatedFast(CreatedFast::new(handshake_msg));
        Ok(CompletedHandshakeComponents {
            response,
            hop_settings,
            crypto_out,
            crypto_in,
        })
    }

    /// Run the server side of a CREATE2 ntor handshake on `msg_body`.
    ///
    /// The client may have used either our latest or our previous ntor key, so
    /// both are offered. Each is paired with `our_rsa_id`. The hop uses tor1
    /// relay crypto (AES-128-CTR / SHA-1) and fixed-window congestion control.
    fn handle_create2_ntor(
        &self,
        msg_body: &[u8],
        our_rsa_id: &RsaIdentity,
    ) -> Result<CompletedHandshakeComponents, HandleCreateError> {
        let ntor_keys = self.ntor_keys(|k| {
            NtorSecretKey::new(k.secret().clone(), *k.public().inner(), *our_rsa_id)
        });
        let (keygen, handshake_msg) = NtorServer::server(
            &mut rand::rng(),
            &mut |_: &()| Some(()),
            ntor_keys.as_ref(),
            msg_body,
        )?;
        let crypt = tor1::CryptStatePair::<Aes128Ctr, Sha1>::construct(keygen)
            .map_err(into_internal!("Circuit crypt state construction failed"))?;
        let (crypto_out, crypto_in, _binding) = split_relay_layer(crypt);
        let hop_settings = self.fixed_window_hop_settings()?;
        let response = CreateResponse::Created2(Created2::new(handshake_msg));
        Ok(CompletedHandshakeComponents {
            response,
            hop_settings,
            crypto_out,
            crypto_in,
        })
    }

    /// Handle a CREATE2 ntor-v3 handshake.
    ///
    /// This is not implemented yet: it always fails with
    /// [`HandleCreateError::Create2HandshakeType`].
    fn handle_create2_ntorv3(
        &self,
        _msg_body: &[u8],
        _our_ed25519_id: &Ed25519Identity,
    ) -> Result<CompletedHandshakeComponents, HandleCreateError> {
        Err(HandleCreateError::Create2HandshakeType(
            HandshakeType::NTOR_V3,
        ))
    }

    /// Build [`HopSettings`] for a hop that negotiated no extensions and uses
    /// fixed-window congestion control.
    ///
    /// The settings come from the current network parameters and the default
    /// protocol set. Both the CREATE_FAST and CREATE2-ntor paths use this.
    fn fixed_window_hop_settings(&self) -> Result<HopSettings, HandleCreateError> {
        // The read guard is a temporary and is released at the end of this statement.
        let circ_params = self
            .circ_net_params
            .read()
            .expect("rwlock poisoned")
            .as_circ_parameters(AlgorithmDiscriminants::FixedWindow)?;
        let protos = tor_protover::Protocols::default();
        let hop_settings =
            HopSettings::from_params_and_caps(HopNegotiationType::None, &circ_params, &protos)
                .map_err(into_internal!("Unable to build `HopSettings`"))?;
        Ok(hop_settings)
    }

    /// Apply `map` to each of our current ntor keypairs and collect the results.
    ///
    /// The latest keypair comes first, followed by the previous one if it
    /// exists. The keys lock is held only while collecting.
    ///
    /// # Panics
    ///
    /// Panics if the keys lock is poisoned.
    fn ntor_keys<T>(&self, map: impl FnMut(&RelayNtorKeypair) -> T) -> impl AsRef<[T]> {
        let ntor_keys = self.ntor_keys.read().expect("rwlock poisoned");
        let ntor_keys = [Some(ntor_keys.latest()), ntor_keys.previous()];
        ntor_keys
            .into_iter()
            .flatten()
            .map(map)
            .collect::<SmallVec<[T; 2]>>()
    }
}
/// Split a relay crypto layer into its outbound and inbound halves.
///
/// Both halves are returned boxed as trait objects, together with the
/// layer's [`CircuitBinding`].
fn split_relay_layer<F, B>(
    crypt: impl RelayLayer<F, B>,
) -> (
    Box<dyn OutboundRelayLayer + Send>,
    Box<dyn InboundRelayLayer + Send>,
    CircuitBinding,
)
where
    F: OutboundRelayLayer + Send + 'static,
    B: InboundRelayLayer + Send + 'static,
{
    let (forward, backward, binding) = crypt.split_relay_layer();
    // Each `Box<F>` / `Box<B>` is unsized to its trait object at the return site.
    (Box::new(forward), Box::new(backward), binding)
}
/// An error that occurred while handling a CREATE* request.
///
/// In `handle_create`, these errors are logged at debug level and replaced by
/// a DESTROY message with no reason.
#[derive(Debug, thiserror::Error)]
enum HandleCreateError {
    /// The server side of the circuit handshake failed.
    #[error("Circuit relay handshake failed")]
    Handshake(#[from] RelayHandshakeError),
    /// The CREATE2 cell requested a handshake type we don't support.
    #[error("Unsupported handshake type {0}")]
    Create2HandshakeType(HandshakeType),
    /// Setting up the circuit's memory-quota account or queue failed.
    #[error("Memquota error")]
    Memquota(#[from] tor_memquota::Error),
    /// Spawning the circuit reactor task failed.
    #[error("Runtime task spawn error")]
    Spawn(#[from] futures::task::SpawnError),
    /// An internal invariant was violated.
    #[error("Internal error")]
    Internal(#[from] tor_error::Bug),
}
impl HasKind for HandleCreateError {
fn kind(&self) -> ErrorKind {
match self {
Self::Handshake(e) => e.kind(),
Self::Create2HandshakeType(_) => ErrorKind::NotImplemented,
Self::Memquota(e) => e.kind(),
Self::Spawn(e) => e.kind(),
Self::Internal(_) => ErrorKind::Internal,
}
}
}
/// Everything produced by a successful server-side circuit handshake.
struct CompletedHandshakeComponents {
    /// The CREATED* reply to send back to the client.
    response: CreateResponse,
    /// Settings negotiated (or defaulted) for the new hop.
    hop_settings: HopSettings,
    /// The outbound half of the hop's relay crypto.
    crypto_out: Box<dyn OutboundRelayLayer + Send>,
    /// The inbound half of the hop's relay crypto.
    crypto_in: Box<dyn InboundRelayLayer + Send>,
}
/// The parts of a newly-created relay circuit returned to the caller of
/// `CreateRequestHandler::handle_create`.
pub(crate) struct RelayCircComponents {
    /// Handle to the new relay circuit.
    pub(crate) circ: Arc<RelayCirc>,
    /// Sending side of the circuit's inbound queue; its receiver was given to the reactor.
    pub(crate) sender: CircuitRxSender,
    /// Padding controller shared with the circuit reactor.
    pub(crate) padding_ctrl: PaddingController,
}
/// Congestion-control and flow-control parameters taken from the network.
#[derive(Debug, Clone)]
// All fields are `pub`; presumably callers build this with a struct literal.
#[allow(clippy::exhaustive_structs)]
pub struct CongestionControlNetParams {
    /// Fixed-window algorithm parameters.
    ///
    /// Used when fixed-window is selected, and always supplied as the
    /// fixed-window parameters of the resulting congestion-control settings.
    pub fixed_window: FixedWindowParams,
    /// Vegas algorithm parameters, used when Vegas is selected.
    ///
    /// NOTE(review): the name suggests these are the exit-side Vegas values. Confirm.
    pub vegas_exit: VegasParams,
    /// Congestion-window parameters.
    pub cwnd: CongestionWindowParams,
    /// Round-trip-time estimator parameters.
    pub rtt: RoundTripEstimatorParams,
    /// Flow-control parameters.
    pub flow_ctrl: FlowCtrlParameters,
}
impl CongestionControlNetParams {
    /// Build a parameter set in which every component uses its test defaults.
    #[cfg(test)]
    pub(crate) fn defaults_for_tests() -> Self {
        let fixed_window = FixedWindowParams::defaults_for_tests();
        let vegas_exit = VegasParams::defaults_for_tests();
        let cwnd = CongestionWindowParams::defaults_for_tests();
        let rtt = RoundTripEstimatorParams::defaults_for_tests();
        let flow_ctrl = FlowCtrlParameters::defaults_for_tests();
        Self {
            fixed_window,
            vegas_exit,
            cwnd,
            rtt,
            flow_ctrl,
        }
    }
}
/// Circuit-related parameters taken from the network.
#[derive(Debug, Clone)]
// All fields are `pub`; presumably callers build this with a struct literal.
#[allow(clippy::exhaustive_structs)]
pub struct CircNetParameters {
    /// Value passed as the first argument of `CircParameters::new`.
    ///
    /// NOTE(review): by its name, this controls whether EXTEND requests should
    /// identify the next hop by Ed25519 ID. Confirm.
    pub extend_by_ed25519_id: bool,
    /// Congestion-control and flow-control parameters.
    pub cc: CongestionControlNetParams,
}
impl CircNetParameters {
    /// Convert these network parameters into [`CircParameters`] for a circuit
    /// using congestion-control algorithm `algorithm`.
    ///
    /// Every field is destructured explicitly, and `unused` warns here. This
    /// means adding a field to either parameter struct forces this conversion
    /// to be revisited.
    ///
    /// # Errors
    ///
    /// Returns a [`Bug`] if the congestion-control parameters cannot be built.
    #[warn(unused)]
    fn as_circ_parameters(&self, algorithm: AlgorithmDiscriminants) -> Result<CircParameters, Bug> {
        let Self {
            extend_by_ed25519_id,
            cc,
        } = self;
        let CongestionControlNetParams {
            fixed_window,
            vegas_exit,
            cwnd,
            rtt,
            flow_ctrl,
        } = cc;

        let selected = match algorithm {
            AlgorithmDiscriminants::Vegas => Algorithm::Vegas(*vegas_exit),
            AlgorithmDiscriminants::FixedWindow => Algorithm::FixedWindow(*fixed_window),
        };

        // The fixed-window parameters are supplied whichever algorithm is selected.
        let cc_params = CongestionControlParams::builder()
            .alg(selected)
            .fixed_window_params(*fixed_window)
            .cwnd_params(*cwnd)
            .rtt_params(rtt.clone())
            .build()
            .map_err(into_internal!("Could not build `CongestionControlParams`"))?;

        let params = CircParameters::new(*extend_by_ed25519_id, cc_params, flow_ctrl.clone());
        Ok(params)
    }
}
/// A source of [`IncomingStreamRequestFilter`]s.
///
/// `CreateRequestHandler` asks for a fresh filter each time it builds a relay circuit.
///
/// Any `Fn() -> Box<dyn IncomingStreamRequestFilter>` closure implements this trait.
pub trait IncomingStreamRequestFilterFactory {
    /// Return the filter to install on a newly-created circuit.
    fn current_filter(&self) -> Box<dyn IncomingStreamRequestFilter>;
}
/// Any closure that produces a boxed filter can act as a filter factory.
impl<F> IncomingStreamRequestFilterFactory for F
where
    F: Fn() -> Box<dyn IncomingStreamRequestFilter>,
{
    /// Produce a filter by invoking the closure.
    fn current_filter(&self) -> Box<dyn IncomingStreamRequestFilter> {
        let make_filter = self;
        make_filter()
    }
}