use crate::ext::future::TryFutureExtExt;
#[warn(unused_imports)]
use crate::{
common::protocol::tunnel::{
Tunnel, TunnelAddressInfo, TunnelError, TunnelIncomingType, TunnelName, TunnelSide,
},
util::{cancellation::CancellationListener, tunnel_stream::TunnelStream},
};
use futures::{future::BoxFuture, FutureExt, TryFutureExt, TryStreamExt};
use std::{
collections::HashMap,
fmt::Debug,
marker::{PhantomData, Unpin},
sync::Arc,
};
/// Metadata describing the tunnel that is being authenticated.
///
/// `perform_authentication` builds one from `Tunnel::side` and `Tunnel::addr`
/// and hands it to [`AuthenticationHandler::authenticate`].
#[derive(Debug, Clone)]
pub struct TunnelInfo {
/// Which side of the tunnel this end is (e.g. `TunnelSide::Listen` or
/// `TunnelSide::Connect`).
pub side: TunnelSide,
/// Address information reported by the tunnel.
pub addr: TunnelAddressInfo,
}
/// Failure raised while handling an authentication attempt, parameterised over
/// the application's own error type `TInner`.
///
/// Wrapped by [`AuthenticationError::Handling`]; `map_err` / `err_into` change
/// `TInner` without changing the variant.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum AuthenticationHandlingError<TInner> {
/// A dependency failed; carries a message and the underlying error, both of
/// which appear in the `Display` output.
#[error("Authentication dependency failure: {0} - {1}")]
DependencyFailure(String, TInner),
/// An application error, displayed transparently as the inner error.
#[error(transparent)]
ApplicationError(TInner),
/// A `tokio` task join failed; `#[from]` lets `?` convert a `JoinError`
/// into this variant directly.
#[error("Authentication thread join failed")]
JoinError(
#[from]
#[cfg_attr(feature = "backtrace", backtrace)]
tokio::task::JoinError,
),
}
impl<TInner> AuthenticationHandlingError<TInner> {
    /// Converts the wrapped application error with `f`, keeping the variant.
    ///
    /// `JoinError` carries no `TInner`, so it is passed through untouched and
    /// `f` is never invoked for it.
    pub fn map_err<TNew, F>(self, f: F) -> AuthenticationHandlingError<TNew>
    where
        F: FnOnce(TInner) -> TNew,
    {
        match self {
            Self::JoinError(join_failure) => AuthenticationHandlingError::JoinError(join_failure),
            Self::ApplicationError(source) => {
                AuthenticationHandlingError::ApplicationError(f(source))
            }
            Self::DependencyFailure(description, source) => {
                let converted = f(source);
                AuthenticationHandlingError::DependencyFailure(description, converted)
            }
        }
    }

