use std::sync::{Arc, Mutex};
use crate::event::ChanMgrEventSender;
use async_trait::async_trait;
use tor_error::{HasKind, HasRetryTime, internal};
use tor_linkspec::{HasChanMethod, OwnedChanTarget, PtTransportName};
use tor_proto::channel::Channel;
use tor_proto::memquota::ChannelAccount;
use tracing::{debug, instrument};
#[cfg(feature = "relay")]
use safelog::Sensitive;
/// A cloneable handle to a shared [`ChanMgrEventSender`].
///
/// A `BootstrapReporter` is passed to channel factories when they open a
/// channel. The sender lives behind an `Arc<Mutex<_>>`, so every clone of the
/// reporter refers to the same underlying sender.
#[derive(Clone)]
pub struct BootstrapReporter(pub(crate) Arc<Mutex<ChanMgrEventSender>>);

impl BootstrapReporter {
    /// Build a reporter for tests.
    ///
    /// The receiving half of the event channel is dropped when this function
    /// returns, so nothing observes events sent through the returned reporter.
    #[cfg(test)]
    pub(crate) fn fake() -> Self {
        let (sender, _receiver) = crate::event::channel();
        let shared = Arc::new(Mutex::new(sender));
        BootstrapReporter(shared)
    }
}
/// An object that knows how to open an outgoing channel to a target.
///
/// Implementations may differ in the transport they use; within this module,
/// `CompoundFactory` selects between implementations based on the target's
/// channel method.
#[async_trait]
pub trait ChannelFactory: Send + Sync {
    /// Open a new channel to `target`.
    ///
    /// `reporter` is handed to the implementation so it can send events through
    /// the shared [`ChanMgrEventSender`] (NOTE(review): presumably bootstrap
    /// progress — confirm against implementors). `memquota` is the
    /// [`ChannelAccount`] to use for the new channel.
    ///
    /// On success, returns the new channel wrapped in an `Arc`.
    async fn connect_via_transport(
        &self,
        target: &OwnedChanTarget,
        reporter: BootstrapReporter,
        memquota: ChannelAccount,
    ) -> crate::Result<Arc<Channel>>;
}
/// An object that can build a channel from an incoming stream.
#[async_trait]
pub trait IncomingChannelFactory: Send + Sync {
    /// The type of stream this factory builds channels from.
    ///
    /// NOTE(review): presumably a network stream produced by a listener —
    /// confirm against implementors.
    type Stream: Send + Sync + 'static;

