modo/auth/session/jwt/
extractor.rs1use axum::body::to_bytes;
2use axum::extract::{FromRef, FromRequest, FromRequestParts, OptionalFromRequestParts, Request};
3use http::request::Parts;
4
5use crate::Error;
6use crate::Result;
7use crate::auth::session::Session;
8use crate::auth::session::meta::SessionMeta;
9
10use super::claims::Claims;
11use super::error::JwtError;
12
/// Bearer token extracted from an `Authorization: Bearer <token>` header.
///
/// The wrapped string is the raw credential; read `.0` when the token itself is
/// needed. `Debug` output deliberately redacts it so the token does not end up in
/// logs or error reports when the value is printed with `{:?}`.
#[derive(Clone)]
pub struct Bearer(pub String);

impl std::fmt::Debug for Bearer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the credential itself.
        f.debug_tuple("Bearer").field(&"[REDACTED]").finish()
    }
}
26
27impl<S: Send + Sync> FromRequestParts<S> for Bearer {
28 type Rejection = Error;
29
30 async fn from_request_parts(
31 parts: &mut Parts,
32 _state: &S,
33 ) -> std::result::Result<Self, Self::Rejection> {
34 let header = parts
35 .headers
36 .get(http::header::AUTHORIZATION)
37 .and_then(|v| v.to_str().ok())
38 .ok_or_else(|| {
39 Error::unauthorized("unauthorized")
40 .chain(JwtError::MissingToken)
41 .with_code(JwtError::MissingToken.code())
42 })?;
43
44 let token = header
45 .split_once(' ')
46 .and_then(|(scheme, rest)| {
47 scheme
48 .eq_ignore_ascii_case("Bearer")
49 .then(|| rest.trim_start())
50 })
51 .ok_or_else(|| {
52 Error::unauthorized("unauthorized")
53 .chain(JwtError::MissingToken)
54 .with_code(JwtError::MissingToken.code())
55 })?;
56
57 if token.is_empty() {
58 return Err(Error::unauthorized("unauthorized")
59 .chain(JwtError::MissingToken)
60 .with_code(JwtError::MissingToken.code()));
61 }
62
63 Ok(Bearer(token.to_string()))
64 }
65}
66
67impl<S: Send + Sync> FromRequestParts<S> for Claims {
74 type Rejection = Error;
75
76 async fn from_request_parts(
77 parts: &mut Parts,
78 _state: &S,
79 ) -> std::result::Result<Self, Self::Rejection> {
80 parts
81 .extensions
82 .get::<Claims>()
83 .cloned()
84 .ok_or_else(|| Error::unauthorized("unauthorized"))
85 }
86}
87
88impl<S: Send + Sync> OptionalFromRequestParts<S> for Claims {
93 type Rejection = Error;
94
95 async fn from_request_parts(
96 parts: &mut Parts,
97 _state: &S,
98 ) -> std::result::Result<Option<Self>, Self::Rejection> {
99 Ok(parts.extensions.get::<Claims>().cloned())
100 }
101}
102
103use super::service::JwtSessionService;
104use super::source::TokenSourceConfig;
105use super::tokens::TokenPair;
106
107pub struct JwtSession {
138 service: JwtSessionService,
139 parts: Parts,
140 body_refresh: Option<String>,
141}
142
143impl<S: Send + Sync> FromRequest<S> for JwtSession
144where
145 JwtSessionService: FromRef<S>,
146{
147 type Rejection = Error;
148
149 async fn from_request(req: Request, state: &S) -> Result<Self> {
150 let service = JwtSessionService::from_ref(state);
151 let (parts, body) = req.into_parts();
152
153 let body_refresh =
154 if let TokenSourceConfig::Body { field } = &service.config().refresh_source {
155 if let Ok(bytes) = to_bytes(body, 1024 * 1024).await {
156 if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&bytes) {
157 v.get(field.as_str())
158 .and_then(|x| x.as_str())
159 .map(str::to_string)
160 } else {
161 None
162 }
163 } else {
164 None
165 }
166 } else {
167 None
168 };
169
170 Ok(Self {
171 service,
172 parts,
173 body_refresh,
174 })
175 }
176}
177
178impl JwtSession {
179 pub fn current(&self) -> Option<&Session> {
181 self.parts.extensions.get::<Session>()
182 }
183
184 pub async fn authenticate(&self, user_id: &str, meta: &SessionMeta) -> Result<TokenPair> {
188 self.service.authenticate(user_id, meta).await
189 }
190
191 pub async fn rotate(&self) -> Result<TokenPair> {
195 let token = self.find_refresh_token()?;
196 self.service.rotate(&token).await
197 }
198
199 pub async fn logout(&self) -> Result<()> {
203 let token = self.find_access_token()?;
204 self.service.logout(&token).await
205 }
206
207 pub async fn list(&self, user_id: &str) -> Result<Vec<Session>> {
209 self.service.list(user_id).await
210 }
211
212 pub async fn revoke(&self, user_id: &str, id: &str) -> Result<()> {
214 self.service.revoke(user_id, id).await
215 }
216
217 pub async fn revoke_all(&self, user_id: &str) -> Result<()> {
219 self.service.revoke_all(user_id).await
220 }
221
222 pub async fn revoke_all_except(&self, user_id: &str, keep_id: &str) -> Result<()> {
224 self.service.revoke_all_except(user_id, keep_id).await
225 }
226
227 fn find_access_token(&self) -> Result<String> {
228 match &self.service.config().access_source {
229 TokenSourceConfig::Bearer => self
230 .parts
231 .headers
232 .get(http::header::AUTHORIZATION)
233 .and_then(|v| v.to_str().ok())
234 .and_then(|s| {
235 s.split_once(' ').and_then(|(scheme, rest)| {
236 scheme
237 .eq_ignore_ascii_case("Bearer")
238 .then(|| rest.trim_start())
239 })
240 })
241 .map(str::to_string)
242 .ok_or_else(|| {
243 Error::unauthorized("unauthorized").with_code("auth:access_missing")
244 }),
245 TokenSourceConfig::Cookie { name } => {
246 let cookie_header = self
247 .parts
248 .headers
249 .get(http::header::COOKIE)
250 .and_then(|v| v.to_str().ok())
251 .unwrap_or("");
252 for cookie in cookie_header.split(';') {
253 let cookie = cookie.trim();
254 if let Some((k, v)) = cookie.split_once('=')
255 && k.trim() == name.as_str()
256 && !v.is_empty()
257 {
258 return Ok(v.trim().to_string());
259 }
260 }
261 Err(Error::unauthorized("unauthorized").with_code("auth:access_missing"))
262 }
263 TokenSourceConfig::Header { name } => self
264 .parts
265 .headers
266 .get(name.as_str())
267 .and_then(|v| v.to_str().ok())
268 .filter(|s| !s.is_empty())
269 .map(str::to_string)
270 .ok_or_else(|| {
271 Error::unauthorized("unauthorized").with_code("auth:access_missing")
272 }),
273 TokenSourceConfig::Query { name } => {
274 let query = self.parts.uri.query().unwrap_or("");
275 for pair in query.split('&') {
276 if let Some((k, v)) = pair.split_once('=')
277 && k == name.as_str()
278 && !v.is_empty()
279 {
280 return Ok(v.to_string());
281 }
282 }
283 Err(Error::unauthorized("unauthorized").with_code("auth:access_missing"))
284 }
285 TokenSourceConfig::Body { .. } => {
286 Err(Error::internal("access_source=Body is not supported"))
287 }
288 }
289 }
290
291 fn find_refresh_token(&self) -> Result<String> {
292 if let Some(t) = &self.body_refresh {
293 return Ok(t.clone());
294 }
295 match &self.service.config().refresh_source {
296 TokenSourceConfig::Body { .. } => {
297 Err(Error::bad_request("refresh token missing").with_code("auth:refresh_missing"))
298 }
299 TokenSourceConfig::Bearer => self.find_access_token(),
300 TokenSourceConfig::Cookie { name } => {
301 let cookie_header = self
302 .parts
303 .headers
304 .get(http::header::COOKIE)
305 .and_then(|v| v.to_str().ok())
306 .unwrap_or("");
307 for cookie in cookie_header.split(';') {
308 let cookie = cookie.trim();
309 if let Some((k, v)) = cookie.split_once('=')
310 && k.trim() == name.as_str()
311 && !v.is_empty()
312 {
313 return Ok(v.trim().to_string());
314 }
315 }
316 Err(Error::unauthorized("unauthorized").with_code("auth:refresh_missing"))
317 }
318 TokenSourceConfig::Header { name } => self
319 .parts
320 .headers
321 .get(name.as_str())
322 .and_then(|v| v.to_str().ok())
323 .filter(|s| !s.is_empty())
324 .map(str::to_string)
325 .ok_or_else(|| {
326 Error::unauthorized("unauthorized").with_code("auth:refresh_missing")
327 }),
328 TokenSourceConfig::Query { name } => {
329 let query = self.parts.uri.query().unwrap_or("");
330 for pair in query.split('&') {
331 if let Some((k, v)) = pair.split_once('=')
332 && k == name.as_str()
333 && !v.is_empty()
334 {
335 return Ok(v.to_string());
336 }
337 }
338 Err(Error::unauthorized("unauthorized").with_code("auth:refresh_missing"))
339 }
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds request parts, optionally carrying an `Authorization` header.
    fn request_parts(authorization: Option<&str>) -> Parts {
        let mut builder = http::Request::builder();
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        builder.body(()).unwrap().into_parts().0
    }

    /// Runs the `Bearer` extractor against parts built by `request_parts`.
    async fn extract_bearer(authorization: Option<&str>) -> std::result::Result<Bearer, Error> {
        let mut parts = request_parts(authorization);
        <Bearer as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await
    }

    #[tokio::test]
    async fn bearer_extracts_token() {
        let bearer = extract_bearer(Some("Bearer my-token")).await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_scheme_is_case_insensitive() {
        let bearer = extract_bearer(Some("bearer my-token")).await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_missing_header_returns_401() {
        let err = extract_bearer(None).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_wrong_scheme_returns_401() {
        let err = extract_bearer(Some("Basic abc")).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_empty_token_returns_401() {
        let err = extract_bearer(Some("Bearer   ")).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_extract_from_extensions() {
        let mut parts = request_parts(None);
        let claims = Claims::new().with_sub("user_1").with_exp(9999999999);
        parts.extensions.insert(claims.clone());
        let extracted = <Claims as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(extracted.sub, Some("user_1".into()));
    }

    #[tokio::test]
    async fn claims_missing_returns_401() {
        let mut parts = request_parts(None);
        let err = <Claims as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn option_claims_none_when_missing() {
        let mut parts = request_parts(None);
        let result =
            <Claims as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn option_claims_some_when_present() {
        let mut parts = request_parts(None);
        parts.extensions.insert(Claims::new().with_sub("user_1"));
        let result =
            <Claims as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(result.unwrap().is_some());
    }
}