auth_framework/server/core/
common_jwt.rs1use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Clone)]
15pub struct JwtConfig {
16 pub algorithm: Algorithm,
18 pub signing_key: EncodingKey,
20 pub verification_key: DecodingKey,
22 pub default_expiration: u64,
24 pub issuer: String,
26 pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31 pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33 Self {
34 algorithm: Algorithm::HS256,
35 signing_key: EncodingKey::from_secret(secret),
36 verification_key: DecodingKey::from_secret(secret),
37 default_expiration: 3600, issuer,
39 audiences: vec![],
40 }
41 }
42
43 pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45 let signing_key = EncodingKey::from_rsa_pem(private_key)
46 .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48 let verification_key = DecodingKey::from_rsa_pem(public_key)
49 .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51 Ok(Self {
52 algorithm: Algorithm::RS256,
53 signing_key,
54 verification_key,
55 default_expiration: 3600, issuer,
57 audiences: vec![],
58 })
59 }
60
61 pub fn with_audience(mut self, audience: String) -> Self {
63 self.audiences.push(audience);
64 self
65 }
66
67 pub fn with_expiration(mut self, expiration: u64) -> Self {
69 self.default_expiration = expiration;
70 self
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77 pub iss: String,
79 pub sub: String,
81 pub aud: Vec<String>,
83 pub exp: i64,
85 pub iat: i64,
87 pub nbf: Option<i64>,
89 pub jti: Option<String>,
91 #[serde(flatten)]
93 pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97 pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99 let now = SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .unwrap()
102 .as_secs() as i64;
103
104 Self {
105 iss: issuer,
106 sub: subject,
107 aud: audiences,
108 exp: expiration,
109 iat: now,
110 nbf: None,
111 jti: None,
112 custom: HashMap::new(),
113 }
114 }
115
116 pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118 self.custom.insert(key, value);
119 self
120 }
121
122 pub fn with_jti(mut self, jti: String) -> Self {
124 self.jti = Some(jti);
125 self
126 }
127
128 pub fn with_nbf(mut self, nbf: i64) -> Self {
130 self.nbf = Some(nbf);
131 self
132 }
133}
134
135pub struct JwtManager {
220 config: JwtConfig,
221}
222
223impl JwtManager {
224 pub fn new(config: JwtConfig) -> Self {
226 Self { config }
227 }
228
229 pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
231 let header = Header {
232 alg: self.config.algorithm,
233 ..Default::default()
234 };
235
236 encode(&header, claims, &self.config.signing_key)
237 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
238 }
239
240 pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
242 where
243 T: Serialize,
244 {
245 let header = Header {
246 alg: self.config.algorithm,
247 ..Default::default()
248 };
249
250 encode(&header, claims, &self.config.signing_key)
251 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
252 }
253
254 pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
256 common_validation::jwt::validate_jwt_format(token)?;
258
259 let mut validation = Validation::new(self.config.algorithm);
260 validation.set_issuer(&[&self.config.issuer]);
261
262 if !self.config.audiences.is_empty() {
263 validation.set_audience(
264 &self
265 .config
266 .audiences
267 .iter()
268 .map(String::as_str)
269 .collect::<Vec<_>>(),
270 );
271 }
272
273 let token_data =
274 decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
275 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
276
277 let claims_value = serde_json::to_value(&token_data.claims)
279 .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
280
281 common_validation::jwt::validate_time_claims(&claims_value)?;
282
283 Ok(token_data.claims)
284 }
285
286 pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
288 where
289 T: for<'de> Deserialize<'de>,
290 {
291 common_validation::jwt::validate_jwt_format(token)?;
292
293 let mut validation = Validation::new(self.config.algorithm);
294 validation.set_issuer(&[&self.config.issuer]);
295
296 if !self.config.audiences.is_empty() {
297 validation.set_audience(
298 &self
299 .config
300 .audiences
301 .iter()
302 .map(String::as_str)
303 .collect::<Vec<_>>(),
304 );
305 }
306
307 let token_data = decode::<T>(token, &self.config.verification_key, &validation)
308 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
309
310 Ok(token_data.claims)
311 }
312
313 pub fn create_access_token(
315 &self,
316 subject: String,
317 scope: Vec<String>,
318 client_id: Option<String>,
319 ) -> Result<String> {
320 let exp = SystemTime::now()
321 .duration_since(UNIX_EPOCH)
322 .unwrap()
323 .as_secs() as i64
324 + self.config.default_expiration as i64;
325
326 let mut claims = CommonJwtClaims::new(
327 self.config.issuer.clone(),
328 subject,
329 self.config.audiences.clone(),
330 exp,
331 );
332
333 claims
334 .custom
335 .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
336
337 if let Some(client_id) = client_id {
338 claims.custom.insert(
339 "client_id".to_string(),
340 serde_json::Value::String(client_id),
341 );
342 }
343
344 claims.custom.insert(
345 "token_type".to_string(),
346 serde_json::Value::String("access_token".to_string()),
347 );
348
349 self.create_token(&claims)
350 }
351
352 pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
354 let exp = SystemTime::now()
356 .duration_since(UNIX_EPOCH)
357 .unwrap()
358 .as_secs() as i64
359 + (self.config.default_expiration * 24) as i64; let mut claims = CommonJwtClaims::new(
362 self.config.issuer.clone(),
363 subject,
364 self.config.audiences.clone(),
365 exp,
366 );
367
368 claims.custom.insert(
369 "client_id".to_string(),
370 serde_json::Value::String(client_id),
371 );
372 claims.custom.insert(
373 "token_type".to_string(),
374 serde_json::Value::String("refresh_token".to_string()),
375 );
376
377 self.create_token(&claims)
378 }
379
380 pub fn create_id_token(
382 &self,
383 subject: String,
384 nonce: Option<String>,
385 auth_time: Option<i64>,
386 user_info: HashMap<String, serde_json::Value>,
387 ) -> Result<String> {
388 let exp = SystemTime::now()
389 .duration_since(UNIX_EPOCH)
390 .unwrap()
391 .as_secs() as i64
392 + 300; let mut claims = CommonJwtClaims::new(
395 self.config.issuer.clone(),
396 subject,
397 self.config.audiences.clone(),
398 exp,
399 );
400
401 claims.custom.insert(
402 "token_type".to_string(),
403 serde_json::Value::String("id_token".to_string()),
404 );
405
406 if let Some(nonce) = nonce {
407 claims
408 .custom
409 .insert("nonce".to_string(), serde_json::Value::String(nonce));
410 }
411
412 if let Some(auth_time) = auth_time {
413 claims.custom.insert(
414 "auth_time".to_string(),
415 serde_json::Value::Number(auth_time.into()),
416 );
417 }
418
419 for (key, value) in user_info {
421 claims.custom.insert(key, value);
422 }
423
424 self.create_token(&claims)
425 }
426}
427
428pub mod utils {
430 use super::*;
431
432 pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
442 common_validation::jwt::extract_claims_unsafe(token)
443 }
444
445 pub fn is_token_expired(token: &str) -> Result<bool> {
452 let claims = extract_claims_unsafe(token)?;
453
454 let now = SystemTime::now()
455 .duration_since(UNIX_EPOCH)
456 .unwrap()
457 .as_secs() as i64;
458
459 if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
460 Ok(now >= exp)
461 } else {
462 Ok(false) }
464 }
465
466 pub fn get_token_expiration(token: &str) -> Result<Option<i64>> {
472 let claims = extract_claims_unsafe(token)?;
473 Ok(claims.get("exp").and_then(|v| v.as_i64()))
474 }
475
476 pub fn get_token_subject(token: &str) -> Result<Option<String>> {
482 let claims = extract_claims_unsafe(token)?;
483 Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
484 }
485
486 pub fn get_token_scopes(token: &str) -> Result<Vec<String>> {
492 let claims = extract_claims_unsafe(token)?;
493
494 if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
495 Ok(scope_str.split_whitespace().map(String::from).collect())
496 } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
497 Ok(scopes_array
498 .iter()
499 .filter_map(|v| v.as_str())
500 .map(String::from)
501 .collect())
502 } else {
503 Ok(vec![])
504 }
505 }
506}