use alloc::string::ToString as _;
use core::future::Future;
use telers::{
errors::{EventErrorKind, ExtractionError},
event::{telegram::HandlerResponse, EventReturn},
middlewares::{outer::MiddlewareResponse, InnerMiddleware, Next, OuterMiddleware},
Extractor, Request, Router,
};
#[cfg(feature = "async")]
use crate::async_impl::Container as AsyncContainer;
use crate::{Container, Context, DefaultScope::Request as RequestScope, Inject, InjectTransient, ResolveErrorKind, Scope};
/// Failure reasons when extracting an injected dependency from a telers request.
#[derive(Debug, thiserror::Error)]
pub enum InjectErrorKind {
    /// The request extensions hold no container of the expected type — e.g. the outer
    /// middleware installed by `setup*` did not run for this request.
    #[error("Container not found in extensions")]
    ContainerNotFound,
    /// A container was present, but resolving the dependency from it failed.
    #[error(transparent)]
    Resolve(ResolveErrorKind),
}
impl From<InjectErrorKind> for ExtractionError {
fn from(value: InjectErrorKind) -> Self {
Self::new(value.to_string())
}
}
/// Generates the outer/inner middleware pair for one container flavour.
///
/// * `$OuterStructName` — outer middleware that stores the root container plus the scope to enter.
///   For every request it builds a child container and inserts it into the request extensions.
/// * `$InnerStructName` — unit struct. Its `InnerMiddleware` impl is written separately per
///   container type, because closing differs between the sync and async containers.
/// * `$ContainerType` — the container type being wrapped.
macro_rules! impl_setup {
    (
        $OuterStructName:ident,
        $InnerStructName:ident,
        $ContainerType:ty
    ) => {
        /// Outer middleware that enters a child scope of the wrapped container for every request
        /// and stores the resulting container in the request extensions.
        ///
        /// # Panics
        ///
        /// `call` panics if the child container cannot be built for the configured scope.
        #[derive(Clone)]
        pub struct $OuterStructName<WithScope> {
            container: $ContainerType,
            scope: WithScope,
        }

        /// Inner middleware that closes the per-request container after the handler chain has run.
        #[derive(Debug, Default, Clone, Copy)]
        pub struct $InnerStructName;

        impl<Client, WithScope> OuterMiddleware<Client> for $OuterStructName<WithScope>
        where
            Client: Send,
            WithScope: Scope + Clone + Send + Sync + 'static,
        {
            fn call(
                &mut self,
                mut request: Request<Client>,
            ) -> impl Future<Output = Result<MiddlewareResponse<Client>, EventErrorKind>> + Send {
                // Put the current update into the container context, presumably so that
                // request-scoped factories can depend on it.
                let mut context = Context::new();
                context.insert(request.update.clone());
                // NOTE(review): a build failure is a configuration error (wrong scope for this
                // container). It still panics, but now with a diagnosable message. Consider mapping
                // it into an `EventErrorKind` once the right variant is settled.
                let container = self
                    .container
                    .clone()
                    .enter()
                    .with_scope(self.scope.clone())
                    .with_context(context)
                    .build()
                    .expect("failed to build a scoped container for the incoming update; check the scope passed to `setup`/`setup_async`");
                request.extensions.insert(container);
                async move { Ok((request, EventReturn::Finish)) }
            }
        }
    };
}
impl_setup!(ContainerOuterMiddleware, ContainerInnerMiddleware, Container);

/// Closes the per-request [`Container`] (if one was inserted) after the downstream chain finishes.
///
/// The close runs whether the handler succeeded or failed. The handler's result is then
/// returned unchanged.
impl<Client> InnerMiddleware<Client> for ContainerInnerMiddleware
where
    Client: Send,
{
    fn call(
        &mut self,
        request: Request<Client>,
        next: Next<Client>,
    ) -> impl Future<Output = Result<HandlerResponse<Client>, EventErrorKind>> + Send {
        // Take a handle now, because `request` is moved into `next` below.
        let scoped = request.extensions.get::<Container>().cloned();

        async move {
            let outcome = next(request).await;
            if let Some(handle) = scoped {
                handle.close();
            }
            outcome
        }
    }
}
#[cfg(feature = "async")]
impl_setup!(AsyncContainerOuterMiddleware, AsyncContainerInnerMiddleware, AsyncContainer);

