1use std::marker::PhantomData;
32
33use jsonwebtoken::{decode, DecodingKey, TokenData, Validation};
34use serde::de::DeserializeOwned;
35
36use super::{AuthError, Authenticator};
37
/// Signature algorithm used to verify incoming JWTs, together with its key material.
///
/// HMAC variants (`HS*`) carry the shared secret; RSA (`RS*`) and `EdDSA` variants carry a
/// PEM-encoded public key.
///
/// `Debug` is implemented by hand instead of derived: the derived implementation would print
/// HMAC secrets verbatim, and that output is reachable through `JwtConfig`'s `Debug` impl
/// (and therefore through any log line that formats a config or validator configuration).
#[derive(Clone)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256; holds the shared secret.
    HS256(String),
    /// HMAC with SHA-384; holds the shared secret.
    HS384(String),
    /// HMAC with SHA-512; holds the shared secret.
    HS512(String),
    /// RSA signature with SHA-256; holds a PEM-encoded RSA public key.
    RS256(String),
    /// RSA signature with SHA-384; holds a PEM-encoded RSA public key.
    RS384(String),
    /// RSA signature with SHA-512; holds a PEM-encoded RSA public key.
    RS512(String),
    /// EdDSA signature; holds a PEM-encoded Ed25519/EdDSA public key.
    EdDSA(String),
}

