grpc_jwt_tonic/
jwt_engine.rs

use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::ser::{SerializeMap, Serializer};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::fmt;
use std::fs;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
/// Errors raised while configuring the engine or while issuing, parsing and checking JWTs.
#[derive(Debug)]
pub enum JwtError {
    /// No usable secret or key material was configured
    MissingSecretKey,
    /// The requested signing algorithm is not one the engine supports
    InvalidSigningAlgorithm,
    /// The token (or its refresh window) has expired
    ExpiredToken,
    /// The claims carry no `exp` entry
    MissingExpField,
    /// The `exp` entry is present but is not an unsigned integer
    WrongExpFormat,
    /// Any other token failure, with a human-readable detail message
    Token(String),
    /// The caller is not allowed to access the resource
    Forbidden,
    /// No authenticator function has been provided
    MissingAuthenticatorFunc,
    /// The login request lacks a username or a password
    MissingLoginValues,
    /// The supplied credentials were rejected
    FailedAuthentication,
    /// A JWT could not be created
    FailedTokenCreation,
    /// The authorization header is present but empty
    EmptyAuthHeader,
    /// The authorization header is not in the expected format
    InvalidAuthHeader,
    /// The token taken from a query parameter is empty
    EmptyQueryToken,
    /// The token taken from a cookie is empty
    EmptyCookieToken,
    /// The token taken from a request parameter is empty
    EmptyParamToken,
    /// The private key file could not be read
    NoPrivKeyFile,
    /// The public key file could not be read
    NoPubKeyFile,
    /// The private key could not be parsed
    InvalidPrivKey,
    /// The public key could not be parsed
    InvalidPubKey,
}

impl fmt::Display for JwtError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            // The only variant with a payload writes its own formatted message.
            Self::Token(detail) => return write!(f, "token error: {detail}"),
            Self::MissingSecretKey => "secret key is required",
            Self::InvalidSigningAlgorithm => "invalid signing algorithm",
            Self::ExpiredToken => "token is expired",
            Self::MissingExpField => "missing exp field",
            Self::WrongExpFormat => "exp must be number",
            Self::Forbidden => "you don't have permission to access this resource",
            Self::MissingAuthenticatorFunc => "authenticator func is undefined",
            Self::MissingLoginValues => "missing Username or Password",
            Self::FailedAuthentication => "incorrect Username or Password",
            Self::FailedTokenCreation => "failed to create JWT Token",
            Self::EmptyAuthHeader => "auth header is empty",
            Self::InvalidAuthHeader => "auth header is invalid",
            Self::EmptyQueryToken => "query token is empty",
            Self::EmptyCookieToken => "cookie token is empty",
            Self::EmptyParamToken => "parameter token is empty",
            Self::NoPrivKeyFile => "private key file unreadable",
            Self::NoPubKeyFile => "public key file unreadable",
            Self::InvalidPrivKey => "private key invalid",
            Self::InvalidPubKey => "public key invalid",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JwtError {}
