tower-oauth2-resource-server 0.12.0

Tower middleware that provides JWT authorization against an OpenID Connect (OIDC) Provider
Documentation
use futures_util::{Future, future::BoxFuture};
use http::{Request, Response};
use pin_project::pin_project;
use serde::de::DeserializeOwned;

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll, ready},
};
use tower::{Layer, Service};

use crate::{error_handler::ErrorHandler, server::OAuth2ResourceServer};

/// Internal abstraction for the authorization step that runs before the inner service.
///
/// Implementors consume the incoming request and return a future resolving to either
/// `Ok(request)`, meaning the request should be forwarded downstream, or
/// `Err(response)`, meaning the given response should be returned instead.
trait Authorize<ReqBody, ResBody> {
    /// Future yielding `Ok(request)` to continue or `Err(response)` to short-circuit.
    type Future: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>;

    /// Starts authorizing `req`.
    fn authorize(&mut self, req: Request<ReqBody>) -> Self::Future;
}

impl<S, ReqBody, ResBody, Claims> Authorize<ReqBody, ResBody>
    for OAuth2ResourceServerService<S, ResBody, Claims>
where
    Claims: DeserializeOwned + Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Future = BoxFuture<'static, Result<Request<ReqBody>, Response<ResBody>>>;

    fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
        let auth = self.auth_manager.clone();
        let error_handler = self.error_handler.clone();
        Box::pin(async move {
            match auth.authorize_request(request).await {
                Ok(request) => Ok(request),
                Err(error) => Err(error_handler.map_error(error)),
            }
        })
    }
}

/// Tower [`Layer`] that wraps an inner service with OAuth2 / JWT authorization.
///
/// Every service produced by [`Layer::layer`] receives clones of this layer's
/// `OAuth2ResourceServer` and error handler.
pub struct OAuth2ResourceServerLayer<ResBody, Claims> {
    /// Performs the actual request authorization.
    auth_manager: OAuth2ResourceServer<Claims>,
    /// Turns authorization errors into responses.
    error_handler: Arc<dyn ErrorHandler<ResBody>>,
}

// A manual impl avoids the `ResBody: Clone` bound that `#[derive(Clone)]` would add.
impl<ResBody, Claims> Clone for OAuth2ResourceServerLayer<ResBody, Claims>
where
    Claims: Clone,
{
    fn clone(&self) -> Self {
        let auth_manager = self.auth_manager.clone();
        let error_handler = Arc::clone(&self.error_handler);
        Self {
            auth_manager,
            error_handler,
        }
    }
}

impl<S, ResBody, Claims> Layer<S> for OAuth2ResourceServerLayer<ResBody, Claims>
where
    Claims: Clone + DeserializeOwned + Send + 'static,
{
    type Service = OAuth2ResourceServerService<S, ResBody, Claims>;

    fn layer(&self, inner: S) -> Self::Service {
        OAuth2ResourceServerService::new(
            inner,
            self.auth_manager.clone(),
            self.error_handler.clone(),
        )
    }
}

impl<ResBody, Claims> OAuth2ResourceServerLayer<ResBody, Claims> {
    /// Builds a layer from an auth manager and the handler that converts
    /// authorization errors into responses.
    pub(crate) fn new(
        auth_manager: OAuth2ResourceServer<Claims>,
        error_handler: Arc<dyn ErrorHandler<ResBody>>,
    ) -> Self {
        Self {
            auth_manager,
            error_handler,
        }
    }
}

/// Tower [`Service`] that authorizes each request before forwarding it to `inner`.
///
/// Rejected requests are answered with the error handler's response; the inner
/// service is only invoked for requests that pass authorization.
pub struct OAuth2ResourceServerService<S, ResBody, Claims> {
    /// Wrapped downstream service.
    inner: S,
    /// Performs the actual request authorization.
    auth_manager: OAuth2ResourceServer<Claims>,
    /// Turns authorization errors into responses.
    error_handler: Arc<dyn ErrorHandler<ResBody>>,
}

