Skip to main content

modo/auth/jwt/
middleware.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::response::IntoResponse;
9use http::Request;
10use serde::de::DeserializeOwned;
11use tower::{Layer, Service};
12
13use crate::Error;
14
15use super::claims::Claims;
16use super::decoder::JwtDecoder;
17use super::error::JwtError;
18use super::revocation::Revocation;
19use super::source::{BearerSource, TokenSource};
20
21/// Tower [`Layer`] that installs JWT authentication on a route.
22///
23/// For each request the middleware:
24/// 1. Tries each `TokenSource` in order; returns `401` if none yields a token.
25/// 2. Decodes and validates the token with `JwtDecoder`; returns `401` on failure.
26/// 3. If a `Revocation` backend is registered and the token has a `jti`, calls
27///    `is_revoked()`; returns `401` on revocation or backend error (fail-closed).
28/// 4. Inserts `Claims<T>` into request extensions for handler extraction.
29///
30/// The default token source is [`BearerSource`] (`Authorization: Bearer <token>`).
31pub struct JwtLayer<T> {
32    decoder: JwtDecoder,
33    sources: Arc<[Arc<dyn TokenSource>]>,
34    revocation: Option<Arc<dyn Revocation>>,
35    _marker: PhantomData<T>,
36}
37
38impl<T> JwtLayer<T>
39where
40    T: DeserializeOwned + Clone + Send + Sync + 'static,
41{
42    /// Creates a `JwtLayer` with `BearerSource` as the sole token source
43    /// and no revocation backend.
44    pub fn new(decoder: JwtDecoder) -> Self {
45        Self {
46            decoder,
47            sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
48            revocation: None,
49            _marker: PhantomData,
50        }
51    }
52
53    /// Replaces the token sources with the provided list.
54    ///
55    /// Sources are tried in order; the first to return `Some` is used.
56    pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
57        self.sources = Arc::from(sources);
58        self
59    }
60
61    /// Attaches a revocation backend. Tokens with a `jti` claim are checked
62    /// against it on every request.
63    pub fn with_revocation(mut self, revocation: Arc<dyn Revocation>) -> Self {
64        self.revocation = Some(revocation);
65        self
66    }
67}
68
69impl<T> Clone for JwtLayer<T> {
70    fn clone(&self) -> Self {
71        Self {
72            decoder: self.decoder.clone(),
73            sources: self.sources.clone(),
74            revocation: self.revocation.clone(),
75            _marker: PhantomData,
76        }
77    }
78}
79
80impl<Svc, T> Layer<Svc> for JwtLayer<T>
81where
82    T: DeserializeOwned + Clone + Send + Sync + 'static,
83{
84    type Service = JwtMiddleware<Svc, T>;
85
86    fn layer(&self, inner: Svc) -> Self::Service {
87        JwtMiddleware {
88            inner,
89            decoder: self.decoder.clone(),
90            sources: self.sources.clone(),
91            revocation: self.revocation.clone(),
92            _marker: PhantomData,
93        }
94    }
95}
96
97/// Tower [`Service`] produced by [`JwtLayer`]. See that type for behavior details.
98pub struct JwtMiddleware<Svc, T> {
99    inner: Svc,
100    decoder: JwtDecoder,
101    sources: Arc<[Arc<dyn TokenSource>]>,
102    revocation: Option<Arc<dyn Revocation>>,
103    _marker: PhantomData<T>,
104}
105
106impl<Svc: Clone, T> Clone for JwtMiddleware<Svc, T> {
107    fn clone(&self) -> Self {
108        Self {
109            inner: self.inner.clone(),
110            decoder: self.decoder.clone(),
111            sources: self.sources.clone(),
112            revocation: self.revocation.clone(),
113            _marker: PhantomData,
114        }
115    }
116}
117
118impl<Svc, T> Service<Request<Body>> for JwtMiddleware<Svc, T>
119where
120    Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
121    Svc::Future: Send + 'static,
122    Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
123    T: DeserializeOwned + Clone + Send + Sync + 'static,
124{
125    type Response = http::Response<Body>;
126    type Error = Svc::Error;
127    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        self.inner.poll_ready(cx)
131    }
132
133    fn call(&mut self, request: Request<Body>) -> Self::Future {
134        let decoder = self.decoder.clone();
135        let sources = self.sources.clone();
136        let revocation = self.revocation.clone();
137        let mut inner = self.inner.clone();
138        std::mem::swap(&mut self.inner, &mut inner);
139
140        Box::pin(async move {
141            let (mut parts, body) = request.into_parts();
142
143            // Try each token source in order
144            let token = sources.iter().find_map(|s| s.extract(&parts));
145            let token = match token {
146                Some(t) => t,
147                None => {
148                    let err = Error::unauthorized("unauthorized")
149                        .chain(JwtError::MissingToken)
150                        .with_code(JwtError::MissingToken.code());
151                    return Ok(err.into_response());
152                }
153            };
154
155            // Decode and validate (sync)
156            let claims: Claims<T> = match decoder.decode(&token) {
157                Ok(c) => c,
158                Err(e) => return Ok(e.into_response()),
159            };
160
161            // Check revocation (async) if backend registered and jti present
162            if let (Some(rev), Some(jti)) = (&revocation, claims.token_id()) {
163                match rev.is_revoked(jti).await {
164                    Ok(true) => {
165                        let err = Error::unauthorized("unauthorized")
166                            .chain(JwtError::Revoked)
167                            .with_code(JwtError::Revoked.code());
168                        return Ok(err.into_response());
169                    }
170                    Err(e) => {
171                        tracing::warn!(error = %e, jti = jti, "JWT revocation check failed");
172                        let err = Error::unauthorized("unauthorized")
173                            .chain(JwtError::RevocationCheckFailed)
174                            .with_code(JwtError::RevocationCheckFailed.code());
175                        return Ok(err.into_response());
176                    }
177                    Ok(false) => {} // not revoked, proceed
178                }
179            }
180
181            // Insert claims into extensions
182            parts.extensions.insert(claims);
183
184            let request = Request::from_parts(parts, body);
185            inner.call(request).await
186        })
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;

    use crate::auth::jwt::{Claims, JwtConfig, JwtEncoder};

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestClaims {
        role: String,
    }

    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    fn make_token(config: &JwtConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_sub("user_1")
        .with_exp(now_secs() + 3600);
        encoder.encode(&claims).unwrap()
    }

    /// Builds a request carrying `token` in the `Authorization: Bearer` header.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_claims = req.extensions().get::<Claims<TestClaims>>().is_some();
        let body = if has_claims { "ok" } else { "no-claims" };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn malformed_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request("not-a-jwt")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        let layer = JwtLayer::<TestClaims>::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // Flip a character in the middle of the signature where all 6 bits are significant.
        // The last character of a base64url string may have insignificant low bits,
        // so flipping it can decode to identical bytes (making the test flaky).
        let dot = token.rfind('.').unwrap();
        let mid = dot + (token.len() - dot) / 2;
        let mut bytes = token.into_bytes();
        bytes[mid] = if bytes[mid] == b'A' { b'Z' } else { b'A' };
        let token = String::from_utf8(bytes).unwrap();
        let layer = JwtLayer::<TestClaims>::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder);

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims<TestClaims>>().unwrap();
            assert_eq!(claims.custom.role, "admin");
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn rejected_request_does_not_reach_inner_service() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder);

        let called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&called);
        let inner = tower::service_fn(move |_req: Request<Body>| {
            let flag = Arc::clone(&flag);
            async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }
        });

        let svc = layer.layer(inner);
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder)
            .with_sources(vec![
                Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>
            ]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn sources_fall_back_in_order() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // The query source finds nothing (no query string), so the Bearer source is used.
        let layer = JwtLayer::<TestClaims>::new(decoder).with_sources(vec![
            Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>,
            Arc::new(BearerSource) as Arc<dyn TokenSource>,
        ]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_sources_reject_every_request() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::<TestClaims>::new(decoder).with_sources(Vec::new());
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}