use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::{adapter::Adapter, socket::Socket};
use socketioxide_core::Value;
use super::MakeErasedHandler;
/// Boxed, type-erased connect handler.
pub(crate) type BoxedConnectHandler<A> = Box<dyn ErasedConnectHandler<A>>;
/// Outcome of a connect middleware: `Ok(())`, or a boxed error that only needs to be
/// `Display` and `Send`.
type MiddlewareRes = Result<(), Box<dyn std::fmt::Display + Send>>;
/// Pinned, boxed, `Send` future resolving to a [`MiddlewareRes`]; it may borrow for `'a`.
type MiddlewareResFut<'a> = Pin<Box<dyn Future<Output = MiddlewareRes> + Send + 'a>>;
/// Connect handler interface usable as a trait object (`dyn ErasedConnectHandler<A>`):
/// unlike `ConnectHandler<A, T>`, it carries no extractor type parameter.
pub(crate) trait ErasedConnectHandler<A: Adapter>: Send + Sync + 'static {
/// Invokes the handler for socket `s`, handing over ownership of the optional `auth` payload.
fn call(&self, s: Arc<Socket<A>>, auth: Option<Value>);
/// Runs the middleware attached to the handler.
///
/// The returned future may borrow both `self` and `auth` for `'a`.
fn call_middleware<'a>(
&'a self,
s: Arc<Socket<A>>,
auth: &'a Option<Value>,
) -> MiddlewareResFut<'a>;
/// Returns a clone of this handler in a new box.
fn boxed_clone(&self) -> BoxedConnectHandler<A>;
}
/// Extractor trait: builds `Self` from the connecting socket and the optional auth payload.
///
/// The handler and middleware impls generated by the macros in this file require every
/// function argument to implement this trait.
#[diagnostic::on_unimplemented(
note = "Function argument is not a valid socketio extractor.
See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details",
label = "Invalid extractor"
)]
pub trait FromConnectParts<A: Adapter>: Sized {
/// Error returned when the extraction fails.
type Error: std::error::Error + Send + 'static;
/// Attempts the extraction from the socket `s` and the borrowed `auth` payload.
fn from_connect_parts(s: &Arc<Socket<A>>, auth: &Option<Value>) -> Result<Self, Self::Error>;
}
/// A connect middleware: an asynchronous step that receives the socket and a borrowed auth
/// payload and resolves to `Ok(())` or to a boxed, displayable error.
///
/// `T` is a marker type parameter; the macro-generated impls in this file set it to the
/// tuple of extractor argument types.
#[diagnostic::on_unimplemented(
note = "This function is not a ConnectMiddleware. Check that:
* It is a clonable async `FnOnce` that returns `Result<(), E> where E: Display`.
* All its arguments are valid connect extractors.
* If you use a custom adapter, it must be generic over the adapter type.
See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n",
label = "Invalid ConnectMiddleware"
)]
pub trait ConnectMiddleware<A: Adapter, T>: Sized + Clone + Send + Sync + 'static {
/// Runs the middleware for socket `s` with the borrowed `auth` payload.
///
/// The returned future must be `Send`.
fn call<'a>(
&'a self,
s: Arc<Socket<A>>,
auth: &'a Option<Value>,
) -> impl Future<Output = MiddlewareRes> + Send;
// Hidden helper returning a `PhantomData` over `(A, T)`.
// NOTE(review): presumably exists so the type parameters appear in a method signature —
// confirm the intent before removing.
#[doc(hidden)]
fn phantom(&self) -> std::marker::PhantomData<(A, T)> {
std::marker::PhantomData
}
}
#[diagnostic::on_unimplemented(
note = "This function is not a ConnectHandler. Check that:
* It is a clonable async `FnOnce` that returns nothing.
* All its arguments are valid connect extractors.
* If you use a custom adapter, it must be generic over the adapter type.
See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n",
label = "Invalid ConnectHandler"
)]
pub trait ConnectHandler<A: Adapter, T>: Sized + Clone + Send + Sync + 'static {
fn call(&self, s: Arc<Socket<A>>, auth: Option<Value>);
fn call_middleware<'a>(
&'a self,
_: Arc<Socket<A>>,
_: &'a Option<Value>,
) -> MiddlewareResFut<'a> {
Box::pin(async move { Ok(()) })
}
fn with<M, T1>(self, middleware: M) -> impl ConnectHandler<A, T>
where
M: ConnectMiddleware<A, T1> + Send + Sync + 'static,
T: Send + Sync + 'static,
T1: Send + Sync + 'static,
{
LayeredConnectHandler {
handler: self,
middleware,
phantom: std::marker::PhantomData,
}
}
#[doc(hidden)]
fn phantom(&self) -> std::marker::PhantomData<T> {
std::marker::PhantomData
}
}
/// A connect handler `H` paired with the middleware `M` attached to it through `with`.
///
/// `A`, `T` and `T1` do not appear in any stored value; they are only carried by `phantom`.
struct LayeredConnectHandler<A, H, M, T, T1> {
/// The wrapped handler; `ConnectHandler::call` forwards to it.
handler: H,
/// The middleware run by `call_middleware`.
middleware: M,
/// Marker tying the otherwise unused type parameters to the struct.
phantom: std::marker::PhantomData<(A, T, T1)>,
}
/// Two middlewares chained together: `middleware` is awaited first and `next` only runs
/// when `middleware` succeeded (see the `ConnectMiddleware` impl for this type).
struct ConnectMiddlewareLayer<M, N, T, T1> {
/// The middleware executed first.
middleware: M,
/// The middleware executed afterwards.
next: N,
/// Marker tying the otherwise unused type parameters to the struct.
phantom: std::marker::PhantomData<(T, T1)>,
}
impl<A: Adapter, T, H> MakeErasedHandler<H, A, T>
where
H: ConnectHandler<A, T> + Send + Sync + 'static,
T: Send + Sync + 'static,
{
pub fn new_ns_boxed(inner: H) -> Box<dyn ErasedConnectHandler<A>> {
Box::new(MakeErasedHandler::new(inner))
}
}
/// Type-erasure glue: every method forwards to the wrapped `ConnectHandler`.
impl<A: Adapter, T, H> ErasedConnectHandler<A> for MakeErasedHandler<H, A, T>
where
    H: ConnectHandler<A, T> + Send + Sync + 'static,
    T: Send + Sync + 'static,
{
    // NOTE(review): only `self` and `s` are skipped, so `auth` is recorded on the trace
    // span when the `tracing` feature is on — confirm this is intended, since the payload
    // may carry credentials.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self, s), fields(id = ?s.id)))]
    fn call(&self, s: Arc<Socket<A>>, auth: Option<Value>) {
        <H as ConnectHandler<A, T>>::call(&self.handler, s, auth);
    }

    /// Forwards to the wrapped handler's `call_middleware`.
    fn call_middleware<'a>(
        &'a self,
        s: Arc<Socket<A>>,
        auth: &'a Option<Value>,
    ) -> MiddlewareResFut<'a> {
        <H as ConnectHandler<A, T>>::call_middleware(&self.handler, s, auth)
    }

    /// Clones this wrapper and boxes the clone.
    fn boxed_clone(&self) -> BoxedConnectHandler<A> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }
}
/// A layered handler is itself a connect handler: `call` reaches the wrapped handler and
/// `call_middleware` runs the attached middleware.
#[diagnostic::do_not_recommend]
impl<A, H, M, T, T1> ConnectHandler<A, T> for LayeredConnectHandler<A, H, M, T, T1>
where
    A: Adapter,
    H: ConnectHandler<A, T> + Send + Sync + 'static,
    M: ConnectMiddleware<A, T1> + Send + Sync + 'static,
    T: Send + Sync + 'static,
    T1: Send + Sync + 'static,
{
    /// Forwards to the wrapped handler.
    fn call(&self, s: Arc<Socket<A>>, auth: Option<Value>) {
        <H as ConnectHandler<A, T>>::call(&self.handler, s, auth);
    }

    /// Boxes a future that awaits the attached middleware.
    fn call_middleware<'a>(
        &'a self,
        s: Arc<Socket<A>>,
        auth: &'a Option<Value>,
    ) -> MiddlewareResFut<'a> {
        // The middleware is only invoked once the boxed future is first polled.
        let pending = async move { self.middleware.call(s, auth).await };
        Box::pin(pending)
    }

    /// Adds another middleware without nesting handlers: the newly added `next` runs
    /// first, then the middleware that was already attached.
    fn with<M2, T2>(self, next: M2) -> impl ConnectHandler<A, T>
    where
        M2: ConnectMiddleware<A, T2> + Send + Sync + 'static,
        T2: Send + Sync + 'static,
    {
        let LayeredConnectHandler {
            handler,
            middleware: previous,
            ..
        } = self;
        let stacked = ConnectMiddlewareLayer {
            middleware: next,
            next: previous,
            phantom: std::marker::PhantomData,
        };
        LayeredConnectHandler {
            handler,
            middleware: stacked,
            phantom: std::marker::PhantomData,
        }
    }
}
/// A layered handler can also act as a middleware: it simply runs the middleware attached
/// to it (the wrapped handler is not involved).
#[diagnostic::do_not_recommend]
impl<A, H, N, T, T1> ConnectMiddleware<A, T1> for LayeredConnectHandler<A, H, N, T, T1>
where
    A: Adapter,
    H: ConnectHandler<A, T> + Send + Sync + 'static,
    N: ConnectMiddleware<A, T1> + Send + Sync + 'static,
    T: Send + Sync + 'static,
    T1: Send + Sync + 'static,
{
    async fn call<'a>(&'a self, s: Arc<Socket<A>>, auth: &'a Option<Value>) -> MiddlewareRes {
        let attached = &self.middleware;
        attached.call(s, auth).await
    }
}
/// Manual `Clone`: only the stored handler and middleware need to be `Clone`, not the
/// marker type parameters.
impl<A, H, N, T, T1> Clone for LayeredConnectHandler<A, H, N, T, T1>
where
    H: Clone,
    N: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            handler,
            middleware,
            ..
        } = self;
        Self {
            handler: H::clone(handler),
            middleware: N::clone(middleware),
            phantom: std::marker::PhantomData,
        }
    }
}
/// Manual `Clone`: only the two stored middlewares need to be `Clone`, not the marker
/// type parameters.
impl<M, N, T, T1> Clone for ConnectMiddlewareLayer<M, N, T, T1>
where
    M: Clone,
    N: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            middleware, next, ..
        } = self;
        Self {
            middleware: M::clone(middleware),
            next: N::clone(next),
            phantom: std::marker::PhantomData,
        }
    }
}
/// Runs the two chained middlewares in sequence, stopping at the first error.
#[diagnostic::do_not_recommend]
impl<A, M, N, T, T1> ConnectMiddleware<A, T> for ConnectMiddlewareLayer<M, N, T, T1>
where
    A: Adapter,
    M: ConnectMiddleware<A, T> + Send + Sync + 'static,
    N: ConnectMiddleware<A, T1> + Send + Sync + 'static,
    T: Send + Sync + 'static,
    T1: Send + Sync + 'static,
{
    async fn call<'a>(&'a self, s: Arc<Socket<A>>, auth: &'a Option<Value>) -> MiddlewareRes {
        // The first middleware gets its own handle to the socket so that the original
        // handle can still be moved into the second one.
        if let Err(rejection) = self.middleware.call(Arc::clone(&s), auth).await {
            return Err(rejection);
        }
        self.next.call(s, auth).await
    }
}
// Generates the `ConnectHandler` impl for cloneable async functions whose arguments are
// the listed extractor types.
macro_rules! impl_handler_async {
    ([$($ty:ident),*]) => {
        // `non_snake_case`: each type ident (`T1`, ...) doubles as the name of the local
        // holding its extracted value. `unused`: the zero-argument expansion never reads
        // `s` or `auth`.
        #[allow(non_snake_case, unused)]
        #[diagnostic::do_not_recommend]
        impl<A, F, Fut, $($ty,)*> ConnectHandler<A, ($($ty,)*)> for F
        where
            F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static,
            Fut: Future<Output = ()> + Send + 'static,
            A: Adapter,
            $( $ty: FromConnectParts<A> + Send, )*
        {
            fn call(&self, s: Arc<Socket<A>>, auth: Option<Value>) {
                // Run the extractors in argument order; the first failure aborts the call
                // and the handler is never invoked.
                $(
                    let $ty = match <$ty as FromConnectParts<A>>::from_connect_parts(&s, &auth) {
                        Ok(extracted) => extracted,
                        Err(_err) => {
                            #[cfg(feature = "tracing")]
                            tracing::error!("Error while extracting data: {}", _err);
                            return;
                        }
                    };
                )*
                // The function is `FnOnce`, so every invocation consumes its own clone.
                let handler = self.clone();
                // The returned future is spawned and its join handle is dropped.
                tokio::spawn(handler($($ty,)*));
            }
        }
    };
}
// Generates the `ConnectMiddleware` impl for cloneable async functions returning
// `Result<(), E>` whose arguments are the listed extractor types.
macro_rules! impl_middleware_async {
    ([$($ty:ident),*]) => {
        // `non_snake_case`: each type ident (`T1`, ...) doubles as the name of the local
        // holding its extracted value. `unused`: the zero-argument expansion never reads
        // `s` or `auth`.
        #[allow(non_snake_case, unused)]
        #[diagnostic::do_not_recommend]
        impl<A, F, Fut, E, $($ty,)*> ConnectMiddleware<A, ($($ty,)*)> for F
        where
            F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static,
            Fut: Future<Output = Result<(), E>> + Send + 'static,
            A: Adapter,
            E: std::fmt::Display + Send + 'static,
            $( $ty: FromConnectParts<A> + Send, )*
        {
            async fn call<'a>(
                &'a self,
                s: Arc<Socket<A>>,
                auth: &'a Option<Value>,
            ) -> MiddlewareRes {
                // Run the extractors in argument order; an extraction failure is reported
                // as the middleware's error.
                $(
                    let $ty = match <$ty as FromConnectParts<A>>::from_connect_parts(&s, auth) {
                        Ok(extracted) => extracted,
                        Err(err) => {
                            #[cfg(feature = "tracing")]
                            tracing::error!("Error while extracting data: {}", err);
                            return Err(Box::new(err) as _);
                        }
                    };
                )*
                // The function is `FnOnce`, so every invocation consumes its own clone.
                let outcome = (self.clone())($($ty,)*).await;
                match outcome {
                    Ok(()) => Ok(()),
                    Err(err) => {
                        #[cfg(feature = "tracing")]
                        tracing::trace!("middleware returned error: {}", err);
                        // Erase the concrete error type behind the boxed `Display` object.
                        Err(Box::new(err) as _)
                    }
                }
            }
        }
    };
}
// Invokes the macro named `$name` once per arity, from zero up to sixteen type
// identifiers (`T1` ..= `T16`), each time passing the identifiers as a bracketed list.
#[rustfmt::skip]
macro_rules! all_the_tuples {
($name:ident) => {
$name!([]);
$name!([T1]);
$name!([T1, T2]);
$name!([T1, T2, T3]);
$name!([T1, T2, T3, T4]);
$name!([T1, T2, T3, T4, T5]);
$name!([T1, T2, T3, T4, T5, T6]);
$name!([T1, T2, T3, T4, T5, T6, T7]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]);
};
}
// Generate the `ConnectHandler` impls for async functions taking 0 to 16 extractors.
all_the_tuples!(impl_handler_async);
// Generate the `ConnectMiddleware` impls for async functions taking 0 to 16 extractors.
all_the_tuples!(impl_middleware_async);