    /// Build a channel from `stream`, which came from `peer`.
    ///
    /// `peer` is wrapped in `safelog::Sensitive` so that it can be redacted
    /// when logged. `memquota` is the [`ChannelAccount`] to use for the new
    /// channel. Only available with the `relay` feature.
    #[cfg(feature = "relay")]
    async fn accept_from_transport(
        &self,
        peer: Sensitive<std::net::SocketAddr>,
        stream: Self::Stream,
        memquota: ChannelAccount,
    ) -> crate::Result<Arc<Channel>>;
}
#[async_trait]
impl<CF> crate::mgr::AbstractChannelFactory for CF
where
CF: ChannelFactory + IncomingChannelFactory + Sync,
{
type Channel = tor_proto::channel::Channel;
type BuildSpec = OwnedChanTarget;
type Stream = CF::Stream;
#[instrument(skip_all, level = "trace")]
async fn build_channel(
&self,
target: &Self::BuildSpec,
reporter: BootstrapReporter,
memquota: ChannelAccount,
) -> crate::Result<Arc<Self::Channel>> {
debug!("Attempting to open a new channel to {target}");
self.connect_via_transport(target, reporter, memquota).await
}
#[cfg(feature = "relay")]
#[instrument(skip_all, level = "trace")]
async fn build_channel_using_incoming(
&self,
peer: Sensitive<std::net::SocketAddr>,
stream: Self::Stream,
memquota: ChannelAccount,
) -> crate::Result<Arc<tor_proto::channel::Channel>> {
debug!("Attempting to open a new channel from {peer}");
self.accept_from_transport(peer, stream, memquota).await
}
}
/// An error reported by an [`AbstractPtMgr`] while looking up a transport.
///
/// This trait bundles the error traits the channel manager needs, so that such
/// errors can be passed around as `Arc<dyn AbstractPtError>`.
pub trait AbstractPtError:
    std::error::Error + HasKind + HasRetryTime + Send + Sync + std::fmt::Debug
{
}
/// An object that can provide a [`ChannelFactory`] for a named pluggable
/// transport.
#[async_trait]
pub trait AbstractPtMgr: Send + Sync {
    /// Look up a channel factory for `transport`.
    ///
    /// Returns `Ok(Some(factory))` when one is available and `Ok(None)` when
    /// there is none (`CompoundFactory` reports the latter as
    /// `Error::NoSuchTransport`). Returns `Err` if the lookup itself fails.
    async fn factory_for_transport(
        &self,
        transport: &PtTransportName,
    ) -> Result<Option<Arc<dyn ChannelFactory + Send + Sync>>, Arc<dyn AbstractPtError>>;
}
/// An optional pluggable-transport manager.
///
/// `Some(mgr)` behaves exactly like `mgr`; `None` never provides a factory.
#[async_trait]
impl<P> AbstractPtMgr for Option<P>
where
    P: AbstractPtMgr,
{
    /// Forward to the wrapped manager when present; otherwise return `Ok(None)`.
    async fn factory_for_transport(
        &self,
        transport: &PtTransportName,
    ) -> Result<Option<Arc<dyn ChannelFactory + Send + Sync>>, Arc<dyn AbstractPtError>> {
        if let Some(inner) = self {
            inner.factory_for_transport(transport).await
        } else {
            Ok(None)
        }
    }
}
/// A channel factory that picks an underlying factory per target.
///
/// Direct connections (and, with the `relay` feature, incoming streams) are
/// handled by `default_factory`. With the `pt-client` feature,
/// pluggable-transport connections are handled by a factory obtained from
/// `ptmgr`.
pub(crate) struct CompoundFactory<CF> {
    /// Manager used to find factories for pluggable transports, if any.
    #[cfg(feature = "pt-client")]
    ptmgr: Option<Arc<dyn AbstractPtMgr + 'static>>,
    /// Factory used for direct connections and for incoming streams.
    default_factory: Arc<CF>,
}
/// Cloning a `CompoundFactory` only clones its `Arc` handles.
///
/// This impl is written by hand because `#[derive(Clone)]` would add a
/// `CF: Clone` bound, which is not needed to clone an `Arc<CF>`.
impl<CF> Clone for CompoundFactory<CF> {
    fn clone(&self) -> Self {
        let default_factory = Arc::clone(&self.default_factory);
        #[cfg(feature = "pt-client")]
        let ptmgr = self.ptmgr.clone();
        Self {
            #[cfg(feature = "pt-client")]
            ptmgr,
            default_factory,
        }
    }
}
/// With this impl, a `CompoundFactory` is itself a [`ChannelFactory`] that
/// chooses the underlying factory for each target.
#[async_trait]
impl<CF: ChannelFactory> ChannelFactory for CompoundFactory<CF> {
    /// Choose a factory based on `target`'s channel method, then use it to
    /// open the channel.
    ///
    /// - `Direct` targets use `default_factory`.
    /// - `Pluggable` targets (only with the `pt-client` feature) ask `ptmgr`
    ///   for a factory for the target's transport.
    ///   - A failed lookup is returned as `Error::Pt`.
    ///   - If there is no `ptmgr`, or the lookup returns `Ok(None)`, this
    ///     returns `Error::NoSuchTransport`.
    /// - Any other channel method returns `Error::Internal`.
    #[instrument(skip_all, level = "trace")]
    async fn connect_via_transport(
        &self,
        target: &OwnedChanTarget,
        reporter: BootstrapReporter,
        memquota: ChannelAccount,
    ) -> crate::Result<Arc<Channel>> {
        use tor_linkspec::ChannelMethod::*;
        // NOTE(review): the `Direct` arm produces an `Arc<CF>`, while the
        // `Pluggable` arm (with `pt-client`) produces an
        // `Arc<dyn ChannelFactory + Send + Sync>`. This `match` therefore relies
        // on the compiler coercing its arms to a common type; be careful when
        // reordering or rewriting the arms.
        let factory = match target.chan_method() {
            Direct(_) => self.default_factory.clone(),
            #[cfg(feature = "pt-client")]
            Pluggable(a) => match self.ptmgr.as_ref() {
                Some(mgr) => mgr
                    .factory_for_transport(a.transport())
                    .await
                    .map_err(crate::Error::Pt)?
                    .ok_or_else(|| crate::Error::NoSuchTransport(a.transport().clone().into()))?,
                None => return Err(crate::Error::NoSuchTransport(a.transport().clone().into())),
            },
            // Catch-all for channel methods not handled above. In builds where
            // every `ChannelMethod` variant is already matched, this arm is
            // unreachable, hence the `allow`.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(crate::Error::Internal(internal!(
                    "No support for channel method"
                )));
            }
        };
        factory
            .connect_via_transport(target, reporter, memquota)
            .await
    }
}
/// With this impl, a `CompoundFactory` accepts incoming streams using its
/// default factory.
#[async_trait]
impl<CF: IncomingChannelFactory> IncomingChannelFactory for CompoundFactory<CF> {
    type Stream = CF::Stream;

    /// Hand the incoming `stream` to `default_factory`.
    ///
    /// `ptmgr` is never consulted for incoming streams.
    #[cfg(feature = "relay")]
    async fn accept_from_transport(
        &self,
        peer: Sensitive<std::net::SocketAddr>,
        stream: Self::Stream,
        memquota: ChannelAccount,
    ) -> crate::Result<Arc<Channel>> {
        let inner: &CF = &self.default_factory;
        inner.accept_from_transport(peer, stream, memquota).await
    }
}
impl<CF: ChannelFactory + 'static> CompoundFactory<CF> {
pub(crate) fn new(
default_factory: Arc<CF>,
#[cfg(feature = "pt-client")] ptmgr: Option<Arc<dyn AbstractPtMgr + 'static>>,
) -> Self {
Self {
default_factory,
#[cfg(feature = "pt-client")]
ptmgr,
}
}
#[cfg(feature = "pt-client")]
pub(crate) fn replace_ptmgr(&mut self, ptmgr: Arc<dyn AbstractPtMgr + 'static>) {
self.ptmgr = Some(ptmgr);
}
#[cfg(feature = "relay")]
pub(crate) fn default_factory(&self) -> &CF {
&self.default_factory
}
#[cfg(feature = "relay")]
pub(crate) fn replace_default_factory(&mut self, factory: Arc<CF>) {
self.default_factory = factory;
}
}