// fregate 0.17.1
//
// Framework for services creation
// Documentation
use crate::middleware::ProxyError;
use axum::response::IntoResponse;
use core::any::type_name;
use hyper::body::{Bytes, HttpBody};
use hyper::http::uri::PathAndQuery;
use hyper::service::Service;
use hyper::{Request, Response, Uri};
use std::error::Error;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::str::FromStr;

/// State shared by the proxy middleware: the destination to forward to plus the
/// user-supplied callbacks that customise each stage of proxying.
///
/// `TBody` and `TRespBody` are the request and response body types, and
/// `TExtension` (defaulting to `()`) is the type of the value handed by
/// reference to the callbacks. None of the three is stored; they are only
/// recorded through `phantom`.
pub(crate) struct Shared<
    TBody,
    TRespBody,
    ShouldProxyCallback,
    OnProxyErrorCallback,
    OnProxyRequestCallback,
    OnProxyResponseCallback,
    TExtension = (),
> {
    /// Destination whose scheme and authority replace those of every request
    /// passed to `proxy`. `new_with_ext` rejects a `Uri` lacking either part.
    pub(crate) destination: Uri,
    /// Callback stored for the caller's use; it is not invoked anywhere in this
    /// file. NOTE(review): presumably decides whether a request gets proxied —
    /// confirm against the code that calls it.
    pub(crate) should_proxy: ShouldProxyCallback,
    /// Turns a `ProxyError` into the response returned from `proxy` when the
    /// exchange fails.
    pub(crate) on_proxy_error: OnProxyErrorCallback,
    /// Invoked by `proxy` with the outgoing request after its URI has been
    /// rewritten and before it is sent.
    pub(crate) on_proxy_request: OnProxyRequestCallback,
    /// Invoked by `proxy` with the received response before it is converted
    /// into the returned `axum` response.
    pub(crate) on_proxy_response: OnProxyResponseCallback,
    /// Zero-sized marker tying the otherwise unused `TExtension`, `TBody` and
    /// `TRespBody` parameters to the struct.
    pub(crate) phantom: PhantomData<(TExtension, TBody, TRespBody)>,
}

/// Diagnostic formatting for [`Shared`].
///
/// The callbacks are generic types with no `Debug` bound, so each one is
/// represented by its type name rather than by a value. The `phantom` marker is
/// omitted; the extension type it carries is reported as `extension`.
impl<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
        TExtension,
    > Debug
    for Shared<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
        TExtension,
    >
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `format_args!` is used so the type names are printed verbatim instead
        // of as quoted, escaped string literals.
        f.debug_struct("Shared")
            .field("destination", &self.destination)
            // Previously missing: every stored callback is now reported.
            .field(
                "should_proxy",
                &format_args!("{}", type_name::<ShouldProxyCallback>()),
            )
            .field(
                "on_proxy_error",
                &format_args!("{}", type_name::<OnProxyErrorCallback>()),
            )
            .field(
                "on_proxy_request",
                &format_args!("{}", type_name::<OnProxyRequestCallback>()),
            )
            .field(
                "on_proxy_response",
                &format_args!("{}", type_name::<OnProxyResponseCallback>()),
            )
            .field("extension", &format_args!("{}", type_name::<TExtension>()))
            .finish()
    }
}

// The constructor's return type has to spell out every type parameter of
// `Shared`, which is what trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
impl<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
    >
    Shared<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
    >
{
    /// Builds a [`Shared`] whose callbacks receive a `TExtension`.
    ///
    /// `destination` is parsed into a [`Uri`] and must carry both a scheme and
    /// an authority, because `proxy` later relies on both being present. The
    /// callbacks are stored as given.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `destination` is not a valid URI,
    /// or when the parsed URI has no scheme or no authority (the scheme is
    /// checked first).
    pub(crate) fn new_with_ext<TExtension>(
        destination: impl Into<String>,
        should_proxy: ShouldProxyCallback,
        on_proxy_error: OnProxyErrorCallback,
        on_proxy_request: OnProxyRequestCallback,
        on_proxy_response: OnProxyResponseCallback,
    ) -> Result<
        Shared<
            TBody,
            TRespBody,
            ShouldProxyCallback,
            OnProxyErrorCallback,
            OnProxyRequestCallback,
            OnProxyResponseCallback,
            TExtension,
        >,
        String,
    > {
        let destination = Uri::from_str(&destination.into()).map_err(|err| err.to_string())?;

        // Guard clauses: the error strings are only allocated when a check fails.
        if destination.scheme().is_none() {
            return Err("destination Uri has no scheme!".to_string());
        }
        if destination.authority().is_none() {
            return Err("destination Uri has no authority!".to_string());
        }

        // The type parameters are inferred from the declared return type.
        Ok(Shared {
            destination,
            should_proxy,
            on_proxy_error,
            on_proxy_request,
            on_proxy_response,
            phantom: PhantomData,
        })
    }
}

