use std::convert::Infallible;
use std::sync::Arc;
use crate::adapter::Adapter;
use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts};
use crate::socket::{DisconnectReason, Socket};
use socketioxide_core::Value;
#[cfg(feature = "extensions")]
#[cfg_attr(docsrs, doc(cfg(feature = "extensions")))]
pub use extensions_extract::*;
/// Error produced by the extension extractors in this file when no value of type `T`
/// is present in the extensions they look into.
///
/// The error carries no data: `T` is only tracked through
/// [`PhantomData`](std::marker::PhantomData) so that its type name can be reported.
pub struct ExtensionNotFound<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Display for ExtensionNotFound<T> {
    /// Writes a human-readable message that names the missing type `T`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = std::any::type_name::<T>();
        f.write_fmt(format_args!(
            "Extension of type {} not found, maybe you forgot to insert it in the extensions map?",
            missing
        ))
    }
}

// Written by hand: unlike a derived impl, this places no `Debug` bound on `T`.
impl<T> std::fmt::Debug for ExtensionNotFound<T> {
    /// Writes the error name followed by the type name of `T`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = std::any::type_name::<T>();
        f.write_fmt(format_args!("ExtensionNotFound {}", missing))
    }
}

impl<T> std::error::Error for ExtensionNotFound<T> {}
/// Looks up a value of type `T` in the extensions of the socket's request parts
/// (`s.req_parts().extensions`) and returns a clone of it.
///
/// # Errors
///
/// Returns [`ExtensionNotFound`] when no value of type `T` is stored there.
fn extract_http_extension<T: Clone + Send + Sync + 'static>(
    s: &Arc<Socket<impl Adapter>>,
) -> Result<T, ExtensionNotFound<T>> {
    match s.req_parts().extensions.get::<T>() {
        // The lookup only lends the stored value, so hand back an owned copy.
        Some(found) => Ok(found.clone()),
        None => Err(ExtensionNotFound(std::marker::PhantomData)),
    }
}
/// Extractor wrapping a value of type `T` cloned from the extensions of the socket's
/// request parts (`s.req_parts().extensions`).
///
/// The extractor impls in this file fail with [`ExtensionNotFound`] when no value of
/// type `T` is stored there.
pub struct HttpExtension<T>(pub T);
/// Optional counterpart of [`HttpExtension`]: holds `Some(T)` when a value of type `T`
/// is found in the extensions of the socket's request parts, and `None` otherwise.
///
/// The extractor impls in this file use `Infallible` as their error type, so a missing
/// value is reported as `None` rather than as an error.
pub struct MaybeHttpExtension<T>(pub Option<T>);
impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for HttpExtension<T> {
type Error = ExtensionNotFound<T>;
fn from_connect_parts(
s: &Arc<Socket<A>>,
_: &Option<Value>,
) -> Result<Self, ExtensionNotFound<T>> {
extract_http_extension(s).map(HttpExtension)
}
}
impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for MaybeHttpExtension<T> {
type Error = Infallible;
fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Infallible> {
Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
}
}
/// Builds an [`HttpExtension`] from disconnect parts.
impl<A, T> FromDisconnectParts<A> for HttpExtension<T>
where
    A: Adapter,
    T: Clone + Send + Sync + 'static,
{
    type Error = ExtensionNotFound<T>;

    /// Clones the `T` stored in the request parts' extensions of `s`.
    ///
    /// The disconnect reason is ignored. Fails with [`ExtensionNotFound`] when no `T` is stored.
    fn from_disconnect_parts(
        s: &Arc<Socket<A>>,
        _reason: DisconnectReason,
    ) -> Result<Self, Self::Error> {
        let value = extract_http_extension(s)?;
        Ok(Self(value))
    }
}
/// Builds a [`MaybeHttpExtension`] from disconnect parts; this never fails.
impl<A, T> FromDisconnectParts<A> for MaybeHttpExtension<T>
where
    A: Adapter,
    T: Clone + Send + Sync + 'static,
{
    type Error = Infallible;

    /// Wraps the looked-up `T` in `Some`, or yields `None` when it is absent.
    ///
    /// The disconnect reason is ignored.
    fn from_disconnect_parts(
        s: &Arc<Socket<A>>,
        _reason: DisconnectReason,
    ) -> Result<Self, Self::Error> {
        // A failed lookup is not an error here: it simply becomes `None`.
        let found = extract_http_extension(s).ok();
        Ok(Self(found))
    }
}
/// Builds an [`HttpExtension`] from message parts.
impl<A, T> FromMessageParts<A> for HttpExtension<T>
where
    A: Adapter,
    T: Clone + Send + Sync + 'static,
{
    type Error = ExtensionNotFound<T>;

    /// Clones the `T` stored in the request parts' extensions of `s`.
    ///
    /// The second and third arguments are ignored. Fails with [`ExtensionNotFound`] when
    /// no `T` is stored.
    fn from_message_parts(
        s: &Arc<Socket<A>>,
        _: &mut Value,
        _: &Option<i64>,
    ) -> Result<Self, Self::Error> {
        match extract_http_extension(s) {
            Ok(value) => Ok(Self(value)),
            Err(not_found) => Err(not_found),
        }
    }
}
/// Builds a [`MaybeHttpExtension`] from message parts; this never fails.
impl<A, T> FromMessageParts<A> for MaybeHttpExtension<T>
where
    A: Adapter,
    T: Clone + Send + Sync + 'static,
{
    type Error = Infallible;

    /// Wraps the looked-up `T` in `Some`, or yields `None` when it is absent.
    ///
    /// The second and third arguments are ignored.
    fn from_message_parts(
        s: &Arc<Socket<A>>,
        _: &mut Value,
        _: &Option<i64>,
    ) -> Result<Self, Self::Error> {
        match extract_http_extension(s) {
            Ok(value) => Ok(Self(Some(value))),
            // A failed lookup is not an error here: it simply becomes `None`.
            Err(_) => Ok(Self(None)),
        }
    }
}
// NOTE(review): `__impl_deref!` is reached through the parent module and its body is not
// visible here; judging by its name it presumably generates `Deref`-style impls exposing
// each wrapper's inner value (`T`, resp. `Option<T>`) — confirm against the macro definition.
super::__impl_deref!(HttpExtension);
super::__impl_deref!(MaybeHttpExtension<T>: Option<T>);
#[cfg(feature = "extensions")]
mod extensions_extract {
    use super::*;

