use crate::errors::{AuthError, Result};
use crate::server::core::common_validation;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Settings that a [`JwtManager`] applies to every token it signs or verifies.
///
/// Build one with [`JwtConfig::with_symmetric_key`] or [`JwtConfig::with_rsa_keys`], then
/// refine it with the chainable `with_*` methods.
#[derive(Clone)]
pub struct JwtConfig {
    /// Algorithm written into the header when signing and required when verifying.
    pub algorithm: Algorithm,
    /// Key handed to `encode` when a token is created.
    pub signing_key: EncodingKey,
    /// Key handed to `decode` when a token is verified.
    pub verification_key: DecodingKey,
    /// Access-token lifetime in seconds; refresh tokens are issued for 24 times this value.
    pub default_expiration: u64,
    /// Value stamped into `iss` on issued tokens and required on verified ones.
    pub issuer: String,
    /// Values stamped into `aud` on issued tokens. Verification only enforces an audience
    /// when this list is non-empty.
    pub audiences: Vec<String>,
}
impl JwtConfig {
pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
Self {
algorithm: Algorithm::HS256,
signing_key: EncodingKey::from_secret(secret),
verification_key: DecodingKey::from_secret(secret),
default_expiration: 3600, issuer,
audiences: vec![],
}
}
pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
let signing_key = EncodingKey::from_rsa_pem(private_key)
.map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
let verification_key = DecodingKey::from_rsa_pem(public_key)
.map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
Ok(Self {
algorithm: Algorithm::RS256,
signing_key,
verification_key,
default_expiration: 3600, issuer,
audiences: vec![],
})
}
pub fn with_audience(mut self, audience: String) -> Self {
self.audiences.push(audience);
self
}
pub fn with_expiration(mut self, expiration: u64) -> Self {
self.default_expiration = expiration;
self
}
}
/// Claim set used for the tokens this module issues and verifies.
///
/// The typed fields carry the registered JWT claims; every other member of the JSON claims
/// object is collected into [`custom`](Self::custom) through `#[serde(flatten)]`.
///
/// NOTE(review): `aud` is typed as a list, so a token whose `aud` is a single JSON string
/// (which RFC 7519 also permits) will not deserialize into this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommonJwtClaims {
    /// Issuer of the token.
    pub iss: String,
    /// Subject the token was issued for.
    pub sub: String,
    /// Audiences the token is intended for.
    pub aud: Vec<String>,
    /// Expiration time, in seconds since the Unix epoch.
    pub exp: i64,
    /// Issue time, in seconds since the Unix epoch.
    pub iat: i64,
    /// Optional not-before time, in seconds since the Unix epoch.
    ///
    /// Left out of the serialized claims when unset: `nbf` must be a NumericDate, so
    /// emitting `"nbf": null` yields a claim that stricter verifiers reject.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,
    /// Optional unique token identifier; left out of the serialized claims when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
    /// Any additional claims, keyed by claim name.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
impl CommonJwtClaims {
    /// Creates a claim set issued "now".
    ///
    /// `expiration` is stored in `exp` unchanged. `iat` is the current system time in seconds
    /// since the Unix epoch, or 0 if the clock reads earlier than the epoch. `nbf`, `jti` and
    /// the custom claims start out empty.
    pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        let issued_at = since_epoch.as_secs() as i64;