impl<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
        TExtension,
    >
    Shared<
        TBody,
        TRespBody,
        ShouldProxyCallback,
        OnProxyErrorCallback,
        OnProxyRequestCallback,
        OnProxyResponseCallback,
        TExtension,
    >
where
    TExtension: Default + Clone + Send + Sync + 'static,
    ShouldProxyCallback: for<'any> Fn(
            &'any Request<TBody>,
            &'any TExtension,
        ) -> Pin<
            Box<dyn Future<Output = Result<bool, axum::response::Response>> + Send + 'any>,
        > + Send
        + Sync
        + 'static,
    OnProxyErrorCallback:
        Fn(ProxyError, &TExtension) -> axum::response::Response + Send + Sync + 'static,
    OnProxyRequestCallback: Fn(&mut Request<TBody>, &TExtension) + Send + Sync + 'static,
    OnProxyResponseCallback: Fn(&mut Response<TRespBody>, &TExtension) + Send + Sync + 'static,
    TBody: Sync + Send + 'static,
    TRespBody: HttpBody<Data = Bytes> + Sync + Send + 'static,
    TRespBody::Error: Into<Box<(dyn Error + Send + Sync + 'static)>>,
{
    pub(crate) async fn proxy<TClient>(
        &self,
        mut req: Request<TBody>,
        client: TClient,
        extension: TExtension,
        poll_error: Option<Box<(dyn Error + Send + Sync + 'static)>>,
    ) -> axum::response::Response
    where
        TClient: Service<Request<TBody>, Response = Response<TRespBody>>,
        TClient: Clone + Send + Sync + 'static,
        <TClient as Service<Request<TBody>>>::Future: Send + 'static,
        <TClient as Service<Request<TBody>>>::Error:
            Into<Box<(dyn Error + Send + Sync + 'static)>> + Send,
    {
        if let Some(err) = poll_error {
            return (self.on_proxy_error)(ProxyError::SendRequest(err), &extension);
        }

        let build_uri = |req: &Request<TBody>| {
            let p_and_q = req
                .uri()
                .path_and_query()
                .map_or_else(|| req.uri().path(), PathAndQuery::as_str);

            let destination_parts = self.destination.clone().into_parts();

            #[allow(clippy::expect_used)]
            // SAFETY: Checked on [`Shared::new`]
            let authority = destination_parts
                .authority
                .expect("Destination uri must have [Authority]");

            #[allow(clippy::expect_used)]
            // SAFETY: Checked on [`Shared::new`]
            let scheme = destination_parts
                .scheme
                .expect("Destination uri must have [Scheme]");

            Uri::builder()
                .authority(authority)
                .scheme(scheme)
                .path_and_query(p_and_q)
                .build()
                .map_err(ProxyError::UriBuilder)
        };

        match build_uri(&req) {
            Ok(new_uri) => {
                *req.uri_mut() = new_uri;

                (self.on_proxy_request)(&mut req, &extension);
                let result = send_request(client, req).await;

                match result {
                    Ok(mut response) => {
                        (self.on_proxy_response)(&mut response, &extension);
                        response.into_response()
                    }
                    Err(err) => (self.on_proxy_error)(err, &extension),
                }
            }
            Err(err) => (self.on_proxy_error)(err, &extension),
        }
    }
}

/// Returns a clone of the `TExtension` stored in the request's extensions, or
/// `TExtension::default()` when the request carries no value of that type.
pub(crate) fn get_extension<TBody, TExtension>(request: &Request<TBody>) -> TExtension
where
    TExtension: Default + Clone + Send + Sync + 'static,
{
    match request.extensions().get::<TExtension>() {
        Some(stored) => stored.clone(),
        None => TExtension::default(),
    }
}

/// Dispatches `request` through `service` and awaits the response.
///
/// This function calls `Service::call` directly and does not poll readiness
/// itself. NOTE(review): callers are presumably expected to have driven
/// `poll_ready` beforehand — confirm against the call sites.
///
/// # Errors
///
/// Any error produced by the service is boxed and returned as
/// [`ProxyError::SendRequest`].
async fn send_request<TClient, TBody, TRespBody>(
    mut service: TClient,
    request: Request<TBody>,
) -> Result<Response<TRespBody>, ProxyError>
where
    TClient: Service<Request<TBody>, Response = Response<TRespBody>>,
    TClient: Clone + Send + Sync + 'static,
    <TClient as Service<Request<TBody>>>::Future: Send + 'static,
    <TClient as Service<Request<TBody>>>::Error:
        Into<Box<(dyn Error + Send + Sync + 'static)>> + Send,
{
    // The mapped `Result` already has the declared return type, so it is
    // returned as-is rather than unwrapped with `?` and re-wrapped in `Ok`.
    service
        .call(request)
        .await
        .map_err(|err| ProxyError::SendRequest(err.into()))
}