axum_webtools/security/
jwt.rs1use axum::{
2 extract::FromRequestParts,
3 http::{request::Parts, StatusCode},
4 response::{IntoResponse, Response},
5 Json, RequestPartsExt,
6};
7
8use chrono::TimeDelta;
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10
11use axum_extra::{
12 headers::{authorization::Bearer, Authorization},
13 TypedHeader,
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::fmt::Display;
18
/// Returns the secret used to sign and verify tokens.
///
/// The value is read from the `JWT_SECRET` environment variable on every call.
///
/// # Panics
///
/// Panics with `JWT_SECRET must be set` when the variable cannot be read.
fn get_jwt_secret() -> String {
    let lookup = std::env::var("JWT_SECRET");
    lookup.expect("JWT_SECRET must be set")
}
22
/// Returns the issuer (`iss`) value for tokens.
///
/// The value is read from the `JWT_ISSUER` environment variable on every call.
///
/// # Panics
///
/// Panics with `JWT_ISSUER must be set` when the variable cannot be read.
fn get_jwt_issuer() -> String {
    let lookup = std::env::var("JWT_ISSUER");
    lookup.expect("JWT_ISSUER must be set")
}
26
/// Returns the audience (`aud`) value for tokens.
///
/// The value is read from the `JWT_AUDIENCE` environment variable on every call.
///
/// # Panics
///
/// Panics with `JWT_AUDIENCE must be set` when the variable cannot be read.
fn get_jwt_audience() -> String {
    let lookup = std::env::var("JWT_AUDIENCE");
    lookup.expect("JWT_AUDIENCE must be set")
}
30
31pub fn parse_jwt_token(token: impl Into<String>) -> Result<Claims, jsonwebtoken::errors::Error> {
32 let token = token.into();
33 let jwt_issuer = get_jwt_issuer();
34 let jwt_audience = get_jwt_audience();
35 let jwt_secret = get_jwt_secret();
36 let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
37
38 let mut validation = Validation::new(Algorithm::HS256);
39 validation.set_audience(&[jwt_audience]);
40 validation.set_issuer(&[jwt_issuer]);
41 let token_data = decode::<Claims>(token.as_str(), &decode_key, &validation)?;
42 Ok(token_data.claims)
43}
44
45pub fn create_jwt_token(subject: impl Into<String>) -> String {
46 let now = chrono::Utc::now();
47 let expires_at = TimeDelta::try_days(7)
48 .map(|d| now + d)
49 .expect("Failed to calculate expiration date");
50 let issued_at = now.timestamp() as u64;
51 let exp = expires_at.timestamp() as u64;
52 let iss = get_jwt_issuer();
53 let aud = get_jwt_audience();
54 let sub = subject.into();
55
56 let claims = Claims {
57 iss,
58 sub,
59 issued_at,
60 exp,
61 aud,
62 };
63
64 let jwt_secret = get_jwt_secret();
65 let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
66 encode(&Header::default(), &claims, &encode_key).unwrap()
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70pub struct Claims {
71 pub sub: String,
72 pub aud: String,
73 pub iss: String,
74 pub issued_at: u64,
75 pub exp: u64,
76}
77
78impl Display for Claims {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "Email: {}", self.sub)
81 }
82}
83
84impl<S> FromRequestParts<S> for Claims
85where
86 S: Send + Sync,
87{
88 type Rejection = AuthError;
89
90 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
91 let TypedHeader(Authorization(bearer)) = parts
92 .extract::<TypedHeader<Authorization<Bearer>>>()
93 .await
94 .map_err(|_| AuthError::InvalidToken)?;
95
96 let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
97
98 Ok(claims)
99 }
100}
101
102impl IntoResponse for AuthError {
103 fn into_response(self) -> Response {
104 let (status, error_message) = match self {
105 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
106 };
107 let body = Json(json!({
108 "error": error_message,
109 }));
110 (status, body).into_response()
111 }
112}
113
/// Rejection produced when a request cannot be authenticated.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be logged with `{}`
/// and propagated through error-handling code that expects a standard error type.
#[derive(Debug)]
pub enum AuthError {
    /// The bearer token is missing, malformed, or failed validation.
    InvalidToken,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must pick its own message.
        let message = match self {
            AuthError::InvalidToken => "Invalid token",
        };
        f.write_str(message)
    }
}

impl std::error::Error for AuthError {}
118
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;

    /// Populates the environment variables the JWT helpers read.
    fn setup() {
        let settings = [
            ("JWT_SECRET", "secret"),
            ("JWT_ISSUER", "issuer"),
            ("JWT_AUDIENCE", "audience"),
        ];
        for (key, value) in settings {
            std::env::set_var(key, value);
        }
    }

    /// A freshly issued token parses back to claims carrying the same subject.
    #[test]
    fn test_create_token() {
        setup();
        let subject: String = FreeEmail().fake();
        let issued = create_jwt_token(subject.clone());
        let claims = parse_jwt_token(issued).unwrap();
        assert_eq!(subject, claims.sub);
    }

    /// A token with an extra trailing character is rejected.
    #[test]
    fn test_invalid_token() {
        setup();
        let subject: String = FreeEmail().fake();
        let tampered = format!("{}a", create_jwt_token(subject));
        assert!(parse_jwt_token(tampered).is_err());
    }
}
149}