1use std::marker::PhantomData;
32
33use jsonwebtoken::{decode, DecodingKey, TokenData, Validation};
34use serde::de::DeserializeOwned;
35
36use super::{AuthError, Authenticator};
37
/// Signature algorithm accepted for incoming tokens, bundled with the key
/// material used to verify them.
///
/// The `String` payload is the shared secret for the HMAC (`HS*`) variants and
/// PEM text for the RSA (`RS*`) and `EdDSA` variants.
#[derive(Clone, Debug)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256; carries the shared secret.
    HS256(String),
    /// HMAC with SHA-384; carries the shared secret.
    HS384(String),
    /// HMAC with SHA-512; carries the shared secret.
    HS512(String),
    /// RSA signature with SHA-256; carries the key as PEM text.
    RS256(String),
    /// RSA signature with SHA-384; carries the key as PEM text.
    RS384(String),
    /// RSA signature with SHA-512; carries the key as PEM text.
    RS512(String),
    /// EdDSA signature; carries the key as PEM text.
    EdDSA(String),
}
56
57impl JwtAlgorithm {
58 fn jsonwebtoken_algorithm(&self) -> jsonwebtoken::Algorithm {
59 match self {
60 JwtAlgorithm::HS256(_) => jsonwebtoken::Algorithm::HS256,
61 JwtAlgorithm::HS384(_) => jsonwebtoken::Algorithm::HS384,
62 JwtAlgorithm::HS512(_) => jsonwebtoken::Algorithm::HS512,
63 JwtAlgorithm::RS256(_) => jsonwebtoken::Algorithm::RS256,
64 JwtAlgorithm::RS384(_) => jsonwebtoken::Algorithm::RS384,
65 JwtAlgorithm::RS512(_) => jsonwebtoken::Algorithm::RS512,
66 JwtAlgorithm::EdDSA(_) => jsonwebtoken::Algorithm::EdDSA,
67 }
68 }
69
70 fn decoding_key(&self) -> Result<DecodingKey, AuthError> {
71 match self {
72 JwtAlgorithm::HS256(secret)
73 | JwtAlgorithm::HS384(secret)
74 | JwtAlgorithm::HS512(secret) => Ok(DecodingKey::from_secret(secret.as_bytes())),
75 JwtAlgorithm::RS256(pem)
76 | JwtAlgorithm::RS384(pem)
77 | JwtAlgorithm::RS512(pem) => DecodingKey::from_rsa_pem(pem.as_bytes())
78 .map_err(|e| AuthError::Internal(format!("Invalid RSA key: {}", e))),
79 JwtAlgorithm::EdDSA(pem) => DecodingKey::from_ed_pem(pem.as_bytes())
80 .map_err(|e| AuthError::Internal(format!("Invalid EdDSA key: {}", e))),
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct JwtConfig {
88 pub algorithm: JwtAlgorithm,
90 pub issuer: Option<String>,
92 pub audience: Option<String>,
94 pub leeway_seconds: u64,
96 pub validate_exp: bool,
98 pub validate_nbf: bool,
100}
101
102impl JwtConfig {
103 pub fn hs256(secret: impl Into<String>) -> Self {
113 Self {
114 algorithm: JwtAlgorithm::HS256(secret.into()),
115 issuer: None,
116 audience: None,
117 leeway_seconds: 0,
118 validate_exp: true,
119 validate_nbf: true,
120 }
121 }
122
123 pub fn hs384(secret: impl Into<String>) -> Self {
125 Self {
126 algorithm: JwtAlgorithm::HS384(secret.into()),
127 issuer: None,
128 audience: None,
129 leeway_seconds: 0,
130 validate_exp: true,
131 validate_nbf: true,
132 }
133 }
134
135 pub fn hs512(secret: impl Into<String>) -> Self {
137 Self {
138 algorithm: JwtAlgorithm::HS512(secret.into()),
139 issuer: None,
140 audience: None,
141 leeway_seconds: 0,
142 validate_exp: true,
143 validate_nbf: true,
144 }
145 }
146
147 pub fn rs256(public_key_pem: impl Into<String>) -> Self {
152 Self {
153 algorithm: JwtAlgorithm::RS256(public_key_pem.into()),
154 issuer: None,
155 audience: None,
156 leeway_seconds: 0,
157 validate_exp: true,
158 validate_nbf: true,
159 }
160 }
161
162 pub fn rs384(public_key_pem: impl Into<String>) -> Self {
164 Self {
165 algorithm: JwtAlgorithm::RS384(public_key_pem.into()),
166 issuer: None,
167 audience: None,
168 leeway_seconds: 0,
169 validate_exp: true,
170 validate_nbf: true,
171 }
172 }
173
174 pub fn rs512(public_key_pem: impl Into<String>) -> Self {
176 Self {
177 algorithm: JwtAlgorithm::RS512(public_key_pem.into()),
178 issuer: None,
179 audience: None,
180 leeway_seconds: 0,
181 validate_exp: true,
182 validate_nbf: true,
183 }
184 }
185
186 pub fn eddsa(public_key_pem: impl Into<String>) -> Self {
188 Self {
189 algorithm: JwtAlgorithm::EdDSA(public_key_pem.into()),
190 issuer: None,
191 audience: None,
192 leeway_seconds: 0,
193 validate_exp: true,
194 validate_nbf: true,
195 }
196 }
197
198 pub fn from_env() -> Option<Self> {
207 let config = if let Ok(public_key) = std::env::var("JWT_PUBLIC_KEY") {
208 Self::rs256(public_key)
209 } else if let Ok(secret) = std::env::var("JWT_SECRET") {
210 Self::hs256(secret)
211 } else {
212 return None;
213 };
214
215 let config = if let Ok(issuer) = std::env::var("JWT_ISSUER") {
216 config.with_issuer(issuer)
217 } else {
218 config
219 };
220
221 let config = if let Ok(audience) = std::env::var("JWT_AUDIENCE") {
222 config.with_audience(audience)
223 } else {
224 config
225 };
226
227 let config = if let Ok(leeway) = std::env::var("JWT_LEEWAY") {
228 if let Ok(seconds) = leeway.parse() {
229 config.with_leeway(seconds)
230 } else {
231 config
232 }
233 } else {
234 config.with_leeway(60) };
236
237 Some(config)
238 }
239
240 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
242 self.issuer = Some(issuer.into());
243 self
244 }
245
246 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
248 self.audience = Some(audience.into());
249 self
250 }
251
252 pub fn with_leeway(mut self, seconds: u64) -> Self {
254 self.leeway_seconds = seconds;
255 self
256 }
257
258 pub fn without_exp_validation(mut self) -> Self {
260 self.validate_exp = false;
261 self
262 }
263
264 pub fn without_nbf_validation(mut self) -> Self {
266 self.validate_nbf = false;
267 self
268 }
269}
270
271pub struct JwtValidator<C> {
295 config: JwtConfig,
296 decoding_key: DecodingKey,
297 validation: Validation,
298 _phantom: PhantomData<C>,
299}
300
301impl<C> JwtValidator<C> {
302 pub fn new(config: JwtConfig) -> Self {
309 Self::try_new(config).expect("Invalid JWT configuration")
310 }
311
312 pub fn try_new(config: JwtConfig) -> Result<Self, AuthError> {
316 let decoding_key = config.algorithm.decoding_key()?;
317
318 let mut validation = Validation::new(config.algorithm.jsonwebtoken_algorithm());
319 validation.leeway = config.leeway_seconds;
320 validation.validate_exp = config.validate_exp;
321 validation.validate_nbf = config.validate_nbf;
322
323 if let Some(ref iss) = config.issuer {
324 validation.set_issuer(&[iss]);
325 }
326
327 if let Some(ref aud) = config.audience {
328 validation.set_audience(&[aud]);
329 } else {
330 validation.validate_aud = false;
333 }
334
335 Ok(Self {
336 config,
337 decoding_key,
338 validation,
339 _phantom: PhantomData,
340 })
341 }
342
343 pub fn config(&self) -> &JwtConfig {
345 &self.config
346 }
347}
348
349impl<C: Clone + Send + Sync + DeserializeOwned + 'static> JwtValidator<C> {
350 pub fn validate(&self, token: &str) -> Result<C, AuthError> {
352 let token_data: TokenData<C> =
353 decode(token, &self.decoding_key, &self.validation).map_err(|e| {
354 use jsonwebtoken::errors::ErrorKind;
355 match e.kind() {
356 ErrorKind::ExpiredSignature => AuthError::TokenExpired,
357 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
358 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
359 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
360 ErrorKind::InvalidToken => AuthError::InvalidToken("malformed token".into()),
361 ErrorKind::InvalidAlgorithm => {
362 AuthError::InvalidToken("wrong algorithm".into())
363 }
364 _ => AuthError::InvalidToken(e.to_string()),
365 }
366 })?;
367
368 Ok(token_data.claims)
369 }
370}
371
372#[async_trait::async_trait]
373impl<C: Clone + Send + Sync + DeserializeOwned + 'static> Authenticator for JwtValidator<C> {
374 type Claims = C;
375
376 async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError> {
377 self.validate(token)
379 }
380}
381
382#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
386pub struct StandardClaims {
387 pub sub: String,
389 #[serde(default)]
391 pub exp: Option<i64>,
392 #[serde(default)]
394 pub iat: Option<i64>,
395 #[serde(default)]
397 pub nbf: Option<i64>,
398 #[serde(default)]
400 pub iss: Option<String>,
401 #[serde(default)]
403 pub aud: Option<String>,
404}
405
406impl super::HasSubject for StandardClaims {
407 fn subject(&self) -> &str {
408 &self.sub
409 }
410}
411
412impl super::HasExpiration for StandardClaims {
413 fn expiration(&self) -> Option<i64> {
414 self.exp
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared HMAC secret used to sign and verify tokens in these tests.
    const TEST_SECRET: &str = "test-secret-that-is-long-enough-for-hs256";

    /// Signs `claims` with [`TEST_SECRET`] using the default JWT header.
    fn create_test_token(claims: &impl serde::Serialize) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let header = Header::default();
        let key = EncodingKey::from_secret(TEST_SECRET.as_bytes());
        encode(&header, claims, &key).unwrap()
    }

    /// Minimal claims payload used by the validator tests.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
    struct TestClaims {
        sub: String,
        exp: i64,
    }

    /// Claims for `user123` expiring at the given timestamp.
    fn claims_expiring_at(exp: i64) -> TestClaims {
        TestClaims {
            sub: "user123".to_string(),
            exp,
        }
    }

    /// Claims for `user123` that stay valid for one more hour.
    fn fresh_claims() -> TestClaims {
        let one_hour_ahead = chrono::Utc::now() + chrono::Duration::hours(1);
        claims_expiring_at(one_hour_ahead.timestamp())
    }

    /// HS256 validator for `TestClaims` keyed with `secret`.
    fn validator_for(secret: &str) -> JwtValidator<TestClaims> {
        JwtValidator::<TestClaims>::new(JwtConfig::hs256(secret))
    }

    #[test]
    fn test_jwt_config_hs256() {
        let cfg = JwtConfig::hs256("secret");

        assert!(matches!(cfg.algorithm, JwtAlgorithm::HS256(_)));
        assert!(cfg.issuer.is_none());
        assert!(cfg.audience.is_none());
    }

    #[test]
    fn test_jwt_config_builder() {
        let base = JwtConfig::hs256("secret");
        let cfg = base
            .with_issuer("my-app")
            .with_audience("my-audience")
            .with_leeway(120);

        assert_eq!(cfg.issuer, Some("my-app".to_string()));
        assert_eq!(cfg.audience, Some("my-audience".to_string()));
        assert_eq!(cfg.leeway_seconds, 120);
    }

    #[test]
    fn test_jwt_validator_valid_token() {
        let token = create_test_token(&fresh_claims());

        let outcome = validator_for(TEST_SECRET).validate(&token);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().sub, "user123");
    }

    #[test]
    fn test_jwt_validator_expired_token() {
        // A timestamp of 0 is far in the past, so the token is already expired.
        let token = create_test_token(&claims_expiring_at(0));

        let outcome = validator_for(TEST_SECRET).validate(&token);

        assert!(matches!(outcome, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_jwt_validator_wrong_secret() {
        let token = create_test_token(&fresh_claims());

        let outcome = validator_for("wrong-secret").validate(&token);

        assert!(matches!(outcome, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_jwt_validator_invalid_token() {
        let outcome = validator_for(TEST_SECRET).validate("not-a-jwt");

        assert!(matches!(outcome, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_jwt_validator_authenticator_trait() {
        let token = create_test_token(&fresh_claims());
        let validator = validator_for(TEST_SECRET);

        // Goes through the `Authenticator` trait rather than `validate`.
        let outcome = validator.authenticate(&token).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().sub, "user123");
    }

    #[test]
    fn test_standard_claims() {
        use super::super::HasExpiration;
        use super::super::HasSubject;

        let claims = StandardClaims {
            sub: "user123".to_string(),
            exp: Some(1234567890),
            iat: Some(1234567800),
            nbf: None,
            iss: Some("my-app".to_string()),
            aud: None,
        };

        // Direct field access.
        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.exp, Some(1234567890));

        // Access through the trait accessors.
        assert_eq!(claims.subject(), "user123");
        assert_eq!(claims.expiration(), Some(1234567890));
    }
}