axum_webtools/security/
jwt.rs1use axum::{
2 extract::FromRequestParts,
3 http::{request::Parts, StatusCode},
4 response::{IntoResponse, Response},
5 Json, RequestPartsExt,
6};
7
8use chrono::TimeDelta;
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10
11use axum_extra::{
12 headers::{authorization::Bearer, Authorization},
13 TypedHeader,
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::fmt::Display;
18
/// A value that can name the subject of a JWT.
///
/// The `Sync + Send` supertraits allow implementors to be shared and moved across threads.
/// NOTE(review): the trait is not used elsewhere in this file; presumably `subject()` is what
/// callers pass to `create_jwt_token` — confirm against call sites.
pub trait JwtToken: Sync + Send {
    /// Returns the subject string for this value.
    fn subject(&self) -> String;
}
22
/// Reads a mandatory environment variable.
///
/// # Panics
/// Panics with `"<NAME> must be set: <VarError:?>"` (the same text `Result::expect` would
/// produce) when the variable is missing or not valid Unicode.
fn required_env_var(name: &str) -> String {
    std::env::var(name).unwrap_or_else(|err| panic!("{} must be set: {:?}", name, err))
}

/// Shared secret used as the HMAC key for signing and verifying tokens (`JWT_SECRET`).
///
/// # Panics
/// Panics if `JWT_SECRET` is not set.
fn get_jwt_secret() -> String {
    required_env_var("JWT_SECRET")
}

/// Expected/emitted `iss` claim value (`JWT_ISSUER`).
///
/// # Panics
/// Panics if `JWT_ISSUER` is not set.
fn get_jwt_issuer() -> String {
    required_env_var("JWT_ISSUER")
}

/// Expected/emitted `aud` claim value (`JWT_AUDIENCE`).
///
/// # Panics
/// Panics if `JWT_AUDIENCE` is not set.
fn get_jwt_audience() -> String {
    required_env_var("JWT_AUDIENCE")
}
34
35pub fn parse_jwt_token(token: impl Into<String>) -> Result<Claims, jsonwebtoken::errors::Error> {
36 let token = token.into();
37 let jwt_issuer = get_jwt_issuer();
38 let jwt_audience = get_jwt_audience();
39 let jwt_secret = get_jwt_secret();
40 let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
41
42 let mut validation = Validation::new(Algorithm::HS256);
43 validation.set_audience(&[jwt_audience]);
44 validation.set_issuer(&[jwt_issuer]);
45 let token_data = decode::<Claims>(token.as_str(), &decode_key, &validation)?;
46 Ok(token_data.claims)
47}
48
49pub fn create_jwt_token(subject: impl Into<String>) -> String {
50 let now = chrono::Utc::now();
51 let expires_at = TimeDelta::try_days(7)
52 .map(|d| now + d)
53 .expect("Failed to calculate expiration date");
54 let issued_at = now.timestamp() as u64;
55 let exp = expires_at.timestamp() as u64;
56 let iss = get_jwt_issuer();
57 let aud = get_jwt_audience();
58 let sub = subject.into();
59
60 let claims = Claims {
61 iss,
62 sub,
63 issued_at,
64 exp,
65 aud,
66 };
67
68 let jwt_secret = get_jwt_secret();
69 let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
70 encode(&Header::default(), &claims, &encode_key).unwrap()
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74pub struct Claims {
75 pub sub: String,
76 pub aud: String,
77 pub iss: String,
78 pub issued_at: u64,
79 pub exp: u64,
80}
81
82impl Display for Claims {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "Email: {}", self.sub)
85 }
86}
87
88#[cfg(not(test))]
89impl<S> FromRequestParts<S> for Claims
90where
91 S: Send + Sync,
92{
93 type Rejection = AuthError;
94
95 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
96 let TypedHeader(Authorization(bearer)) = parts
97 .extract::<TypedHeader<Authorization<Bearer>>>()
98 .await
99 .map_err(|_| AuthError::InvalidToken)?;
100
101 let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
102
103 Ok(claims)
104 }
105}
106
/// Test-only stand-in for the JWT extractor.
///
/// It trusts the `X-Claims-Subject` header as the subject and fills the remaining claims with
/// fixed test values, so handler tests do not need to mint real tokens.
///
/// A missing header, or one that is not visible ASCII (`HeaderValue::to_str` fails), is
/// rejected with [`AuthError::InvalidToken`], the same rejection the production extractor uses.
/// It no longer panics inside the handler under test.
#[cfg(test)]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let sub = parts
            .headers
            .get("X-Claims-Subject")
            .and_then(|value| value.to_str().ok())
            .ok_or(AuthError::InvalidToken)?
            .to_owned();

        // One timestamp for both fields so `exp - issued_at` is exactly seven days.
        let now = chrono::Utc::now();
        Ok(Claims {
            sub,
            aud: "audience".to_string(),
            iss: "issuer".to_string(),
            issued_at: now.timestamp() as u64,
            exp: (now + TimeDelta::days(7)).timestamp() as u64,
        })
    }
}
131
132impl IntoResponse for AuthError {
133 fn into_response(self) -> Response {
134 let (status, error_message) = match self {
135 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
136 };
137 let body = Json(json!({
138 "error": error_message,
139 }));
140 (status, body).into_response()
141 }
142}
143
/// Authentication failure used as the rejection type of the `Claims` request extractor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The bearer credentials were missing, malformed, or failed JWT validation.
    InvalidToken,
}

