1use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
28use serde::{Deserialize, Serialize};
29use std::fmt;
30use std::sync::Arc;
31
32use crate::revocation::TokenBlacklist;
33
/// Errors produced while issuing or verifying JWTs with a [`JwtManager`].
#[derive(Debug)]
pub enum AuthError {
    /// Decoding or validation failed for a reason other than expiry; carries the
    /// underlying jsonwebtoken error message.
    InvalidToken(String),
    /// jsonwebtoken rejected the token with `ExpiredSignature`.
    Expired,
    /// The token's `jti` is present on the attached revocation blacklist.
    Revoked,
    /// `encode` was called on a manager that was built without a signing key.
    EncodingKeyMissing,
    /// jsonwebtoken failed to serialize or sign the claims; carries its error message.
    Encode(String),
    /// Internal failure, e.g. PEM key material that could not be parsed.
    Internal(String),
}
53
54impl fmt::Display for AuthError {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 match self {
57 Self::InvalidToken(e) => write!(f, "invalid token: {e}"),
58 Self::Expired => f.write_str("token expired"),
59 Self::Revoked => f.write_str("token revoked"),
60 Self::EncodingKeyMissing => f.write_str("no encoding key configured"),
61 Self::Encode(e) => write!(f, "encoding failed: {e}"),
62 Self::Internal(e) => write!(f, "internal error: {e}"),
63 }
64 }
65}
66
67impl std::error::Error for AuthError {}
68
/// Exposes a claims type's JWT ID (`jti`) so that [`JwtManager::decode`] can check it
/// against the revocation blacklist.
///
/// The default implementation reports no `jti`. Tokens decoded into such a claims type
/// are never checked for revocation, even when a blacklist is attached.
pub trait HasJti {
    /// Returns the `jti` claim, or `None` when this claims type does not carry one.
    fn jti(&self) -> Option<&str> {
        None
    }
}
97
98pub struct JwtConfig {
104 pub decoding_key: DecodingKey,
105 pub encoding_key: Option<EncodingKey>,
106 pub validation: Validation,
107}
108
109#[derive(Clone)]
124pub struct JwtManager {
125 config: Arc<JwtConfig>,
126 blacklist: Option<TokenBlacklist>,
127}
128
129impl JwtManager {
130 pub fn new(secret: &[u8]) -> Self {
132 let mut validation = Validation::new(Algorithm::HS256);
133 validation.leeway = 60; Self {
135 config: Arc::new(JwtConfig {
136 decoding_key: DecodingKey::from_secret(secret),
137 encoding_key: Some(EncodingKey::from_secret(secret)),
138 validation,
139 }),
140 blacklist: None,
141 }
142 }
143
144 pub fn verify_only(secret: &[u8]) -> Self {
146 let mut validation = Validation::new(Algorithm::HS256);
147 validation.leeway = 60;
148 Self {
149 config: Arc::new(JwtConfig {
150 decoding_key: DecodingKey::from_secret(secret),
151 encoding_key: None,
152 validation,
153 }),
154 blacklist: None,
155 }
156 }
157
158 pub fn from_rsa_pem(private_key_pem: &[u8], public_key_pem: &[u8]) -> Result<Self, AuthError> {
163 let encoding_key = EncodingKey::from_rsa_pem(private_key_pem)
164 .map_err(|e| AuthError::Internal(format!("RSA private key: {e}")))?;
165 let decoding_key = DecodingKey::from_rsa_pem(public_key_pem)
166 .map_err(|e| AuthError::Internal(format!("RSA public key: {e}")))?;
167 Ok(Self {
168 config: Arc::new(JwtConfig {
169 encoding_key: Some(encoding_key),
170 decoding_key,
171 validation: Validation::new(Algorithm::RS256),
172 }),
173 blacklist: None,
174 })
175 }
176
177 pub fn from_rsa_public_pem(public_key_pem: &[u8]) -> Result<Self, AuthError> {
183 let decoding_key = DecodingKey::from_rsa_pem(public_key_pem)
184 .map_err(|e| AuthError::Internal(format!("RSA public key: {e}")))?;
185 Ok(Self {
186 config: Arc::new(JwtConfig {
187 encoding_key: None,
188 decoding_key,
189 validation: Validation::new(Algorithm::RS256),
190 }),
191 blacklist: None,
192 })
193 }
194
195 pub fn from_ec_pem(private_key_pem: &[u8], public_key_pem: &[u8]) -> Result<Self, AuthError> {
200 let encoding_key = EncodingKey::from_ec_pem(private_key_pem)
201 .map_err(|e| AuthError::Internal(format!("EC private key: {e}")))?;
202 let decoding_key = DecodingKey::from_ec_pem(public_key_pem)
203 .map_err(|e| AuthError::Internal(format!("EC public key: {e}")))?;
204 Ok(Self {
205 config: Arc::new(JwtConfig {
206 encoding_key: Some(encoding_key),
207 decoding_key,
208 validation: Validation::new(Algorithm::ES256),
209 }),
210 blacklist: None,
211 })
212 }
213
214 pub fn from_ec_public_pem(public_key_pem: &[u8]) -> Result<Self, AuthError> {
220 let decoding_key = DecodingKey::from_ec_pem(public_key_pem)
221 .map_err(|e| AuthError::Internal(format!("EC public key: {e}")))?;
222 Ok(Self {
223 config: Arc::new(JwtConfig {
224 encoding_key: None,
225 decoding_key,
226 validation: Validation::new(Algorithm::ES256),
227 }),
228 blacklist: None,
229 })
230 }
231
232 pub fn with_config(config: JwtConfig) -> Self {
234 Self {
235 config: Arc::new(config),
236 blacklist: None,
237 }
238 }
239
240 pub fn with_blacklist(mut self, blacklist: TokenBlacklist) -> Self {
248 self.blacklist = Some(blacklist);
249 self
250 }
251
252 pub fn decode<T>(&self, token: &str) -> Result<T, AuthError>
262 where
263 T: for<'de> Deserialize<'de> + HasJti,
264 {
265 let token_data = decode::<T>(token, &self.config.decoding_key, &self.config.validation)
266 .map_err(|e| match e.kind() {
267 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::Expired,
268 _ => AuthError::InvalidToken(e.to_string()),
269 })?;
270
271 if let Some(bl) = &self.blacklist
272 && let Some(jti) = token_data.claims.jti()
273 && bl.is_revoked(jti)
274 {
275 return Err(AuthError::Revoked);
276 }
277
278 Ok(token_data.claims)
279 }
280
281 pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String, AuthError> {
287 let key = self
288 .config
289 .encoding_key
290 .as_ref()
291 .ok_or(AuthError::EncodingKeyMissing)?;
292 encode(&Header::default(), claims, key).map_err(|e| AuthError::Encode(e.to_string()))
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestClaims {
        sub: String,
        exp: u64,
    }

    // These claims carry no `jti`, so the default `HasJti` impl (no revocation check) applies.
    impl HasJti for TestClaims {}

    /// An `exp` timestamp around the year 9999, so test tokens never expire mid-run.
    fn far_future_exp() -> u64 {
        253_370_764_800_u64
    }

    /// Builds non-expiring test claims for subject `sub`.
    fn claims_for(sub: &str) -> TestClaims {
        TestClaims {
            sub: sub.to_string(),
            exp: far_future_exp(),
        }
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mgr = JwtManager::new(b"test-secret-key");
        let claims = claims_for("user-42");
        let token = mgr.encode(&claims).expect("encode should succeed");
        assert!(!token.is_empty());
        let decoded: TestClaims = mgr.decode(&token).expect("decode should succeed");
        assert_eq!(decoded, claims);
    }

    #[test]
    fn test_decode_wrong_secret_fails() {
        let mgr_sign = JwtManager::new(b"correct-secret");
        let mgr_verify = JwtManager::new(b"wrong-secret");
        let token = mgr_sign
            .encode(&claims_for("user-1"))
            .expect("encode must succeed");
        let result: Result<TestClaims, _> = mgr_verify.decode(&token);
        assert!(result.is_err(), "decode with wrong secret should fail");
    }

    #[test]
    fn test_decode_invalid_token_fails() {
        let mgr = JwtManager::new(b"any-secret");
        let result: Result<TestClaims, _> = mgr.decode("not.a.jwt");
        assert!(result.is_err(), "decode of garbage should fail");
    }

    #[test]
    fn test_decode_mangled_token_fails() {
        let mgr = JwtManager::new(b"secret");
        let mut token = mgr.encode(&claims_for("u")).expect("encode ok");
        // Tamper with the FIRST character of the signature segment. Every bit of it is
        // significant, so the decoded signature bytes always change. Changing the LAST
        // character from 'A' to 'B' would flip only base64 padding bits.
        let sig_start = token
            .rfind('.')
            .expect("a JWT has three dot-separated segments")
            + 1;
        let replacement = if token[sig_start..].starts_with('A') {
            "B"
        } else {
            "A"
        };
        token.replace_range(sig_start..=sig_start, replacement);
        let result: Result<TestClaims, _> = mgr.decode(&token);
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "mangled token should be rejected as invalid, got {result:?}"
        );
    }

    #[test]
    fn test_clone_shares_key() {
        let mgr1 = JwtManager::new(b"shared-key");
        let mgr2 = mgr1.clone();
        let token = mgr1.encode(&claims_for("u")).unwrap();
        let decoded: TestClaims = mgr2.decode(&token).expect("clone should decode");
        assert_eq!(decoded.sub, "u");
    }

    #[test]
    fn test_encode_without_key_returns_error() {
        let config = JwtConfig {
            decoding_key: DecodingKey::from_secret(b"secret"),
            encoding_key: None,
            validation: Validation::new(Algorithm::HS256),
        };
        let mgr = JwtManager::with_config(config);
        let result = mgr.encode(&claims_for("x"));
        assert!(
            matches!(result, Err(AuthError::EncodingKeyMissing)),
            "expected EncodingKeyMissing, got {result:?}"
        );
    }

    #[test]
    fn test_verify_only_decodes_but_cannot_encode() {
        let signer = JwtManager::new(b"shared-secret");
        let verifier = JwtManager::verify_only(b"shared-secret");
        let claims = claims_for("u");
        let token = signer.encode(&claims).expect("signer should encode");
        let decoded: TestClaims = verifier.decode(&token).expect("verifier should decode");
        assert_eq!(decoded, claims);
        let result = verifier.encode(&claims);
        assert!(
            matches!(result, Err(AuthError::EncodingKeyMissing)),
            "verify-only manager must not sign, got {result:?}"
        );
    }

    #[test]
    fn test_revoked_token_rejected() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct ClaimsWithJti {
            sub: String,
            jti: String,
            exp: u64,
        }
        impl HasJti for ClaimsWithJti {
            fn jti(&self) -> Option<&str> {
                Some(&self.jti)
            }
        }

        let blacklist = TokenBlacklist::new();
        let mgr = JwtManager::new(b"s").with_blacklist(blacklist.clone());
        let claims = ClaimsWithJti {
            sub: "u".into(),
            jti: "unique-jti-1".into(),
            exp: far_future_exp(),
        };
        let token = mgr.encode(&claims).unwrap();

        mgr.decode::<ClaimsWithJti>(&token)
            .expect("should be valid before revocation");

        blacklist.revoke("unique-jti-1".into(), None);
        let result = mgr.decode::<ClaimsWithJti>(&token);
        assert!(
            matches!(result, Err(AuthError::Revoked)),
            "revoked token should be rejected, got {result:?}"
        );
    }

    #[test]
    fn test_blacklist_skipped_for_claims_without_jti() {
        let mgr = JwtManager::new(b"s").with_blacklist(TokenBlacklist::new());
        let token = mgr.encode(&claims_for("u")).unwrap();
        let decoded: TestClaims = mgr
            .decode(&token)
            .expect("claims without a jti are never checked for revocation");
        assert_eq!(decoded.sub, "u");
    }

    #[test]
    fn test_expired_token_returns_expired_error() {
        let mgr = JwtManager::new(b"secret");
        let claims = serde_json::json!({ "sub": "u", "exp": 1_u64 });
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"secret"),
        )
        .unwrap();
        let result: Result<TestClaims, _> = mgr.decode(&token);
        assert!(
            matches!(result, Err(AuthError::Expired)),
            "expected Expired, got {result:?}"
        );
    }

    #[test]
    fn test_auth_error_display_messages() {
        assert_eq!(
            AuthError::InvalidToken("bad".into()).to_string(),
            "invalid token: bad"
        );
        assert_eq!(AuthError::Expired.to_string(), "token expired");
        assert_eq!(AuthError::Revoked.to_string(), "token revoked");
        assert_eq!(
            AuthError::EncodingKeyMissing.to_string(),
            "no encoding key configured"
        );
        assert_eq!(AuthError::Encode("x".into()).to_string(), "encoding failed: x");
        assert_eq!(AuthError::Internal("y".into()).to_string(), "internal error: y");
    }
}