Skip to main content

axum_webtools/security/
jwt.rs

1use axum::{
2    extract::Request, http::StatusCode, response::IntoResponse, response::Response, Json,
3    RequestExt,
4};
5use chrono::TimeDelta;
6use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
7
8use axum::extract::FromRequestParts;
9use axum::http::request::Parts;
10use futures_util::future::BoxFuture;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::fmt::Display;
14use std::task::{Context, Poll};
15
/// A type that can name the subject of a JWT.
///
/// The `Send + Sync` supertraits let implementors be shared across threads.
/// NOTE(review): nothing in this file consumes the trait; presumably it is
/// implemented by callers (e.g. user models) — confirm intended usage.
pub trait JwtToken: Sync + Send {
    /// Returns the subject identifier the token should be issued for.
    fn subject(&self) -> String;
}
19
/// Reads a mandatory environment variable.
///
/// # Panics
/// Panics with `"<NAME> must be set: <VarError>"` when the variable is unset
/// or not valid Unicode (the same message `Result::expect` would produce).
fn required_env_var(name: &str) -> String {
    match std::env::var(name) {
        Ok(value) => value,
        Err(err) => panic!("{} must be set: {:?}", name, err),
    }
}

/// Shared HMAC secret used to sign and verify tokens (`JWT_SECRET`).
fn get_jwt_secret() -> String {
    required_env_var("JWT_SECRET")
}

/// Expected/emitted `iss` claim (`JWT_ISSUER`).
fn get_jwt_issuer() -> String {
    required_env_var("JWT_ISSUER")
}

/// Expected/emitted `aud` claim (`JWT_AUDIENCE`).
fn get_jwt_audience() -> String {
    required_env_var("JWT_AUDIENCE")
}
31
32pub fn parse_jwt_token(token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
33    let jwt_issuer = get_jwt_issuer();
34    let jwt_audience = get_jwt_audience();
35    let jwt_secret = get_jwt_secret();
36    let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
37
38    let mut validation = Validation::new(Algorithm::HS256);
39    validation.set_audience(&[jwt_audience]);
40    validation.set_issuer(&[jwt_issuer]);
41    let token_data = decode::<Claims>(token, &decode_key, &validation)?;
42    Ok(token_data.claims)
43}
44
/// Output of [`create_jwt_token`].
pub struct CreateJwtResult {
    /// The signed, encoded JWT.
    pub access_token: String,
    /// Absolute expiry as a Unix timestamp in seconds (the same value as the
    /// token's `exp` claim) — not a relative lifetime, despite the name.
    pub expires_in: u64,
    /// Scopes embedded in the token.
    pub scopes: Vec<String>,
}
50
51pub fn create_jwt_token(subject: impl Into<String>, scopes: Vec<String>) -> CreateJwtResult {
52    let now = chrono::Utc::now();
53    let expires_at = TimeDelta::try_days(7)
54        .map(|d| now + d)
55        .expect("Failed to calculate expiration date");
56    let issued_at = now.timestamp() as u64;
57    let exp = expires_at.timestamp() as u64;
58    let iss = get_jwt_issuer();
59    let aud = get_jwt_audience();
60    let sub = subject.into();
61
62    let claims = Claims {
63        iss,
64        sub,
65        issued_at,
66        exp,
67        aud,
68        scopes: scopes.clone(),
69    };
70
71    let jwt_secret = get_jwt_secret();
72    let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
73    let access_token = encode(&Header::default(), &claims, &encode_key).unwrap();
74    CreateJwtResult {
75        access_token,
76        expires_in: exp,
77        scopes,
78    }
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82pub struct Claims {
83    pub sub: String,
84    pub aud: String,
85    pub iss: String,
86    pub issued_at: u64,
87    pub exp: u64,
88    pub scopes: Vec<String>,
89}
90
91impl Claims {
92    pub fn has_scopes(&self, expected_scopes: &[String]) -> bool {
93        expected_scopes
94            .iter()
95            .all(|scope| self.scopes.contains(&scope.to_string()))
96    }
97}
98
99impl Display for Claims {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(f, "Email: {}", self.sub)
102    }
103}
104
/// Tower layer that wraps a service in [`RequireScopeMiddleware`].
///
/// Requests must carry extractable [`Claims`] holding every configured scope.
#[derive(Clone)]
pub struct RequireScopeLayer {
    required_scopes: Vec<String>,
}