/// Async counterpart of the sync inner middleware.
///
/// After the downstream chain finishes, it awaits the close of the per-request `AsyncContainer`
/// (if one was inserted). It then returns the handler's result unchanged.
#[cfg(feature = "async")]
impl<Client> InnerMiddleware<Client> for AsyncContainerInnerMiddleware
where
    Client: Send,
{
    fn call(
        &mut self,
        request: Request<Client>,
        next: Next<Client>,
    ) -> impl Future<Output = Result<HandlerResponse<Client>, EventErrorKind>> + Send {
        // Take a handle now, because `request` is moved into `next` below.
        let scoped = request.extensions.get::<AsyncContainer>().cloned();

        async move {
            let outcome = next(request).await;
            if let Some(handle) = scoped {
                handle.close().await;
            }
            outcome
        }
    }
}
impl<Client, Dep, const PREFER_SYNC_OVER_ASYNC: bool> Extractor<Client> for Inject<Dep, PREFER_SYNC_OVER_ASYNC>
where
Dep: Send + Sync + 'static,
{
type Error = InjectErrorKind;
#[cfg(not(feature = "async"))]
fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
let res = match request.extensions.get::<Container>() {
Some(container) => match container.get() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
};
async move { res }
}
#[cfg(feature = "async")]
fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
let sync_container = request.extensions.get::<Container>();
let async_container = request.extensions.get::<AsyncContainer>();
async move {
if PREFER_SYNC_OVER_ASYNC {
return match sync_container {
Some(container) => match container.get() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => match async_container {
Some(container) => match container.get().await {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
},
};
}
match async_container {
Some(container) => match container.get().await {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => match sync_container {
Some(container) => match container.get() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
},
}
}
}
}
impl<Client, Dep, const PREFER_SYNC_OVER_ASYNC: bool> Extractor<Client> for InjectTransient<Dep, PREFER_SYNC_OVER_ASYNC>
where
Dep: Send + Sync + 'static,
{
type Error = InjectErrorKind;
#[cfg(not(feature = "async"))]
fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
let res = match request.extensions.get::<Container>() {
Some(container) => match container.get_transient() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
};
async move { res }
}
#[cfg(feature = "async")]
fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
let sync_container = request.extensions.get::<Container>();
let async_container = request.extensions.get::<AsyncContainer>();
async move {
if PREFER_SYNC_OVER_ASYNC {
return match sync_container {
Some(container) => match container.get_transient() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => match async_container {
Some(container) => match container.get_transient().await {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
},
};
}
match async_container {
Some(container) => match container.get_transient().await {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => match sync_container {
Some(container) => match container.get_transient() {
Ok(dep) => Ok(Self(dep)),
Err(err) => Err(Self::Error::Resolve(err)),
},
None => Err(Self::Error::ContainerNotFound),
},
}
}
}
}
/// Wires a sync [`Container`] into `router`.
///
/// Two middlewares are registered:
/// * the container inner middleware, on every observer (via `on_all`), which closes the
///   per-request container after handling;
/// * the container outer middleware, on the update observer, which builds a child container in
///   `scope` for each request.
#[inline]
#[must_use]
pub fn setup<Client, WithScope>(router: Router<Client>, container: Container, scope: WithScope) -> Router<Client>
where
    WithScope: Scope + Clone + Send + Sync + 'static,
    Client: Send + Sync + 'static,
{
    let opener = ContainerOuterMiddleware { container, scope };
    let router = router.on_all(|observer| observer.register_inner_middleware(ContainerInnerMiddleware));
    router.on_update(move |observer| observer.register_outer_middleware(opener))
}
/// Same as `setup`, with the scope fixed to the default request scope (`DefaultScope::Request`).
#[inline]
#[must_use]
pub fn setup_default<Client>(router: Router<Client>, container: Container) -> Router<Client>
where
    Client: Send + Sync + 'static,
{
    let scope = RequestScope;
    setup(router, container, scope)
}
/// Wires an async container into `router`.
///
/// Two middlewares are registered:
/// * the async inner middleware, on every observer (via `on_all`), which awaits the close of the
///   per-request container after handling;
/// * the async outer middleware, on the update observer, which builds a child container in
///   `scope` for each request.
#[inline]
#[must_use]
#[cfg(feature = "async")]
pub fn setup_async<Client, WithScope>(router: Router<Client>, container: AsyncContainer, scope: WithScope) -> Router<Client>
where
    WithScope: Scope + Clone + Send + Sync + 'static,
    Client: Send + Sync + 'static,
{
    let opener = AsyncContainerOuterMiddleware { container, scope };
    let router = router.on_all(|observer| observer.register_inner_middleware(AsyncContainerInnerMiddleware));
    router.on_update(move |observer| observer.register_outer_middleware(opener))
}
/// Same as `setup_async`, with the scope fixed to the default request scope (`DefaultScope::Request`).
#[inline]
#[must_use]
#[cfg(feature = "async")]
pub fn setup_async_default<Client>(router: Router<Client>, container: AsyncContainer) -> Router<Client>
where
    Client: Send + Sync + 'static,
{
    let scope = RequestScope;
    setup_async(router, container, scope)
}