1use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation, Algorithm};
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use std::sync::Arc;
5use std::time::{SystemTime, UNIX_EPOCH};
6use thiserror::Error;
7use dashmap::DashMap;
8use uuid::Uuid;
9use chrono::Duration;
10
11#[derive(Debug, Error, Clone)]
13pub enum JwtError {
14 #[error("Invalid token: {0}")]
15 InvalidToken(String),
16 #[error("Token expired")]
17 TokenExpired,
18 #[error("Invalid signature")]
19 InvalidSignature,
20 #[error("Token revoked")]
21 TokenRevoked,
22 #[error("Invalid issuer")]
23 InvalidIssuer,
24 #[error("Invalid audience")]
25 InvalidAudience,
26 #[error("Missing required claim: {0}")]
27 MissingClaim(String),
28 #[error("Key generation error: {0}")]
29 KeyError(String),
30}
31
32impl From<jsonwebtoken::errors::Error> for JwtError {
33 fn from(err: jsonwebtoken::errors::Error) -> Self {
34 match err.kind() {
35 jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::TokenExpired,
36 jsonwebtoken::errors::ErrorKind::InvalidSignature => JwtError::InvalidSignature,
37 jsonwebtoken::errors::ErrorKind::InvalidIssuer => JwtError::InvalidIssuer,
38 jsonwebtoken::errors::ErrorKind::InvalidAudience => JwtError::InvalidAudience,
39 _ => JwtError::InvalidToken(err.to_string()),
40 }
41 }
42}
43
44#[derive(Debug, Serialize, Deserialize, Clone)]
46pub struct Claims {
47 pub sub: String,
48 pub exp: usize,
49 pub iat: usize,
50 pub jti: String,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub iss: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub aud: Option<String>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub nbf: Option<usize>,
57 #[serde(flatten)]
58 pub custom: serde_json::Map<String, serde_json::Value>,
59}
60
61impl Claims {
62 #[inline]
63 pub fn new(sub: impl Into<String>) -> Self {
64 let now = Self::now();
65 Self {
66 sub: sub.into(),
67 exp: now + 900,
68 iat: now,
69 jti: Uuid::now_v7().to_string(),
70 iss: None,
71 aud: None,
72 nbf: None,
73 custom: serde_json::Map::new(),
74 }
75 }
76
77 #[inline]
78 pub fn now() -> usize {
79 SystemTime::now()
80 .duration_since(UNIX_EPOCH)
81 .unwrap_or_default()
82 .as_secs() as usize
83 }
84
85 #[inline]
86 pub fn with_expiration(mut self, seconds: u64) -> Self {
87 self.exp = Self::now() + seconds as usize;
88 self
89 }
90
91 #[inline]
92 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
93 self.iss = Some(issuer.into());
94 self
95 }
96
97 #[inline]
98 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
99 self.aud = Some(audience.into());
100 self
101 }
102
103 #[inline]
104 pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
105 self.custom.insert(key.into(), value);
106 self
107 }
108
109 #[inline]
110 pub fn is_expired(&self) -> bool {
111 Self::now() > self.exp
112 }
113
114 #[inline]
115 pub fn remaining_time(&self) -> Option<Duration> {
116 if self.is_expired() {
117 return None;
118 }
119 let remaining = self.exp - Self::now();
120 Some(Duration::seconds(remaining as i64))
121 }
122}
123
/// Kind of token to issue; each kind maps to a fixed lifetime.
#[derive(Debug, Clone, Copy)]
pub enum TokenType {
    /// Access token: 15 minutes.
    Access,
    /// Refresh token: 7 days.
    Refresh,
    /// Reset token: 1 hour.
    Reset,
    /// Verification token: 24 hours.
    Verify,
    /// Caller-chosen lifetime, in seconds.
    Custom(u64),
}

impl TokenType {
    const SECS_PER_MINUTE: u64 = 60;
    const SECS_PER_HOUR: u64 = 60 * Self::SECS_PER_MINUTE;
    const SECS_PER_DAY: u64 = 24 * Self::SECS_PER_HOUR;