impl RequireScopeLayer {
    /// Creates a layer with no required scopes. A request then passes as long
    /// as its claims can be extracted.
    pub fn new() -> Self {
        RequireScopeLayer {
            required_scopes: vec![],
        }
    }

    /// Sets the scopes every request must hold, replacing any set earlier.
    pub fn with(mut self, require_scope: Vec<&str>) -> Self {
        self.required_scopes.clear();
        self.required_scopes
            .extend(require_scope.into_iter().map(String::from));
        self
    }
}

impl Default for RequireScopeLayer {
    /// Equivalent to [`RequireScopeLayer::new`].
    fn default() -> Self {
        RequireScopeLayer::new()
    }
}
128
129impl<S> Layer<S> for RequireScopeLayer {
130    type Service = RequireScopeMiddleware<S>;
131
132    fn layer(&self, inner: S) -> Self::Service {
133        RequireScopeMiddleware {
134            inner,
135            required_scopes: self.required_scopes.clone(),
136        }
137    }
138}
139
140#[derive(Clone)]
141pub struct RequireScopeMiddleware<S> {
142    inner: S,
143    required_scopes: Vec<String>,
144}
145
146impl<S> Service<Request> for RequireScopeMiddleware<S>
147where
148    S: Service<Request, Response = Response> + Clone + Send + 'static,
149    S::Future: Send + 'static,
150{
151    type Response = S::Response;
152    type Error = S::Error;
153    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
154
155    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        self.inner.poll_ready(cx)
157    }
158
159    fn call(&mut self, mut request: Request) -> Self::Future {
160        let required_scopes = self.required_scopes.clone();
161        let mut inner = self.inner.clone();
162
163        Box::pin(async move {
164            match request.extract_parts::<Claims>().await {
165                Ok(claims) => {
166                    if claims.has_scopes(&required_scopes) {
167                        return inner.call(request).await;
168                    }
169                    let response = AuthError::NotSufficientScopes.into_response();
170                    Ok(response)
171                }
172                Err(_) => {
173                    let response = AuthError::InvalidToken.into_response();
174                    Ok(response)
175                }
176            }
177        })
178    }
179}
180
181#[cfg(not(any(test, feature = "mock_jwt")))]
182use axum::RequestPartsExt;
183#[cfg(not(any(test, feature = "mock_jwt")))]
184use axum_extra::{
185    headers::{authorization::Bearer, Authorization},
186    TypedHeader,
187};
188use derive_more::Display;
189use thiserror::Error;
190use tower::{Layer, Service};
191
192#[cfg(not(any(test, feature = "mock_jwt")))]
193impl<S> FromRequestParts<S> for Claims
194where
195    S: Send + Sync,
196{
197    type Rejection = AuthError;
198
199    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
200        let TypedHeader(Authorization(bearer)) = parts
201            .extract::<TypedHeader<Authorization<Bearer>>>()
202            .await
203            .map_err(|_| AuthError::InvalidToken)?;
204        let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
205        Ok(claims)
206    }
207}
208
/// Test/mock extractor that builds [`Claims`] from plain `X-Claims-*` headers
/// instead of a signed token.
///
/// Expected headers:
/// - `X-Claims-Subject`, `X-Claims-Issuer`, `X-Claims-Audience`;
/// - `X-Claims-Issued-At` and `X-Claims-Expiration` (decimal integers);
/// - `X-Claims-Scopes` (comma-separated).
///
/// A missing, non-visible-ASCII or unparsable header is rejected with
/// [`AuthError::InvalidToken`], mirroring the real extractor, instead of
/// panicking inside the request task.
#[cfg(any(test, feature = "mock_jwt"))]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let sub = mock_claims_header(parts, "X-Claims-Subject")?.to_string();
        let iss = mock_claims_header(parts, "X-Claims-Issuer")?.to_string();
        let aud = mock_claims_header(parts, "X-Claims-Audience")?.to_string();
        let issued_at: u64 = mock_claims_header(parts, "X-Claims-Issued-At")?
            .parse()
            .map_err(|_| AuthError::InvalidToken)?;
        let exp: u64 = mock_claims_header(parts, "X-Claims-Expiration")?
            .parse()
            .map_err(|_| AuthError::InvalidToken)?;
        // Trim and drop empty segments. An empty header then means "no scopes"
        // (like a real token issued with an empty list), and "a, b" yields
        // ["a", "b"].
        let scopes = mock_claims_header(parts, "X-Claims-Scopes")?
            .split(',')
            .map(str::trim)
            .filter(|scope| !scope.is_empty())
            .map(String::from)
            .collect();

        Ok(Claims {
            sub,
            aud,
            iss,
            issued_at,
            exp,
            scopes,
        })
    }
}