/// Uses the same message text as the HTTP response body.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::InvalidToken => f.write_str("Invalid token"),
        }
    }
}

/// Lets `AuthError` flow through `?` into `Box<dyn Error>` and similar error-erasing types.
impl std::error::Error for AuthError {}
148
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;
    use std::sync::Once;

    /// Sets the JWT environment variables once for the whole test binary.
    ///
    /// Tests run on parallel threads. The `Once` guard keeps them from repeatedly mutating
    /// process-global state while other tests are reading it.
    fn setup() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            std::env::set_var("JWT_SECRET", "secret");
            std::env::set_var("JWT_ISSUER", "issuer");
            std::env::set_var("JWT_AUDIENCE", "audience");
        });
    }

    /// Claims that match the test environment and expire one day from now.
    fn valid_claims(sub: &str) -> Claims {
        let now = chrono::Utc::now();
        Claims {
            sub: sub.to_string(),
            aud: "audience".to_string(),
            iss: "issuer".to_string(),
            issued_at: now.timestamp() as u64,
            exp: (now + TimeDelta::days(1)).timestamp() as u64,
        }
    }

    /// Signs `claims` with `secret` using the default header.
    fn sign(claims: &Claims, secret: &str) -> String {
        encode(&Header::default(), claims, &EncodingKey::from_secret(secret.as_bytes())).unwrap()
    }

    #[test]
    fn test_create_token() {
        setup();
        let email: String = FreeEmail().fake();
        let token = create_jwt_token(email.clone());
        let claims = parse_jwt_token(token).unwrap();
        assert_eq!(email, claims.sub);
    }

    #[test]
    fn test_invalid_token() {
        setup();
        let email: String = FreeEmail().fake();
        let mut token = create_jwt_token(email);
        token.push_str("a");
        let claims = parse_jwt_token(token);
        assert!(claims.is_err());
    }

    /// Positive control: keeps the negative tests below from passing vacuously.
    #[test]
    fn test_hand_signed_valid_token_is_accepted() {
        setup();
        let claims = parse_jwt_token(sign(&valid_claims("user@example.com"), "secret")).unwrap();
        assert_eq!(claims.sub, "user@example.com");
    }

    #[test]
    fn test_token_signed_with_other_secret_is_rejected() {
        setup();
        let token = sign(&valid_claims("user@example.com"), "not-the-secret");
        assert!(parse_jwt_token(token).is_err());
    }

    #[test]
    fn test_wrong_audience_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        claims.aud = "someone-else".to_string();
        assert!(parse_jwt_token(sign(&claims, "secret")).is_err());
    }

    #[test]
    fn test_wrong_issuer_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        claims.iss = "evil-issuer".to_string();
        assert!(parse_jwt_token(sign(&claims, "secret")).is_err());
    }

    #[test]
    fn test_expired_token_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        // An hour in the past is well beyond the library's default expiry leeway.
        claims.exp = (chrono::Utc::now() - TimeDelta::hours(1)).timestamp() as u64;
        assert!(parse_jwt_token(sign(&claims, "secret")).is_err());
    }
}