// A manual impl avoids the `ResBody: Clone` bound that `#[derive(Clone)]` would add.
impl<S, ResBody, Claims> Clone for OAuth2ResourceServerService<S, ResBody, Claims>
where
    S: Clone,
    Claims: Clone,
{
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        let auth_manager = self.auth_manager.clone();
        let error_handler = Arc::clone(&self.error_handler);
        Self {
            inner,
            auth_manager,
            error_handler,
        }
    }
}

impl<S, ResBody, Claims> OAuth2ResourceServerService<S, ResBody, Claims> {
    fn new(
        inner: S,
        auth_manager: OAuth2ResourceServer<Claims>,
        error_handler: Arc<dyn ErrorHandler<ResBody>>,
    ) -> Self {
        Self {
            inner,
            auth_manager,
            error_handler,
        }
    }
}

impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
    for OAuth2ResourceServerService<S, ResBody, Claims>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    ResBody: Default + Send + 'static,
    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S, ReqBody, ResBody, Claims>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authorizes `request` and, if accepted, forwards it to the inner service.
    ///
    /// The returned future runs authorization first. A rejected request resolves to
    /// the error handler's response without the inner service being called.
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // `poll_ready` drove readiness on `self.inner`, not on a clone of it. Move that
        // ready instance into the future and keep a fresh clone here for the next
        // `poll_ready`/`call` cycle. Calling an un-polled clone breaks services that
        // reserve capacity in `poll_ready` (e.g. tower's `Buffer`, `ConcurrencyLimit`).
        let fresh = self.inner.clone();
        let ready = std::mem::replace(&mut self.inner, fresh);
        let authorize = self.authorize(request);

        ResponseFuture {
            state: State::Authorize { authorize },
            service: ready,
        }
    }
}

/// Shorthand for the authorization future produced by
/// `OAuth2ResourceServerService`'s `Authorize` implementation.
type AuthorizeFuture<S, ReqBody, ResBody, Claims> =
    <OAuth2ResourceServerService<S, ResBody, Claims> as Authorize<ReqBody, ResBody>>::Future;

/// Response future returned by `OAuth2ResourceServerService::call`.
///
/// It starts in the authorization state. Once authorization succeeds, it calls the
/// held inner service and resolves to that service's result. If authorization fails,
/// it resolves to `Ok` carrying the rejection response.
#[pin_project]
pub struct ResponseFuture<S, ReqBody, ResBody, Claims>
where
    // These bounds match the `Authorize` impl so that `AuthorizeFuture` can be named.
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ReqBody: Send + 'static,
    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
    ResBody: Send + 'static,
{
    /// Current stage: authorizing, or awaiting the inner service's future.
    #[pin]
    state: State<AuthorizeFuture<S, ReqBody, ResBody, Claims>, S::Future>,
    /// Inner service, called once authorization succeeds.
    service: S,
}

/// Two-stage state machine driven by `ResponseFuture::poll`.
#[pin_project(project = StateProj)]
enum State<A, SFut> {
    /// Waiting for the authorization future `A` to complete.
    Authorize {
        #[pin]
        authorize: A,
    },
    /// Authorization succeeded; waiting on the inner service's future.
    Authorized {
        #[pin]
        fut: SFut,
    },
}

impl<S, ReqBody, ResBody, Claims> Future for ResponseFuture<S, ReqBody, ResBody, Claims>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Default + Send + 'static,
    ReqBody: Send + 'static,
    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
{
    type Output = Result<Response<ResBody>, S::Error>;

    /// Drives authorization first, then the inner service.
    ///
    /// A rejected request resolves immediately to `Ok(response)` with the rejection
    /// response, and the inner service is never called in that case.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();

        loop {
            // Each iteration either finishes (via `return`) or computes the next state.
            let next_state = match me.state.as_mut().project() {
                StateProj::Authorized { fut } => return fut.poll(cx),
                StateProj::Authorize { authorize } => match ready!(authorize.poll(cx)) {
                    Err(rejection) => return Poll::Ready(Ok(rejection)),
                    Ok(authorized) => State::Authorized {
                        fut: me.service.call(authorized),
                    },
                },
            };
            me.state.set(next_state);
        }
    }
}