Skip to main content

modo/auth/session/jwt/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use crate::auth::session::Session;
12
13use super::claims::Claims;
14use super::decoder::{JwtDecoder, auth_err, jwt_err};
15use super::error::JwtError;
16use super::service::JwtSessionService;
17use super::source::{BearerSource, TokenSource};
18use crate::auth::session::token::SessionToken;
19
20/// Tower [`Layer`] that installs JWT authentication on a route.
21///
22/// For each request the middleware:
23/// 1. Tries each `TokenSource` in order; returns `401` if none yields a token.
24/// 2. Decodes and validates the token with `JwtDecoder`; returns `401` on failure.
25/// 3. Inserts [`Claims`] into request extensions for handler extraction.
26/// 4. When constructed via [`JwtLayer::from_service`], also performs a stateful
27///    database row lookup: hashes the `jti` claim and reads the session row,
28///    inserting the transport-agnostic [`Session`](crate::auth::session::Session)
29///    into extensions. Returns `401` if the row is missing (logged-out / revoked).
30///
31/// The default token source is [`BearerSource`] (`Authorization: Bearer <token>`).
32#[derive(Clone)]
33pub struct JwtLayer {
34    decoder: JwtDecoder,
35    sources: Arc<[Arc<dyn TokenSource>]>,
36    /// Present only when stateful validation is enabled (constructed via
37    /// [`JwtLayer::from_service`]). When `None` the layer behaves as a
38    /// purely stateless JWT validator.
39    service: Option<JwtSessionService>,
40}
41
42impl JwtLayer {
43    /// Creates a `JwtLayer` with `BearerSource` as the sole token source.
44    ///
45    /// This constructor performs **stateless** JWT validation only (signature +
46    /// claims). No database row lookup is performed. Use [`JwtLayer::from_service`]
47    /// for stateful validation that also inserts [`Session`](crate::auth::session::Session)
48    /// into request extensions.
49    pub fn new(decoder: JwtDecoder) -> Self {
50        Self {
51            decoder,
52            sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
53            service: None,
54        }
55    }
56
57    /// Creates a `JwtLayer` backed by a [`JwtSessionService`].
58    ///
59    /// After JWT signature/claims validation the middleware hashes the `jti`
60    /// claim, looks up the session row in the database, and inserts the
61    /// transport-agnostic [`Session`](crate::auth::session::Session) into
62    /// request extensions. Returns `401` with `auth:session_not_found` when
63    /// the session row is absent (logged-out or revoked).
64    ///
65    /// Use [`JwtSessionService::layer`] as the primary entry-point; this
66    /// constructor is the lower-level building block.
67    pub fn from_service(service: JwtSessionService) -> Self {
68        let decoder = service.decoder().clone();
69        Self {
70            decoder,
71            sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
72            service: Some(service),
73        }
74    }
75
76    /// Replaces the token sources with the provided list.
77    ///
78    /// Sources are tried in order; the first to return `Some` is used.
79    pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
80        self.sources = Arc::from(sources);
81        self
82    }
83}
84
85impl<Svc> Layer<Svc> for JwtLayer {
86    type Service = JwtMiddleware<Svc>;
87
88    fn layer(&self, inner: Svc) -> Self::Service {
89        JwtMiddleware {
90            inner,
91            decoder: self.decoder.clone(),
92            sources: self.sources.clone(),
93            service: self.service.clone(),
94        }
95    }
96}
97
98/// Tower [`Service`] produced by [`JwtLayer`]. See that type for behavior details.
99pub struct JwtMiddleware<Svc> {
100    inner: Svc,
101    decoder: JwtDecoder,
102    sources: Arc<[Arc<dyn TokenSource>]>,
103    service: Option<JwtSessionService>,
104}
105
106impl<Svc: Clone> Clone for JwtMiddleware<Svc> {
107    fn clone(&self) -> Self {
108        Self {
109            inner: self.inner.clone(),
110            decoder: self.decoder.clone(),
111            sources: self.sources.clone(),
112            service: self.service.clone(),
113        }
114    }
115}
116
117impl<Svc> Service<Request<Body>> for JwtMiddleware<Svc>
118where
119    Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
120    Svc::Future: Send + 'static,
121    Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
122{
123    type Response = http::Response<Body>;
124    type Error = Svc::Error;
125    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
126
127    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128        self.inner.poll_ready(cx)
129    }
130
131    fn call(&mut self, request: Request<Body>) -> Self::Future {
132        let decoder = self.decoder.clone();
133        let sources = self.sources.clone();
134        let service = self.service.clone();
135        let mut inner = self.inner.clone();
136        std::mem::swap(&mut self.inner, &mut inner);
137
138        Box::pin(async move {
139            let (mut parts, body) = request.into_parts();
140
141            let token = match sources.iter().find_map(|s| s.extract(&parts)) {
142                Some(t) => t,
143                None => return Ok(jwt_err(JwtError::MissingToken).into_response()),
144            };
145
146            let claims: Claims = match decoder.decode(&token) {
147                Ok(c) => c,
148                Err(e) => return Ok(e.into_response()),
149            };
150
151            if let Some(svc) = service {
152                if claims.aud.as_deref() != Some("access") {
153                    return Ok(auth_err("auth:aud_mismatch").into_response());
154                }
155
156                if svc.config().stateful_validation {
157                    let session_token = match claims.jti.as_deref().and_then(SessionToken::from_raw)
158                    {
159                        Some(t) => t,
160                        None => {
161                            return Ok(auth_err("auth:session_not_found").into_response());
162                        }
163                    };
164
165                    let raw = match svc.store().read_by_token_hash(&session_token.hash()).await {
166                        Err(e) => return Ok(e.into_response()),
167                        Ok(None) => {
168                            return Ok(auth_err("auth:session_not_found").into_response());
169                        }
170                        Ok(Some(row)) => row,
171                    };
172
173                    parts.extensions.insert(Session::from(raw));
174                }
175            }
176
177            parts.extensions.insert(claims);
178
179            let request = Request::from_parts(parts, body);
180            inner.call(request).await
181        })
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::auth::session::jwt::{Claims, JwtEncoder, JwtSessionsConfig};

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    fn make_token(config: &JwtSessionsConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        encoder.encode(&claims).unwrap()
    }

    /// A token that expired an hour ago — far outside any plausible validation
    /// leeway, so the test does not depend on the decoder's leeway setting.
    fn make_expired_token(config: &JwtSessionsConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new().with_exp(now_secs() - 3600);
        encoder.encode(&claims).unwrap()
    }

    /// Builds an empty-bodied request carrying `Authorization: Bearer <token>`.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_claims = req.extensions().get::<Claims>().is_some();
        let body = if has_claims { "ok" } else { "no-claims" };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_expired_token(&config);
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // Flip a character in the middle of the signature where all 6 bits are significant.
        // The last character of a base64url string may have insignificant low bits,
        // so flipping it can decode to identical bytes (making the test flaky).
        let dot = token.rfind('.').unwrap();
        let mid = dot + (token.len() - dot) / 2;
        let mut bytes = token.into_bytes();
        bytes[mid] = if bytes[mid] == b'A' { b'Z' } else { b'A' };
        let token = String::from_utf8(bytes).unwrap();
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder);

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims>().unwrap();
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder).with_sources(vec![Arc::new(
            super::super::source::QuerySource("token"),
        ) as Arc<dyn TokenSource>]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_source_list_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // With no sources nothing can be extracted, even from a valid header.
        let layer = JwtLayer::new(decoder).with_sources(Vec::new());
        let svc = layer.layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn later_source_used_when_earlier_yields_nothing() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder).with_sources(vec![
            Arc::new(BearerSource) as Arc<dyn TokenSource>,
            Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>,
        ]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        // No Authorization header, so the bearer source yields nothing.
        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn first_yielding_source_wins() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let valid = make_token(&config);
        let expired = make_expired_token(&config);
        let layer = JwtLayer::new(decoder).with_sources(vec![
            Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>,
            Arc::new(BearerSource) as Arc<dyn TokenSource>,
        ]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        // The query source comes first and yields the expired token, so the
        // valid bearer token must not be consulted.
        let req = Request::builder()
            .uri(format!("/path?token={expired}"))
            .header("Authorization", format!("Bearer {valid}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}