use crate::errors::{AuthError, Result};
use crate::server::core::common_validation;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Signing and verification settings used by a [`JwtManager`].
#[derive(Clone)]
pub struct JwtConfig {
    /// JWS algorithm written into token headers and expected when verifying.
    pub algorithm: Algorithm,
    /// Key used to sign newly issued tokens.
    pub signing_key: EncodingKey,
    /// Key used to check the signature of incoming tokens.
    pub verification_key: DecodingKey,
    /// Default token lifetime in seconds (3600 unless changed via `with_expiration`).
    pub default_expiration: u64,
    /// Value placed in, and expected from, the `iss` claim.
    pub issuer: String,
    /// Audiences placed in the `aud` claim; empty by default.
    pub audiences: Vec<String>,
}

impl JwtConfig {
    /// Lifetime the constructors start with, in seconds.
    const DEFAULT_EXPIRATION_SECS: u64 = 3600;

    /// Builds an HS256 configuration in which `secret` both signs and verifies tokens.
    pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
        let (signing_key, verification_key) =
            (EncodingKey::from_secret(secret), DecodingKey::from_secret(secret));
        Self {
            issuer,
            audiences: Vec::new(),
            algorithm: Algorithm::HS256,
            signing_key,
            verification_key,
            default_expiration: Self::DEFAULT_EXPIRATION_SECS,
        }
    }

    /// Builds an RS256 configuration from PEM-encoded RSA keys.
    ///
    /// # Errors
    ///
    /// Returns a validation error if either PEM cannot be parsed. The private key is
    /// checked first.
    pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
        Ok(Self {
            algorithm: Algorithm::RS256,
            signing_key: EncodingKey::from_rsa_pem(private_key)
                .map_err(|err| AuthError::validation(format!("Invalid private key: {}", err)))?,
            verification_key: DecodingKey::from_rsa_pem(public_key)
                .map_err(|err| AuthError::validation(format!("Invalid public key: {}", err)))?,
            default_expiration: Self::DEFAULT_EXPIRATION_SECS,
            issuer,
            audiences: Vec::new(),
        })
    }

    /// Appends `audience` to the list of audiences.
    pub fn with_audience(mut self, audience: String) -> Self {
        self.audiences.push(audience);
        self
    }

    /// Replaces the default token lifetime (seconds).
    pub fn with_expiration(self, expiration: u64) -> Self {
        Self {
            default_expiration: expiration,
            ..self
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommonJwtClaims {
pub iss: String,
pub sub: String,
pub aud: Vec<String>,
pub exp: i64,
pub iat: i64,
pub nbf: Option<i64>,
pub jti: Option<String>,
#[serde(flatten)]
pub custom: HashMap<String, serde_json::Value>,
}
impl CommonJwtClaims {
pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
Self {
iss: issuer,
sub: subject,
aud: audiences,
exp: expiration,
iat: now,
nbf: None,
jti: None,
custom: HashMap::new(),
}
}
pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
self.custom.insert(key, value);
self
}
pub fn with_jti(mut self, jti: String) -> Self {
self.jti = Some(jti);
self
}
pub fn with_nbf(mut self, nbf: i64) -> Self {
self.nbf = Some(nbf);
self
}
}
pub struct JwtManager {
config: JwtConfig,
}
impl JwtManager {
pub fn new(config: JwtConfig) -> Self {
Self { config }
}
pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
let header = Header {
alg: self.config.algorithm,
..Default::default()
};
encode(&header, claims, &self.config.signing_key)
.map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
}
pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
where
T: Serialize,
{
let header = Header {
alg: self.config.algorithm,
..Default::default()
};
encode(&header, claims, &self.config.signing_key)
.map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
}
pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
common_validation::jwt::validate_jwt_format(token)?;
let mut validation = Validation::new(self.config.algorithm);
validation.set_issuer(&[&self.config.issuer]);
if !self.config.audiences.is_empty() {
validation.set_audience(
&self
.config
.audiences
.iter()
.map(String::as_str)
.collect::<Vec<_>>(),
);
}
let token_data =
decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
.map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
let claims_value = serde_json::to_value(&token_data.claims)
.map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
common_validation::jwt::validate_time_claims(&claims_value)?;
Ok(token_data.claims)
}
pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
common_validation::jwt::validate_jwt_format(token)?;
let mut validation = Validation::new(self.config.algorithm);
validation.set_issuer(&[&self.config.issuer]);
if !self.config.audiences.is_empty() {
validation.set_audience(
&self
.config
.audiences
.iter()
.map(String::as_str)
.collect::<Vec<_>>(),
);
}
let token_data = decode::<T>(token, &self.config.verification_key, &validation)
.map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
Ok(token_data.claims)
}
pub fn create_access_token(
&self,
subject: String,
scope: Vec<String>,
client_id: Option<String>,
) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ self.config.default_expiration as i64;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims
.custom
.insert("scope".to_string(), serde_json::json!(scope.join(" ")));
if let Some(client_id) = client_id {
claims.custom.insert(
"client_id".to_string(),
serde_json::Value::String(client_id),
);
}
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("access_token".to_string()),
);
self.create_token(&claims)
}
pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ (self.config.default_expiration * 24) as i64;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims.custom.insert(
"client_id".to_string(),
serde_json::Value::String(client_id),
);
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("refresh_token".to_string()),
);
self.create_token(&claims)
}
pub fn create_id_token(
&self,
subject: String,
nonce: Option<String>,
auth_time: Option<i64>,
user_info: HashMap<String, serde_json::Value>,
) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ 300;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("id_token".to_string()),
);
if let Some(nonce) = nonce {
claims
.custom
.insert("nonce".to_string(), serde_json::Value::String(nonce));
}
if let Some(auth_time) = auth_time {
claims.custom.insert(
"auth_time".to_string(),
serde_json::Value::Number(auth_time.into()),
);
}
for (key, value) in user_info {
claims.custom.insert(key, value);
}
self.create_token(&claims)
}
}
/// Helpers that read JWT claims **without verifying the signature**.
///
/// Nothing here proves a token is authentic. Do not base authorization decisions on
/// these results.
pub mod utils {
    use super::*;

