modo/auth/apikey/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use http::Request;
8use tower::{Layer, Service};
9
10use crate::error::Error;
11
12use super::store::ApiKeyStore;
13
14pub struct ApiKeyLayer {
23 store: ApiKeyStore,
24 header: HeaderSource,
25}
26
27#[derive(Clone)]
28enum HeaderSource {
29 Authorization,
30 Custom(http::HeaderName),
31}
32
33impl Clone for ApiKeyLayer {
34 fn clone(&self) -> Self {
35 Self {
36 store: self.store.clone(),
37 header: self.header.clone(),
38 }
39 }
40}
41
42impl ApiKeyLayer {
43 pub fn new(store: ApiKeyStore) -> Self {
45 Self {
46 store,
47 header: HeaderSource::Authorization,
48 }
49 }
50
51 pub fn from_header(store: ApiKeyStore, header: &str) -> crate::Result<Self> {
57 let name = http::HeaderName::from_bytes(header.as_bytes())
58 .map_err(|_| Error::bad_request(format!("invalid header name: {header}")))?;
59 Ok(Self {
60 store,
61 header: HeaderSource::Custom(name),
62 })
63 }
64}
65
66impl<S> Layer<S> for ApiKeyLayer {
67 type Service = ApiKeyMiddleware<S>;
68
69 fn layer(&self, inner: S) -> Self::Service {
70 ApiKeyMiddleware {
71 inner,
72 store: self.store.clone(),
73 header: self.header.clone(),
74 }
75 }
76}
77
78pub struct ApiKeyMiddleware<S> {
80 inner: S,
81 store: ApiKeyStore,
82 header: HeaderSource,
83}
84
85impl<S: Clone> Clone for ApiKeyMiddleware<S> {
86 fn clone(&self) -> Self {
87 Self {
88 inner: self.inner.clone(),
89 store: self.store.clone(),
90 header: self.header.clone(),
91 }
92 }
93}
94
95impl<S> Service<Request<Body>> for ApiKeyMiddleware<S>
96where
97 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
98 S::Future: Send + 'static,
99 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
100{
101 type Response = http::Response<Body>;
102 type Error = S::Error;
103 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
104
105 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 self.inner.poll_ready(cx)
107 }
108
109 fn call(&mut self, request: Request<Body>) -> Self::Future {
110 let store = self.store.clone();
111 let header = self.header.clone();
112 let mut inner = self.inner.clone();
113 std::mem::swap(&mut self.inner, &mut inner);
114
115 Box::pin(async move {
116 let (mut parts, body) = request.into_parts();
117
118 let raw_token = match extract_token(&parts, &header) {
120 Ok(token) => token,
121 Err(e) => return Ok(e.into_response()),
122 };
123
124 let meta = match store.verify(raw_token).await {
126 Ok(m) => m,
127 Err(e) => return Ok(e.into_response()),
128 };
129
130 parts.extensions.insert(meta);
132
133 let request = Request::from_parts(parts, body);
134 inner.call(request).await
135 })
136 }
137}
138
139fn extract_token<'a>(
140 parts: &'a http::request::Parts,
141 header: &HeaderSource,
142) -> Result<&'a str, Error> {
143 match header {
144 HeaderSource::Authorization => {
145 let value = parts
146 .headers
147 .get(http::header::AUTHORIZATION)
148 .ok_or_else(|| Error::unauthorized("missing API key"))?
149 .to_str()
150 .map_err(|_| Error::unauthorized("invalid API key"))?;
151 value
152 .strip_prefix("Bearer ")
153 .ok_or_else(|| Error::unauthorized("invalid API key"))
154 }
155 HeaderSource::Custom(name) => {
156 let value = parts
157 .headers
158 .get(name)
159 .ok_or_else(|| Error::unauthorized("missing API key"))?
160 .to_str()
161 .map_err(|_| Error::unauthorized("invalid API key"))?;
162 Ok(value)
163 }
164 }
165}