//! JWT issuing, verification, and revocation (`fr_rust/jwt/jwt.rs`).
1use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation, Algorithm};
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use std::sync::Arc;
5use std::time::{SystemTime, UNIX_EPOCH};
6use thiserror::Error;
7use dashmap::DashMap;
8use uuid::Uuid;
9use chrono::Duration;
10
11// ============ Error Types ============
12#[derive(Debug, Error, Clone)]
13pub enum JwtError {
14    #[error("Invalid token: {0}")]
15    InvalidToken(String),
16    #[error("Token expired")]
17    TokenExpired,
18    #[error("Invalid signature")]
19    InvalidSignature,
20    #[error("Token revoked")]
21    TokenRevoked,
22    #[error("Invalid issuer")]
23    InvalidIssuer,
24    #[error("Invalid audience")]
25    InvalidAudience,
26    #[error("Missing required claim: {0}")]
27    MissingClaim(String),
28    #[error("Key generation error: {0}")]
29    KeyError(String),
30}
31
32impl From<jsonwebtoken::errors::Error> for JwtError {
33    fn from(err: jsonwebtoken::errors::Error) -> Self {
34        match err.kind() {
35            jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::TokenExpired,
36            jsonwebtoken::errors::ErrorKind::InvalidSignature => JwtError::InvalidSignature,
37            jsonwebtoken::errors::ErrorKind::InvalidIssuer => JwtError::InvalidIssuer,
38            jsonwebtoken::errors::ErrorKind::InvalidAudience => JwtError::InvalidAudience,
39            _ => JwtError::InvalidToken(err.to_string()),
40        }
41    }
42}
43
44// ============ Claims ============
45#[derive(Debug, Serialize, Deserialize, Clone)]
46pub struct Claims {
47    pub sub: String,
48    pub exp: usize,
49    pub iat: usize,
50    pub jti: String,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub iss: Option<String>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub aud: Option<String>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub nbf: Option<usize>,
57    #[serde(flatten)]
58    pub custom: serde_json::Map<String, serde_json::Value>,
59}
60
61impl Claims {
62    #[inline]
63    pub fn new(sub: impl Into<String>) -> Self {
64        let now = Self::now();
65        Self {
66            sub: sub.into(),
67            exp: now + 900,
68            iat: now,
69            jti: Uuid::now_v7().to_string(),
70            iss: None,
71            aud: None,
72            nbf: None,
73            custom: serde_json::Map::new(),
74        }
75    }
76
77    #[inline]
78    pub fn now() -> usize {
79        SystemTime::now()
80            .duration_since(UNIX_EPOCH)
81            .unwrap_or_default()
82            .as_secs() as usize
83    }
84
85    #[inline]
86    pub fn with_expiration(mut self, seconds: u64) -> Self {
87        self.exp = Self::now() + seconds as usize;
88        self
89    }
90
91    #[inline]
92    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
93        self.iss = Some(issuer.into());
94        self
95    }
96
97    #[inline]
98    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
99        self.aud = Some(audience.into());
100        self
101    }
102
103    #[inline]
104    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
105        self.custom.insert(key.into(), value);
106        self
107    }
108
109    #[inline]
110    pub fn is_expired(&self) -> bool {
111        Self::now() > self.exp
112    }
113
114    #[inline]
115    pub fn remaining_time(&self) -> Option<Duration> {
116        if self.is_expired() {
117            return None;
118        }
119        let remaining = self.exp - Self::now();
120        Some(Duration::seconds(remaining as i64))
121    }
122}
123
// ============ Token Types ============
/// Kind of token to mint; the kind fixes the token's lifetime.
#[derive(Debug, Clone, Copy)]
pub enum TokenType {
    /// Lives 15 minutes.
    Access,
    /// Lives 7 days.
    Refresh,
    /// Lives 1 hour.
    Reset,
    /// Lives 24 hours.
    Verify,
    /// Lives for the given number of seconds.
    Custom(u64),
}

impl TokenType {
    /// Lifetime, in seconds, of a token of this kind.
    #[inline]
    pub const fn duration_seconds(&self) -> u64 {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        match *self {
            TokenType::Custom(secs) => secs,
            TokenType::Access => 15 * MINUTE,
            TokenType::Reset => HOUR,
            TokenType::Verify => DAY,
            TokenType::Refresh => 7 * DAY,
        }
    }
}