    /// Decodes the token payload without signature verification.
    ///
    /// Delegates to `common_validation::jwt::extract_claims_unsafe`.
    pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
        common_validation::jwt::extract_claims_unsafe(token)
    }

    /// Returns `true` when the unverified `exp` claim is at or before the current time.
    ///
    /// A token with no numeric `exp` is reported as not expired.
    pub fn is_token_expired(token: &str) -> Result<bool> {
        let claims = extract_claims_unsafe(token)?;
        Ok(numeric_date(&claims, "exp").is_some_and(|exp| now_secs() >= exp))
    }

    /// Returns the unverified `exp` claim in whole Unix seconds, if present and numeric.
    pub fn get_token_expiration(token: &str) -> Result<Option<i64>> {
        let claims = extract_claims_unsafe(token)?;
        Ok(numeric_date(&claims, "exp"))
    }

    /// Returns the unverified `sub` claim, if present and a string.
    pub fn get_token_subject(token: &str) -> Result<Option<String>> {
        let claims = extract_claims_unsafe(token)?;
        Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
    }

    /// Returns the token's scopes.
    ///
    /// Sources, in order:
    /// - a space-delimited `scope` string;
    /// - otherwise the string entries of a `scopes` array;
    /// - otherwise an empty list.
    pub fn get_token_scopes(token: &str) -> Result<Vec<String>> {
        let claims = extract_claims_unsafe(token)?;
        if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
            Ok(scope_str.split_whitespace().map(String::from).collect())
        } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
            Ok(scopes_array
                .iter()
                .filter_map(|v| v.as_str())
                .map(String::from)
                .collect())
        } else {
            Ok(vec![])
        }
    }

    /// Reads an RFC 7519 NumericDate claim as whole seconds.
    ///
    /// - Non-integral values are accepted and floored, which makes expiry checks
    ///   conservative by under a second.
    /// - Out-of-range values saturate.
    /// - A missing or non-numeric claim yields `None`.
    fn numeric_date(claims: &serde_json::Value, name: &str) -> Option<i64> {
        let value = claims.get(name)?;
        value
            .as_i64()
            .or_else(|| value.as_f64().map(|secs| secs.floor() as i64))
    }

    /// Current Unix time in whole seconds.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    fn now_secs() -> i64 {
        let secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs();
        i64::try_from(secs).unwrap_or(i64::MAX)
    }
}