ts-webapi 0.4.11

Library for my web API projects
Documentation
//! Token validation middleware.

use core::{
    mem,
    task::{Context, Poll},
};
use std::sync::LazyLock;

use http::{Request, Response, StatusCode, header::AUTHORIZATION};
use http_body::Body;
use parking_lot::RwLock;
use tower_layer::Layer;
use tower_service::Service;
use ts_error::{IntoReport, LogError};
use ts_token::{
    JsonWebToken,
    jwks::{JsonWebKeyCache, JsonWebKeySetProvider},
    jwt::TokenType,
};

use crate::middleware::futures::UndefinedFuture;

/// Process-wide JSON web key cache, shared by every [`TokenAuth`] instance in this process.
///
/// Built lazily on first access and guarded by a reader-writer lock.
static CACHE: LazyLock<RwLock<JsonWebKeyCache>> = LazyLock::new(|| {
    let empty_cache = JsonWebKeyCache::new();
    RwLock::new(empty_cache)
});

/// Middleware layer that validates bearer JSON web tokens on incoming requests.
///
/// Signing keys are obtained from `provider` when the shared cache does not hold the token's key.
/// A consent token is additionally checked against the method and path of the route it is
/// presented on.
#[derive(Debug, Clone)]
pub struct TokenAuth<P: JsonWebKeySetProvider> {
    /// Whether a request with no `Authorization` header is rejected (`true`) or passed through.
    pub is_token_required: bool,
    /// Source of the JSON web key set used to verify token signatures.
    pub provider: P,
}

impl<P> TokenAuth<P>
where
    P: JsonWebKeySetProvider,
{
    /// Create token auth where the token is required.
    ///
    /// Requests without an `Authorization` header are rejected with `401 Unauthorized`.
    pub fn required(provider: P) -> Self {
        Self {
            is_token_required: true,
            provider,
        }
    }

    /// Create token auth where the token is not required.
    ///
    /// Requests without an `Authorization` header pass through untouched; a header that *is*
    /// present is still fully validated.
    pub fn optional(provider: P) -> Self {
        Self {
            is_token_required: false,
            provider,
        }
    }

    /// Authenticate a request against this auth.
    ///
    /// On success the decoded [`JsonWebToken`] is inserted into the request's extensions and the
    /// request is returned. If there is no `Authorization` header and the token is optional, the
    /// request is returned unchanged.
    ///
    /// # Errors
    ///
    /// - [`StatusCode::UNAUTHORIZED`] if a required token is missing, the header is not a
    ///   visible-ASCII `Bearer <token>` value, the token cannot be deserialized, its claims are
    ///   invalid, no cached or freshly fetched key matches its `kid`, or its signature does not
    ///   verify.
    /// - [`StatusCode::FORBIDDEN`] if a consent token's action is not `"<METHOD> <path>"` for
    ///   this request.
    /// - [`StatusCode::INTERNAL_SERVER_ERROR`] if fetching the key set from the provider fails.
    pub async fn authenticate<T>(self, mut request: Request<T>) -> Result<Request<T>, StatusCode> {
        let Some(authorization_header) = request.headers().get(AUTHORIZATION) else {
            return if self.is_token_required {
                Err(StatusCode::UNAUTHORIZED)
            } else {
                Ok(request)
            };
        };

        // `to_str` rejects header values that are not visible ASCII. Such a header, or one
        // without a case-insensitive `Bearer ` prefix, cannot carry a token.
        let Some(encoded_token) = authorization_header
            .to_str()
            .ok()
            .and_then(bearer_credentials)
        else {
            return Err(StatusCode::UNAUTHORIZED);
        };

        let Some(token) = JsonWebToken::deserialize(encoded_token) else {
            return Err(StatusCode::UNAUTHORIZED);
        };

        if !token.claims.is_valid() {
            return Err(StatusCode::UNAUTHORIZED);
        }

        // Evict stale keys before looking up this token's key.
        // NOTE(review): this takes the write lock on every request.
        {
            let mut cache = CACHE.write();
            cache.remove_stale_keys();
        }

        // Check if the cache contains the key. The guard is scoped so it is released before the
        // fetch below awaits.
        let cache_contains_key = {
            let cache = CACHE.read();
            cache.get(&token.header.kid).is_some()
        };

        // On a cache miss, refresh the key set from the provider.
        // NOTE(review): the signature has not been checked yet, so `kid` is client-controlled
        // here. A client can force one provider fetch per request by sending an unknown `kid`.
        // Consider rate-limiting refreshes.
        if !cache_contains_key {
            let jwks = self
                .provider
                .fetch()
                .await
                .into_report()
                .log_err()
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
            let mut cache = CACHE.write();
            cache.insert(jwks);
        }

        // Verify the signature, holding the read lock only for the lookup and the check.
        {
            let cache = CACHE.read();
            let Some(key) = cache.get(&token.header.kid) else {
                return Err(StatusCode::UNAUTHORIZED);
            };
            if !key.verifies_signature(&token) {
                return Err(StatusCode::UNAUTHORIZED);
            }
        }

        // A consent token is only valid for the exact `"<METHOD> <path>"` action it names.
        if let TokenType::Consent { act } = &token.claims.typ {
            let expected_action = format!("{} {}", request.method(), request.uri().path());
            if expected_action != *act {
                return Err(StatusCode::FORBIDDEN);
            }
        }

        request.extensions_mut().insert(token);

        Ok(request)
    }
}

