ts-webapi 0.4.11

Library for my web API projects
Documentation
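//! API key middleware: validates that an API key header is present and that
//! its value is in the allow list. Requests without a readable key are answered
//! with `401 Unauthorized`; requests with an unknown key get `403 Forbidden`.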

use alloc::sync::Arc;
use core::task::{Context, Poll};

use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
use http_body::Body;
use tower_layer::Layer;
use tower_service::Service;

use crate::middleware::futures::DefinedFuture;

/// API key validation layer.
///
/// Holds the name of the header to read and the list of accepted keys. Both
/// fields are wrapped in [`Arc`], so cloning this value shares the underlying
/// data instead of copying the key list.
///
/// NOTE(review): the derived `Debug` implementation prints the contents of
/// `allowed_keys`; avoid logging this value if the keys are secrets.
#[derive(Debug, Clone)]
pub struct ApiKeyAuth {
    /// The header to find the API key in.
    pub header: Arc<HeaderName>,
    /// The list of allowed keys. A request key must equal one entry exactly.
    pub allowed_keys: Arc<Vec<String>>,
}
impl ApiKeyAuth {
    /// Create new API key auth.
    pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
        Self {
            allowed_keys: Arc::new(allowed_keys),
            header: Arc::new(header),
        }
    }

    /// Try get the API key header from a request.
    pub fn get_header<'a, T>(&self, request: &'a Request<T>) -> Option<&'a str> {
        request
            .headers()
            .get(self.header.as_ref())
            .map(HeaderValue::to_str)
            .transpose()
            .ok()
            .flatten()
    }

    /// Check if a given API key is allowed
    pub fn is_allowed_key(&self, key: &str) -> bool {
        self.allowed_keys
            .iter()
            .map(String::as_str)
            .any(|allowed| allowed.eq(key))
    }
}

/// Lets [`ApiKeyAuth`] be used as a tower layer: wrapping a service yields an
/// [`ApiKeyAuthService`] that carries its own clone of the auth configuration.
impl<S> Layer<S> for ApiKeyAuth {
    type Service = ApiKeyAuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // Each wrapped service gets its own handle to the shared configuration.
        let auth = Clone::clone(self);
        ApiKeyAuthService { inner, auth }
    }
}

/// Tower service behind the API key auth middleware.
///
/// Wraps an inner service and consults an [`ApiKeyAuth`] before forwarding a
/// request to it.
#[derive(Debug, Clone)]
pub struct ApiKeyAuthService<S> {
    /// Inner service that receives requests carrying an allowed key.
    inner: S,
    /// The logic layer: which header to read and which keys are accepted.
    auth: ApiKeyAuth,
}

impl<S> ApiKeyAuthService<S> {
    /// Create a new service.
    pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
        Self { inner, auth }
    }
}

/// Request handling for the API key middleware.
///
/// Readiness is delegated to the inner service. On `call`, the request is
/// forwarded only when it carries an allowed key; otherwise a status-only
/// response is produced without calling the inner service.
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Body + Default,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = DefinedFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // `None`: no readable key header; `Some(allowed)`: a key was supplied.
        // Reducing to a plain `bool` ends the borrow of `request` before it is
        // moved into the inner service.
        let verdict = self
            .auth
            .get_header(&request)
            .map(|key| self.auth.is_allowed_key(key));

        match verdict {
            None => DefinedFuture::return_status(StatusCode::UNAUTHORIZED),
            Some(false) => DefinedFuture::return_status(StatusCode::FORBIDDEN),
            Some(true) => DefinedFuture::proceed(self.inner.call(request)),
        }
    }
}

#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::{middleware::api_key::ApiKeyAuth, test::ResponseTestExt};

    /// Header the tests send the API key in.
    const KEY_HEADER: &str = "x-api-key";
    /// The only key on the allow list used by these tests.
    const VALID_KEY: &str = "api-key-1";

    /// Inner service: answers with the request body (status defaults to 200).
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Auth configuration shared by every test: one header, one allowed key.
    fn single_key_auth() -> ApiKeyAuth {
        ApiKeyAuth::new(
            HeaderName::from_static(KEY_HEADER),
            vec![VALID_KEY.to_string()],
        )
    }

    /// Build an empty-bodied request, optionally carrying an API key header.
    fn empty_request(api_key: Option<&'static str>) -> Request<Full<Bytes>> {
        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(value) = api_key {
            request
                .headers_mut()
                .insert(KEY_HEADER, HeaderValue::from_static(value));
        }
        request
    }

    /// Run `request` through the middleware stacked on top of `echo`.
    async fn send(request: Request<Full<Bytes>>) -> Response<Full<Bytes>> {
        let mut service = ServiceBuilder::new()
            .layer(single_key_auth())
            .service_fn(echo);

        service
            .ready()
            .await
            .unwrap()
            .oneshot(request)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn axum_compat() {
        // The layer must be accepted by an axum router and still reject
        // requests that carry no key.
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(single_key_auth());

        let request = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();

        router
            .oneshot(request)
            .await
            .unwrap()
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_no_header() {
        send(empty_request(None))
            .await
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        send(empty_request(Some("not-allowed-key")))
            .await
            .expect_status(StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn allows_allowed() {
        send(empty_request(Some(VALID_KEY)))
            .await
            .expect_status(StatusCode::OK);
    }
}