/// Returns header `name` as `&str`. Yields [`AuthError::InvalidToken`] when the
/// header is missing or its value is not visible ASCII.
#[cfg(any(test, feature = "mock_jwt"))]
fn mock_claims_header<'a>(parts: &'a Parts, name: &str) -> Result<&'a str, AuthError> {
    parts
        .headers
        .get(name)
        .and_then(|value| value.to_str().ok())
        .ok_or(AuthError::InvalidToken)
}
273
274impl IntoResponse for AuthError {
275    fn into_response(self) -> axum::response::Response {
276        let (status, error_message) = match self {
277            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
278            AuthError::NotSufficientScopes => (StatusCode::FORBIDDEN, "Not sufficient scopes"),
279        };
280        let body = Json(json!({
281            "error": error_message,
282        }));
283        (status, body).into_response()
284    }
285}
286
287#[derive(Debug, Error, Display)]
288pub enum AuthError {
289    InvalidToken,
290    NotSufficientScopes,
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;
    use std::sync::Once;

    /// Sets the JWT environment variables exactly once.
    ///
    /// Tests run on parallel threads. Routing the writes through `Once` means
    /// every test in this module blocks until the variables are set, and none
    /// of them mutates the environment while another is reading it.
    fn setup() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            std::env::set_var("JWT_SECRET", "secret");
            std::env::set_var("JWT_ISSUER", "issuer");
            std::env::set_var("JWT_AUDIENCE", "audience");
        });
    }

    #[test]
    fn test_create_token() {
        setup();
        let email: String = FreeEmail().fake();
        let jwt_token = create_jwt_token(email.clone(), vec!["customers:read".to_string()]);
        assert_eq!(jwt_token.scopes, vec!["customers:read"]);

        // `expires_in` is the absolute `exp` timestamp, about 7 days from now.
        let now = chrono::Utc::now();
        let seven_days_from_now = now + chrono::Duration::days(7);
        let lower_bound = seven_days_from_now - chrono::Duration::seconds(30);
        assert!(jwt_token.expires_in > lower_bound.timestamp() as u64);
        assert!(jwt_token.expires_in <= seven_days_from_now.timestamp() as u64);

        let claims = parse_jwt_token(&jwt_token.access_token).unwrap();
        assert_eq!(vec!["customers:read".to_string()], claims.scopes);
        assert_eq!(email, claims.sub);
        assert_eq!("issuer", claims.iss);
        assert_eq!("audience", claims.aud);
        assert_eq!(jwt_token.expires_in, claims.exp);
    }

    #[test]
    fn test_invalid_token() {
        setup();
        let email: String = FreeEmail().fake();
        let mut token = create_jwt_token(email, vec![]);
        token.access_token.push('a');
        let claims = parse_jwt_token(&token.access_token);
        assert!(claims.is_err());
    }

    #[test]
    fn test_has_scopes() {
        let claims = Claims {
            sub: "user".to_string(),
            aud: "audience".to_string(),
            iss: "issuer".to_string(),
            issued_at: 0,
            exp: 0,
            scopes: vec!["a:read".to_string(), "b:write".to_string()],
        };
        assert!(claims.has_scopes(&[]));
        assert!(claims.has_scopes(&["a:read".to_string()]));
        assert!(claims.has_scopes(&["b:write".to_string(), "a:read".to_string()]));
        assert!(!claims.has_scopes(&["c:admin".to_string()]));
        assert!(!claims.has_scopes(&["a:read".to_string(), "c:admin".to_string()]));
    }
}