/// Authorization scheme prefix (including the separating space), matched ASCII
/// case-insensitively.
const BEARER_PREFIX: &str = "bearer ";

/// Return the credentials that follow a case-insensitive `Bearer ` prefix in `header`.
///
/// Returns `None` if `header` is shorter than the prefix, the prefix length does not fall on a
/// character boundary, or the prefix does not match. Unlike lower-casing the prefix, this does
/// not allocate.
fn bearer_credentials(header: &str) -> Option<&str> {
    let (scheme, credentials) = header.split_at_checked(BEARER_PREFIX.len())?;
    scheme
        .eq_ignore_ascii_case(BEARER_PREFIX)
        .then_some(credentials)
}

impl<S, P> Layer<S> for TokenAuth<P>
where
    P: JsonWebKeySetProvider,
{
    type Service = TokenAuthService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        TokenAuthService {
            inner,
            auth: self.clone(),
        }
    }
}

/// The tower service produced by [`TokenAuth`]: it authenticates each request before handing
/// it to the wrapped service.
#[derive(Debug, Clone)]
pub struct TokenAuthService<S, P: JsonWebKeySetProvider> {
    /// The wrapped service that receives authenticated requests.
    inner: S,
    /// Authentication settings applied to every request.
    auth: TokenAuth<P>,
}

impl<S, P> TokenAuthService<S, P>
where
    P: JsonWebKeySetProvider,
{
    /// Create a new service.
    pub fn new(inner: S, auth: TokenAuth<P>) -> Self {
        Self { inner, auth }
    }
}

impl<Svc, Prov, ReqBody, ResBody> Service<Request<ReqBody>> for TokenAuthService<Svc, Prov>
where
    Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send,
    ResBody: Body + Send + Default,
    ReqBody: Send + 'static,
    Prov: JsonWebKeySetProvider,
{
    type Response = Svc::Response;
    type Error = Svc::Error;
    type Future = UndefinedFuture<Svc, ReqBody>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Build the authentication future for `request` and pair it with the ready inner service
    /// in an [`UndefinedFuture`].
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        let authentication = self.auth.clone().authenticate(request);
        // Give the future the instance that `poll_ready` drove to readiness and keep a fresh
        // clone for the next call. See
        // https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
        let fresh = self.inner.clone();
        let ready = mem::replace(&mut self.inner, fresh);

        UndefinedFuture::define(Box::pin(authentication), ready)
    }
}

#[cfg(test)]
mod test {
    use axum::{Extension, Router, routing::get};
    use bytes::Bytes;
    use http::{Request, Response, StatusCode, header::AUTHORIZATION};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;
    use ts_token::{
        JsonWebKey, JsonWebKeySet, JsonWebToken, jwks::StaticJsonWebKeySet, jwt::TokenType,
    };