    /// Looks up a value of type `T` in the socket's own `extensions` field.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionNotFound`] when no value of type `T` is stored there.
    fn extract_extension<T: Clone + Send + Sync + 'static>(
        s: &Arc<Socket<impl Adapter>>,
    ) -> Result<T, ExtensionNotFound<T>> {
        match s.extensions.get::<T>() {
            // The lookup already yields an owned `T`, so it is returned as-is.
            Some(found) => Ok(found),
            None => Err(ExtensionNotFound(std::marker::PhantomData)),
        }
    }

    /// Extractor wrapping a value of type `T` taken from the socket's `extensions` field.
    ///
    /// The extractor impls below fail with [`ExtensionNotFound`] when no value of type `T`
    /// is stored there.
    pub struct Extension<T>(pub T);

    /// Optional counterpart of [`Extension`]: holds `Some(T)` when a value of type `T` is
    /// found in the socket's `extensions` field, and `None` otherwise.
    ///
    /// The extractor impls below use `Infallible` as their error type.
    pub struct MaybeExtension<T>(pub Option<T>);

    /// Builds an [`Extension`] from connect parts.
    impl<A, T> FromConnectParts<A> for Extension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = ExtensionNotFound<T>;

        /// Fetches the `T` stored in the extensions of `s`; the second argument is ignored.
        fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Self::Error> {
            extract_extension(s).map(Self)
        }
    }

    /// Builds a [`MaybeExtension`] from connect parts; this never fails.
    impl<A, T> FromConnectParts<A> for MaybeExtension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = Infallible;

        /// Yields `Some(T)` when found and `None` otherwise; the second argument is ignored.
        fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Self::Error> {
            let found = extract_extension(s).ok();
            Ok(Self(found))
        }
    }

    /// Builds an [`Extension`] from disconnect parts.
    impl<A, T> FromDisconnectParts<A> for Extension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = ExtensionNotFound<T>;

        /// Fetches the `T` stored in the extensions of `s`; the disconnect reason is ignored.
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _reason: DisconnectReason,
        ) -> Result<Self, Self::Error> {
            let value = extract_extension(s)?;
            Ok(Self(value))
        }
    }

    /// Builds a [`MaybeExtension`] from disconnect parts; this never fails.
    impl<A, T> FromDisconnectParts<A> for MaybeExtension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = Infallible;

        /// Yields `Some(T)` when found and `None` otherwise; the disconnect reason is ignored.
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _reason: DisconnectReason,
        ) -> Result<Self, Self::Error> {
            let found = extract_extension(s).ok();
            Ok(Self(found))
        }
    }

    /// Builds an [`Extension`] from message parts.
    impl<A, T> FromMessageParts<A> for Extension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = ExtensionNotFound<T>;

        /// Fetches the `T` stored in the extensions of `s`; the other arguments are ignored.
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, Self::Error> {
            match extract_extension(s) {
                Ok(value) => Ok(Self(value)),
                Err(not_found) => Err(not_found),
            }
        }
    }

    /// Builds a [`MaybeExtension`] from message parts; this never fails.
    impl<A, T> FromMessageParts<A> for MaybeExtension<T>
    where
        A: Adapter,
        T: Clone + Send + Sync + 'static,
    {
        type Error = Infallible;

        /// Yields `Some(T)` when found and `None` otherwise; the other arguments are ignored.
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, Self::Error> {
            match extract_extension(s) {
                Ok(value) => Ok(Self(Some(value))),
                // A failed lookup is not an error here: it simply becomes `None`.
                Err(_) => Ok(Self(None)),
            }
        }
    }

    // NOTE(review): the body of `__impl_deref!` is not visible from this file; presumably it
    // generates `Deref`-style impls for the two wrappers — confirm against its definition.
    super::super::__impl_deref!(Extension);
    super::super::__impl_deref!(MaybeExtension<T>: Option<T>);
}