    /// Converts the wrapped application error through its `Into<TNew>`
    /// implementation; shorthand for `map_err` with `Into::into`.
    pub fn err_into<TNew>(self) -> AuthenticationHandlingError<TNew>
    where
        TInner: Into<TNew>,
    {
        self.map_err(|source| source.into())
    }
}
/// Authentication failure involving the remote peer or the link to it; the
/// non-handler half of [`AuthenticationError`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum RemoteAuthenticationError {
/// The authentication attempt was refused.
#[error("Remote authentication refused")]
Refused,
/// The link was closed on this side; `tunnel_error_to_remote_auth_error`
/// maps `TunnelError::ApplicationClosed` and `TunnelError::LocallyClosed` here.
#[error("Remote authentication link closed locally")]
LinkClosedLocally,
/// The link was closed by the remote; mapped from
/// `TunnelError::ConnectionClosed`.
#[error("Remote authentication link closed by the remote")]
LinkClosedRemotely,
/// No incoming stream was available; `perform_authentication` reports this
/// when the downlink is missing or does not yield a bidirectional stream.
#[error("Connection closed by remote")]
IncomingStreamsClosed,
/// The attempt timed out; also used by `perform_authentication` as the
/// fallback error of its shutdown-aware waits.
#[error("Remote connection timed out")]
TimedOut,
/// The transport reported an error; mapped from `TunnelError::TransportError`.
#[error("Transport error encountered authenticating remote")]
TransportError,
/// The remote violated the authentication protocol; the string is included
/// in the `Display` output.
#[error("Remote authentication protocol violation: {0}")]
ProtocolViolation(String),
}
/// Top-level authentication error: either a handling failure or a remote
/// failure.
///
/// Both variants are `#[error(transparent)]`, so `Display` is delegated to the
/// wrapped error, and both have `#[from]` conversions for use with `?`.
#[derive(thiserror::Error, Debug)]
pub enum AuthenticationError<TInner> {
/// Failure raised while handling the authentication attempt.
#[error(transparent)]
Handling(#[from] AuthenticationHandlingError<TInner>),
/// Failure involving the remote peer or the link to it.
#[error(transparent)]
Remote(#[from] RemoteAuthenticationError),
}
impl<TInner> AuthenticationError<TInner> {
    /// Splits a flat result into a nested one: handling failures become the
    /// outer `Err`, remote failures the inner `Err`.
    pub fn to_nested_result<T>(
        res: Result<T, Self>,
    ) -> Result<Result<T, RemoteAuthenticationError>, AuthenticationHandlingError<TInner>> {
        match res {
            Ok(value) => Ok(Ok(value)),
            Err(Self::Remote(remote)) => Ok(Err(remote)),
            Err(Self::Handling(handling)) => Err(handling),
        }
    }

    /// Inverse of [`Self::to_nested_result`]: flattens the nested form back
    /// into a single `Result` carrying an `AuthenticationError`.
    pub fn from_nested_result<T>(
        res: Result<Result<T, RemoteAuthenticationError>, AuthenticationHandlingError<TInner>>,
    ) -> Result<T, Self> {
        match res {
            Ok(outcome) => outcome.map_err(Self::Remote),
            Err(handling) => Err(Self::Handling(handling)),
        }
    }

    /// Converts the application error inside a `Handling` failure with `f`;
    /// `Remote` failures carry no `TInner` and pass through unchanged.
    pub fn map_err<TNew, F>(self, f: F) -> AuthenticationError<TNew>
    where
        F: FnOnce(TInner) -> TNew,
    {
        match self {
            Self::Remote(remote) => AuthenticationError::Remote(remote),
            Self::Handling(handling) => {
                let converted = handling.map_err(f);
                AuthenticationError::Handling(converted)
            }
        }
    }

    /// Converts the application error through its `Into<TNew>` implementation.
    pub fn err_into<TNew>(self) -> AuthenticationError<TNew>
    where
        TInner: Into<TNew>,
    {
        self.map_err(|source| source.into())
    }
}
/// Lets `?` be used on results whose error type is
/// [`std::convert::Infallible`] inside functions returning
/// [`AuthenticationError`].
impl<TInner> From<std::convert::Infallible> for AuthenticationError<TInner> {
    /// Never executes: `Infallible` has no values.
    fn from(never: std::convert::Infallible) -> Self {
        // An empty match on an uninhabited type is checked by the compiler,
        // so no runtime panic path is needed.
        match never {}
    }
}
/// String-keyed byte-blob attributes returned by a successful authentication
/// alongside the tunnel's name.
pub type AuthenticationAttributes = HashMap<String, Vec<u8>>;
/// Type-erased `TunnelStream` (`Send + Unpin`) lent to
/// [`AuthenticationHandler::authenticate`] for the lifetime `'a`.
pub type AuthenticationChannel<'a> = dyn TunnelStream + Send + Unpin + 'a;
/// Strategy for authenticating one tunnel over a dedicated channel.
///
/// Implementors must be `Debug + Send + Sync`. This file also implements the
/// trait for `Box<T>` and `Arc<T>` by forwarding to the pointee.
pub trait AuthenticationHandler: std::fmt::Debug + Send + Sync {
/// Application-specific error carried inside [`AuthenticationHandlingError`].
type Error: Send;
/// Runs authentication for the tunnel described by `tunnel_info`.
///
/// * `channel` - stream lent to the handler for the authentication exchange.
/// * `tunnel_info` - side and address of the tunnel.
/// * `shutdown_notifier` - cancellation listener; NOTE(review): presumably
///   implementations should stop early when it fires - confirm.
///
/// On success the future yields the tunnel's name together with its
/// attributes; otherwise an [`AuthenticationError`].
fn authenticate<'a>(
&'a self,
channel: &'a mut AuthenticationChannel<'a>,
tunnel_info: TunnelInfo,
shutdown_notifier: &'a CancellationListener,
) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<Self::Error>>>;
}
/// Adapter produced by [`AuthenticationHandlerExt::map_err`]: wraps `inner`
/// and converts its handler error with the closure `f`.
#[derive(Copy, Clone)]
pub struct MappedAuthenticationHandler<F, TInner> {
// Wrapped handler that performs the actual authentication.
inner: TInner,
// Conversion applied to the inner handler's application error.
f: F,
}
impl<F, TInner> Debug for MappedAuthenticationHandler<F, TInner>
where
    TInner: Debug,
{
    /// Prints only the wrapped handler; the closure `f` has no `Debug` bound
    /// and is represented by the trailing `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MappedAuthenticationHandler");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
impl<F, TInner, TOutput> AuthenticationHandler for MappedAuthenticationHandler<F, TInner>
where
    TInner: AuthenticationHandler,
    TOutput: Send,
    F: (Fn(<TInner as AuthenticationHandler>::Error) -> TOutput) + Send + Sync,
{
    type Error = TOutput;

    /// Delegates to the wrapped handler and converts the application error
    /// inside any resulting `AuthenticationError` with the stored closure.
    /// Successful results are returned unchanged.
    fn authenticate<'a>(
        &'a self,
        channel: &'a mut AuthenticationChannel<'a>,
        tunnel_info: TunnelInfo,
        shutdown_notifier: &'a CancellationListener,
    ) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<Self::Error>>>
    {
        // The inner future is created right away; only the error conversion
        // is deferred until it completes.
        let delegated = self
            .inner
            .authenticate(channel, tunnel_info, shutdown_notifier);
        delegated
            .map(move |outcome| {
                outcome.map_err(|failure| failure.map_err(|source| (self.f)(source)))
            })
            .boxed()
    }
}
/// Adapter produced by [`AuthenticationHandlerExt::err_into`]: wraps `inner`
/// and converts its handler error into `TOutput` via `Into`.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct MappedErrIntoAuthenticationHandler<TInner, TOutput> {
// Wrapped handler that performs the actual authentication.
inner: TInner,
// Zero-sized marker naming `TOutput` without storing a value of it.
// `Arc<Mutex<T>>` is `Send + Sync` whenever `T: Send`, so the marker's auto
// traits depend only on `TOutput: Send`.
// NOTE(review): presumably chosen for exactly that reason - confirm.
phantom_output: PhantomData<std::sync::Arc<std::sync::Mutex<TOutput>>>,
}
impl<TInner, TOutput> Debug for MappedErrIntoAuthenticationHandler<TInner, TOutput>
where
    TInner: Debug,
{
    /// Prints only the wrapped handler; the phantom marker is represented by
    /// the trailing `..`, so `TOutput` needs no `Debug` bound.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MappedErrIntoAuthenticationHandler");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
impl<TInner, TOutput> AuthenticationHandler for MappedErrIntoAuthenticationHandler<TInner, TOutput>
where
    TInner: AuthenticationHandler,
    TInner::Error: Into<TOutput>,
    TOutput: Send,
{
    type Error = TOutput;

    /// Delegates to the wrapped handler and converts the application error
    /// inside any resulting `AuthenticationError` through `Into<TOutput>`.
    /// Successful results are returned unchanged.
    fn authenticate<'a>(
        &'a self,
        channel: &'a mut AuthenticationChannel<'a>,
        tunnel_info: TunnelInfo,
        shutdown_notifier: &'a CancellationListener,
    ) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<Self::Error>>>
    {
        // The inner future is created right away; only the error conversion
        // is deferred until it completes.
        let delegated = self
            .inner
            .authenticate(channel, tunnel_info, shutdown_notifier);
        delegated
            .map(|outcome| outcome.map_err(|failure| failure.err_into::<TOutput>()))
            .boxed()
    }
}
/// Combinators for adapting the error type of an [`AuthenticationHandler`].
pub trait AuthenticationHandlerExt: AuthenticationHandler {
    /// Wraps `self` so that its application error is converted with `f`.
    fn map_err<TNew, F>(self, f: F) -> MappedAuthenticationHandler<F, Self>
    where
        Self: Sized,
        F: (Fn(<Self as AuthenticationHandler>::Error) -> TNew) + Send + Sync,
    {
        let inner = self;
        MappedAuthenticationHandler { inner, f }
    }