    /// Returns the lifetime, in seconds, of a token of this kind.
    ///
    /// Usable in `const` contexts.
    #[inline]
    pub const fn duration_seconds(&self) -> u64 {
        match *self {
            TokenType::Access => 15 * Self::SECS_PER_MINUTE,
            TokenType::Refresh => 7 * Self::SECS_PER_DAY,
            TokenType::Reset => Self::SECS_PER_HOUR,
            TokenType::Verify => Self::SECS_PER_DAY,
            TokenType::Custom(lifetime) => lifetime,
        }
    }
}
146
147#[derive(Clone)]
149pub struct TokenBlacklist {
150 store: Arc<DashMap<String, usize>>,
151 cleanup_interval: tokio::time::Duration,
152}
153
154impl TokenBlacklist {
155 pub fn new(cleanup_interval_seconds: u64) -> Self {
156 let blacklist = Self {
157 store: Arc::new(DashMap::with_capacity(10000)),
158 cleanup_interval: tokio::time::Duration::from_secs(cleanup_interval_seconds),
159 };
160
161 let store = blacklist.store.clone();
162 let interval = blacklist.cleanup_interval;
163 tokio::spawn(async move {
164 let mut interval = tokio::time::interval(interval);
165 loop {
166 interval.tick().await;
167 let now = Claims::now();
168 store.retain(|_, &mut exp| exp > now);
169 }
170 });
171
172 blacklist
173 }
174
175 #[inline]
176 pub fn revoke(&self, jti: &str, exp: usize) {
177 self.store.insert(jti.to_string(), exp);
178 }
179
180 #[inline]
181 pub fn is_revoked(&self, jti: &str) -> bool {
182 if let Some(entry) = self.store.get_mut(jti) {
183 let exp = *entry;
184 if exp > Claims::now() {
185 return true;
186 }
187 drop(entry);
188 self.store.remove(jti);
189 }
190 false
191 }
192
193 #[inline]
194 pub fn len(&self) -> usize {
195 self.store.len()
196 }
197
198 #[inline]
199 pub fn is_empty(&self) -> bool {
200 self.store.is_empty()
201 }
202}
203
204impl Default for TokenBlacklist {
205 fn default() -> Self {
206 Self::new(300)
207 }
208}
209
210#[derive(Clone)]
212pub struct JwtService {
213 encoding_key: Arc<EncodingKey>,
214 decoding_key: Arc<DecodingKey>,
215 algorithm: Algorithm,
216 validation: Arc<Validation>,
217 blacklist: Option<TokenBlacklist>,
218 issuer: Option<String>,
219 audience: Option<String>,
220}
221
222impl JwtService {
223 pub fn new_hs256(secret: impl AsRef<[u8]>) -> Self {
226 let secret = secret.as_ref();
227 let mut validation = Validation::new(Algorithm::HS256);
228 validation.validate_exp = true;
229 validation.required_spec_claims = HashSet::from([
230 "exp".to_string(),
231 "iat".to_string(),
232 "jti".to_string(),
233 ]);
234
235 Self {
236 encoding_key: Arc::new(EncodingKey::from_secret(secret)),
237 decoding_key: Arc::new(DecodingKey::from_secret(secret)),
238 algorithm: Algorithm::HS256,
239 validation: Arc::new(validation),
240 blacklist: None,
241 issuer: None,
242 audience: None,
243 }
244 }
245
246 pub fn new_rs256(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
247 let mut validation = Validation::new(Algorithm::RS256);
248 validation.validate_exp = true;
249 validation.required_spec_claims = HashSet::from([
250 "exp".to_string(),
251 "iat".to_string(),
252 "jti".to_string(),
253 ]);
254
255 Ok(Self {
256 encoding_key: Arc::new(EncodingKey::from_rsa_pem(private_key.as_ref())
257 .map_err(|e| JwtError::KeyError(e.to_string()))?),
258 decoding_key: Arc::new(DecodingKey::from_rsa_pem(public_key.as_ref())
259 .map_err(|e| JwtError::KeyError(e.to_string()))?),
260 algorithm: Algorithm::RS256,
261 validation: Arc::new(validation),
262 blacklist: None,
263 issuer: None,
264 audience: None,
265 })
266 }
267
268 pub fn new_rs384(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
269 let mut validation = Validation::new(Algorithm::RS384);
270 validation.validate_exp = true;
271 validation.required_spec_claims = HashSet::from([
272 "exp".to_string(),
273 "iat".to_string(),
274 "jti".to_string(),
275 ]);
276
277 Ok(Self {
278 encoding_key: Arc::new(EncodingKey::from_rsa_pem(private_key.as_ref())
279 .map_err(|e| JwtError::KeyError(e.to_string()))?),
280 decoding_key: Arc::new(DecodingKey::from_rsa_pem(public_key.as_ref())
281 .map_err(|e| JwtError::KeyError(e.to_string()))?),
282 algorithm: Algorithm::RS384,
283 validation: Arc::new(validation),
284 blacklist: None,
285 issuer: None,
286 audience: None,
287 })
288 }
289
290 pub fn new_ecdsa_p256(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
291 let mut validation = Validation::new(Algorithm::ES256);
292 validation.validate_exp = true;
293 validation.required_spec_claims = HashSet::from([
294 "exp".to_string(),
295 "iat".to_string(),
296 "jti".to_string(),
297 ]);
298
299 Ok(Self {
300 encoding_key: Arc::new(EncodingKey::from_ec_pem(private_key.as_ref())
301 .map_err(|e| JwtError::KeyError(e.to_string()))?),
302 decoding_key: Arc::new(DecodingKey::from_ec_pem(public_key.as_ref())
303 .map_err(|e| JwtError::KeyError(e.to_string()))?),
304 algorithm: Algorithm::ES256,
305 validation: Arc::new(validation),
306 blacklist: None,
307 issuer: None,
308 audience: None,
309 })
310 }
311
312 pub fn new_ed25519(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
313 let mut validation = Validation::new(Algorithm::EdDSA);
314 validation.validate_exp = true;
315 validation.required_spec_claims = HashSet::from([
316 "exp".to_string(),
317 "iat".to_string(),
318 "jti".to_string(),
319 ]);
320
321 Ok(Self {
322 encoding_key: Arc::new(EncodingKey::from_ed_pem(private_key.as_ref())
323 .map_err(|e| JwtError::KeyError(e.to_string()))?),
324 decoding_key: Arc::new(DecodingKey::from_ed_pem(public_key.as_ref())
325 .map_err(|e| JwtError::KeyError(e.to_string()))?),
326 algorithm: Algorithm::EdDSA,
327 validation: Arc::new(validation),
328 blacklist: None,
329 issuer: None,
330 audience: None,
331 })
332 }
333
334 #[inline]
337 pub fn with_blacklist(mut self, blacklist: TokenBlacklist) -> Self {
338 self.blacklist = Some(blacklist);
339 self
340 }
341
342 #[inline]
343 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
344 let issuer = issuer.into();
345 self.issuer = Some(issuer.clone());
346 let validation = Arc::make_mut(&mut self.validation);
347 validation.set_issuer(&[issuer]);
348 self
349 }
350
351 #[inline]
352 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
353 let audience = audience.into();
354 self.audience = Some(audience.clone());
355 let validation = Arc::make_mut(&mut self.validation);
356 validation.set_audience(&[audience]);
357 self
358 }
359
360 #[inline]
361 pub fn with_leeway(mut self, seconds: u64) -> Self {
362 let validation = Arc::make_mut(&mut self.validation);
363 validation.leeway = seconds;
364 self
365 }
366
367 #[inline]
368 pub fn disable_exp_validation(mut self) -> Self {
369 let validation = Arc::make_mut(&mut self.validation);
370 validation.validate_exp = false;
371 self
372 }
373
374 #[inline]
377 pub fn generate(&self, sub: impl Into<String>, token_type: TokenType) -> Result<String, JwtError> {
378 let mut claims = Claims::new(sub);
379 let duration = token_type.duration_seconds();
380 claims.exp = Claims::now() + duration as usize;
381
382 if let Some(ref iss) = self.issuer {
383 claims.iss = Some(iss.clone());
384 }
385 if let Some(ref aud) = self.audience {
386 claims.aud = Some(aud.clone());
387 }
388
389 let header = Header::new(self.algorithm);
390 Ok(encode(&header, &claims, &self.encoding_key)?)
391 }
392
393 #[inline]
395 pub fn generate_exp_token(&self, sub: impl Into<String>, exp: usize) -> Result<String, JwtError> {
396 let mut claims = Claims::new(sub);
397 claims.exp = exp;
398
399 if let Some(ref iss) = self.issuer {
400 claims.iss = Some(iss.clone());
401 }
402 if let Some(ref aud) = self.audience {
403 claims.aud = Some(aud.clone());
404 }
405
406 let header = Header::new(self.algorithm);
407 Ok(encode(&header, &claims, &self.encoding_key)?)
408 }
409
410 #[inline]
411 pub fn generate_with_claims(&self, mut claims: Claims, token_type: TokenType) -> Result<String, JwtError> {
412 let duration = token_type.duration_seconds();
413 claims.exp = Claims::now() + duration as usize;
414 claims.iat = Claims::now();
415 claims.jti = Uuid::now_v7().to_string();
416
417 let header = Header::new(self.algorithm);
418 Ok(encode(&header, &claims, &self.encoding_key)?)
419 }
420
421 #[inline]
422 pub fn generate_pair(&self, sub: impl Into<String>) -> Result<(String, String), JwtError> {
423 let sub_str = sub.into();
424 let access = self.generate(sub_str.clone(), TokenType::Access)?;
425 let refresh = self.generate(sub_str, TokenType::Refresh)?;
426 Ok((access, refresh))
427 }
428
429 #[inline]
430 pub fn generate_access_refresh_with_claims(&self, claims: Claims) -> Result<(String, String), JwtError> {
431 let access_claims = claims.clone();
432 let refresh_claims = claims;
433
434 let access = self.generate_with_claims(access_claims, TokenType::Access)?;
435 let refresh = self.generate_with_claims(refresh_claims, TokenType::Refresh)?;
436
437 Ok((access, refresh))
438 }
439
440 #[inline]
443 pub fn verify(&self, token: &str) -> Result<Claims, JwtError> {
444 let token_data = decode::<Claims>(
445 token,
446 &self.decoding_key,
447 &self.validation,
448 )?;
449
450 let claims = token_data.claims;
451
452 if let Some(ref blacklist) = self.blacklist {
453 if blacklist.is_revoked(&claims.jti) {
454 return Err(JwtError::TokenRevoked);
455 }
456 }
457
458 Ok(claims)
459 }
460
461 #[inline]
462 pub fn verify_token(&self, token: &str) -> bool {
463 self.verify(token).is_ok()
464 }
465
466 #[inline]
467 pub fn verify_without_expiry(&self, token: &str) -> Result<Claims, JwtError> {
468 let mut validation = Validation::new(self.algorithm);
470 validation.validate_exp = false;
471 validation.leeway = self.validation.leeway;
473 validation.required_spec_claims = self.validation.required_spec_claims.clone();
474 if let Some(ref iss) = self.issuer {
476 validation.set_issuer(&[iss.clone()]);
477 }
478 if let Some(ref aud) = self.audience {
479 validation.set_audience(&[aud.clone()]);
480 }
481
482 let token_data = decode::<Claims>(
483 token,
484 &self.decoding_key,
485 &validation,
486 )?;
487
488 Ok(token_data.claims)
489 }
490
491 #[inline]
494 pub fn refresh_access(&self, refresh_token: &str) -> Result<String, JwtError> {
495 let claims = self.verify(refresh_token)?;
496
497 if claims.is_expired() {
498 return Err(JwtError::TokenExpired);
499 }
500
501 let new_claims = Claims::new(claims.sub);
502 self.generate_with_claims(new_claims, TokenType::Access)
503 }
504
505 #[inline]
506 pub fn revoke_token(&self, token: &str) -> Result<(), JwtError> {
507 let claims = self.verify(token)?;
508
509 if let Some(ref blacklist) = self.blacklist {
510 blacklist.revoke(&claims.jti, claims.exp);
511 Ok(())
512 } else {
513 Err(JwtError::InvalidToken("Blacklist not configured".to_string()))
514 }
515 }
516
517 #[inline]
518 pub fn revoke_by_jti(&self, jti: &str, exp: usize) -> Result<(), JwtError> {
519 if let Some(ref blacklist) = self.blacklist {
520 blacklist.revoke(jti, exp);
521 Ok(())
522 } else {
523 Err(JwtError::InvalidToken("Blacklist not configured".to_string()))
524 }
525 }
526
527 #[inline]
528 pub fn is_revoked(&self, jti: &str) -> bool {
529 self.blacklist
530 .as_ref()
531 .map(|b| b.is_revoked(jti))
532 .unwrap_or(false)
533 }
534
535 #[inline]
539 pub fn peek_claims(&self, token: &str) -> Option<Claims> {
540 let mut validation = Validation::default();
542 validation.validate_exp = false;
543 validation.validate_nbf = false;
544 validation.validate_aud = false;
545 decode::<Claims>(token, &self.decoding_key, &validation)
548 .ok()
549 .map(|data| data.claims)
550 }
551
552 #[inline]
553 pub fn extract_subject(&self, token: &str) -> Option<String> {
554 self.peek_claims(token).map(|c| c.sub)
555 }
556
557 #[inline]
558 pub fn get_token_expiry(&self, token: &str) -> Option<usize> {
559 self.peek_claims(token).map(|c| c.exp)
560 }
561
562 #[inline]
563 pub fn get_token_jti(&self, token: &str) -> Option<String> {
564 self.peek_claims(token).map(|c| c.jti)
565 }
566
567 #[inline]
568 pub fn get_token_issuer(&self, token: &str) -> Option<String> {
569 self.peek_claims(token).and_then(|c| c.iss)
570 }
571
572 #[inline]
573 pub fn get_token_audience(&self, token: &str) -> Option<String> {
574 self.peek_claims(token).and_then(|c| c.aud)
575 }
576}