1use axum::{
2 extract::Request, http::StatusCode, response::IntoResponse, response::Response, Json,
3 RequestExt,
4};
5use chrono::TimeDelta;
6use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
7
8use axum::extract::FromRequestParts;
9use axum::http::request::Parts;
10use futures_util::future::BoxFuture;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::fmt::Display;
14use std::task::{Context, Poll};
15
/// Something that can report the subject it identifies.
///
/// Implementors must be shareable across threads (`Send + Sync`).
pub trait JwtToken: Send + Sync {
    /// Returns the subject identifier as an owned string.
    fn subject(&self) -> String;
}
19
/// Reads the required environment variable `name`.
///
/// # Panics
///
/// Panics with `"<name> must be set: <error>"` when the variable is absent or
/// not valid Unicode.
fn require_env_var(name: &str) -> String {
    std::env::var(name).unwrap_or_else(|err| panic!("{} must be set: {:?}", name, err))
}

/// HMAC secret used to sign and verify tokens (`JWT_SECRET`).
///
/// # Panics
///
/// Panics if `JWT_SECRET` is unset.
fn get_jwt_secret() -> String {
    require_env_var("JWT_SECRET")
}

/// Expected token issuer (`JWT_ISSUER`).
///
/// # Panics
///
/// Panics if `JWT_ISSUER` is unset.
fn get_jwt_issuer() -> String {
    require_env_var("JWT_ISSUER")
}