    /// Wraps `self` so that its application error is converted into `TNew`
    /// through `Into`.
    fn err_into<TNew>(self) -> MappedErrIntoAuthenticationHandler<Self, TNew>
    where
        Self: Sized,
        Self::Error: Into<TNew>,
    {
        let phantom_output = PhantomData;
        MappedErrIntoAuthenticationHandler {
            phantom_output,
            inner: self,
        }
    }
}
/// Blanket implementation: every sized `AuthenticationHandler` type gets the
/// `AuthenticationHandlerExt` methods.
impl<T> AuthenticationHandlerExt for T where T: AuthenticationHandler {}
/// Forwards authentication through a `Box`, including boxed trait objects.
impl<T> AuthenticationHandler for Box<T>
where
    T: AuthenticationHandler + ?Sized,
{
    type Error = T::Error;

    /// Calls `authenticate` on the boxed handler with the same arguments.
    fn authenticate<'a>(
        &'a self,
        channel: &'a mut AuthenticationChannel<'a>,
        tunnel_info: TunnelInfo,
        shutdown_notifier: &'a CancellationListener,
    ) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<Self::Error>>>
    {
        let target: &'a T = &**self;
        target.authenticate(channel, tunnel_info, shutdown_notifier)
    }
}
/// Forwards authentication through an `Arc`, so one handler can be shared.
impl<T> AuthenticationHandler for Arc<T>
where
    T: AuthenticationHandler + ?Sized,
{
    type Error = T::Error;

    /// Calls `authenticate` on the shared handler with the same arguments.
    fn authenticate<'a>(
        &'a self,
        channel: &'a mut AuthenticationChannel<'a>,
        tunnel_info: TunnelInfo,
        shutdown_notifier: &'a CancellationListener,
    ) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<Self::Error>>>
    {
        let target: &'a T = &**self;
        target.authenticate(channel, tunnel_info, shutdown_notifier)
    }
}
/// Translates a tunnel-level error into the corresponding
/// [`RemoteAuthenticationError`].
///
/// Both locally initiated closures (`ApplicationClosed`, `LocallyClosed`)
/// collapse into `LinkClosedLocally`; the other variants map one-to-one.
fn tunnel_error_to_remote_auth_error(e: TunnelError) -> RemoteAuthenticationError {
    match e {
        TunnelError::TransportError => RemoteAuthenticationError::TransportError,
        TunnelError::TimedOut => RemoteAuthenticationError::TimedOut,
        TunnelError::ConnectionClosed => RemoteAuthenticationError::LinkClosedRemotely,
        TunnelError::ApplicationClosed | TunnelError::LocallyClosed => {
            RemoteAuthenticationError::LinkClosedLocally
        }
    }
}
/// Establishes the authentication channel on `tunnel` and runs `handler` over it.
///
/// The channel is obtained according to `tunnel.side()`:
/// * `TunnelSide::Listen` opens a link with `Tunnel::open_link`;
/// * `TunnelSide::Connect` takes the next item from the tunnel's downlink
///   stream and requires it to be a `TunnelIncomingType::BiStream`.
///
/// Both waits are combined with `shutdown_notifier` through
/// `try_poll_until_or_else`, with a `TimedOut` error as the fallback value.
/// NOTE(review): that helper is defined elsewhere; presumably it yields the
/// fallback once the notifier fires - confirm.
///
/// The returned future is instrumented with an `authentication` tracing span
/// that records the tunnel's side and address.
///
/// # Errors
///
/// Establishment failures are returned as [`AuthenticationError::Remote`]
/// (`TunnelError`s are translated by `tunnel_error_to_remote_auth_error`).
/// Whatever `handler.authenticate` returns is passed through unchanged.
pub fn perform_authentication<'a, T: AuthenticationHandler + ?Sized>(
handler: &'a T,
tunnel: &'a (dyn Tunnel + Send + Sync + 'a),
shutdown_notifier: &'a CancellationListener,
) -> BoxFuture<'a, Result<(TunnelName, AuthenticationAttributes), AuthenticationError<T::Error>>>
where
T::Error: std::fmt::Debug + Send,
{
use tracing::{debug, debug_span, warn, Instrument};
// Snapshot of the tunnel's side and address; used for the tracing span and
// later handed to the handler.
let tunnel_info = TunnelInfo {
side: tunnel.side(),
addr: tunnel.addr(),
};
let tracing_span_authentication =
debug_span!("authentication", side=?tunnel_info.side, addr=?tunnel_info.addr);
// Phase 1: a future that obtains the channel on which authentication will run.
let establishment = {
let side = tunnel_info.side;
async move {
// Shadows the listener with a future that awaits `cancelled()` on a clone of it.
let shutdown_notifier = async move { shutdown_notifier.clone().cancelled().await };
let auth_channel: Result<_, AuthenticationError<_>> = match side {
// Listening side: open the link ourselves.
TunnelSide::Listen => {
tunnel.open_link()
.instrument(debug_span!("open_link"))
.try_poll_until_or_else(shutdown_notifier, || Err(TunnelError::TimedOut))
.map_err(tunnel_error_to_remote_auth_error).map_err(AuthenticationError::from)
.await
},
// Connecting side: wait for the first incoming item on the downlink.
TunnelSide::Connect => {
tunnel
.downlink()
.await
// `?` leaves the enclosing async block immediately, so a missing
// downlink bypasses the logging `map_err` further down.
.ok_or(RemoteAuthenticationError::IncomingStreamsClosed)?
.as_stream()
.map_err(tunnel_error_to_remote_auth_error)
.try_next()
.try_poll_until_or_else(shutdown_notifier, || Err(RemoteAuthenticationError::TimedOut))
.instrument(debug_span!("accept_link"))
// Only a bidirectional stream is accepted; anything else, including the
// end of the stream, is reported as `IncomingStreamsClosed`.
.and_then(|s: Option<TunnelIncomingType>| futures::future::ready(match s {
Some(TunnelIncomingType::BiStream(stream)) => Ok(stream),
_ => Err(RemoteAuthenticationError::IncomingStreamsClosed),
}))
.map_err(AuthenticationError::from)
.await
}
};
// Log establishment failures (handling errors at warn level, remote errors
// at debug level) and rebuild the error in the function's error type.
auth_channel.map_err(|e| match e {
AuthenticationError::Handling(local_err) => {
warn!(error=?local_err, "AuthenticationError reported during tunnel establishment phase");
local_err.into()
},
AuthenticationError::Remote(remote_err) => {
debug!(error=?remote_err, "Remote authentication failure reported in tunnel establishment phase");
remote_err.into()
}
})
}.instrument(debug_span!("establishment"))
};
// Phase 2: run the handler over the established channel.
async move {
// The annotation fixes the error type that the establishment block infers.
let establishment: Result<_, AuthenticationError<_>> = establishment.await;
let auth_channel = establishment?;
handler
// The channel is boxed and lent to the handler for the duration of the call.
.authenticate(&mut Box::new(auth_channel), tunnel_info, shutdown_notifier)
.instrument(debug_span!("authenticator"))
.await
}
.instrument(tracing_span_authentication)
.boxed()
}