auth_framework/server/core/
common_jwt.rs1use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Clone)]
15pub struct JwtConfig {
16 pub algorithm: Algorithm,
18 pub signing_key: EncodingKey,
20 pub verification_key: DecodingKey,
22 pub default_expiration: u64,
24 pub issuer: String,
26 pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31 pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33 Self {
34 algorithm: Algorithm::HS256,
35 signing_key: EncodingKey::from_secret(secret),
36 verification_key: DecodingKey::from_secret(secret),
37 default_expiration: 3600, issuer,
39 audiences: vec![],
40 }
41 }
42
43 pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45 let signing_key = EncodingKey::from_rsa_pem(private_key)
46 .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48 let verification_key = DecodingKey::from_rsa_pem(public_key)
49 .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51 Ok(Self {
52 algorithm: Algorithm::RS256,
53 signing_key,
54 verification_key,
55 default_expiration: 3600, issuer,
57 audiences: vec![],
58 })
59 }
60
61 pub fn with_audience(mut self, audience: String) -> Self {
63 self.audiences.push(audience);
64 self
65 }
66
67 pub fn with_expiration(mut self, expiration: u64) -> Self {
69 self.default_expiration = expiration;
70 self
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77 pub iss: String,
79 pub sub: String,
81 pub aud: Vec<String>,
83 pub exp: i64,
85 pub iat: i64,
87 pub nbf: Option<i64>,
89 pub jti: Option<String>,
91 #[serde(flatten)]
93 pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97 pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99 let now = SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .unwrap()
102 .as_secs() as i64;
103
104 Self {
105 iss: issuer,
106 sub: subject,
107 aud: audiences,
108 exp: expiration,
109 iat: now,
110 nbf: None,
111 jti: None,
112 custom: HashMap::new(),
113 }
114 }
115
116 pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118 self.custom.insert(key, value);
119 self
120 }
121
122 pub fn with_jti(mut self, jti: String) -> Self {
124 self.jti = Some(jti);
125 self
126 }
127
128 pub fn with_nbf(mut self, nbf: i64) -> Self {
130 self.nbf = Some(nbf);
131 self
132 }
133}
134
135pub struct JwtManager {
211 config: JwtConfig,
212}
213
214impl JwtManager {
215 pub fn new(config: JwtConfig) -> Self {
217 Self { config }
218 }
219
220 pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
222 let header = Header {
223 alg: self.config.algorithm,
224 ..Default::default()
225 };
226
227 encode(&header, claims, &self.config.signing_key)
228 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
229 }
230
231 pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
233 where
234 T: Serialize,
235 {
236 let header = Header {
237 alg: self.config.algorithm,
238 ..Default::default()
239 };
240
241 encode(&header, claims, &self.config.signing_key)
242 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
243 }
244
245 pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
247 common_validation::jwt::validate_jwt_format(token)?;
249
250 let mut validation = Validation::new(self.config.algorithm);
251 validation.set_issuer(&[&self.config.issuer]);
252
253 if !self.config.audiences.is_empty() {
254 validation.set_audience(
255 &self
256 .config
257 .audiences
258 .iter()
259 .map(String::as_str)
260 .collect::<Vec<_>>(),
261 );
262 }
263
264 let token_data =
265 decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
266 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
267
268 let claims_value = serde_json::to_value(&token_data.claims)
270 .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
271
272 common_validation::jwt::validate_time_claims(&claims_value)?;
273
274 Ok(token_data.claims)
275 }
276
277 pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
279 where
280 T: for<'de> Deserialize<'de>,
281 {
282 common_validation::jwt::validate_jwt_format(token)?;
283
284 let mut validation = Validation::new(self.config.algorithm);
285 validation.set_issuer(&[&self.config.issuer]);
286
287 if !self.config.audiences.is_empty() {
288 validation.set_audience(
289 &self
290 .config
291 .audiences
292 .iter()
293 .map(String::as_str)
294 .collect::<Vec<_>>(),
295 );
296 }
297
298 let token_data = decode::<T>(token, &self.config.verification_key, &validation)
299 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
300
301 Ok(token_data.claims)
302 }
303
304 pub fn create_access_token(
306 &self,
307 subject: String,
308 scope: Vec<String>,
309 client_id: Option<String>,
310 ) -> Result<String> {
311 let exp = SystemTime::now()
312 .duration_since(UNIX_EPOCH)
313 .unwrap()
314 .as_secs() as i64
315 + self.config.default_expiration as i64;
316
317 let mut claims = CommonJwtClaims::new(
318 self.config.issuer.clone(),
319 subject,
320 self.config.audiences.clone(),
321 exp,
322 );
323
324 claims
325 .custom
326 .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
327
328 if let Some(client_id) = client_id {
329 claims.custom.insert(
330 "client_id".to_string(),
331 serde_json::Value::String(client_id),
332 );
333 }
334
335 claims.custom.insert(
336 "token_type".to_string(),
337 serde_json::Value::String("access_token".to_string()),
338 );
339
340 self.create_token(&claims)
341 }
342
343 pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
345 let exp = SystemTime::now()
347 .duration_since(UNIX_EPOCH)
348 .unwrap()
349 .as_secs() as i64
350 + (self.config.default_expiration * 24) as i64; let mut claims = CommonJwtClaims::new(
353 self.config.issuer.clone(),
354 subject,
355 self.config.audiences.clone(),
356 exp,
357 );
358
359 claims.custom.insert(
360 "client_id".to_string(),
361 serde_json::Value::String(client_id),
362 );
363 claims.custom.insert(
364 "token_type".to_string(),
365 serde_json::Value::String("refresh_token".to_string()),
366 );
367
368 self.create_token(&claims)
369 }
370
371 pub fn create_id_token(
373 &self,
374 subject: String,
375 nonce: Option<String>,
376 auth_time: Option<i64>,
377 user_info: HashMap<String, serde_json::Value>,
378 ) -> Result<String> {
379 let exp = SystemTime::now()
380 .duration_since(UNIX_EPOCH)
381 .unwrap()
382 .as_secs() as i64
383 + 300; let mut claims = CommonJwtClaims::new(
386 self.config.issuer.clone(),
387 subject,
388 self.config.audiences.clone(),
389 exp,
390 );
391
392 claims.custom.insert(
393 "token_type".to_string(),
394 serde_json::Value::String("id_token".to_string()),
395 );
396
397 if let Some(nonce) = nonce {
398 claims
399 .custom
400 .insert("nonce".to_string(), serde_json::Value::String(nonce));
401 }
402
403 if let Some(auth_time) = auth_time {
404 claims.custom.insert(
405 "auth_time".to_string(),
406 serde_json::Value::Number(auth_time.into()),
407 );
408 }
409
410 for (key, value) in user_info {
412 claims.custom.insert(key, value);
413 }
414
415 self.create_token(&claims)
416 }
417}
418
419pub mod utils {
421 use super::*;
422
423 pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
433 common_validation::jwt::extract_claims_unsafe(token)
434 }
435
436 pub fn is_token_expired(token: &str) -> Result<bool> {
443 let claims = extract_claims_unsafe(token)?;
444
445 let now = SystemTime::now()
446 .duration_since(UNIX_EPOCH)
447 .unwrap()
448 .as_secs() as i64;
449
450 if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
451 Ok(now >= exp)
452 } else {
453 Ok(false) }
455 }
456
457 pub fn get_token_expiration(token: &str) -> Result<Option<i64>> {
463 let claims = extract_claims_unsafe(token)?;
464 Ok(claims.get("exp").and_then(|v| v.as_i64()))
465 }
466
467 pub fn get_token_subject(token: &str) -> Result<Option<String>> {
473 let claims = extract_claims_unsafe(token)?;
474 Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
475 }
476
477 pub fn get_token_scopes(token: &str) -> Result<Vec<String>> {
483 let claims = extract_claims_unsafe(token)?;
484
485 if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
486 Ok(scope_str.split_whitespace().map(String::from).collect())
487 } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
488 Ok(scopes_array
489 .iter()
490 .filter_map(|v| v.as_str())
491 .map(String::from)
492 .collect())
493 } else {
494 Ok(vec![])
495 }
496 }
497}
498
499