axum_security/jwt/
session.rs1use std::convert::Infallible;
2
3use axum::{
4 extract::{FromRequestParts, OptionalFromRequestParts},
5 http::{Extensions, StatusCode, request::Parts},
6};
7
/// Newtype carrying the value `T` associated with a JWT.
///
/// The wrapper is looked up in the request [`Extensions`] by its concrete type
/// `Jwt<T>`, so the same `T` must be used when storing and when extracting it.
/// What `T` holds is up to the caller (presumably the decoded claims; the code
/// that inserts it is not part of this file).
#[derive(Debug, Clone)]
pub struct Jwt<T>(
    /// The wrapped value, exposed directly to handlers.
    pub T,
);
10
11impl<T: Send + Sync + 'static> Jwt<T> {
12 pub fn from_extensions(extensions: &mut Extensions) -> Option<Self> {
13 extensions.remove()
14 }
15}
16
17impl<S, T> FromRequestParts<S> for Jwt<T>
18where
19 S: Send + Sync,
20 T: Send + Sync + 'static,
21{
22 type Rejection = StatusCode;
23
24 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
25 if let Some(session) = <Jwt<T>>::from_extensions(&mut parts.extensions) {
26 Ok(session)
27 } else {
28 Err(StatusCode::UNAUTHORIZED)
29 }
30 }
31}
32
33impl<S, T> OptionalFromRequestParts<S> for Jwt<T>
34where
35 S: Send + Sync,
36 T: Send + Sync + 'static,
37{
38 type Rejection = Infallible;
39
40 async fn from_request_parts(
41 parts: &mut Parts,
42 _state: &S,
43 ) -> Result<Option<Self>, Self::Rejection> {
44 Ok(<Jwt<T>>::from_extensions(&mut parts.extensions))
45 }
46}
47
/// Tests for extracting [`Jwt`] from request parts via the required and the
/// optional extractor.
#[cfg(test)]
mod extract_jwt {
    use axum::{
        extract::FromRequestParts,
        http::{Request, StatusCode},
    };

    use crate::jwt::Jwt;

    // `OptionalFromRequestParts` is deliberately NOT imported: both extractor
    // traits expose `from_request_parts`, so having both in scope would make
    // the plain `Jwt::<i32>::from_request_parts(..)` calls ambiguous. The
    // optional extractor is invoked through its fully qualified path instead.

    /// A `Jwt` stored in the extensions is returned by the required extractor.
    #[tokio::test]
    async fn extract() {
        let jwt = Jwt(1i32);

        let (mut parts, _) = Request::builder()
            .extension(jwt.clone())
            .body(())
            .unwrap()
            .into_parts();

        let extracted_jwt = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap();

        // `assert_eq!` reports both operands on failure, unlike `assert!(a == b)`.
        assert_eq!(jwt.0, extracted_jwt.0);
    }

    /// Without a stored `Jwt` the required extractor rejects with 401.
    #[tokio::test]
    async fn extract_rejection() {
        let (mut parts, _) = Request::builder().body(()).unwrap().into_parts();

        let rejection = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();

        assert_eq!(rejection, StatusCode::UNAUTHORIZED);
    }

    /// Extraction moves the value out of the extensions, so a second required
    /// extraction on the same parts is rejected.
    #[tokio::test]
    async fn extract_consumes_extension() {
        let (mut parts, _) = Request::builder()
            .extension(Jwt(1i32))
            .body(())
            .unwrap()
            .into_parts();

        let _first = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap();

        let rejection = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();

        assert_eq!(rejection, StatusCode::UNAUTHORIZED);
    }

    /// The optional extractor yields `Some` when a `Jwt` is stored.
    #[tokio::test]
    async fn extract_optional_present() {
        let (mut parts, _) = Request::builder()
            .extension(Jwt(7i32))
            .body(())
            .unwrap()
            .into_parts();

        let extracted =
            <Jwt<i32> as axum::extract::OptionalFromRequestParts<()>>::from_request_parts(
                &mut parts,
                &(),
            )
            .await
            .unwrap();

        assert_eq!(extracted.map(|jwt| jwt.0), Some(7));
    }

    /// The optional extractor yields `None` (and never rejects) when no `Jwt`
    /// is stored.
    #[tokio::test]
    async fn extract_optional_absent() {
        let (mut parts, _) = Request::builder().body(()).unwrap().into_parts();

        let extracted =
            <Jwt<i32> as axum::extract::OptionalFromRequestParts<()>>::from_request_parts(
                &mut parts,
                &(),
            )
            .await
            .unwrap();

        assert!(extracted.is_none());
    }
}