1use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4use std::fmt;
5use std::fs;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
/// Every error the JWT engine can report.
///
/// The [`Display`](fmt::Display) text of each variant is the message surfaced
/// to callers. The type is `Clone + PartialEq + Eq` so errors can be compared
/// directly (e.g. `assert_eq!(err, JwtError::ExpiredToken)`) and stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtError {
    /// No HMAC secret or RSA key material was supplied.
    MissingSecretKey,
    /// The algorithm does not match the kind of key supplied
    /// (e.g. a non-RSA algorithm combined with RSA key files).
    InvalidSigningAlgorithm,
    /// The token is past its `exp` or outside its refresh window.
    ExpiredToken,
    /// The claims contain no `exp` entry.
    MissingExpField,
    /// The `exp` claim exists but has an unusable format.
    WrongExpFormat,
    /// Encoding, decoding or signature verification failed; carries the
    /// underlying library's error text.
    Token(String),
    /// The caller lacks permission for the requested resource.
    Forbidden,
    /// No authenticator callback is configured.
    MissingAuthenticatorFunc,
    /// The login request lacked a username or password.
    MissingLoginValues,
    /// The username or password was incorrect.
    FailedAuthentication,
    /// A token could not be created.
    FailedTokenCreation,
    /// The authorization header is empty.
    EmptyAuthHeader,
    /// The authorization header is malformed.
    InvalidAuthHeader,
    /// The query-string token is empty.
    EmptyQueryToken,
    /// The cookie token is empty.
    EmptyCookieToken,
    /// The parameter token is empty.
    EmptyParamToken,
    /// The private key file could not be read.
    NoPrivKeyFile,
    /// The public key file could not be read.
    NoPubKeyFile,
    /// The private key could not be parsed.
    InvalidPrivKey,
    /// The public key could not be parsed.
    InvalidPubKey,
}

impl fmt::Display for JwtError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            // The only variant whose message includes a dynamic payload.
            Self::Token(detail) => return write!(f, "token error: {detail}"),
            Self::MissingSecretKey => "secret key is required",
            Self::InvalidSigningAlgorithm => "invalid signing algorithm",
            Self::ExpiredToken => "token is expired",
            Self::MissingExpField => "missing exp field",
            Self::WrongExpFormat => "exp must be number",
            Self::Forbidden => "you don't have permission to access this resource",
            Self::MissingAuthenticatorFunc => "authenticator func is undefined",
            Self::MissingLoginValues => "missing Username or Password",
            Self::FailedAuthentication => "incorrect Username or Password",
            Self::FailedTokenCreation => "failed to create JWT Token",
            Self::EmptyAuthHeader => "auth header is empty",
            Self::InvalidAuthHeader => "auth header is invalid",
            Self::EmptyQueryToken => "query token is empty",
            Self::EmptyCookieToken => "cookie token is empty",
            Self::EmptyParamToken => "parameter token is empty",
            Self::NoPrivKeyFile => "private key file unreadable",
            Self::NoPubKeyFile => "public key file unreadable",
            Self::InvalidPrivKey => "private key invalid",
            Self::InvalidPubKey => "public key invalid",
        };
        f.write_str(msg)
    }
}

/// No `source()` is provided: no variant wraps another error value
/// (`Token` keeps only the underlying error's text).
impl std::error::Error for JwtError {}
82
/// Signing algorithm together with the key material it needs.
///
/// HMAC variants carry the shared secret. RSA variants carry the PEM-encoded
/// private key (used to sign) and public key (used to verify).
#[derive(Clone)]
pub enum AlgorithmKind {
    /// HMAC with SHA-256, keyed by the contained secret.
    HS256(Vec<u8>),
    /// HMAC with SHA-384, keyed by the contained secret.
    HS384(Vec<u8>),
    /// HMAC with SHA-512, keyed by the contained secret.
    HS512(Vec<u8>),
    /// RSA signature with SHA-256 (JWA `RS256`).
    RS256 { private_pem: Vec<u8>, public_pem: Vec<u8> },
    /// RSA signature with SHA-384 (JWA `RS384`).
    RS384 { private_pem: Vec<u8>, public_pem: Vec<u8> },
    /// RSA signature with SHA-512 (JWA `RS512`).
    RS512 { private_pem: Vec<u8>, public_pem: Vec<u8> },
}

