ts-webapi 0.4.11

Library for my web API projects
Documentation
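//! Authorization middleware.
//!
//! Provides a Tower layer, [`Authorization`], and the service it produces,
//! [`AuthorizationService`], which hand every incoming request to a
//! caller-supplied `authorize` function.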

use core::{
    mem,
    task::{Context, Poll},
};

use http::{Request, Response};
use http_body::Body;
use tower_layer::Layer;
use tower_service::Service;

use crate::middleware::futures::{UndefinedFuture, undefined::DefiningFuture};

/// Tower layer that adds an authorization step in front of a service.
///
/// The layer stores nothing but the `authorize` function pointer; applying it
/// to a service through its [`Layer`] implementation yields an
/// [`AuthorizationService`].
#[derive(Debug)]
pub struct Authorization<B, Fut: DefiningFuture<B>> {
    /// Function pointer called with each incoming request; it returns a
    /// future of type `Fut`.
    authorize: fn(Request<B>) -> Fut,
}

impl<B, Fut: DefiningFuture<B>> Clone for Authorization<B, Fut> {
    /// Produces a new layer holding a copy of the same function pointer.
    ///
    /// Written by hand instead of derived — presumably so that cloning does
    /// not demand `B: Clone` or `Fut: Clone`.
    fn clone(&self) -> Self {
        let Self { authorize } = self;
        Self {
            // Function pointers are `Copy`, so a dereference is enough.
            authorize: *authorize,
        }
    }
}

impl<B, Fut> Authorization<B, Fut>
where
    Fut: DefiningFuture<B>,
{
    /// Create a new authorization layer.
    pub fn new(authorize: fn(Request<B>) -> Fut) -> Self {
        Self { authorize }
    }
}

impl<Svc: Clone, B, Fut: DefiningFuture<B>> Layer<Svc> for Authorization<B, Fut> {
    type Service = AuthorizationService<Svc, B, Fut>;

    /// Wraps `inner` in an [`AuthorizationService`] that carries a clone of
    /// this layer.
    fn layer(&self, inner: Svc) -> Self::Service {
        let auth = self.clone();
        AuthorizationService { inner, auth }
    }
}

/// Tower service produced by the [`Authorization`] layer.
///
/// Pairs the wrapped service with the layer whose `authorize` function is
/// called for every request.
#[derive(Debug)]
pub struct AuthorizationService<Svc: Clone, B, Fut: DefiningFuture<B>> {
    /// The wrapped (inner) service.
    inner: Svc,
    /// The layer holding the authorization function.
    auth: Authorization<B, Fut>,
}

impl<Svc: Clone, B, Fut: DefiningFuture<B>> Clone for AuthorizationService<Svc, B, Fut> {
    /// Clones both the inner service and the authorization layer.
    fn clone(&self) -> Self {
        let Self { inner, auth } = self;
        Self {
            inner: inner.clone(),
            auth: auth.clone(),
        }
    }
}

impl<Svc, B, Fut> AuthorizationService<Svc, B, Fut>
where
    Svc: Clone,
    Fut: DefiningFuture<B>,
{
    /// Create a new service.
    pub fn new(inner: Svc, auth: Authorization<B, Fut>) -> Self {
        Self { inner, auth }
    }
}

/// [`Service`] implementation for the authorization middleware.
///
/// The response and error types are those of the inner service; the returned
/// future is an [`UndefinedFuture`] built from the authorization future and
/// the inner service.
impl<Svc, ReqBody, ResBody, Fut> Service<Request<ReqBody>>
    for AuthorizationService<Svc, ReqBody, Fut>
where
    Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    ResBody: Body + Send + Default,
    ReqBody: Send + 'static,
    Fut: DefiningFuture<ReqBody>,
{
    type Response = Svc::Response;
    type Error = Svc::Error;
    type Future = UndefinedFuture<Svc, ReqBody>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts authorization for `request` and returns the combined future.
    ///
    /// The request is passed to the layer's `authorize` function; the
    /// resulting future is boxed and handed, together with the inner service,
    /// to [`UndefinedFuture::define`].
    // NOTE(review): presumably the returned future forwards the request to the
    // inner service once authorization succeeds — confirm in `UndefinedFuture`.
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // The function pointer is `Copy`, so it can be called straight from the
        // field; cloning the whole layer first is unnecessary.
        let auth_future = (self.auth.authorize)(request);

        // Give the returned future the instance that `poll_ready` was driven on
        // and keep a fresh clone in `self`, see
        // https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
        let replacement = self.inner.clone();
        let inner = mem::replace(&mut self.inner, replacement);

        UndefinedFuture::define(Box::pin(auth_future), inner)
    }
}

#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;
    use ts_token::jwt::TokenType;

    use crate::{
        middleware::{authorization::Authorization, test::get_request},
        test::ResponseTestExt,
    };

    /// With an `authorize` function that accepts every request, a request sent
    /// through the layered router must come back with `200 OK`.
    #[tokio::test]
    async fn axum() {
        /// Authorizer that approves unconditionally and returns the request untouched.
        async fn authorize<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
            Ok(request)
        }

        let request = get_request(Some(TokenType::Common));
        let layer = Authorization::new(authorize);

        // Single route whose handler always answers `200 OK`, behind the layer.
        let app = Router::new()
            .route("/resource/id", get(|| async move { StatusCode::OK }))
            .layer(layer);

        app.oneshot(request)
            .await
            .unwrap()
            .expect_status(StatusCode::OK);
    }
}