apikeys_rs/axum_layer/
mod.rs1use std::task::{Context, Poll};
2
3use axum::{
4 extract::Request,
5 response::{IntoResponse, Response},
6};
7use futures_util::future::BoxFuture;
8use http::HeaderMap;
9use tower::{Layer, Service};
10
11pub mod errors;
12use tracing::error;
13
14use self::errors::ApiKeyLayerError;
15use crate::{errors::ApiKeyManagerError, traits::ApiKeyManager};
16
/// Tower [`Layer`] that wraps services in [`ApiKeyMiddleware`].
///
/// The layer owns one API key manager and hands a clone of it to every
/// middleware it produces.
///
/// The `ApiKeyManager + Send + Sync + Clone` requirements are stated on the
/// `impl` blocks that actually need them rather than on the type itself, so the
/// type can be named without repeating the bounds.
#[derive(Clone)]
pub struct ApiKeyLayer<T> {
    /// Manager cloned into each middleware created by this layer.
    manager: T,
}
24
25impl<S, T> Layer<S> for ApiKeyLayer<T>
26where
27 T: ApiKeyManager + Send + Sync + Clone,
28{
29 type Service = ApiKeyMiddleware<S, T>;
30
31 fn layer(&self, inner: S) -> Self::Service {
32 ApiKeyMiddleware { inner, manager: self.manager.clone() }
33 }
34}
35
/// Service wrapper that checks the request's API key before letting the
/// response of the inner service through.
///
/// The bounds on `T` (and on `S`) live on the `impl` blocks that need them, not
/// on the struct, so the type can be named without repeating them.
#[derive(Clone)]
pub struct ApiKeyMiddleware<S, T> {
    /// The wrapped service that produces the real response.
    inner: S,
    /// Manager used to validate the key carried by each request.
    manager: T,
}
44
45impl<S, T> Service<Request> for ApiKeyMiddleware<S, T>
46where
47 S: Service<Request, Response = Response> + Send + 'static,
48 S::Future: Send + 'static,
49 T: ApiKeyManager + Send + Sync + Clone + 'static,
50{
51 type Response = S::Response;
52 type Error = S::Error;
53 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
55
56 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57 self.inner.poll_ready(cx)
58 }
59
60 fn call(&mut self, request: Request) -> Self::Future {
61 let headers = request.headers().clone();
62 let x_api_key = match extract_header("x-api-key", &headers) {
65 Some(key) => key,
66 None => {
67 return Box::pin(async move {
68 let response = errors::ApiKeyLayerError::MissingApiKey.into_response();
69 Ok(response)
70 });
71 }
72 };
73
74 let manager = self.manager.clone();
75 let future = self.inner.call(request);
76 let verification_future = verify_api_key(manager, x_api_key);
77 Box::pin(async move {
78 match verification_future.await {
79 Ok(true) => {
80 let response: Response = future.await?;
81 Ok(response)
82 }
83 Ok(false) => {
84 let response = errors::ApiKeyLayerError::InvalidApiKey.into_response();
85 Ok(response)
86 }
87 Err(e) => {
88 let response = e.into_response();
89 Ok(response)
90 }
91 }
92 })
93 }
94}
95
96impl<T> ApiKeyLayer<T>
97where
98 T: ApiKeyManager + Send + Sync + Clone,
99{
100 pub fn new(manager: T) -> Self
101 where
102 T: ApiKeyManager + Send + Sync + Clone,
103 {
104 Self { manager }
105 }
106}
107
108fn extract_header(key: &str, headers: &HeaderMap) -> Option<String> {
109 match headers.get(key) {
110 Some(key) => match key.to_str() {
111 Ok(key) => Some(key.to_string()),
112 Err(_) => None,
113 },
114 None => None,
115 }
116}
117
118async fn verify_api_key(
119 manager: impl ApiKeyManager + Send + Sync,
120 key: String,
121) -> Result<bool, errors::ApiKeyLayerError> {
122 match manager.use_key(key.as_str()).await {
123 Ok(key) => key,
124 Err(e) => {
125 return Err(e.into());
126 }
127 };
128
129 Ok(true)
130}
131
132impl From<ApiKeyManagerError> for ApiKeyLayerError {
133 fn from(error: ApiKeyManagerError) -> Self {
134 match error {
135 ApiKeyManagerError::LimiterError(e) => ApiKeyLayerError::LimiterError(e),
136 e => {
137 error!("{e:?}");
138 ApiKeyLayerError::InvalidApiKey
139 }
140 }
141 }
142}