147// ============ Blacklist with Sharded Storage ============
148#[derive(Clone)]
149pub struct TokenBlacklist {
150    store: Arc<DashMap<String, usize>>,
151    cleanup_interval: tokio::time::Duration,
152}
153
154impl TokenBlacklist {
155    pub fn new(cleanup_interval_seconds: u64) -> Self {
156        let blacklist = Self {
157            store: Arc::new(DashMap::with_capacity(10000)),
158            cleanup_interval: tokio::time::Duration::from_secs(cleanup_interval_seconds),
159        };
160
161        let store = blacklist.store.clone();
162        let interval = blacklist.cleanup_interval;
163        tokio::spawn(async move {
164            let mut interval = tokio::time::interval(interval);
165            loop {
166                interval.tick().await;
167                let now = Claims::now();
168                store.retain(|_, &mut exp| exp > now);
169            }
170        });
171
172        blacklist
173    }
174
175    #[inline]
176    pub fn revoke(&self, jti: &str, exp: usize) {
177        self.store.insert(jti.to_string(), exp);
178    }
179
180    #[inline]
181    pub fn is_revoked(&self, jti: &str) -> bool {
182        if let Some(entry) = self.store.get_mut(jti) {
183            let exp = *entry;
184            if exp > Claims::now() {
185                return true;
186            }
187            drop(entry);
188            self.store.remove(jti);
189        }
190        false
191    }
192
193    #[inline]
194    pub fn len(&self) -> usize {
195        self.store.len()
196    }
197
198    #[inline]
199    pub fn is_empty(&self) -> bool {
200        self.store.is_empty()
201    }
202}
203
204impl Default for TokenBlacklist {
205    fn default() -> Self {
206        Self::new(300)
207    }
208}
209
210// ============ Main JWT Service ============
211#[derive(Clone)]
212pub struct JwtService {
213    encoding_key: Arc<EncodingKey>,
214    decoding_key: Arc<DecodingKey>,
215    algorithm: Algorithm,
216    validation: Arc<Validation>,
217    blacklist: Option<TokenBlacklist>,
218    issuer: Option<String>,
219    audience: Option<String>,
220}
221
222impl JwtService {
223    // ===== Factory Methods =====
224
225    pub fn new_hs256(secret: impl AsRef<[u8]>) -> Self {
226        let secret = secret.as_ref();
227        let mut validation = Validation::new(Algorithm::HS256);
228        validation.validate_exp = true;
229        validation.required_spec_claims = HashSet::from([
230            "exp".to_string(),
231            "iat".to_string(),
232            "jti".to_string(),
233        ]);
234
235        Self {
236            encoding_key: Arc::new(EncodingKey::from_secret(secret)),
237            decoding_key: Arc::new(DecodingKey::from_secret(secret)),
238            algorithm: Algorithm::HS256,
239            validation: Arc::new(validation),
240            blacklist: None,
241            issuer: None,
242            audience: None,
243        }
244    }
245
246    pub fn new_rs256(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
247        let mut validation = Validation::new(Algorithm::RS256);
248        validation.validate_exp = true;
249        validation.required_spec_claims = HashSet::from([
250            "exp".to_string(),
251            "iat".to_string(),
252            "jti".to_string(),
253        ]);
254
255        Ok(Self {
256            encoding_key: Arc::new(EncodingKey::from_rsa_pem(private_key.as_ref())
257                .map_err(|e| JwtError::KeyError(e.to_string()))?),
258            decoding_key: Arc::new(DecodingKey::from_rsa_pem(public_key.as_ref())
259                .map_err(|e| JwtError::KeyError(e.to_string()))?),
260            algorithm: Algorithm::RS256,
261            validation: Arc::new(validation),
262            blacklist: None,
263            issuer: None,
264            audience: None,
265        })
266    }
267
268    pub fn new_rs384(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
269        let mut validation = Validation::new(Algorithm::RS384);
270        validation.validate_exp = true;
271        validation.required_spec_claims = HashSet::from([
272            "exp".to_string(),
273            "iat".to_string(),
274            "jti".to_string(),
275        ]);
276
277        Ok(Self {
278            encoding_key: Arc::new(EncodingKey::from_rsa_pem(private_key.as_ref())
279                .map_err(|e| JwtError::KeyError(e.to_string()))?),
280            decoding_key: Arc::new(DecodingKey::from_rsa_pem(public_key.as_ref())
281                .map_err(|e| JwtError::KeyError(e.to_string()))?),
282            algorithm: Algorithm::RS384,
283            validation: Arc::new(validation),
284            blacklist: None,
285            issuer: None,
286            audience: None,
287        })
288    }
289
290    pub fn new_ecdsa_p256(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
291        let mut validation = Validation::new(Algorithm::ES256);
292        validation.validate_exp = true;
293        validation.required_spec_claims = HashSet::from([
294            "exp".to_string(),
295            "iat".to_string(),
296            "jti".to_string(),
297        ]);
298
299        Ok(Self {
300            encoding_key: Arc::new(EncodingKey::from_ec_pem(private_key.as_ref())
301                .map_err(|e| JwtError::KeyError(e.to_string()))?),
302            decoding_key: Arc::new(DecodingKey::from_ec_pem(public_key.as_ref())
303                .map_err(|e| JwtError::KeyError(e.to_string()))?),
304            algorithm: Algorithm::ES256,
305            validation: Arc::new(validation),
306            blacklist: None,
307            issuer: None,
308            audience: None,
309        })
310    }
311
312    pub fn new_ed25519(private_key: impl AsRef<[u8]>, public_key: impl AsRef<[u8]>) -> Result<Self, JwtError> {
313        let mut validation = Validation::new(Algorithm::EdDSA);
314        validation.validate_exp = true;
315        validation.required_spec_claims = HashSet::from([
316            "exp".to_string(),
317            "iat".to_string(),
318            "jti".to_string(),
319        ]);
320
321        Ok(Self {
322            encoding_key: Arc::new(EncodingKey::from_ed_pem(private_key.as_ref())
323                .map_err(|e| JwtError::KeyError(e.to_string()))?),
324            decoding_key: Arc::new(DecodingKey::from_ed_pem(public_key.as_ref())
325                .map_err(|e| JwtError::KeyError(e.to_string()))?),
326            algorithm: Algorithm::EdDSA,
327            validation: Arc::new(validation),
328            blacklist: None,
329            issuer: None,
330            audience: None,
331        })
332    }
333
334    // ===== Configuration =====
335
336    #[inline]
337    pub fn with_blacklist(mut self, blacklist: TokenBlacklist) -> Self {
338        self.blacklist = Some(blacklist);
339        self
340    }
341
342    #[inline]
343    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
344        let issuer = issuer.into();
345        self.issuer = Some(issuer.clone());
346        let validation = Arc::make_mut(&mut self.validation);
347        validation.set_issuer(&[issuer]);
348        self
349    }
350
351    #[inline]
352    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
353        let audience = audience.into();
354        self.audience = Some(audience.clone());
355        let validation = Arc::make_mut(&mut self.validation);
356        validation.set_audience(&[audience]);
357        self
358    }
359
360    #[inline]
361    pub fn with_leeway(mut self, seconds: u64) -> Self {
362        let validation = Arc::make_mut(&mut self.validation);
363        validation.leeway = seconds;
364        self
365    }
366
367    #[inline]
368    pub fn disable_exp_validation(mut self) -> Self {
369        let validation = Arc::make_mut(&mut self.validation);
370        validation.validate_exp = false;
371        self
372    }
373
374    // ===== Token Generation =====
375
376    #[inline]
377    pub fn generate(&self, sub: impl Into<String>, token_type: TokenType) -> Result<String, JwtError> {
378        let mut claims = Claims::new(sub);
379        let duration = token_type.duration_seconds();
380        claims.exp = Claims::now() + duration as usize;
381
382        if let Some(ref iss) = self.issuer {
383            claims.iss = Some(iss.clone());
384        }
385        if let Some(ref aud) = self.audience {
386            claims.aud = Some(aud.clone());
387        }
388
389        let header = Header::new(self.algorithm);
390        Ok(encode(&header, &claims, &self.encoding_key)?)
391    }
392
393    /// Generate a token with a custom expiration timestamp (in seconds since epoch)
394    #[inline]
395    pub fn generate_exp_token(&self, sub: impl Into<String>, exp: usize) -> Result<String, JwtError> {
396        let mut claims = Claims::new(sub);
397        claims.exp = exp;
398
399        if let Some(ref iss) = self.issuer {
400            claims.iss = Some(iss.clone());
401        }
402        if let Some(ref aud) = self.audience {
403            claims.aud = Some(aud.clone());
404        }
405
406        let header = Header::new(self.algorithm);
407        Ok(encode(&header, &claims, &self.encoding_key)?)
408    }
409
410    #[inline]
411    pub fn generate_with_claims(&self, mut claims: Claims, token_type: TokenType) -> Result<String, JwtError> {
412        let duration = token_type.duration_seconds();
413        claims.exp = Claims::now() + duration as usize;
414        claims.iat = Claims::now();
415        claims.jti = Uuid::now_v7().to_string();
416
417        let header = Header::new(self.algorithm);
418        Ok(encode(&header, &claims, &self.encoding_key)?)
419    }
420
421    #[inline]
422    pub fn generate_pair(&self, sub: impl Into<String>) -> Result<(String, String), JwtError> {
423        let sub_str = sub.into();
424        let access = self.generate(sub_str.clone(), TokenType::Access)?;
425        let refresh = self.generate(sub_str, TokenType::Refresh)?;
426        Ok((access, refresh))
427    }
428
429    #[inline]
430    pub fn generate_access_refresh_with_claims(&self, claims: Claims) -> Result<(String, String), JwtError> {
431        let access_claims = claims.clone();
432        let refresh_claims = claims;
433
434        let access = self.generate_with_claims(access_claims, TokenType::Access)?;
435        let refresh = self.generate_with_claims(refresh_claims, TokenType::Refresh)?;
436
437        Ok((access, refresh))
438    }
439
440    // ===== Token Verification =====
441
442    #[inline]
443    pub fn verify(&self, token: &str) -> Result<Claims, JwtError> {
444        let token_data = decode::<Claims>(
445            token,
446            &self.decoding_key,
447            &self.validation,
448        )?;
449
450        let claims = token_data.claims;
451
452        if let Some(ref blacklist) = self.blacklist {
453            if blacklist.is_revoked(&claims.jti) {
454                return Err(JwtError::TokenRevoked);
455            }
456        }
457
458        Ok(claims)
459    }
460
461    #[inline]
462    pub fn verify_token(&self, token: &str) -> bool {
463        self.verify(token).is_ok()
464    }
465
466    #[inline]
467    pub fn verify_without_expiry(&self, token: &str) -> Result<Claims, JwtError> {
468        // Create a new Validation with the same algorithm and settings, but with validate_exp = false
469        let mut validation = Validation::new(self.algorithm);
470        validation.validate_exp = false;
471        // Copy other public settings from the stored validation
472        validation.leeway = self.validation.leeway;
473        validation.required_spec_claims = self.validation.required_spec_claims.clone();
474        // Copy issuer and audience if set
475        if let Some(ref iss) = self.issuer {
476            validation.set_issuer(&[iss.clone()]);
477        }
478        if let Some(ref aud) = self.audience {
479            validation.set_audience(&[aud.clone()]);
480        }
481
482        let token_data = decode::<Claims>(
483            token,
484            &self.decoding_key,
485            &validation,
486        )?;
487
488        Ok(token_data.claims)
489    }
490
491    // ===== Refresh & Revoke =====
492
493    #[inline]
494    pub fn refresh_access(&self, refresh_token: &str) -> Result<String, JwtError> {
495        let claims = self.verify(refresh_token)?;
496
497        if claims.is_expired() {
498            return Err(JwtError::TokenExpired);
499        }
500
501        let new_claims = Claims::new(claims.sub);
502        self.generate_with_claims(new_claims, TokenType::Access)
503    }
504
505    #[inline]
506    pub fn revoke_token(&self, token: &str) -> Result<(), JwtError> {
507        let claims = self.verify(token)?;
508
509        if let Some(ref blacklist) = self.blacklist {
510            blacklist.revoke(&claims.jti, claims.exp);
511            Ok(())
512        } else {
513            Err(JwtError::InvalidToken("Blacklist not configured".to_string()))
514        }
515    }
516
517    #[inline]
518    pub fn revoke_by_jti(&self, jti: &str, exp: usize) -> Result<(), JwtError> {
519        if let Some(ref blacklist) = self.blacklist {
520            blacklist.revoke(jti, exp);
521            Ok(())
522        } else {
523            Err(JwtError::InvalidToken("Blacklist not configured".to_string()))
524        }
525    }
526
527    #[inline]
528    pub fn is_revoked(&self, jti: &str) -> bool {
529        self.blacklist
530            .as_ref()
531            .map(|b| b.is_revoked(jti))
532            .unwrap_or(false)
533    }
534
535    // ===== Utilities =====
536
537    /// Extract claims without validation (for debugging only)
538    #[inline]
539    pub fn peek_claims(&self, token: &str) -> Option<Claims> {
540        // Create a minimal Validation with no checks
541        let mut validation = Validation::default();
542        validation.validate_exp = false;
543        validation.validate_nbf = false;
544        validation.validate_aud = false;
545        // Issuer and subject validation are skipped by not setting `iss` or `sub`.
546
547        decode::<Claims>(token, &self.decoding_key, &validation)
548            .ok()
549            .map(|data| data.claims)
550    }
551
552    #[inline]
553    pub fn extract_subject(&self, token: &str) -> Option<String> {
554        self.peek_claims(token).map(|c| c.sub)
555    }
556
557    #[inline]
558    pub fn get_token_expiry(&self, token: &str) -> Option<usize> {
559        self.peek_claims(token).map(|c| c.exp)
560    }
561
562    #[inline]
563    pub fn get_token_jti(&self, token: &str) -> Option<String> {
564        self.peek_claims(token).map(|c| c.jti)
565    }
566
567    #[inline]
568    pub fn get_token_issuer(&self, token: &str) -> Option<String> {
569        self.peek_claims(token).and_then(|c| c.iss)
570    }
571
572    #[inline]
573    pub fn get_token_audience(&self, token: &str) -> Option<String> {
574        self.peek_claims(token).and_then(|c| c.aud)
575    }
576}