/// Redacted `Debug`: prints only the algorithm name (e.g. `HS256 { .. }`).
///
/// Secrets and private keys are never formatted, so logging a value (or a
/// struct containing one) cannot leak key material.
impl fmt::Debug for AlgorithmKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AlgorithmKind::HS256(_) => "HS256",
            AlgorithmKind::HS384(_) => "HS384",
            AlgorithmKind::HS512(_) => "HS512",
            AlgorithmKind::RS256 { .. } => "RS256",
            AlgorithmKind::RS384 { .. } => "RS384",
            AlgorithmKind::RS512 { .. } => "RS512",
        };
        f.debug_struct(name).finish_non_exhaustive()
    }
}
99
100impl AlgorithmKind {
101 pub fn algorithm(&self) -> Algorithm {
103 match self {
104 AlgorithmKind::HS256(_) => Algorithm::HS256,
105 AlgorithmKind::HS384(_) => Algorithm::HS384,
106 AlgorithmKind::HS512(_) => Algorithm::HS512,
107 AlgorithmKind::RS256 { .. } => Algorithm::RS256,
108 AlgorithmKind::RS384 { .. } => Algorithm::RS384,
109 AlgorithmKind::RS512 { .. } => Algorithm::RS512,
110 }
111 }
112
113 fn encoding_key(&self) -> Result<EncodingKey, JwtError> {
119 match self {
120 AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
121 Ok(EncodingKey::from_secret(k))
122 }
123 AlgorithmKind::RS256 { private_pem, .. }
124 | AlgorithmKind::RS384 { private_pem, .. }
125 | AlgorithmKind::RS512 { private_pem, .. } => {
126 EncodingKey::from_rsa_pem(private_pem).map_err(|e| JwtError::Token(e.to_string()))
127 }
128 }
129 }
130
131 fn decoding_key(&self) -> Result<DecodingKey, JwtError> {
137 match self {
138 AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
139 Ok(DecodingKey::from_secret(k))
140 }
141 AlgorithmKind::RS256 { public_pem, .. }
142 | AlgorithmKind::RS384 { public_pem, .. }
143 | AlgorithmKind::RS512 { public_pem, .. } => {
144 DecodingKey::from_rsa_pem(public_pem).map_err(|e| JwtError::Token(e.to_string()))
145 }
146 }
147 }
148
149 pub fn from_rsa_files(alg: Algorithm, priv_key_file: &str, pub_key_file: &str) -> Result<Self, JwtError> {
160 let private_pem = fs::read(priv_key_file)
161 .map_err(|_| JwtError::NoPrivKeyFile)?;
162 let public_pem = fs::read(pub_key_file)
163 .map_err(|_| JwtError::NoPubKeyFile)?;
164
165 EncodingKey::from_rsa_pem(&private_pem)
167 .map_err(|_| JwtError::InvalidPrivKey)?;
168 DecodingKey::from_rsa_pem(&public_pem)
169 .map_err(|_| JwtError::InvalidPubKey)?;
170
171 match alg {
172 Algorithm::RS256 => Ok(AlgorithmKind::RS256 { private_pem, public_pem }),
173 Algorithm::RS384 => Ok(AlgorithmKind::RS384 { private_pem, public_pem }),
174 Algorithm::RS512 => Ok(AlgorithmKind::RS512 { private_pem, public_pem }),
175 _ => Err(JwtError::InvalidSigningAlgorithm),
176 }
177 }
178}
179
180#[derive(Clone)]
182pub struct JwtEngineOptions {
183 pub realm: String,
185 pub alg: AlgorithmKind,
187 pub timeout: Duration,
189 pub max_refresh: Duration,
191 pub identity_key: String,
193 pub token_lookup: String,
195 pub token_head_name: String,
197 pub send_authorization: bool,
199 pub disabled_abort: bool,
201 pub priv_key_file: String,
203 pub pub_key_file: String,
205}
206
207impl Default for JwtEngineOptions {
208 fn default() -> Self {
209 Self {
210 realm: "grpc jwt".to_string(),
211 alg: AlgorithmKind::HS256(b"secret".to_vec()),
212 timeout: Duration::from_secs(3600),
213 max_refresh: Duration::from_secs(3600 * 24 * 7),
214 identity_key: "identity".to_string(),
215 token_lookup: "header:Authorization".to_string(),
216 token_head_name: "Bearer".to_string(),
217 send_authorization: false,
218 disabled_abort: false,
219 priv_key_file: String::new(),
220 pub_key_file: String::new(),
221 }
222 }
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct Claims {
228 pub exp: u64,
230 pub orig_iat: u64,
232 #[serde(flatten)]
234 pub extra: Map<String, Value>,
235}
236
237#[derive(Clone)]
239pub struct JwtEngine {
240 pub opts: Arc<JwtEngineOptions>,
242}
243
244impl JwtEngine {
245 pub fn new(mut opts: JwtEngineOptions) -> Result<Self, JwtError> {
254 if opts.token_lookup.is_empty() {
256 opts.token_lookup = "header:Authorization".to_string();
257 }
258
259 if opts.timeout.as_secs() == 0 {
260 opts.timeout = Duration::from_secs(3600);
261 }
262
263 opts.token_head_name = opts.token_head_name.trim().to_string();
264 if opts.token_head_name.is_empty() {
265 opts.token_head_name = "Bearer".to_string();
266 }
267
268 if opts.identity_key.is_empty() {
269 opts.identity_key = "identity".to_string();
270 }
271
272 if opts.realm.is_empty() {
273 opts.realm = "grpc jwt".to_string();
274 }
275
276 if Self::using_public_key_algo_static(&opts.alg) {
278 if opts.priv_key_file.is_empty() || opts.pub_key_file.is_empty() {
279 return Err(JwtError::MissingSecretKey);
280 }
281 opts.alg = Self::read_keys_static(&opts.priv_key_file, &opts.pub_key_file, &opts.alg)?;
283 } else {
284 match &opts.alg {
286 AlgorithmKind::HS256(k) | AlgorithmKind::HS384(k) | AlgorithmKind::HS512(k) => {
287 if k.is_empty() {
288 return Err(JwtError::MissingSecretKey);
289 }
290 }
291 _ => {}
292 }
293 }
294
295 Ok(Self { opts: Arc::new(opts) })
296 }
297
298 fn using_public_key_algo_static(alg: &AlgorithmKind) -> bool {
307 matches!(
308 alg,
309 AlgorithmKind::RS256 { .. } | AlgorithmKind::RS384 { .. } | AlgorithmKind::RS512 { .. }
310 )
311 }
312
313 fn read_keys_static(priv_key_file: &str, pub_key_file: &str, current_alg: &AlgorithmKind) -> Result<AlgorithmKind, JwtError> {
324 let private_pem = Self::read_private_key_static(priv_key_file)?;
325 let public_pem = Self::read_public_key_static(pub_key_file)?;
326
327 match current_alg {
328 AlgorithmKind::RS256 { .. } => Ok(AlgorithmKind::RS256 { private_pem, public_pem }),
329 AlgorithmKind::RS384 { .. } => Ok(AlgorithmKind::RS384 { private_pem, public_pem }),
330 AlgorithmKind::RS512 { .. } => Ok(AlgorithmKind::RS512 { private_pem, public_pem }),
331 _ => Err(JwtError::InvalidSigningAlgorithm),
332 }
333 }
334
335 fn read_private_key_static(file_path: &str) -> Result<Vec<u8>, JwtError> {
344 let key_data = fs::read(file_path)
345 .map_err(|_| JwtError::NoPrivKeyFile)?;
346
347 EncodingKey::from_rsa_pem(&key_data)
349 .map_err(|_| JwtError::InvalidPrivKey)?;
350
351 Ok(key_data)
352 }
353
354 fn read_public_key_static(file_path: &str) -> Result<Vec<u8>, JwtError> {
363 let key_data = fs::read(file_path)
364 .map_err(|_| JwtError::NoPubKeyFile)?;
365
366 DecodingKey::from_rsa_pem(&key_data)
368 .map_err(|_| JwtError::InvalidPubKey)?;
369
370 Ok(key_data)
371 }
372
373 pub fn using_public_key_algo(&self) -> bool {
379 Self::using_public_key_algo_static(&self.opts.alg)
380 }
381
382 pub fn now() -> u64 {
387 SystemTime::now()
388 .duration_since(UNIX_EPOCH)
389 .unwrap_or_default()
390 .as_secs()
391 }
392
393 pub fn sign_with_extra(&self, mut extra: Map<String, Value>) -> Result<(String, u64), JwtError> {
402 let now = Self::now();
403 let exp = now + self.opts.timeout.as_secs();
404 extra.insert("orig_iat".into(), Value::from(now));
405 extra.insert("exp".into(), Value::from(exp));
406
407 let claims = Claims { exp, orig_iat: now, extra };
408 let mut header = Header::new(self.opts.alg.algorithm());
409 header.typ = Some("JWT".into());
410 let token = jsonwebtoken::encode(&header, &claims, &self.opts.alg.encoding_key()?)
411 .map_err(|e| JwtError::Token(e.to_string()))?;
412 Ok((token, exp))
413 }
414
415 pub fn decode(&self, token: &str) -> Result<Claims, JwtError> {
424 let mut validation = Validation::new(self.opts.alg.algorithm());
425 validation.validate_exp = false;
426 let data = jsonwebtoken::decode::<Claims>(
427 token,
428 &self.opts.alg.decoding_key()?,
429 &validation,
430 ).map_err(|e| JwtError::Token(e.to_string()))?;
431 Ok(data.claims)
432 }
433
434 pub fn get_claims(&self, token: &str) -> Result<Map<String, Value>, JwtError> {
443 let claims = self.decode(token)?;
444 let mut m = claims.extra.clone();
445 m.insert("exp".to_string(), Value::from(claims.exp));
446 m.insert("orig_iat".to_string(), Value::from(claims.orig_iat));
447 Ok(m)
448 }
449
450 pub fn ensure_not_expired(&self, claims: &Map<String, Value>) -> Result<(), JwtError> {
459 let exp = claims
460 .get("exp")
461 .ok_or(JwtError::MissingExpField)?
462 .as_u64()
463 .ok_or(JwtError::WrongExpFormat)?;
464 if (exp as i64) < (Self::now() as i64) {
465 return Err(JwtError::ExpiredToken);
466 }
467 Ok(())
468 }
469
470 pub fn check_if_token_expire(&self, token: &str) -> Result<Map<String, Value>, JwtError> {
479 let claims = self.decode(token)?;
480 let orig_iat = claims.orig_iat;
481 let now = Self::now();
482 if (now as i64) - (orig_iat as i64) > self.opts.max_refresh.as_secs() as i64 {
483 return Err(JwtError::ExpiredToken);
484 }
485 let mut m = claims.extra.clone();
486 m.insert("orig_iat".into(), Value::from(orig_iat));
487 m.insert("exp".into(), Value::from(claims.exp));
488 Ok(m)
489 }
490}