apikeys_rs/axum_layer/mod.rs

1use std::task::{Context, Poll};
2
3use axum::{
4    extract::Request,
5    response::{IntoResponse, Response},
6};
7use futures_util::future::BoxFuture;
8use http::HeaderMap;
9use tower::{Layer, Service};
10
11pub mod errors;
12use tracing::error;
13
14use self::errors::ApiKeyLayerError;
15use crate::{errors::ApiKeyManagerError, traits::ApiKeyManager};
16
/// Tower layer that wraps services in an [`ApiKeyMiddleware`], which checks the
/// `x-api-key` request header before forwarding a request.
///
/// `manager` is the backend used to validate keys; it is cloned into every
/// middleware instance this layer produces.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on
/// the struct itself, so naming the type does not require repeating them.
#[derive(Clone)]
pub struct ApiKeyLayer<T> {
    manager: T,
}
24
25impl<S, T> Layer<S> for ApiKeyLayer<T>
26where
27    T: ApiKeyManager + Send + Sync + Clone,
28{
29    type Service = ApiKeyMiddleware<S, T>;
30
31    fn layer(&self, inner: S) -> Self::Service {
32        ApiKeyMiddleware { inner, manager: self.manager.clone() }
33    }
34}
35
/// Service produced by [`ApiKeyLayer`]: validates the `x-api-key` header with
/// `manager` and only forwards accepted requests to `inner`.
///
/// Trait bounds live on the `Service` impl rather than on the struct, so the type
/// can be named (and cloned) without restating them.
#[derive(Clone)]
pub struct ApiKeyMiddleware<S, T> {
    /// The wrapped service that handles requests once the key is accepted.
    inner: S,
    /// Backend used to validate API keys.
    manager: T,
}
44
45impl<S, T> Service<Request> for ApiKeyMiddleware<S, T>
46where
47    S: Service<Request, Response = Response> + Send + 'static,
48    S::Future: Send + 'static,
49    T: ApiKeyManager + Send + Sync + Clone + 'static,
50{
51    type Response = S::Response;
52    type Error = S::Error;
53    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
54    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
55
56    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57        self.inner.poll_ready(cx)
58    }
59
60    fn call(&mut self, request: Request) -> Self::Future {
61        let headers = request.headers().clone();
62        // let origin = extract_header(header::ORIGIN.as_str(), &headers);
63
64        let x_api_key = match extract_header("x-api-key", &headers) {
65            Some(key) => key,
66            None => {
67                return Box::pin(async move {
68                    let response = errors::ApiKeyLayerError::MissingApiKey.into_response();
69                    Ok(response)
70                });
71            }
72        };
73
74        let manager = self.manager.clone();
75        let future = self.inner.call(request);
76        let verification_future = verify_api_key(manager, x_api_key);
77        Box::pin(async move {
78            match verification_future.await {
79                Ok(true) => {
80                    let response: Response = future.await?;
81                    Ok(response)
82                }
83                Ok(false) => {
84                    let response = errors::ApiKeyLayerError::InvalidApiKey.into_response();
85                    Ok(response)
86                }
87                Err(e) => {
88                    let response = e.into_response();
89                    Ok(response)
90                }
91            }
92        })
93    }
94}
95
96impl<T> ApiKeyLayer<T>
97where
98    T: ApiKeyManager + Send + Sync + Clone,
99{
100    pub fn new(manager: T) -> Self
101    where
102        T: ApiKeyManager + Send + Sync + Clone,
103    {
104        Self { manager }
105    }
106}
107
108fn extract_header(key: &str, headers: &HeaderMap) -> Option<String> {
109    match headers.get(key) {
110        Some(key) => match key.to_str() {
111            Ok(key) => Some(key.to_string()),
112            Err(_) => None,
113        },
114        None => None,
115    }
116}
117
118async fn verify_api_key(
119    manager: impl ApiKeyManager + Send + Sync,
120    key: String,
121) -> Result<bool, errors::ApiKeyLayerError> {
122    match manager.use_key(key.as_str()).await {
123        Ok(key) => key,
124        Err(e) => {
125            return Err(e.into());
126        }
127    };
128
129    Ok(true)
130}
131
132impl From<ApiKeyManagerError> for ApiKeyLayerError {
133    fn from(error: ApiKeyManagerError) -> Self {
134        match error {
135            ApiKeyManagerError::LimiterError(e) => ApiKeyLayerError::LimiterError(e),
136            e => {
137                error!("{e:?}");
138                ApiKeyLayerError::InvalidApiKey
139            }
140        }
141    }
142}