// modo/auth/session/jwt/extractor.rs

1use axum::body::to_bytes;
2use axum::extract::{FromRef, FromRequest, FromRequestParts, OptionalFromRequestParts, Request};
3use http::request::Parts;
4
5use crate::Error;
6use crate::Result;
7use crate::auth::session::Session;
8use crate::auth::session::meta::SessionMeta;
9
10use super::claims::Claims;
11use super::error::JwtError;
12
/// Standalone extractor for the raw Bearer token string.
///
/// Reads the `Authorization` header and strips the `Bearer` scheme prefix
/// (matched case-insensitively, per RFC 7235). Use this when you need the raw
/// token string — for example to forward it or to pass it to a revocation
/// endpoint.
///
/// This extractor is independent of `JwtLayer`: it neither decodes nor
/// validates the token.
///
/// Rejects with `401 Unauthorized` and code `jwt:missing_token` when the header
/// is absent, uses a scheme other than `Bearer`, or carries an empty token.
///
/// The `Debug` output redacts the token so that logging a `Bearer` value does
/// not leak the credential; read `.0` explicitly when the raw value is needed.
#[derive(Clone, PartialEq, Eq)]
pub struct Bearer(pub String);

impl std::fmt::Debug for Bearer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the credential itself; only signal that one is present.
        f.debug_tuple("Bearer").field(&"[REDACTED]").finish()
    }
}
26
27impl<S: Send + Sync> FromRequestParts<S> for Bearer {
28    type Rejection = Error;
29
30    async fn from_request_parts(
31        parts: &mut Parts,
32        _state: &S,
33    ) -> std::result::Result<Self, Self::Rejection> {
34        let header = parts
35            .headers
36            .get(http::header::AUTHORIZATION)
37            .and_then(|v| v.to_str().ok())
38            .ok_or_else(|| {
39                Error::unauthorized("unauthorized")
40                    .chain(JwtError::MissingToken)
41                    .with_code(JwtError::MissingToken.code())
42            })?;
43
44        let token = header
45            .split_once(' ')
46            .and_then(|(scheme, rest)| {
47                scheme
48                    .eq_ignore_ascii_case("Bearer")
49                    .then(|| rest.trim_start())
50            })
51            .ok_or_else(|| {
52                Error::unauthorized("unauthorized")
53                    .chain(JwtError::MissingToken)
54                    .with_code(JwtError::MissingToken.code())
55            })?;
56
57        if token.is_empty() {
58            return Err(Error::unauthorized("unauthorized")
59                .chain(JwtError::MissingToken)
60                .with_code(JwtError::MissingToken.code()));
61        }
62
63        Ok(Bearer(token.to_string()))
64    }
65}
66
67/// Extracts [`Claims`] from request extensions.
68///
69/// [`JwtLayer`](super::middleware::JwtLayer) must be applied to the route — the
70/// middleware decodes the token and inserts `Claims` into extensions before the
71/// handler is called. Returns `401 Unauthorized` when claims are not present
72/// in extensions.
73impl<S: Send + Sync> FromRequestParts<S> for Claims {
74    type Rejection = Error;
75
76    async fn from_request_parts(
77        parts: &mut Parts,
78        _state: &S,
79    ) -> std::result::Result<Self, Self::Rejection> {
80        parts
81            .extensions
82            .get::<Claims>()
83            .cloned()
84            .ok_or_else(|| Error::unauthorized("unauthorized"))
85    }
86}
87
88/// Optionally extracts [`Claims`] from request extensions.
89///
90/// Returns `Ok(None)` when `JwtLayer` is not applied or the token is missing/invalid,
91/// allowing routes to serve both authenticated and unauthenticated users.
92impl<S: Send + Sync> OptionalFromRequestParts<S> for Claims {
93    type Rejection = Error;
94
95    async fn from_request_parts(
96        parts: &mut Parts,
97        _state: &S,
98    ) -> std::result::Result<Option<Self>, Self::Rejection> {
99        Ok(parts.extensions.get::<Claims>().cloned())
100    }
101}
102
103use super::service::JwtSessionService;
104use super::source::TokenSourceConfig;
105use super::tokens::TokenPair;
106
107/// Request-scoped JWT session manager.
108///
109/// `JwtSession` is an axum [`FromRequest`] extractor that captures the
110/// `JwtSessionService` from router state and pre-reads any tokens it needs
111/// (including the body when `refresh_source = Body { field }`).
112///
113/// Handlers use it to call [`rotate`](JwtSession::rotate) or
114/// [`logout`](JwtSession::logout) without manually fishing tokens out of the
115/// request.
116///
117/// # Trade-off
118///
119/// Because this extractor may consume the request body (when the refresh
120/// source is `Body { field }`), handlers that also need a typed body extractor
121/// (e.g., a login handler that parses `LoginReq`) **cannot** combine
122/// `JwtSession` with another body extractor. Those handlers should inject
123/// [`State<JwtSessionService>`](axum::extract::State) directly instead.
124///
125/// # Example
126///
127/// ```rust,ignore
128/// async fn refresh(jwt: JwtSession) -> Result<Json<TokenPair>> {
129///     Ok(Json(jwt.rotate().await?))
130/// }
131///
132/// async fn logout(jwt: JwtSession) -> Result<StatusCode> {
133///     jwt.logout().await?;
134///     Ok(StatusCode::NO_CONTENT)
135/// }
136/// ```
137pub struct JwtSession {
138    service: JwtSessionService,
139    parts: Parts,
140    body_refresh: Option<String>,
141}
142
143impl<S: Send + Sync> FromRequest<S> for JwtSession
144where
145    JwtSessionService: FromRef<S>,
146{
147    type Rejection = Error;
148
149    async fn from_request(req: Request, state: &S) -> Result<Self> {
150        let service = JwtSessionService::from_ref(state);
151        let (parts, body) = req.into_parts();
152
153        let body_refresh =
154            if let TokenSourceConfig::Body { field } = &service.config().refresh_source {
155                if let Ok(bytes) = to_bytes(body, 1024 * 1024).await {
156                    if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&bytes) {
157                        v.get(field.as_str())
158                            .and_then(|x| x.as_str())
159                            .map(str::to_string)
160                    } else {
161                        None
162                    }
163                } else {
164                    None
165                }
166            } else {
167                None
168            };
169
170        Ok(Self {
171            service,
172            parts,
173            body_refresh,
174        })
175    }
176}
177
178impl JwtSession {
179    /// Returns the [`Session`] injected by `JwtLayer`, if present.
180    pub fn current(&self) -> Option<&Session> {
181        self.parts.extensions.get::<Session>()
182    }
183
184    /// Authenticate a user and issue a new [`TokenPair`].
185    ///
186    /// Delegates directly to [`JwtSessionService::authenticate`].
187    pub async fn authenticate(&self, user_id: &str, meta: &SessionMeta) -> Result<TokenPair> {
188        self.service.authenticate(user_id, meta).await
189    }
190
191    /// Rotate the refresh token and return a fresh [`TokenPair`].
192    ///
193    /// Finds the refresh token according to `refresh_source` in the config.
194    pub async fn rotate(&self) -> Result<TokenPair> {
195        let token = self.find_refresh_token()?;
196        self.service.rotate(&token).await
197    }
198
199    /// Revoke the session associated with the current access token.
200    ///
201    /// Finds the access token according to `access_source` in the config.
202    pub async fn logout(&self) -> Result<()> {
203        let token = self.find_access_token()?;
204        self.service.logout(&token).await
205    }
206
207    /// List all active sessions for the given user.
208    pub async fn list(&self, user_id: &str) -> Result<Vec<Session>> {
209        self.service.list(user_id).await
210    }
211
212    /// Revoke a specific session by its ULID identifier.
213    pub async fn revoke(&self, user_id: &str, id: &str) -> Result<()> {
214        self.service.revoke(user_id, id).await
215    }
216
217    /// Revoke all sessions for the given user.
218    pub async fn revoke_all(&self, user_id: &str) -> Result<()> {
219        self.service.revoke_all(user_id).await
220    }
221
222    /// Revoke all sessions for the given user except the session with `keep_id`.
223    pub async fn revoke_all_except(&self, user_id: &str, keep_id: &str) -> Result<()> {
224        self.service.revoke_all_except(user_id, keep_id).await
225    }
226
227    fn find_access_token(&self) -> Result<String> {
228        match &self.service.config().access_source {
229            TokenSourceConfig::Bearer => self
230                .parts
231                .headers
232                .get(http::header::AUTHORIZATION)
233                .and_then(|v| v.to_str().ok())
234                .and_then(|s| {
235                    s.split_once(' ').and_then(|(scheme, rest)| {
236                        scheme
237                            .eq_ignore_ascii_case("Bearer")
238                            .then(|| rest.trim_start())
239                    })
240                })
241                .map(str::to_string)
242                .ok_or_else(|| {
243                    Error::unauthorized("unauthorized").with_code("auth:access_missing")
244                }),
245            TokenSourceConfig::Cookie { name } => {
246                let cookie_header = self
247                    .parts
248                    .headers
249                    .get(http::header::COOKIE)
250                    .and_then(|v| v.to_str().ok())
251                    .unwrap_or("");
252                for cookie in cookie_header.split(';') {
253                    let cookie = cookie.trim();
254                    if let Some((k, v)) = cookie.split_once('=')
255                        && k.trim() == name.as_str()
256                        && !v.is_empty()
257                    {
258                        return Ok(v.trim().to_string());
259                    }
260                }
261                Err(Error::unauthorized("unauthorized").with_code("auth:access_missing"))
262            }
263            TokenSourceConfig::Header { name } => self
264                .parts
265                .headers
266                .get(name.as_str())
267                .and_then(|v| v.to_str().ok())
268                .filter(|s| !s.is_empty())
269                .map(str::to_string)
270                .ok_or_else(|| {
271                    Error::unauthorized("unauthorized").with_code("auth:access_missing")
272                }),
273            TokenSourceConfig::Query { name } => {
274                let query = self.parts.uri.query().unwrap_or("");
275                for pair in query.split('&') {
276                    if let Some((k, v)) = pair.split_once('=')
277                        && k == name.as_str()
278                        && !v.is_empty()
279                    {
280                        return Ok(v.to_string());
281                    }
282                }
283                Err(Error::unauthorized("unauthorized").with_code("auth:access_missing"))
284            }
285            TokenSourceConfig::Body { .. } => {
286                Err(Error::internal("access_source=Body is not supported"))
287            }
288        }
289    }
290
291    fn find_refresh_token(&self) -> Result<String> {
292        if let Some(t) = &self.body_refresh {
293            return Ok(t.clone());
294        }
295        match &self.service.config().refresh_source {
296            TokenSourceConfig::Body { .. } => {
297                Err(Error::bad_request("refresh token missing").with_code("auth:refresh_missing"))
298            }
299            TokenSourceConfig::Bearer => self.find_access_token(),
300            TokenSourceConfig::Cookie { name } => {
301                let cookie_header = self
302                    .parts
303                    .headers
304                    .get(http::header::COOKIE)
305                    .and_then(|v| v.to_str().ok())
306                    .unwrap_or("");
307                for cookie in cookie_header.split(';') {
308                    let cookie = cookie.trim();
309                    if let Some((k, v)) = cookie.split_once('=')
310                        && k.trim() == name.as_str()
311                        && !v.is_empty()
312                    {
313                        return Ok(v.trim().to_string());
314                    }
315                }
316                Err(Error::unauthorized("unauthorized").with_code("auth:refresh_missing"))
317            }
318            TokenSourceConfig::Header { name } => self
319                .parts
320                .headers
321                .get(name.as_str())
322                .and_then(|v| v.to_str().ok())
323                .filter(|s| !s.is_empty())
324                .map(str::to_string)
325                .ok_or_else(|| {
326                    Error::unauthorized("unauthorized").with_code("auth:refresh_missing")
327                }),
328            TokenSourceConfig::Query { name } => {
329                let query = self.parts.uri.query().unwrap_or("");
330                for pair in query.split('&') {
331                    if let Some((k, v)) = pair.split_once('=')
332                        && k == name.as_str()
333                        && !v.is_empty()
334                    {
335                        return Ok(v.to_string());
336                    }
337                }
338                Err(Error::unauthorized("unauthorized").with_code("auth:refresh_missing"))
339            }
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Request parts with no headers, extensions, or body.
    fn empty_parts() -> Parts {
        http::Request::builder().body(()).unwrap().into_parts().0
    }

    /// Runs the `Bearer` extractor against a request whose `Authorization`
    /// header is set to `authorization`.
    async fn extract_bearer(authorization: &str) -> std::result::Result<Bearer, Error> {
        let (mut parts, _) = http::Request::builder()
            .header("Authorization", authorization)
            .body(())
            .unwrap()
            .into_parts();
        <Bearer as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await
    }

    #[tokio::test]
    async fn bearer_extracts_token() {
        let bearer = extract_bearer("Bearer my-token").await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_scheme_is_case_insensitive() {
        let bearer = extract_bearer("bEaReR my-token").await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_missing_header_returns_401() {
        let mut parts = empty_parts();
        let err = <Bearer as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_wrong_scheme_returns_401() {
        let err = extract_bearer("Basic abc").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_without_token_returns_401() {
        for header in ["Bearer", "Bearer    "] {
            let err = extract_bearer(header).await.unwrap_err();
            assert_eq!(
                err.status(),
                http::StatusCode::UNAUTHORIZED,
                "header: {header:?}"
            );
        }
    }

    #[tokio::test]
    async fn claims_extract_from_extensions() {
        let mut parts = empty_parts();
        parts
            .extensions
            .insert(Claims::new().with_sub("user_1").with_exp(9999999999));
        let extracted = <Claims as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(extracted.sub, Some("user_1".into()));
    }

    #[tokio::test]
    async fn claims_missing_returns_401() {
        let mut parts = empty_parts();
        let err = <Claims as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn option_claims_none_when_missing() {
        let mut parts = empty_parts();
        let result =
            <Claims as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn option_claims_some_when_present() {
        let mut parts = empty_parts();
        parts.extensions.insert(Claims::new().with_sub("user_1"));
        let result =
            <Claims as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(result.unwrap().is_some());
    }
}