impl std::fmt::Debug for JwtAlgorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Printed in place of an HMAC secret so it never reaches logs.
        const REDACTED: &str = "<redacted>";

        match self {
            Self::HS256(_) => f.debug_tuple("HS256").field(&REDACTED).finish(),
            Self::HS384(_) => f.debug_tuple("HS384").field(&REDACTED).finish(),
            Self::HS512(_) => f.debug_tuple("HS512").field(&REDACTED).finish(),
            // Public keys are not secret, so they are shown as-is.
            Self::RS256(pem) => f.debug_tuple("RS256").field(pem).finish(),
            Self::RS384(pem) => f.debug_tuple("RS384").field(pem).finish(),
            Self::RS512(pem) => f.debug_tuple("RS512").field(pem).finish(),
            Self::EdDSA(pem) => f.debug_tuple("EdDSA").field(pem).finish(),
        }
    }
}
56
57impl JwtAlgorithm {
58 fn jsonwebtoken_algorithm(&self) -> jsonwebtoken::Algorithm {
59 match self {
60 JwtAlgorithm::HS256(_) => jsonwebtoken::Algorithm::HS256,
61 JwtAlgorithm::HS384(_) => jsonwebtoken::Algorithm::HS384,
62 JwtAlgorithm::HS512(_) => jsonwebtoken::Algorithm::HS512,
63 JwtAlgorithm::RS256(_) => jsonwebtoken::Algorithm::RS256,
64 JwtAlgorithm::RS384(_) => jsonwebtoken::Algorithm::RS384,
65 JwtAlgorithm::RS512(_) => jsonwebtoken::Algorithm::RS512,
66 JwtAlgorithm::EdDSA(_) => jsonwebtoken::Algorithm::EdDSA,
67 }
68 }
69
70 fn decoding_key(&self) -> Result<DecodingKey, AuthError> {
71 match self {
72 JwtAlgorithm::HS256(secret)
73 | JwtAlgorithm::HS384(secret)
74 | JwtAlgorithm::HS512(secret) => Ok(DecodingKey::from_secret(secret.as_bytes())),
75 JwtAlgorithm::RS256(pem) | JwtAlgorithm::RS384(pem) | JwtAlgorithm::RS512(pem) => {
76 DecodingKey::from_rsa_pem(pem.as_bytes())
77 .map_err(|e| AuthError::Internal(format!("Invalid RSA key: {}", e)))
78 }
79 JwtAlgorithm::EdDSA(pem) => DecodingKey::from_ed_pem(pem.as_bytes())
80 .map_err(|e| AuthError::Internal(format!("Invalid EdDSA key: {}", e))),
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct JwtConfig {
88 pub algorithm: JwtAlgorithm,
90 pub issuer: Option<String>,
92 pub audience: Option<String>,
94 pub leeway_seconds: u64,
96 pub validate_exp: bool,
98 pub validate_nbf: bool,
100}
101
102impl JwtConfig {
103 pub fn hs256(secret: impl Into<String>) -> Self {
113 Self {
114 algorithm: JwtAlgorithm::HS256(secret.into()),
115 issuer: None,
116 audience: None,
117 leeway_seconds: 0,
118 validate_exp: true,
119 validate_nbf: true,
120 }
121 }
122
123 pub fn hs384(secret: impl Into<String>) -> Self {
125 Self {
126 algorithm: JwtAlgorithm::HS384(secret.into()),
127 issuer: None,
128 audience: None,
129 leeway_seconds: 0,
130 validate_exp: true,
131 validate_nbf: true,
132 }
133 }
134
135 pub fn hs512(secret: impl Into<String>) -> Self {
137 Self {
138 algorithm: JwtAlgorithm::HS512(secret.into()),
139 issuer: None,
140 audience: None,
141 leeway_seconds: 0,
142 validate_exp: true,
143 validate_nbf: true,
144 }
145 }
146
147 pub fn rs256(public_key_pem: impl Into<String>) -> Self {
152 Self {
153 algorithm: JwtAlgorithm::RS256(public_key_pem.into()),
154 issuer: None,
155 audience: None,
156 leeway_seconds: 0,
157 validate_exp: true,
158 validate_nbf: true,
159 }
160 }
161
162 pub fn rs384(public_key_pem: impl Into<String>) -> Self {
164 Self {
165 algorithm: JwtAlgorithm::RS384(public_key_pem.into()),
166 issuer: None,
167 audience: None,
168 leeway_seconds: 0,
169 validate_exp: true,
170 validate_nbf: true,
171 }
172 }
173
174 pub fn rs512(public_key_pem: impl Into<String>) -> Self {
176 Self {
177 algorithm: JwtAlgorithm::RS512(public_key_pem.into()),
178 issuer: None,
179 audience: None,
180 leeway_seconds: 0,
181 validate_exp: true,
182 validate_nbf: true,
183 }
184 }
185
186 pub fn eddsa(public_key_pem: impl Into<String>) -> Self {
188 Self {
189 algorithm: JwtAlgorithm::EdDSA(public_key_pem.into()),
190 issuer: None,
191 audience: None,
192 leeway_seconds: 0,
193 validate_exp: true,
194 validate_nbf: true,
195 }
196 }
197
198 pub fn from_env() -> Option<Self> {
207 let config = if let Ok(public_key) = std::env::var("JWT_PUBLIC_KEY") {
208 Self::rs256(public_key)
209 } else if let Ok(secret) = std::env::var("JWT_SECRET") {
210 Self::hs256(secret)
211 } else {
212 return None;
213 };
214
215 let config = if let Ok(issuer) = std::env::var("JWT_ISSUER") {
216 config.with_issuer(issuer)
217 } else {
218 config
219 };
220
221 let config = if let Ok(audience) = std::env::var("JWT_AUDIENCE") {
222 config.with_audience(audience)
223 } else {
224 config
225 };
226
227 let config = if let Ok(leeway) = std::env::var("JWT_LEEWAY") {
228 if let Ok(seconds) = leeway.parse() {
229 config.with_leeway(seconds)
230 } else {
231 config
232 }
233 } else {
234 config.with_leeway(60) };
236
237 Some(config)
238 }
239
240 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
242 self.issuer = Some(issuer.into());
243 self
244 }
245
246 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
248 self.audience = Some(audience.into());
249 self
250 }
251
252 pub fn with_leeway(mut self, seconds: u64) -> Self {
254 self.leeway_seconds = seconds;
255 self
256 }
257
258 pub fn without_exp_validation(mut self) -> Self {
260 self.validate_exp = false;
261 self
262 }
263
264 pub fn without_nbf_validation(mut self) -> Self {
266 self.validate_nbf = false;
267 self
268 }
269}
270
271pub struct JwtValidator<C> {
295 config: JwtConfig,
296 decoding_key: DecodingKey,
297 validation: Validation,
298 _phantom: PhantomData<C>,
299}
300
301impl<C> JwtValidator<C> {
302 pub fn new(config: JwtConfig) -> Self {
309 Self::try_new(config).expect("Invalid JWT configuration")
310 }
311
312 pub fn try_new(config: JwtConfig) -> Result<Self, AuthError> {
316 let decoding_key = config.algorithm.decoding_key()?;
317
318 let mut validation = Validation::new(config.algorithm.jsonwebtoken_algorithm());
319 validation.leeway = config.leeway_seconds;
320 validation.validate_exp = config.validate_exp;
321 validation.validate_nbf = config.validate_nbf;
322
323 if let Some(ref iss) = config.issuer {
324 validation.set_issuer(&[iss]);
325 }
326
327 if let Some(ref aud) = config.audience {
328 validation.set_audience(&[aud]);
329 } else {
330 validation.validate_aud = false;
333 }
334
335 Ok(Self {
336 config,
337 decoding_key,
338 validation,
339 _phantom: PhantomData,
340 })
341 }
342
343 pub fn config(&self) -> &JwtConfig {
345 &self.config
346 }
347}
348
349impl<C: Clone + Send + Sync + DeserializeOwned + 'static> JwtValidator<C> {
350 pub fn validate(&self, token: &str) -> Result<C, AuthError> {
352 let token_data: TokenData<C> = decode(token, &self.decoding_key, &self.validation)
353 .map_err(|e| {
354 use jsonwebtoken::errors::ErrorKind;
355 match e.kind() {
356 ErrorKind::ExpiredSignature => AuthError::TokenExpired,
357 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
358 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
359 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
360 ErrorKind::InvalidToken => AuthError::InvalidToken("malformed token".into()),
361 ErrorKind::InvalidAlgorithm => {
362 AuthError::InvalidToken("wrong algorithm".into())
363 }
364 _ => AuthError::InvalidToken(e.to_string()),
365 }
366 })?;
367
368 Ok(token_data.claims)
369 }
370}
371
372#[async_trait::async_trait]
373impl<C: Clone + Send + Sync + DeserializeOwned + 'static> Authenticator for JwtValidator<C> {
374 type Claims = C;
375
376 async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError> {
377 self.validate(token)
379 }
380}
381
382#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
386pub struct StandardClaims {
387 pub sub: String,
389 #[serde(default)]
391 pub exp: Option<i64>,
392 #[serde(default)]
394 pub iat: Option<i64>,
395 #[serde(default)]
397 pub nbf: Option<i64>,
398 #[serde(default)]
400 pub iss: Option<String>,
401 #[serde(default)]
403 pub aud: Option<String>,
404}
405
406impl super::HasSubject for StandardClaims {
407 fn subject(&self) -> &str {
408 &self.sub
409 }
410}
411
412impl super::HasExpiration for StandardClaims {
413 fn expiration(&self) -> Option<i64> {
414 self.exp
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &str = "test-secret-that-is-long-enough-for-hs256";

    /// Signs `claims` with HS256 using `TEST_SECRET`.
    fn create_test_token(claims: &impl serde::Serialize) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(TEST_SECRET.as_bytes()),
        )
        .unwrap()
    }

    /// Unix timestamp one hour in the future, for tokens that must not be expired.
    fn one_hour_from_now() -> i64 {
        (chrono::Utc::now() + chrono::Duration::hours(1)).timestamp()
    }

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
    struct TestClaims {
        sub: String,
        exp: i64,
    }

    /// Claims that also carry `iss` and `aud`, for issuer/audience checks.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
    struct ScopedClaims {
        sub: String,
        exp: i64,
        iss: String,
        aud: String,
    }

    /// Builds an unexpired token with the given issuer and audience.
    fn scoped_token(iss: &str, aud: &str) -> String {
        create_test_token(&ScopedClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
            iss: iss.to_string(),
            aud: aud.to_string(),
        })
    }

    #[test]
    fn test_jwt_config_hs256() {
        let config = JwtConfig::hs256("secret");
        assert!(matches!(config.algorithm, JwtAlgorithm::HS256(_)));
        assert!(config.issuer.is_none());
        assert!(config.audience.is_none());
    }

    #[test]
    fn test_jwt_config_builder() {
        let config = JwtConfig::hs256("secret")
            .with_issuer("my-app")
            .with_audience("my-audience")
            .with_leeway(120);

        assert_eq!(config.issuer, Some("my-app".to_string()));
        assert_eq!(config.audience, Some("my-audience".to_string()));
        assert_eq!(config.leeway_seconds, 120);
    }

    #[test]
    fn test_jwt_validator_valid_token() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.validate(&token);
        assert!(result.is_ok());

        let decoded = result.unwrap();
        assert_eq!(decoded.sub, "user123");
    }

    #[test]
    fn test_jwt_validator_expired_token() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: 0,
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_jwt_validator_wrong_secret() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256("wrong-secret"));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_jwt_validator_invalid_token() {
        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.validate("not-a-jwt");
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_jwt_validator_matching_issuer_and_audience() {
        let validator = JwtValidator::<ScopedClaims>::new(
            JwtConfig::hs256(TEST_SECRET)
                .with_issuer("my-app")
                .with_audience("my-audience"),
        );

        let decoded = validator
            .validate(&scoped_token("my-app", "my-audience"))
            .unwrap();
        assert_eq!(decoded.iss, "my-app");
        assert_eq!(decoded.aud, "my-audience");
    }

    #[test]
    fn test_jwt_validator_wrong_issuer() {
        let validator =
            JwtValidator::<ScopedClaims>::new(JwtConfig::hs256(TEST_SECRET).with_issuer("my-app"));

        let result = validator.validate(&scoped_token("other-app", "my-audience"));
        assert!(matches!(result, Err(AuthError::InvalidIssuer)));
    }

    #[test]
    fn test_jwt_validator_wrong_audience() {
        let validator = JwtValidator::<ScopedClaims>::new(
            JwtConfig::hs256(TEST_SECRET).with_audience("my-audience"),
        );

        let result = validator.validate(&scoped_token("my-app", "other-audience"));
        assert!(matches!(result, Err(AuthError::InvalidAudience)));
    }

    #[tokio::test]
    async fn test_jwt_validator_authenticator_trait() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.authenticate(&token).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().sub, "user123");
    }

    #[test]
    fn test_standard_claims() {
        let claims = StandardClaims {
            sub: "user123".to_string(),
            exp: Some(1234567890),
            iat: Some(1234567800),
            nbf: None,
            iss: Some("my-app".to_string()),
            aud: None,
        };

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.exp, Some(1234567890));

        use super::super::HasSubject;
        assert_eq!(claims.subject(), "user123");

        use super::super::HasExpiration;
        assert_eq!(claims.expiration(), Some(1234567890));
    }
}