Skip to main content

docbox_http/middleware/
api_key.rs

1use axum::{
2    extract::Request,
3    http::StatusCode,
4    response::{IntoResponse, Response},
5};
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tower::{Layer, Service};
11
/// Tower layer that carries the API key handed to every
/// [`ApiKeyMiddleware`] it creates.
#[derive(Clone)]
pub struct ApiKeyLayer {
    /// Key that is copied into each middleware built by this layer.
    key: String,
}

impl ApiKeyLayer {
    /// Builds a layer around the provided API `key`.
    pub fn new(key: String) -> Self {
        ApiKeyLayer { key }
    }
}
22
23impl<S> Layer<S> for ApiKeyLayer {
24    type Service = ApiKeyMiddleware<S>;
25
26    fn layer(&self, inner: S) -> Self::Service {
27        ApiKeyMiddleware {
28            inner,
29            key: self.key.clone(),
30        }
31    }
32}
33
34#[derive(Clone)]
35pub struct ApiKeyMiddleware<S> {
36    inner: S,
37    key: String,
38}
39
40impl<S> Service<Request> for ApiKeyMiddleware<S>
41where
42    S: Service<Request, Response = Response> + Send + 'static,
43    S::Future: Send + 'static,
44{
45    type Response = S::Response;
46    type Error = S::Error;
47    type Future =
48        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
49
50    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51        self.inner.poll_ready(cx)
52    }
53
54    fn call(&mut self, request: Request) -> Self::Future {
55        let header = match request.headers().get("x-docbox-api-key") {
56            Some(value) => value,
57            None => {
58                return Box::pin(async move {
59                    Ok((StatusCode::UNAUTHORIZED, "Missing x-docbox-api-key").into_response())
60                });
61            }
62        };
63
64        if header.to_str().is_ok_and(|value| value.ne(&self.key)) {
65            return Box::pin(async move {
66                Ok((
67                    StatusCode::UNAUTHORIZED,
68                    "Missing or invalid x-docbox-api-key",
69                )
70                    .into_response())
71            });
72        }
73
74        Box::pin(self.inner.call(request))
75    }
76}