Skip to main content

modo/auth/apikey/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use http::Request;
8use tower::{Layer, Service};
9
10use crate::error::Error;
11
12use super::store::ApiKeyStore;
13
14/// Tower [`Layer`] that verifies API keys on incoming requests.
15///
16/// Reads the raw token from the `Authorization: Bearer <token>` header
17/// (or a custom header), calls [`ApiKeyStore::verify`], and inserts
18/// [`super::ApiKeyMeta`] into request extensions on success.
19///
20/// Errors are returned as [`crate::Error`] -- the app's error handler
21/// decides rendering.
22pub struct ApiKeyLayer {
23    store: ApiKeyStore,
24    header: HeaderSource,
25}
26
27#[derive(Clone)]
28enum HeaderSource {
29    Authorization,
30    Custom(http::HeaderName),
31}
32
33impl Clone for ApiKeyLayer {
34    fn clone(&self) -> Self {
35        Self {
36            store: self.store.clone(),
37            header: self.header.clone(),
38        }
39    }
40}
41
42impl ApiKeyLayer {
43    /// Create a layer that reads from `Authorization: Bearer <token>`.
44    pub fn new(store: ApiKeyStore) -> Self {
45        Self {
46            store,
47            header: HeaderSource::Authorization,
48        }
49    }
50
51    /// Create a layer that reads from a custom header.
52    ///
53    /// # Errors
54    ///
55    /// Returns `Error::bad_request` if the header name is invalid.
56    pub fn from_header(store: ApiKeyStore, header: &str) -> crate::Result<Self> {
57        let name = http::HeaderName::from_bytes(header.as_bytes())
58            .map_err(|_| Error::bad_request(format!("invalid header name: {header}")))?;
59        Ok(Self {
60            store,
61            header: HeaderSource::Custom(name),
62        })
63    }
64}
65
66impl<S> Layer<S> for ApiKeyLayer {
67    type Service = ApiKeyMiddleware<S>;
68
69    fn layer(&self, inner: S) -> Self::Service {
70        ApiKeyMiddleware {
71            inner,
72            store: self.store.clone(),
73            header: self.header.clone(),
74        }
75    }
76}
77
78/// Tower [`Service`] that verifies API keys on every request.
79pub struct ApiKeyMiddleware<S> {
80    inner: S,
81    store: ApiKeyStore,
82    header: HeaderSource,
83}
84
85impl<S: Clone> Clone for ApiKeyMiddleware<S> {
86    fn clone(&self) -> Self {
87        Self {
88            inner: self.inner.clone(),
89            store: self.store.clone(),
90            header: self.header.clone(),
91        }
92    }
93}
94
95impl<S> Service<Request<Body>> for ApiKeyMiddleware<S>
96where
97    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
98    S::Future: Send + 'static,
99    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
100{
101    type Response = http::Response<Body>;
102    type Error = S::Error;
103    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
104
105    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        self.inner.poll_ready(cx)
107    }
108
109    fn call(&mut self, request: Request<Body>) -> Self::Future {
110        let store = self.store.clone();
111        let header = self.header.clone();
112        let mut inner = self.inner.clone();
113        std::mem::swap(&mut self.inner, &mut inner);
114
115        Box::pin(async move {
116            let (mut parts, body) = request.into_parts();
117
118            // Extract raw token from header
119            let raw_token = match extract_token(&parts, &header) {
120                Ok(token) => token,
121                Err(e) => return Ok(e.into_response()),
122            };
123
124            // Verify
125            let meta = match store.verify(raw_token).await {
126                Ok(m) => m,
127                Err(e) => return Ok(e.into_response()),
128            };
129
130            // Insert into extensions
131            parts.extensions.insert(meta);
132
133            let request = Request::from_parts(parts, body);
134            inner.call(request).await
135        })
136    }
137}
138
139fn extract_token<'a>(
140    parts: &'a http::request::Parts,
141    header: &HeaderSource,
142) -> Result<&'a str, Error> {
143    match header {
144        HeaderSource::Authorization => {
145            let value = parts
146                .headers
147                .get(http::header::AUTHORIZATION)
148                .ok_or_else(|| Error::unauthorized("missing API key"))?
149                .to_str()
150                .map_err(|_| Error::unauthorized("invalid API key"))?;
151            value
152                .strip_prefix("Bearer ")
153                .ok_or_else(|| Error::unauthorized("invalid API key"))
154        }
155        HeaderSource::Custom(name) => {
156            let value = parts
157                .headers
158                .get(name)
159                .ok_or_else(|| Error::unauthorized("missing API key"))?
160                .to_str()
161                .map_err(|_| Error::unauthorized("invalid API key"))?;
162            Ok(value)
163        }
164    }
165}