        Self {
            iss: issuer,
            sub: subject,
            aud: audiences,
            exp: expiration,
            iat: issued_at,
            nbf: None,
            jti: None,
            custom: HashMap::new(),
        }
    }

    /// Sets the custom claim `key` to `value`, replacing any previous value for that key.
    pub fn with_custom_claim(self, key: String, value: serde_json::Value) -> Self {
        let mut claims = self;
        claims.custom.insert(key, value);
        claims
    }

    /// Sets the token identifier (`jti`).
    pub fn with_jti(self, jti: String) -> Self {
        Self {
            jti: Some(jti),
            ..self
        }
    }

    /// Sets the not-before time (`nbf`).
    pub fn with_nbf(self, nbf: i64) -> Self {
        Self {
            nbf: Some(nbf),
            ..self
        }
    }
}
/// Signs and verifies JWTs according to a single [`JwtConfig`].
pub struct JwtManager {
    /// Keys, algorithm, issuer, audiences and default lifetime used for every operation.
    config: JwtConfig,
}
impl JwtManager {
pub fn new(config: JwtConfig) -> Self {
Self { config }
}
pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
let header = Header {
alg: self.config.algorithm,
..Default::default()
};
encode(&header, claims, &self.config.signing_key)
.map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
}
pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
where
T: Serialize,
{
let header = Header {
alg: self.config.algorithm,
..Default::default()
};
encode(&header, claims, &self.config.signing_key)
.map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
}
pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
common_validation::jwt::validate_jwt_format(token)?;
let mut validation = Validation::new(self.config.algorithm);
validation.set_issuer(&[&self.config.issuer]);
if !self.config.audiences.is_empty() {
validation.set_audience(
&self
.config
.audiences
.iter()
.map(String::as_str)
.collect::<Vec<_>>(),
);
}
let token_data =
decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
.map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
let claims_value = serde_json::to_value(&token_data.claims)
.map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
common_validation::jwt::validate_time_claims(&claims_value)?;
Ok(token_data.claims)
}
pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
common_validation::jwt::validate_jwt_format(token)?;
let mut validation = Validation::new(self.config.algorithm);
validation.set_issuer(&[&self.config.issuer]);
if !self.config.audiences.is_empty() {
validation.set_audience(
&self
.config
.audiences
.iter()
.map(String::as_str)
.collect::<Vec<_>>(),
);
}
let token_data = decode::<T>(token, &self.config.verification_key, &validation)
.map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
Ok(token_data.claims)
}
pub fn create_access_token(
&self,
subject: String,
scope: Vec<String>,
client_id: Option<String>,
) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64
+ self.config.default_expiration as i64;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims
.custom
.insert("scope".to_string(), serde_json::json!(scope.join(" ")));
if let Some(client_id) = client_id {
claims.custom.insert(
"client_id".to_string(),
serde_json::Value::String(client_id),
);
}
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("access_token".to_string()),
);
self.create_token(&claims)
}
pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64
+ (self.config.default_expiration * 24) as i64;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims.custom.insert(
"client_id".to_string(),
serde_json::Value::String(client_id),
);
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("refresh_token".to_string()),
);
self.create_token(&claims)
}
pub fn create_id_token(
&self,
subject: String,
nonce: Option<String>,
auth_time: Option<i64>,
user_info: HashMap<String, serde_json::Value>,
) -> Result<String> {
let exp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64
+ 300;
let mut claims = CommonJwtClaims::new(
self.config.issuer.clone(),
subject,
self.config.audiences.clone(),
exp,
);
claims.custom.insert(
"token_type".to_string(),
serde_json::Value::String("id_token".to_string()),
);
if let Some(nonce) = nonce {
claims
.custom
.insert("nonce".to_string(), serde_json::Value::String(nonce));
}
if let Some(auth_time) = auth_time {
claims.custom.insert(
"auth_time".to_string(),
serde_json::Value::Number(auth_time.into()),
);
}
for (key, value) in user_info {
claims.custom.insert(key, value);
}
self.create_token(&claims)
}
}
pub(crate) mod utils {
    use super::*;