82
83/// Supported JWT signing algorithms with their key material.
84#[derive(Clone)]
85pub enum AlgorithmKind {
86    /// HMAC with SHA-256
87    HS256(Vec<u8>),
88    /// HMAC with SHA-384
89    HS384(Vec<u8>),
90    /// HMAC with SHA-512
91    HS512(Vec<u8>),
92    /// RSA with SHA-256
93    RS256 { private_pem: Vec<u8>, public_pem: Vec<u8> },
94    /// RSA with SHA-384
95    RS384 { private_pem: Vec<u8>, public_pem: Vec<u8> },
96    /// RSA with SHA-512
97    RS512 { private_pem: Vec<u8>, public_pem: Vec<u8> },
98}
99
100impl AlgorithmKind {
101    /// Returns the jsonwebtoken Algorithm enum variant for this algorithm.
102    pub fn algorithm(&self) -> Algorithm {
103        match self {
104            AlgorithmKind::HS256(_) => Algorithm::HS256,
105            AlgorithmKind::HS384(_) => Algorithm::HS384,
106            AlgorithmKind::HS512(_) => Algorithm::HS512,
107            AlgorithmKind::RS256 { .. } => Algorithm::RS256,
108            AlgorithmKind::RS384 { .. } => Algorithm::RS384,
109            AlgorithmKind::RS512 { .. } => Algorithm::RS512,
110        }
111    }
112
113    /// Creates an EncodingKey for token signing.
114    /// 
115    /// # Returns
116    /// * `Ok(EncodingKey)` - Key for signing tokens
117    /// * `Err(JwtError)` - If key creation fails
118    fn encoding_key(&self) -> Result<EncodingKey, JwtError> {
119        match self {
120            AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
121                Ok(EncodingKey::from_secret(k))
122            }
123            AlgorithmKind::RS256 { private_pem, .. }
124            | AlgorithmKind::RS384 { private_pem, .. }
125            | AlgorithmKind::RS512 { private_pem, .. } => {
126                EncodingKey::from_rsa_pem(private_pem).map_err(|e| JwtError::Token(e.to_string()))
127            }
128        }
129    }
130
131    /// Creates a DecodingKey for token verification.
132    /// 
133    /// # Returns
134    /// * `Ok(DecodingKey)` - Key for verifying tokens
135    /// * `Err(JwtError)` - If key creation fails
136    fn decoding_key(&self) -> Result<DecodingKey, JwtError> {
137        match self {
138            AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
139                Ok(DecodingKey::from_secret(k))
140            }
141            AlgorithmKind::RS256 { public_pem, .. }
142            | AlgorithmKind::RS384 { public_pem, .. }
143            | AlgorithmKind::RS512 { public_pem, .. } => {
144                DecodingKey::from_rsa_pem(public_pem).map_err(|e| JwtError::Token(e.to_string()))
145            }
146        }
147    }
148
149    /// Creates an RSA algorithm variant by loading keys from files.
150    /// 
151    /// # Arguments
152    /// * `alg` - The RSA algorithm to use (RS256, RS384, or RS512)
153    /// * `priv_key_file` - Path to the private key file
154    /// * `pub_key_file` - Path to the public key file
155    /// 
156    /// # Returns
157    /// * `Ok(AlgorithmKind)` - The algorithm with loaded keys
158    /// * `Err(JwtError)` - If files cannot be read or keys are invalid
159    pub fn from_rsa_files(alg: Algorithm, priv_key_file: &str, pub_key_file: &str) -> Result<Self, JwtError> {
160        let private_pem = fs::read(priv_key_file)
161            .map_err(|_| JwtError::NoPrivKeyFile)?;
162        let public_pem = fs::read(pub_key_file)
163            .map_err(|_| JwtError::NoPubKeyFile)?;
164
165        // Validate keys
166        EncodingKey::from_rsa_pem(&private_pem)
167            .map_err(|_| JwtError::InvalidPrivKey)?;
168        DecodingKey::from_rsa_pem(&public_pem)
169            .map_err(|_| JwtError::InvalidPubKey)?;
170
171        match alg {
172            Algorithm::RS256 => Ok(AlgorithmKind::RS256 { private_pem, public_pem }),
173            Algorithm::RS384 => Ok(AlgorithmKind::RS384 { private_pem, public_pem }),
174            Algorithm::RS512 => Ok(AlgorithmKind::RS512 { private_pem, public_pem }),
175            _ => Err(JwtError::InvalidSigningAlgorithm),
176        }
177    }
178}
179
180/// Configuration options for the JWT engine.
181#[derive(Clone)]
182pub struct JwtEngineOptions {
183    /// JWT realm identifier (default: "grpc jwt")
184    pub realm: String,
185    /// Signing algorithm with key material
186    pub alg: AlgorithmKind,
187    /// Token expiration duration (default: 1 hour)
188    pub timeout: Duration,
189    /// Maximum time allowed for token refresh (default: 7 days)
190    pub max_refresh: Duration,
191    /// Key name for identity in JWT claims (default: "identity")
192    pub identity_key: String,
193    /// Where to find the token (default: "header:Authorization")
194    pub token_lookup: String,
195    /// Token prefix in header (default: "Bearer")
196    pub token_head_name: String,
197    /// Whether to send authorization in response headers
198    pub send_authorization: bool,
199    /// Whether to disable abort on authentication failure (not handled in this version)
200    pub disabled_abort: bool,
201    /// Path to private key file for RSA algorithms
202    pub priv_key_file: String,
203    /// Path to public key file for RSA algorithms
204    pub pub_key_file: String,
205}
206
207impl Default for JwtEngineOptions {
208    fn default() -> Self {
209        Self {
210            realm: "grpc jwt".to_string(),
211            alg: AlgorithmKind::HS256(b"secret".to_vec()),
212            timeout: Duration::from_secs(3600),
213            max_refresh: Duration::from_secs(3600 * 24 * 7),
214            identity_key: "identity".to_string(),
215            token_lookup: "header:Authorization".to_string(),
216            token_head_name: "Bearer".to_string(),
217            send_authorization: false,
218            disabled_abort: false,
219            priv_key_file: String::new(),
220            pub_key_file: String::new(),
221        }
222    }
223}
224
225/// JWT claims structure containing expiration and custom data.
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct Claims {
228    /// Expiration timestamp (seconds since Unix epoch)
229    pub exp: u64,
230    /// Original issued at timestamp (seconds since Unix epoch)
231    pub orig_iat: u64,
232    /// Additional custom claims
233    #[serde(flatten)]
234    pub extra: Map<String, Value>,
235}
236
237/// JWT engine for creating and validating JWT tokens.
238#[derive(Clone)]
239pub struct JwtEngine {
240    /// Engine configuration options
241    pub opts: Arc<JwtEngineOptions>,
242}
243
244impl JwtEngine {
245    /// Creates a new JWT engine with the given options.
246    /// 
247    /// # Arguments
248    /// * `opts` - Configuration options for the engine
249    /// 
250    /// # Returns
251    /// * `Ok(JwtEngine)` - Successfully created engine
252    /// * `Err(JwtError)` - If configuration is invalid or keys cannot be loaded
253    pub fn new(mut opts: JwtEngineOptions) -> Result<Self, JwtError> {
254        // Apply defaults
255        if opts.token_lookup.is_empty() {
256            opts.token_lookup = "header:Authorization".to_string();
257        }
258        
259        if opts.timeout.as_secs() == 0 {
260            opts.timeout = Duration::from_secs(3600);
261        }
262        
263        opts.token_head_name = opts.token_head_name.trim().to_string();
264        if opts.token_head_name.is_empty() {
265            opts.token_head_name = "Bearer".to_string();
266        }
267        
268        if opts.identity_key.is_empty() {
269            opts.identity_key = "identity".to_string();
270        }
271        
272        if opts.realm.is_empty() {
273            opts.realm = "grpc jwt".to_string();
274        }
275
276        // Read keys if using public key algorithms
277        if Self::using_public_key_algo_static(&opts.alg) {
278            if opts.priv_key_file.is_empty() || opts.pub_key_file.is_empty() {
279                return Err(JwtError::MissingSecretKey);
280            }
281            // Update the algorithm with loaded keys
282            opts.alg = Self::read_keys_static(&opts.priv_key_file, &opts.pub_key_file, &opts.alg)?;
283        } else {
284            // For HMAC algorithms, ensure we have a key
285            match &opts.alg {
286                AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
287                    if k.is_empty() {
288                        return Err(JwtError::MissingSecretKey);
289                    }
290                }
291                _ => {}
292            }
293        }
294
295        Ok(Self { opts: Arc::new(opts) })
296    }
297
298    /// Checks if the algorithm requires public/private key pairs (RSA algorithms).
299    /// 
300    /// # Arguments
301    /// * `alg` - The algorithm to check
302    /// 
303    /// # Returns
304    /// * `true` - If algorithm uses RSA (requires key files)
305    /// * `false` - If algorithm uses HMAC (uses secret)
306    fn using_public_key_algo_static(alg: &AlgorithmKind) -> bool {
307        matches!(
308            alg,
309            AlgorithmKind::RS256 { .. } | AlgorithmKind::RS384 { .. } | AlgorithmKind::RS512 { .. }
310        )
311    }
312    
313    /// Loads RSA keys from files and creates the appropriate algorithm variant.
314    /// 
315    /// # Arguments
316    /// * `priv_key_file` - Path to private key file
317    /// * `pub_key_file` - Path to public key file
318    /// * `current_alg` - Current algorithm to determine which RSA variant to create
319    /// 
320    /// # Returns
321    /// * `Ok(AlgorithmKind)` - Algorithm with loaded keys
322    /// * `Err(JwtError)` - If files cannot be read or keys are invalid
323    fn read_keys_static(priv_key_file: &str, pub_key_file: &str, current_alg: &AlgorithmKind) -> Result<AlgorithmKind, JwtError> {
324        let private_pem = Self::read_private_key_static(priv_key_file)?;
325        let public_pem = Self::read_public_key_static(pub_key_file)?;
326        
327        match current_alg {
328            AlgorithmKind::RS256 { .. } => Ok(AlgorithmKind::RS256 { private_pem, public_pem }),
329            AlgorithmKind::RS384 { .. } => Ok(AlgorithmKind::RS384 { private_pem, public_pem }),
330            AlgorithmKind::RS512 { .. } => Ok(AlgorithmKind::RS512 { private_pem, public_pem }),
331            _ => Err(JwtError::InvalidSigningAlgorithm),
332        }
333    }
334
335    /// Reads and validates a private key from file.
336    /// 
337    /// # Arguments
338    /// * `file_path` - Path to the private key file
339    /// 
340    /// # Returns
341    /// * `Ok(Vec<u8>)` - The private key bytes
342    /// * `Err(JwtError)` - If file cannot be read or key is invalid
343    fn read_private_key_static(file_path: &str) -> Result<Vec<u8>, JwtError> {
344        let key_data = fs::read(file_path)
345            .map_err(|_| JwtError::NoPrivKeyFile)?;
346        
347        // Validate the private key
348        EncodingKey::from_rsa_pem(&key_data)
349            .map_err(|_| JwtError::InvalidPrivKey)?;
350            
351        Ok(key_data)
352    }
353
354    /// Reads and validates a public key from file.
355    /// 
356    /// # Arguments
357    /// * `file_path` - Path to the public key file
358    /// 
359    /// # Returns
360    /// * `Ok(Vec<u8>)` - The public key bytes
361    /// * `Err(JwtError)` - If file cannot be read or key is invalid
362    fn read_public_key_static(file_path: &str) -> Result<Vec<u8>, JwtError> {
363        let key_data = fs::read(file_path)
364            .map_err(|_| JwtError::NoPubKeyFile)?;
365        
366        // Validate the public key
367        DecodingKey::from_rsa_pem(&key_data)
368            .map_err(|_| JwtError::InvalidPubKey)?;
369            
370        Ok(key_data)
371    }
372
373    /// Checks if this engine uses a public key algorithm.
374    /// 
375    /// # Returns
376    /// * `true` - If using RSA algorithm
377    /// * `false` - If using HMAC algorithm
378    pub fn using_public_key_algo(&self) -> bool {
379        Self::using_public_key_algo_static(&self.opts.alg)
380    }
381
382    /// Returns the current Unix timestamp in seconds.
383    /// 
384    /// # Returns
385    /// * `u64` - Current timestamp in seconds since Unix epoch
386    pub fn now() -> u64 {
387        SystemTime::now()
388            .duration_since(UNIX_EPOCH)
389            .unwrap_or_default()
390            .as_secs()
391    }
392
393    /// Signs a JWT token with additional claims.
394    /// 
395    /// # Arguments
396    /// * `extra` - Additional claims to include in the token
397    /// 
398    /// # Returns
399    /// * `Ok((String, u64))` - Tuple of (JWT token, expiration timestamp)
400    /// * `Err(JwtError)` - If token creation fails
401    pub fn sign_with_extra(&self, mut extra: Map<String, Value>) -> Result<(String, u64), JwtError> {
402        let now = Self::now();
403        let exp = now + self.opts.timeout.as_secs();
404        extra.insert("orig_iat".into(), Value::from(now));
405        extra.insert("exp".into(), Value::from(exp));
406
407        let claims = Claims { exp, orig_iat: now, extra };
408        let mut header = Header::new(self.opts.alg.algorithm());
409        header.typ = Some("JWT".into());
410        let token = jsonwebtoken::encode(&header, &claims, &self.opts.alg.encoding_key()?)
411            .map_err(|e| JwtError::Token(e.to_string()))?;
412        Ok((token, exp))
413    }
414
415    /// Decodes a JWT token without validating expiration.
416    /// 
417    /// # Arguments
418    /// * `token` - The JWT token string to decode
419    /// 
420    /// # Returns
421    /// * `Ok(Claims)` - The decoded claims
422    /// * `Err(JwtError)` - If token is malformed or signature is invalid
423    pub fn decode(&self, token: &str) -> Result<Claims, JwtError> {
424        let mut validation = Validation::new(self.opts.alg.algorithm());
425        validation.validate_exp = false;
426        let data = jsonwebtoken::decode::<Claims>(
427            token,
428            &self.opts.alg.decoding_key()?,
429            &validation,
430        ).map_err(|e| JwtError::Token(e.to_string()))?;
431        Ok(data.claims)
432    }
433
434    /// Extracts all claims from a JWT token as a Map.
435    /// 
436    /// # Arguments
437    /// * `token` - The JWT token string to parse
438    /// 
439    /// # Returns
440    /// * `Ok(Map<String, Value>)` - All claims including exp and orig_iat
441    /// * `Err(JwtError)` - If token cannot be decoded
442    pub fn get_claims(&self, token: &str) -> Result<Map<String, Value>, JwtError> {
443        let claims = self.decode(token)?;
444        let mut m = claims.extra.clone();
445        m.insert("exp".to_string(), Value::from(claims.exp));
446        m.insert("orig_iat".to_string(), Value::from(claims.orig_iat));
447        Ok(m)
448    }
449
450    /// Validates that a token has not expired.
451    /// 
452    /// # Arguments
453    /// * `claims` - The claims map to check for expiration
454    /// 
455    /// # Returns
456    /// * `Ok(())` - If token is still valid
457    /// * `Err(JwtError)` - If token is expired or exp field is invalid
458    pub fn ensure_not_expired(&self, claims: &Map<String, Value>) -> Result<(), JwtError> {
459        let exp = claims
460            .get("exp")
461            .ok_or(JwtError::MissingExpField)?
462            .as_u64()
463            .ok_or(JwtError::WrongExpFormat)?;
464        if (exp as i64) < (Self::now() as i64) {
465            return Err(JwtError::ExpiredToken);
466        }
467        Ok(())
468    }
469
470    /// Checks if a token is within the refresh window and returns updated claims.
471    /// 
472    /// # Arguments
473    /// * `token` - The JWT token string to check
474    /// 
475    /// # Returns
476    /// * `Ok(Map<String, Value>)` - Claims if token can be refreshed
477    /// * `Err(JwtError)` - If token is beyond refresh window or invalid
478    pub fn check_if_token_expire(&self, token: &str) -> Result<Map<String, Value>, JwtError> {
479        let claims = self.decode(token)?;
480        let orig_iat = claims.orig_iat;
481        let now = Self::now();
482        if (now as i64) - (orig_iat as i64) > self.opts.max_refresh.as_secs() as i64 {
483            return Err(JwtError::ExpiredToken);
484        }
485        let mut m = claims.extra.clone();
486        m.insert("orig_iat".into(), Value::from(orig_iat));
487        m.insert("exp".into(), Value::from(claims.exp));
488        Ok(m)
489    }
490}