Skip to main content

modo/auth/session/jwt/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use crate::Error;
12use crate::auth::session::Session;
13
14use super::claims::Claims;
15use super::decoder::JwtDecoder;
16use super::error::JwtError;
17use super::service::JwtSessionService;
18use super::source::{BearerSource, TokenSource};
19use crate::auth::session::token::SessionToken;
20
21/// Tower [`Layer`] that installs JWT authentication on a route.
22///
23/// For each request the middleware:
24/// 1. Tries each `TokenSource` in order; returns `401` if none yields a token.
25/// 2. Decodes and validates the token with `JwtDecoder`; returns `401` on failure.
26/// 3. Inserts [`Claims`] into request extensions for handler extraction.
27/// 4. When constructed via [`JwtLayer::from_service`], also performs a stateful
28///    database row lookup: hashes the `jti` claim and reads the session row,
29///    inserting the transport-agnostic [`Session`](crate::auth::session::Session)
30///    into extensions. Returns `401` if the row is missing (logged-out / revoked).
31///
32/// The default token source is [`BearerSource`] (`Authorization: Bearer <token>`).
33#[derive(Clone)]
34pub struct JwtLayer {
35    decoder: JwtDecoder,
36    sources: Arc<[Arc<dyn TokenSource>]>,
37    /// Present only when stateful validation is enabled (constructed via
38    /// [`JwtLayer::from_service`]). When `None` the layer behaves as a
39    /// purely stateless JWT validator.
40    service: Option<JwtSessionService>,
41}
42
43impl JwtLayer {
44    /// Creates a `JwtLayer` with `BearerSource` as the sole token source.
45    ///
46    /// This constructor performs **stateless** JWT validation only (signature +
47    /// claims). No database row lookup is performed. Use [`JwtLayer::from_service`]
48    /// for stateful validation that also inserts [`Session`](crate::auth::session::Session)
49    /// into request extensions.
50    pub fn new(decoder: JwtDecoder) -> Self {
51        Self {
52            decoder,
53            sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
54            service: None,
55        }
56    }
57
58    /// Creates a `JwtLayer` backed by a [`JwtSessionService`].
59    ///
60    /// After JWT signature/claims validation the middleware hashes the `jti`
61    /// claim, looks up the session row in the database, and inserts the
62    /// transport-agnostic [`Session`](crate::auth::session::Session) into
63    /// request extensions. Returns `401` with `auth:session_not_found` when
64    /// the session row is absent (logged-out or revoked).
65    ///
66    /// Use [`JwtSessionService::layer`] as the primary entry-point; this
67    /// constructor is the lower-level building block.
68    pub fn from_service(service: JwtSessionService) -> Self {
69        let decoder = service.decoder().clone();
70        Self {
71            decoder,
72            sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
73            service: Some(service),
74        }
75    }
76
77    /// Replaces the token sources with the provided list.
78    ///
79    /// Sources are tried in order; the first to return `Some` is used.
80    pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
81        self.sources = Arc::from(sources);
82        self
83    }
84}
85
86impl<Svc> Layer<Svc> for JwtLayer {
87    type Service = JwtMiddleware<Svc>;
88
89    fn layer(&self, inner: Svc) -> Self::Service {
90        JwtMiddleware {
91            inner,
92            decoder: self.decoder.clone(),
93            sources: self.sources.clone(),
94            service: self.service.clone(),
95        }
96    }
97}
98
99/// Tower [`Service`] produced by [`JwtLayer`]. See that type for behavior details.
100pub struct JwtMiddleware<Svc> {
101    inner: Svc,
102    decoder: JwtDecoder,
103    sources: Arc<[Arc<dyn TokenSource>]>,
104    service: Option<JwtSessionService>,
105}
106
107impl<Svc: Clone> Clone for JwtMiddleware<Svc> {
108    fn clone(&self) -> Self {
109        Self {
110            inner: self.inner.clone(),
111            decoder: self.decoder.clone(),
112            sources: self.sources.clone(),
113            service: self.service.clone(),
114        }
115    }
116}
117
118impl<Svc> Service<Request<Body>> for JwtMiddleware<Svc>
119where
120    Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
121    Svc::Future: Send + 'static,
122    Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
123{
124    type Response = http::Response<Body>;
125    type Error = Svc::Error;
126    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
127
128    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129        self.inner.poll_ready(cx)
130    }
131
132    fn call(&mut self, request: Request<Body>) -> Self::Future {
133        let decoder = self.decoder.clone();
134        let sources = self.sources.clone();
135        let service = self.service.clone();
136        let mut inner = self.inner.clone();
137        std::mem::swap(&mut self.inner, &mut inner);
138
139        Box::pin(async move {
140            let (mut parts, body) = request.into_parts();
141
142            // Try each token source in order
143            let token = sources.iter().find_map(|s| s.extract(&parts));
144            let token = match token {
145                Some(t) => t,
146                None => {
147                    let err = Error::unauthorized("unauthorized")
148                        .chain(JwtError::MissingToken)
149                        .with_code(JwtError::MissingToken.code());
150                    return Ok(err.into_response());
151                }
152            };
153
154            // Decode and validate (sync)
155            let claims: Claims = match decoder.decode(&token) {
156                Ok(c) => c,
157                Err(e) => return Ok(e.into_response()),
158            };
159
160            // Stateful validation: when backed by a JwtSessionService and
161            // stateful_validation is enabled, hash the jti claim, load the session
162            // row, and insert the transport-agnostic Session into extensions.
163            // Returns 401 when the row is absent; propagates DB errors as 5xx.
164            if let Some(svc) = service {
165                // Fix #1: reject non-access audience tokens before any DB lookup.
166                if claims.aud.as_deref() != Some("access") {
167                    let err = Error::unauthorized("unauthorized").with_code("auth:aud_mismatch");
168                    return Ok(err.into_response());
169                }
170
171                // Fix #2: honor stateful_validation flag.
172                if svc.config().stateful_validation {
173                    let jti = match claims.jti.as_deref() {
174                        Some(j) => j,
175                        None => {
176                            let err = Error::unauthorized("unauthorized")
177                                .with_code("auth:session_not_found");
178                            return Ok(err.into_response());
179                        }
180                    };
181
182                    let session_token = match SessionToken::from_raw(jti) {
183                        Some(t) => t,
184                        None => {
185                            let err = Error::unauthorized("unauthorized")
186                                .with_code("auth:session_not_found");
187                            return Ok(err.into_response());
188                        }
189                    };
190
191                    // Fix #5: propagate DB errors as 5xx; 401 only for missing row.
192                    let lookup = svc.store().read_by_token_hash(&session_token.hash()).await;
193                    let raw = match lookup {
194                        Err(e) => return Ok(e.into_response()),
195                        Ok(None) => {
196                            let err = Error::unauthorized("unauthorized")
197                                .with_code("auth:session_not_found");
198                            return Ok(err.into_response());
199                        }
200                        Ok(Some(row)) => row,
201                    };
202
203                    parts.extensions.insert(Session::from(raw));
204                }
205            }
206
207            // Insert claims into extensions
208            parts.extensions.insert(claims);
209
210            let request = Request::from_parts(parts, body);
211            inner.call(request).await
212        })
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::auth::session::jwt::{Claims, JwtEncoder, JwtSessionsConfig};

    /// Default config with a fixed test signing secret.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs()
    }

    /// A token for subject `user_1` that expires one hour from now.
    fn make_token(config: &JwtSessionsConfig) -> String {
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        JwtEncoder::from_config(config).encode(&claims).unwrap()
    }

    /// A bodiless request presenting `token` as an `Authorization: Bearer` credential.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    /// Replies `"ok"` when `Claims` were inserted into the request, `"no-claims"` otherwise.
    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = match req.extensions().get::<Claims>() {
            Some(_) => "ok",
            None => "no-claims",
        };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let token = make_token(&config);
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let svc = JwtLayer::new(JwtDecoder::from_config(&test_config()))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let expired = Claims::new().with_exp(now_secs() - 10);
        let token = JwtEncoder::from_config(&config).encode(&expired).unwrap();
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let mut bytes = make_token(&config).into_bytes();
        // Corrupt a character in the middle of the signature segment, where all
        // six bits of the base64url digit are significant. The final character
        // may carry insignificant low bits, so altering it could decode to the
        // same signature bytes and make the test flaky.
        let sig_dot = bytes.iter().rposition(|&b| b == b'.').unwrap();
        let mid = sig_dot + (bytes.len() - sig_dot) / 2;
        bytes[mid] = match bytes[mid] {
            b'A' => b'Z',
            _ => b'A',
        };
        let token = String::from_utf8(bytes).unwrap();
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let token = make_token(&config);
        let inner = tower::service_fn(|req: Request<Body>| async move {
            let subject = req.extensions().get::<Claims>().unwrap().subject();
            assert_eq!(subject, Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = JwtLayer::new(JwtDecoder::from_config(&config)).layer(inner);

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let token = make_token(&config);
        let query_source: Arc<dyn TokenSource> =
            Arc::new(super::super::source::QuerySource("token"));
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .with_sources(vec![query_source])
            .layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}