    /// Returns the claims of `token` as raw JSON by delegating to
    /// `common_validation::jwt::extract_claims_unsafe`.
    ///
    /// NOTE(review): the name suggests the claims are read without verifying the signature;
    /// confirm against the helper before trusting any value obtained through this module.
    #[allow(dead_code)]
    pub(crate) fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
        common_validation::jwt::extract_claims_unsafe(token)
    }

    /// Reports whether the integer `exp` claim of `token` is at or before the current time.
    ///
    /// A token without an integer `exp` claim is reported as not expired.
    #[allow(dead_code)]
    pub(crate) fn is_token_expired(token: &str) -> Result<bool> {
        let claims = extract_claims_unsafe(token)?;
        let current = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs() as i64;
        let expired = match claims.get("exp").and_then(serde_json::Value::as_i64) {
            Some(deadline) => deadline <= current,
            None => false,
        };
        Ok(expired)
    }

    /// Returns the `exp` claim of `token`, or `None` when it is missing or not an integer.
    #[allow(dead_code)]
    pub(crate) fn get_token_expiration(token: &str) -> Result<Option<i64>> {
        let claims = extract_claims_unsafe(token)?;
        let expiration = claims.get("exp").and_then(serde_json::Value::as_i64);
        Ok(expiration)
    }

    /// Returns the `sub` claim of `token`, or `None` when it is missing or not a string.
    #[allow(dead_code)]
    pub(crate) fn get_token_subject(token: &str) -> Result<Option<String>> {
        let claims = extract_claims_unsafe(token)?;
        let subject = claims
            .get("sub")
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned);
        Ok(subject)
    }

    /// Returns the scopes carried by `token`.
    ///
    /// A string `scope` claim is split on whitespace. Otherwise the string entries of an
    /// array `scopes` claim are used. With neither claim present the result is empty.
    #[allow(dead_code)]
    pub(crate) fn get_token_scopes(token: &str) -> Result<Vec<String>> {
        let claims = extract_claims_unsafe(token)?;

        if let Some(joined) = claims.get("scope").and_then(serde_json::Value::as_str) {
            return Ok(joined.split_whitespace().map(str::to_owned).collect());
        }

        let scopes: Vec<String> = match claims.get("scopes").and_then(serde_json::Value::as_array)
        {
            Some(entries) => entries
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_owned)
                .collect(),
            None => Vec::new(),
        };
        Ok(scopes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Issuer shared by the manager under test and the claims it is asked to verify.
    const TEST_ISSUER: &str = "https://test-issuer.example.com";

    /// Expiration used where the exact value is irrelevant to the test.
    const FAR_FUTURE: i64 = 9999999999;

    /// Manager backed by a fixed HS256 secret and [`TEST_ISSUER`].
    fn make_manager() -> JwtManager {
        JwtManager::new(JwtConfig::with_symmetric_key(
            b"a-test-secret-key-with-enough-bytes-for-hmac",
            TEST_ISSUER.into(),
        ))
    }

    /// Unix timestamp one hour ahead of the current system time.
    fn one_hour_from_now() -> i64 {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        (now + 3600) as i64
    }

    /// Claims for `subject`, issued by [`TEST_ISSUER`], without audiences, valid for an hour.
    fn claims_for(subject: &str) -> CommonJwtClaims {
        CommonJwtClaims::new(
            TEST_ISSUER.into(),
            subject.into(),
            vec![],
            one_hour_from_now(),
        )
    }

    #[test]
    fn test_jwt_config_symmetric() {
        let cfg = JwtConfig::with_symmetric_key(b"secret", String::from("iss"));
        assert_eq!(cfg.issuer, "iss");
        assert_eq!(cfg.default_expiration, 3600);
    }

    #[test]
    fn test_jwt_config_with_audience() {
        let base = JwtConfig::with_symmetric_key(b"secret", String::from("iss"));
        let cfg = base.with_audience(String::from("aud1"));
        assert_eq!(cfg.audiences, vec!["aud1"]);
    }

    #[test]
    fn test_jwt_config_with_expiration() {
        let base = JwtConfig::with_symmetric_key(b"secret", String::from("iss"));
        let cfg = base.with_expiration(7200);
        assert_eq!(cfg.default_expiration, 7200);
    }

    #[test]
    fn test_claims_new() {
        let built = CommonJwtClaims::new(
            String::from("issuer"),
            String::from("subject"),
            vec![String::from("aud")],
            FAR_FUTURE,
        );
        assert_eq!(built.iss, "issuer");
        assert_eq!(built.sub, "subject");
        assert!(built.iat > 0);
    }

    #[test]
    fn test_claims_with_custom_claim() {
        let base = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], FAR_FUTURE);
        let built = base.with_custom_claim("role".to_string(), serde_json::json!("admin"));
        assert_eq!(built.custom.get("role").unwrap(), "admin");
    }

    #[test]
    fn test_claims_with_jti() {
        let base = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], FAR_FUTURE);
        let built = base.with_jti("test-jti-value".into());
        assert!(built.jti.is_some());
    }

    #[test]
    fn test_create_and_verify_token() {
        let manager = make_manager();
        let token = manager.create_token(&claims_for("user_123")).unwrap();
        let verified = manager.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_123");
    }

    #[test]
    fn test_verify_invalid_token() {
        let manager = make_manager();
        let outcome = manager.verify_token("not.a.valid.jwt");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_verify_wrong_key() {
        let signer = make_manager();
        let verifier = JwtManager::new(JwtConfig::with_symmetric_key(
            b"different-key-entirely-for-testing",
            TEST_ISSUER.into(),
        ));
        let token = signer.create_token(&claims_for("user")).unwrap();
        assert!(verifier.verify_token(&token).is_err());
    }

    #[test]
    fn test_create_access_token() {
        let manager = make_manager();
        let scope = vec![String::from("read")];
        let client = Some(String::from("client_1"));
        let token = manager
            .create_access_token(String::from("user_1"), scope, client)
            .unwrap();
        let verified = manager.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_1");
        assert!(verified.custom.contains_key("scope"));
    }

    #[test]
    fn test_create_refresh_token() {
        let manager = make_manager();
        let token = manager
            .create_refresh_token(String::from("user_2"), String::from("client_2"))
            .unwrap();
        let verified = manager.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_2");
        assert_eq!(
            verified.custom.get("token_type").unwrap(),
            &serde_json::json!("refresh_token")
        );
    }

    #[test]
    fn test_create_id_token() {
        let manager = make_manager();
        let mut profile = HashMap::new();
        profile.insert(String::from("name"), serde_json::json!("Test User"));
        profile.insert(String::from("email"), serde_json::json!("test@example.com"));

        let token = manager
            .create_id_token(
                String::from("user_3"),
                Some(String::from("nonce_123")),
                None,
                profile,
            )
            .unwrap();
        let verified = manager.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_3");
        assert_eq!(verified.custom.get("nonce").unwrap(), "nonce_123");
        assert_eq!(
            verified.custom.get("token_type").unwrap(),
            &serde_json::json!("id_token")
        );
    }

    #[test]
    fn test_extract_claims_unsafe_works() {
        let manager = make_manager();
        let token = manager.create_token(&claims_for("peek_user")).unwrap();
        let peeked = utils::extract_claims_unsafe(&token).unwrap();
        assert_eq!(peeked["sub"], "peek_user");
    }

    #[test]
    fn test_is_token_expired_not_expired() {
        let manager = make_manager();
        let token = manager
            .create_access_token(String::from("user"), vec![], None)
            .unwrap();
        let expired = utils::is_token_expired(&token).unwrap();
        assert!(!expired);
    }
}