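//! froodi 1.0.0-beta.18 — an ergonomic Rust IoC container.
//!
//! Integration with the `telers` Telegram bot framework: middlewares that open a
//! container for every incoming update, plus `Inject`/`InjectTransient` extractors
//! that resolve dependencies from it.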
use alloc::string::ToString as _;
use core::future::Future;
use telers::{
    errors::{EventErrorKind, ExtractionError},
    event::{telegram::HandlerResponse, EventReturn},
    middlewares::{outer::MiddlewareResponse, InnerMiddleware, Next, OuterMiddleware},
    Extractor, Request, Router,
};

#[cfg(feature = "async")]
use crate::async_impl::Container as AsyncContainer;
use crate::{Container, Context, DefaultScope::Request as RequestScope, Inject, InjectTransient, ResolveErrorKind, Scope};

/// Error produced by the `Inject` / `InjectTransient` extractors.
#[derive(Debug, thiserror::Error)]
pub enum InjectErrorKind {
    /// No usable container was present in the request extensions (none of the
    /// container types the extractor looks up was stored there).
    #[error("Container not found in extensions")]
    ContainerNotFound,
    /// A container was found, but resolving the requested dependency failed.
    /// `Display` and `source` are forwarded to the wrapped error (`#[error(transparent)]`).
    #[error(transparent)]
    Resolve(ResolveErrorKind),
}

impl From<InjectErrorKind> for ExtractionError {
    /// Wraps the injection error in an [`ExtractionError`] carrying its `Display` text.
    fn from(err: InjectErrorKind) -> Self {
        let message = err.to_string();
        Self::new(message)
    }
}

// Generates the per-request container middleware types for one container type:
// - `$OuterStructName`: outer middleware holding the root container and the scope
//   to enter; it builds a child container per request and stores it in the
//   request extensions.
// - `$InnerStructName`: unit struct for the companion inner middleware. Its
//   `InnerMiddleware` impl is written outside the macro, because closing differs
//   between the sync and async containers.
// - `$ContainerType`: the container type the outer middleware clones and enters.
macro_rules! impl_setup {
    (
        $OuterStructName:ident,
        $InnerStructName:ident,
        $ContainerType:ty
    ) => {
        /// Outer middleware that builds a child container for every request.
        ///
        /// The child is created by entering `scope` on a clone of `container`. It
        /// gets the request's update in its context and is inserted into the
        /// request extensions for extractors and the inner middleware to find.
        #[derive(Clone)]
        pub struct $OuterStructName<WithScope> {
            container: $ContainerType,
            scope: WithScope,
        }

        /// Inner middleware companion that closes the per-request container.
        #[derive(Clone)]
        pub struct $InnerStructName;

        impl<Client, WithScope> OuterMiddleware<Client> for $OuterStructName<WithScope>
        where
            Client: Send,
            WithScope: Scope + Clone + Send + Sync + 'static,
        {
            /// Builds the per-request container and stores it in `request.extensions`.
            ///
            /// # Panics
            ///
            /// Panics if the child container for the configured scope cannot be built.
            fn call(
                &mut self,
                mut request: Request<Client>,
            ) -> impl Future<Output = Result<MiddlewareResponse<Client>, EventErrorKind>> + Send {
                // Put the current update into the child's context, presumably so it
                // can be resolved as a dependency (TODO confirm against `Context` docs).
                let mut context = Context::new();
                context.insert(request.update.clone());

                // A build failure is treated as a setup/configuration bug.
                // NOTE(review): propagating it as an `EventErrorKind` would avoid the panic.
                let container = self
                    .container
                    .clone()
                    .enter()
                    .with_scope(self.scope.clone())
                    .with_context(context)
                    .build()
                    .expect("failed to build the per-request container from the configured scope");
                request.extensions.insert(container);

                async move { Ok((request, EventReturn::Finish)) }
            }
        }
    };
}

impl_setup!(ContainerOuterMiddleware, ContainerInnerMiddleware, Container);

/// Closes the per-request [`Container`] once the wrapped handler has returned.
///
/// The handler's result (success or error) is passed through unchanged. If no
/// container was stored in the request extensions, nothing is closed.
impl<Client> InnerMiddleware<Client> for ContainerInnerMiddleware
where
    Client: Send,
{
    fn call(
        &mut self,
        request: Request<Client>,
        next: Next<Client>,
    ) -> impl Future<Output = Result<HandlerResponse<Client>, EventErrorKind>> + Send {
        // Take an owned handle up front: `request` is moved into `next` below.
        let scoped = request.extensions.get::<Container>().cloned();
        async move {
            let response = next(request).await;
            let Some(scoped) = scoped else {
                return response;
            };
            scoped.close();
            response
        }
    }
}

#[cfg(feature = "async")]
impl_setup!(AsyncContainerOuterMiddleware, AsyncContainerInnerMiddleware, AsyncContainer);

/// Closes the per-request `AsyncContainer` (awaiting its shutdown) once the
/// wrapped handler has returned.
///
/// The handler's result (success or error) is passed through unchanged. If no
/// async container was stored in the request extensions, nothing is closed.
#[cfg(feature = "async")]
impl<Client> InnerMiddleware<Client> for AsyncContainerInnerMiddleware
where
    Client: Send,
{
    fn call(
        &mut self,
        request: Request<Client>,
        next: Next<Client>,
    ) -> impl Future<Output = Result<HandlerResponse<Client>, EventErrorKind>> + Send {
        // Take an owned handle up front: `request` is moved into `next` below.
        let scoped = request.extensions.get::<AsyncContainer>().cloned();
        async move {
            let response = next(request).await;
            let Some(scoped) = scoped else {
                return response;
            };
            scoped.close().await;
            response
        }
    }
}

impl<Client, Dep, const PREFER_SYNC_OVER_ASYNC: bool> Extractor<Client> for Inject<Dep, PREFER_SYNC_OVER_ASYNC>
where
    Dep: Send + Sync + 'static,
{
    type Error = InjectErrorKind;

    #[cfg(not(feature = "async"))]
    fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        let res = match request.extensions.get::<Container>() {
            Some(container) => match container.get() {
                Ok(dep) => Ok(Self(dep)),
                Err(err) => Err(Self::Error::Resolve(err)),
            },
            None => Err(Self::Error::ContainerNotFound),
        };
        async move { res }
    }

    #[cfg(feature = "async")]
    fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        let sync_container = request.extensions.get::<Container>();
        let async_container = request.extensions.get::<AsyncContainer>();
        async move {
            if PREFER_SYNC_OVER_ASYNC {
                return match sync_container {
                    Some(container) => match container.get() {
                        Ok(dep) => Ok(Self(dep)),
                        Err(err) => Err(Self::Error::Resolve(err)),
                    },
                    None => match async_container {
                        Some(container) => match container.get().await {
                            Ok(dep) => Ok(Self(dep)),
                            Err(err) => Err(Self::Error::Resolve(err)),
                        },
                        None => Err(Self::Error::ContainerNotFound),
                    },
                };
            }

            match async_container {
                Some(container) => match container.get().await {
                    Ok(dep) => Ok(Self(dep)),
                    Err(err) => Err(Self::Error::Resolve(err)),
                },
                None => match sync_container {
                    Some(container) => match container.get() {
                        Ok(dep) => Ok(Self(dep)),
                        Err(err) => Err(Self::Error::Resolve(err)),
                    },
                    None => Err(Self::Error::ContainerNotFound),
                },
            }
        }
    }
}

impl<Client, Dep, const PREFER_SYNC_OVER_ASYNC: bool> Extractor<Client> for InjectTransient<Dep, PREFER_SYNC_OVER_ASYNC>
where
    Dep: Send + Sync + 'static,
{
    type Error = InjectErrorKind;

    #[cfg(not(feature = "async"))]
    fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        let res = match request.extensions.get::<Container>() {
            Some(container) => match container.get_transient() {
                Ok(dep) => Ok(Self(dep)),
                Err(err) => Err(Self::Error::Resolve(err)),
            },
            None => Err(Self::Error::ContainerNotFound),
        };
        async move { res }
    }

    #[cfg(feature = "async")]
    fn extract(request: &Request<Client>) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        let sync_container = request.extensions.get::<Container>();
        let async_container = request.extensions.get::<AsyncContainer>();
        async move {
            if PREFER_SYNC_OVER_ASYNC {
                return match sync_container {
                    Some(container) => match container.get_transient() {
                        Ok(dep) => Ok(Self(dep)),
                        Err(err) => Err(Self::Error::Resolve(err)),
                    },
                    None => match async_container {
                        Some(container) => match container.get_transient().await {
                            Ok(dep) => Ok(Self(dep)),
                            Err(err) => Err(Self::Error::Resolve(err)),
                        },
                        None => Err(Self::Error::ContainerNotFound),
                    },
                };
            }

            match async_container {
                Some(container) => match container.get_transient().await {
                    Ok(dep) => Ok(Self(dep)),
                    Err(err) => Err(Self::Error::Resolve(err)),
                },
                None => match sync_container {
                    Some(container) => match container.get_transient() {
                        Ok(dep) => Ok(Self(dep)),
                        Err(err) => Err(Self::Error::Resolve(err)),
                    },
                    None => Err(Self::Error::ContainerNotFound),
                },
            }
        }
    }
}

/// Installs the sync [`Container`] middlewares on `router`.
///
/// `ContainerInnerMiddleware` is registered through `Router::on_all`. A
/// `ContainerOuterMiddleware` holding `container` and `scope` is registered
/// through `Router::on_update`.
#[must_use]
#[inline]
pub fn setup<Client, WithScope>(router: Router<Client>, container: Container, scope: WithScope) -> Router<Client>
where
    Client: Send + Sync + 'static,
    WithScope: Scope + Clone + Send + Sync + 'static,
{
    let outer = ContainerOuterMiddleware { container, scope };
    let router = router.on_all(|observer| observer.register_inner_middleware(ContainerInnerMiddleware));
    router.on_update(move |observer| observer.register_outer_middleware(outer))
}

/// Same as `setup`, using `DefaultScope::Request` as the per-update scope.
#[must_use]
#[inline]
pub fn setup_default<Client>(router: Router<Client>, container: Container) -> Router<Client>
where
    Client: Send + Sync + 'static,
{
    let scope = RequestScope;
    setup(router, container, scope)
}

/// Installs the `AsyncContainer` middlewares on `router`.
///
/// `AsyncContainerInnerMiddleware` is registered through `Router::on_all`. An
/// `AsyncContainerOuterMiddleware` holding `container` and `scope` is registered
/// through `Router::on_update`.
#[cfg(feature = "async")]
#[must_use]
#[inline]
pub fn setup_async<Client, WithScope>(router: Router<Client>, container: AsyncContainer, scope: WithScope) -> Router<Client>
where
    Client: Send + Sync + 'static,
    WithScope: Scope + Clone + Send + Sync + 'static,
{
    let outer = AsyncContainerOuterMiddleware { container, scope };
    let router = router.on_all(|observer| observer.register_inner_middleware(AsyncContainerInnerMiddleware));
    router.on_update(move |observer| observer.register_outer_middleware(outer))
}

/// Same as `setup_async`, using `DefaultScope::Request` as the per-update scope.
#[cfg(feature = "async")]
#[must_use]
#[inline]
pub fn setup_async_default<Client>(router: Router<Client>, container: AsyncContainer) -> Router<Client>
where
    Client: Send + Sync + 'static,
{
    let scope = RequestScope;
    setup_async(router, container, scope)
}