    use crate::middleware::test::get_request;
    use crate::middleware::{test::JWK, token::TokenAuth};
    use crate::test::ResponseTestExt;

    /// Build an auth layer whose provider serves only the shared test key.
    fn test_auth(token_required: bool) -> TokenAuth<StaticJsonWebKeySet> {
        let jwk: JsonWebKey = serde_json::from_str(JWK).unwrap();
        let provider = StaticJsonWebKeySet::new(JsonWebKeySet { keys: vec![jwk] });
        if token_required {
            TokenAuth::required(provider)
        } else {
            TokenAuth::optional(provider)
        }
    }

    /// Send `request` through `auth` in front of a service that echoes the request body.
    async fn send(
        auth: TokenAuth<StaticJsonWebKeySet>,
        request: Request<Full<Bytes>>,
    ) -> Response<Full<Bytes>> {
        ServiceBuilder::new()
            .layer(auth)
            .service_fn(async |req: Request<Full<Bytes>>| {
                Ok::<_, BoxError>(Response::new(req.into_body()))
            })
            .ready()
            .await
            .expect("service should be ok")
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Send the standard test request, carrying a token of `token_type` if given.
    async fn test_service(
        token_type: Option<TokenType>,
        token_required: bool,
    ) -> Response<Full<Bytes>> {
        send(test_auth(token_required), get_request(token_type)).await
    }

    #[tokio::test]
    async fn axum() {
        Router::new()
            .route(
                "/resource/id",
                get(|Extension(token): Extension<JsonWebToken>| async move {
                    assert_eq!("subject", token.claims.sub);
                    StatusCode::OK
                }),
            )
            .layer(test_auth(true))
            .oneshot(get_request(Some(TokenType::Common)))
            .await
            .unwrap()
            .expect_status(StatusCode::OK);
    }

    #[tokio::test]
    async fn consent_token() {
        let cases = [
            ("DELETE /resource/id", StatusCode::FORBIDDEN),
            ("GET /resource/id", StatusCode::OK),
            ("GET /resource/id2", StatusCode::FORBIDDEN),
        ];

        for (act, expected) in cases {
            let token_type = TokenType::Consent {
                act: act.to_string(),
            };
            test_service(Some(token_type), true)
                .await
                .expect_status(expected);
        }
    }

    #[tokio::test]
    async fn requirement() {
        let cases = [
            (None, true, StatusCode::UNAUTHORIZED),
            (None, false, StatusCode::OK),
            (Some(TokenType::Common), true, StatusCode::OK),
            (Some(TokenType::Common), false, StatusCode::OK),
        ];

        for (token_type, token_required, expected) in cases {
            test_service(token_type, token_required)
                .await
                .expect_status(expected);
        }
    }

    #[tokio::test]
    async fn token_validity() {
        for token_type in [TokenType::Provisioning, TokenType::Common] {
            test_service(Some(token_type), true)
                .await
                .expect_status(StatusCode::OK);
        }

        // This token's payload has `"exp":2`, which is long expired, so it must be rejected.
        const INVALID_JWT: &str = r#"bearer eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIiwia2lkIjoiVU1JaTBoZGxCQmNJRzhvQ09tQmlfMGJ2UWZsaXZneHA5REtlMkw2UGpiRSJ9.eyJ0aWQiOiJ0b2tlbi1pZCIsImV4cCI6MiwiaWF0IjoxLCJzdWIiOiJzdWJqZWN0LWlkIiwidHlwIjoiY29tbW9uIn0.f7PHRouKc9DYxbRNZdUdrdmM6gC-HdmlorxZHPv5s21oqmbJMsOXXFpnh_52fXPbgY-rNPCvwHFyVKsovk51CA"#;

        let expired_request = Request::builder()
            .uri("/resource/id")
            .header(AUTHORIZATION, INVALID_JWT)
            .body(Full::<Bytes>::default())
            .expect("request should be valid");

        send(test_auth(true), expired_request)
            .await
            .expect_status(StatusCode::UNAUTHORIZED);
    }
}