/// Expected token audience (`JWT_AUDIENCE`).
///
/// # Panics
///
/// Panics if `JWT_AUDIENCE` is unset.
fn get_jwt_audience() -> String {
    require_env_var("JWT_AUDIENCE")
}
31
32pub fn parse_jwt_token(token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
33 let jwt_issuer = get_jwt_issuer();
34 let jwt_audience = get_jwt_audience();
35 let jwt_secret = get_jwt_secret();
36 let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
37
38 let mut validation = Validation::new(Algorithm::HS256);
39 validation.set_audience(&[jwt_audience]);
40 validation.set_issuer(&[jwt_issuer]);
41 let token_data = decode::<Claims>(token, &decode_key, &validation)?;
42 Ok(token_data.claims)
43}
44
/// Output of [`create_jwt_token`]: the encoded token plus the metadata it was
/// issued with.
pub struct CreateJwtResult {
    /// Encoded, signed JWT.
    pub access_token: String,
    /// Absolute expiry as a Unix timestamp in seconds.
    ///
    /// This is the same value as the token's `exp` claim. It is a point in
    /// time, not a relative lifetime.
    pub expires_in: u64,
    /// Scopes embedded in the token.
    pub scopes: Vec<String>,
}
50
51pub fn create_jwt_token(subject: impl Into<String>, scopes: Vec<String>) -> CreateJwtResult {
52 let now = chrono::Utc::now();
53 let expires_at = TimeDelta::try_days(7)
54 .map(|d| now + d)
55 .expect("Failed to calculate expiration date");
56 let issued_at = now.timestamp() as u64;
57 let exp = expires_at.timestamp() as u64;
58 let iss = get_jwt_issuer();
59 let aud = get_jwt_audience();
60 let sub = subject.into();
61
62 let claims = Claims {
63 iss,
64 sub,
65 issued_at,
66 exp,
67 aud,
68 scopes: scopes.clone(),
69 };
70
71 let jwt_secret = get_jwt_secret();
72 let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
73 let access_token = encode(&Header::default(), &claims, &encode_key).unwrap();
74 CreateJwtResult {
75 access_token,
76 expires_in: exp,
77 scopes,
78 }
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82pub struct Claims {
83 pub sub: String,
84 pub aud: String,
85 pub iss: String,
86 pub issued_at: u64,
87 pub exp: u64,
88 pub scopes: Vec<String>,
89}
90
91impl Claims {
92 pub fn has_scopes(&self, expected_scopes: &[String]) -> bool {
93 expected_scopes
94 .iter()
95 .all(|scope| self.scopes.contains(&scope.to_string()))
96 }
97}
98
99impl Display for Claims {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "Email: {}", self.sub)
102 }
103}
104
/// Tower layer that wraps a service in [`RequireScopeMiddleware`].
///
/// Every request through the wrapped service must carry all of
/// `required_scopes`.
#[derive(Clone)]
pub struct RequireScopeLayer {
    /// Scopes each request must hold. Empty means no scope requirement.
    required_scopes: Vec<String>,
}
109
110impl RequireScopeLayer {
111 pub fn new() -> Self {
112 Self {
113 required_scopes: Vec::new(),
114 }
115 }
116
117 pub fn with(mut self, require_scope: Vec<&str>) -> Self {
118 self.required_scopes = require_scope.iter().map(|s| s.to_string()).collect();
119 self
120 }
121}
122
123impl Default for RequireScopeLayer {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl<S> Layer<S> for RequireScopeLayer {
130 type Service = RequireScopeMiddleware<S>;
131
132 fn layer(&self, inner: S) -> Self::Service {
133 RequireScopeMiddleware {
134 inner,
135 required_scopes: self.required_scopes.clone(),
136 }
137 }
138}
139
/// Service produced by [`RequireScopeLayer`].
///
/// It extracts [`Claims`] from each request and forwards the request to
/// `inner` only when all `required_scopes` are present.
#[derive(Clone)]
pub struct RequireScopeMiddleware<S> {
    /// Wrapped service that handles authorised requests.
    inner: S,
    /// Scopes each request must hold.
    required_scopes: Vec<String>,
}
145
146impl<S> Service<Request> for RequireScopeMiddleware<S>
147where
148 S: Service<Request, Response = Response> + Clone + Send + 'static,
149 S::Future: Send + 'static,
150{
151 type Response = S::Response;
152 type Error = S::Error;
153 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
154
155 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 self.inner.poll_ready(cx)
157 }
158
159 fn call(&mut self, mut request: Request) -> Self::Future {
160 let required_scopes = self.required_scopes.clone();
161 let mut inner = self.inner.clone();
162
163 Box::pin(async move {
164 match request.extract_parts::<Claims>().await {
165 Ok(claims) => {
166 if claims.has_scopes(&required_scopes) {
167 return inner.call(request).await;
168 }
169 let response = AuthError::NotSufficientScopes.into_response();
170 Ok(response)
171 }
172 Err(_) => {
173 let response = AuthError::InvalidToken.into_response();
174 Ok(response)
175 }
176 }
177 })
178 }
179}
180
181#[cfg(not(any(test, feature = "mock_jwt")))]
182use axum::RequestPartsExt;
183#[cfg(not(any(test, feature = "mock_jwt")))]
184use axum_extra::{
185 headers::{authorization::Bearer, Authorization},
186 TypedHeader,
187};
188use derive_more::Display;
189use thiserror::Error;
190use tower::{Layer, Service};
191
192#[cfg(not(any(test, feature = "mock_jwt")))]
193impl<S> FromRequestParts<S> for Claims
194where
195 S: Send + Sync,
196{
197 type Rejection = AuthError;
198
199 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
200 let TypedHeader(Authorization(bearer)) = parts
201 .extract::<TypedHeader<Authorization<Bearer>>>()
202 .await
203 .map_err(|_| AuthError::InvalidToken)?;
204 let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
205 Ok(claims)
206 }
207}
208
#[cfg(any(test, feature = "mock_jwt"))]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    /// Test/mock extractor that builds [`Claims`] from plain `X-Claims-*`
    /// headers instead of a signed bearer token.
    ///
    /// Headers read:
    /// - `X-Claims-Subject`, `X-Claims-Issuer`, `X-Claims-Audience` are copied verbatim.
    /// - `X-Claims-Issued-At` and `X-Claims-Expiration` are parsed as `u64`.
    /// - `X-Claims-Scopes` is split on `,`, and empty segments are ignored, so an
    ///   empty header means "no scopes".
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidToken`] when a header is missing, cannot be
    /// read as a string, or a timestamp fails to parse. This is the same
    /// rejection the real bearer-token extractor uses, and it replaces the
    /// previous panics.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        /// Reads header `name` as text, rejecting absent or non-text values.
        fn header<'a>(parts: &'a Parts, name: &str) -> Result<&'a str, AuthError> {
            parts
                .headers
                .get(name)
                .and_then(|value| value.to_str().ok())
                .ok_or(AuthError::InvalidToken)
        }

        /// Reads header `name` as an unsigned Unix timestamp.
        fn timestamp(parts: &Parts, name: &str) -> Result<u64, AuthError> {
            header(parts, name)?
                .parse::<u64>()
                .map_err(|_| AuthError::InvalidToken)
        }

        let sub = header(parts, "X-Claims-Subject")?.to_string();
        let iss = header(parts, "X-Claims-Issuer")?.to_string();
        let aud = header(parts, "X-Claims-Audience")?.to_string();
        let issued_at = timestamp(parts, "X-Claims-Issued-At")?;
        let exp = timestamp(parts, "X-Claims-Expiration")?;
        let scopes = header(parts, "X-Claims-Scopes")?
            .split(',')
            // `"".split(',')` yields one empty segment; don't treat it as a scope.
            .filter(|scope| !scope.is_empty())
            .map(str::to_string)
            .collect();

        Ok(Claims {
            sub,
            aud,
            iss,
            issued_at,
            exp,
            scopes,
        })
    }
}
273
274impl IntoResponse for AuthError {
275 fn into_response(self) -> axum::response::Response {
276 let (status, error_message) = match self {
277 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
278 AuthError::NotSufficientScopes => (StatusCode::FORBIDDEN, "Not sufficient scopes"),
279 };
280 let body = Json(json!({
281 "error": error_message,
282 }));
283 (status, body).into_response()
284 }
285}
286
287#[derive(Debug, Error, Display)]
288pub enum AuthError {
289 InvalidToken,
290 NotSufficientScopes,
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;

    /// Points the JWT helpers at fixed test configuration.
    ///
    /// Every test writes the same values, so tests running in parallel observe
    /// a consistent environment.
    fn setup() {
        std::env::set_var("JWT_SECRET", "secret");
        std::env::set_var("JWT_ISSUER", "issuer");
        std::env::set_var("JWT_AUDIENCE", "audience");
    }

    /// A freshly issued token round-trips through `parse_jwt_token` and expires
    /// about seven days from now.
    #[test]
    fn test_create_token() {
        setup();
        let email: String = FreeEmail().fake();
        let jwt_token = create_jwt_token(email.clone(), vec!["customers:read".to_string()]);
        assert_eq!(jwt_token.scopes, vec!["customers:read"]);

        // `expires_in` is an absolute Unix timestamp about seven days out. Allow
        // 30 seconds of slack for the time spent issuing the token.
        let seven_days_minus_slack =
            (chrono::Utc::now() + chrono::Duration::days(7)) - chrono::Duration::seconds(30);
        assert!(jwt_token.expires_in > seven_days_minus_slack.timestamp() as u64);

        let claims = parse_jwt_token(&jwt_token.access_token).unwrap();
        assert_eq!(vec!["customers:read".to_string()], claims.scopes);
        assert_eq!(email, claims.sub);
        // `expires_in` mirrors the `exp` claim rather than a relative lifetime.
        assert_eq!(jwt_token.expires_in, claims.exp);
        assert_eq!("issuer", claims.iss);
        assert_eq!("audience", claims.aud);
    }

    /// Tampering with the signature must make validation fail.
    #[test]
    fn test_invalid_token() {
        setup();
        let email: String = FreeEmail().fake();
        let mut token = create_jwt_token(email, vec![]);
        token.access_token.push('a');
        let claims = parse_jwt_token(&token.access_token);
        assert!(claims.is_err());
    }
}