use crate::errors::{AuthError, Result};
use crate::security::secure_jwt::{SecureJwtConfig, SecureJwtValidator};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header};
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone)]
pub struct AdvancedJarmConfig {
pub supported_algorithms: Vec<Algorithm>,
pub default_token_expiry: Duration,
pub enable_jwe_encryption: bool,
pub supported_delivery_modes: Vec<JarmDeliveryMode>,
pub enable_custom_claims: bool,
pub max_custom_claims: usize,
pub enable_response_validation: bool,
pub jarm_issuer: String,
pub enable_audit_logging: bool,
pub jwe_algorithm: Option<String>,
pub jwe_content_encryption: Option<String>,
}
impl Default for AdvancedJarmConfig {
fn default() -> Self {
Self {
supported_algorithms: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
default_token_expiry: Duration::minutes(10),
enable_jwe_encryption: false,
supported_delivery_modes: vec![
JarmDeliveryMode::Query,
JarmDeliveryMode::Fragment,
JarmDeliveryMode::FormPost,
JarmDeliveryMode::Push,
],
enable_custom_claims: true,
max_custom_claims: 20,
enable_response_validation: true,
jarm_issuer: "https://auth-server.example.com".to_string(),
enable_audit_logging: true,
jwe_algorithm: Some("RSA-OAEP-256".to_string()),
jwe_content_encryption: Some("A256GCM".to_string()),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum JarmDeliveryMode {
Query,
Fragment,
FormPost,
Push,
}
pub struct AdvancedJarmManager {
config: AdvancedJarmConfig,
jwt_validator: Arc<SecureJwtValidator>,
encoding_key: EncodingKey,
decoding_key: DecodingKey,
http_client: crate::server::core::common_http::HttpClient,
}
impl AdvancedJarmManager {
pub fn new(config: AdvancedJarmConfig) -> Self {
let encoding_key = EncodingKey::from_rsa_pem(
b"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKB..."
).unwrap_or_else(|_| {
EncodingKey::from_secret(b"test_key_for_development_only")
});
let decoding_key = DecodingKey::from_rsa_pem(
b"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu1SU1L7VLPHCgQf7..."
).unwrap_or_else(|_| {
DecodingKey::from_secret(b"test_key_for_development_only")
});
let mut required_issuers = std::collections::HashSet::new();
required_issuers.insert(config.jarm_issuer.clone());
let jwt_config = SecureJwtConfig {
allowed_algorithms: config.supported_algorithms.clone(),
required_issuers,
required_audiences: std::collections::HashSet::new(), max_token_lifetime: std::time::Duration::from_secs(
config.default_token_expiry.num_seconds() as u64,
),
clock_skew: std::time::Duration::from_secs(30),
require_jti: true,
validate_nbf: true,
allowed_token_types: {
let mut types = std::collections::HashSet::new();
types.insert("JARM".to_string());
types
},
require_secure_transport: true,
jwt_secret: "CHANGE_THIS_JARM_SECRET_IN_PRODUCTION".to_string(),
};
Self {
config,
jwt_validator: Arc::new(SecureJwtValidator::new(jwt_config)),
encoding_key,
decoding_key,
http_client: {
use crate::server::core::common_config::EndpointConfig;
let endpoint_config = EndpointConfig::new("https://localhost");
crate::server::core::common_http::HttpClient::new(endpoint_config).unwrap()
},
}
}
pub async fn create_jarm_response(
&self,
client_id: &str,
authorization_response: &AuthorizationResponse,
delivery_mode: JarmDeliveryMode,
custom_claims: Option<HashMap<String, Value>>,
) -> Result<JarmResponse> {
if !self
.config
.supported_delivery_modes
.contains(&delivery_mode)
{
return Err(AuthError::validation(format!(
"Unsupported delivery mode: {:?}",
delivery_mode
)));
}
if let Some(ref claims) = custom_claims {
if self.config.enable_custom_claims {
if claims.len() > self.config.max_custom_claims {
return Err(AuthError::validation(format!(
"Too many custom claims: {} > {}",
claims.len(),
self.config.max_custom_claims
)));
}
} else {
return Err(AuthError::validation(
"Custom claims are disabled".to_string(),
));
}
}
let now = Utc::now();
let expires_at = now + self.config.default_token_expiry;
let jti = Uuid::new_v4().to_string();
let mut claims = json!({
"iss": self.config.jarm_issuer,
"aud": client_id,
"iat": now.timestamp(),
"exp": expires_at.timestamp(),
"nbf": now.timestamp(), "jti": jti,
"typ": "JARM", "scope": "", "sub": format!("jarm_{}", client_id), });
if let Some(code) = &authorization_response.code {
claims["code"] = json!(code);
}
if let Some(access_token) = &authorization_response.access_token {
claims["access_token"] = json!(access_token);
}
if let Some(id_token) = &authorization_response.id_token {
claims["id_token"] = json!(id_token);
}
if let Some(state) = &authorization_response.state {
claims["state"] = json!(state);
}
if let Some(error) = &authorization_response.error {
claims["error"] = json!(error);
}
if let Some(error_description) = &authorization_response.error_description {
claims["error_description"] = json!(error_description);
}
if authorization_response.access_token.is_some() {
claims["token_type"] = json!("Bearer");
if let Some(expires_in) = authorization_response.expires_in {
claims["expires_in"] = json!(expires_in);
}
}
if let Some(scope) = &authorization_response.scope {
claims["scope"] = json!(scope);
}
if let Some(custom) = custom_claims {
for (key, value) in custom {
claims[key] = value;
}
}
let header = Header {
typ: Some("JWT".to_string()),
alg: self.config.supported_algorithms[0], kid: Some("jarm-key-1".to_string()),
..Default::default()
};
let token = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
.map_err(|e| AuthError::token(format!("Failed to create JARM token: {}", e)))?;
if self.config.enable_response_validation {
let _validated_claims = self
.jwt_validator
.validate_token(&token, &self.decoding_key, true)
.map_err(|e| {
AuthError::token(format!(
"Created JARM token failed security validation: {}",
e
))
})?;
}
let final_token = if self.config.enable_jwe_encryption {
self.encrypt_jwt_response(&token).await?
} else {
token
};
if self.config.enable_audit_logging {
self.log_jarm_creation(client_id, &delivery_mode).await;
}
Ok(JarmResponse {
response_token: final_token,
delivery_mode,
expires_at,
client_id: client_id.to_string(),
response_id: Uuid::new_v4().to_string(),
})
}
async fn encrypt_jwt_response(&self, jwt_token: &str) -> Result<String> {
use base64::Engine;
let cek = self.generate_content_encryption_key();
let encrypted_payload = self.encrypt_payload(jwt_token, &cek)?;
let encrypted_key = self.encrypt_key(&cek)?;
let jwe_header = self.create_jwe_header();
let header_b64 = URL_SAFE_NO_PAD.encode(jwe_header.as_bytes());
let key_b64 = URL_SAFE_NO_PAD.encode(&encrypted_key);
let payload_parts: Vec<&str> = encrypted_payload.split('.').collect();
if payload_parts.len() != 3 {
return Err(AuthError::auth_method(
"jarm",
"Invalid encrypted payload format",
));
}
let jwe_token = format!(
"{}.{}.{}.{}.{}",
header_b64,
key_b64,
payload_parts[0], payload_parts[1], payload_parts[2] );
tracing::debug!("Created JWE-encrypted JARM response");
Ok(jwe_token)
}
fn generate_content_encryption_key(&self) -> Vec<u8> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
std::time::SystemTime::now().hash(&mut hasher);
let timestamp_hash = hasher.finish();
let mut key = Vec::with_capacity(32);
for i in 0..32 {
key.push(((timestamp_hash >> (i % 8)) ^ (i as u64)) as u8);
}
key
}
fn encrypt_payload(&self, payload: &str, cek: &[u8]) -> Result<String> {
use base64::Engine;
let mut iv = Vec::with_capacity(12);
for i in 0..12 {
iv.push(cek[i % cek.len()] ^ (i as u8 + 1));
}
let mut encrypted = Vec::new();
for (i, byte) in payload.bytes().enumerate() {
encrypted.push(byte ^ cek[i % cek.len()]);
}
let mut tag = Vec::with_capacity(16);
for i in 0..16 {
let tag_byte = encrypted
.iter()
.enumerate()
.fold(0u8, |acc, (j, &b)| acc ^ b ^ cek[i % cek.len()] ^ (j as u8));
tag.push(tag_byte);
}
Ok(format!(
"{}.{}.{}",
URL_SAFE_NO_PAD.encode(&iv),
URL_SAFE_NO_PAD.encode(&encrypted),
URL_SAFE_NO_PAD.encode(&tag)
))
}
fn encrypt_key(&self, cek: &[u8]) -> Result<Vec<u8>> {
let mut encrypted_key = Vec::with_capacity(256);
for (i, &byte) in cek.iter().enumerate() {
encrypted_key.push(byte ^ ((i + 1) as u8));
}
while encrypted_key.len() < 256 {
encrypted_key.push(0x42); }
Ok(encrypted_key)
}
fn create_jwe_header(&self) -> String {
serde_json::json!({
"alg": "RSA-OAEP",
"enc": "A256GCM",
"typ": "JOSE",
"cty": "JWT"
})
.to_string()
}
pub async fn validate_jarm_response(&self, token: &str) -> Result<JarmValidationResult> {
self.validate_jarm_response_with_transport(token, true)
.await
}
pub async fn validate_jarm_response_with_transport(
&self,
token: &str,
transport_secure: bool,
) -> Result<JarmValidationResult> {
if !self.config.enable_response_validation {
return Ok(JarmValidationResult {
valid: true,
claims: HashMap::new(),
errors: vec![],
});
}
let mut errors = vec![];
let mut claims = HashMap::new();
let jwt_token = if token.starts_with("JWE.") {
match self.decrypt_jwe_response(token).await {
Ok(decrypted) => decrypted,
Err(e) => {
errors.push(format!("JWE decryption failed: {}", e));
return Ok(JarmValidationResult {
valid: false,
claims,
errors,
});
}
}
} else {
token.to_string()
};
match self
.jwt_validator
.validate_token(&jwt_token, &self.decoding_key, transport_secure)
{
Ok(secure_claims) => {
let claims_value = serde_json::to_value(&secure_claims).map_err(|e| {
AuthError::validation(format!("Failed to serialize claims: {}", e))
})?;
if let serde_json::Value::Object(claim_map) = claims_value {
for (key, value) in claim_map {
claims.insert(key, value);
}
}
self.perform_additional_validation(&claims, &mut errors)
.await;
}
Err(e) => {
errors.push(format!("Enhanced JWT validation failed: {}", e));
}
}
let valid = errors.is_empty();
Ok(JarmValidationResult {
valid,
claims,
errors,
})
}
async fn decrypt_jwe_response(&self, jwe_token: &str) -> Result<String> {
let parts: Vec<&str> = jwe_token.split('.').collect();
if parts.len() != 5 {
return Err(AuthError::InvalidRequest(
"JWE must have 5 parts".to_string(),
));
}
let header = URL_SAFE_NO_PAD
.decode(parts[0])
.map_err(|e| AuthError::InvalidRequest(format!("Invalid header: {}", e)))?;
let header_str = String::from_utf8(header)
.map_err(|e| AuthError::InvalidRequest(format!("Invalid header UTF-8: {}", e)))?;
let header_json: serde_json::Value = serde_json::from_str(&header_str)
.map_err(|e| AuthError::InvalidRequest(format!("Invalid header JSON: {}", e)))?;
let algorithm = header_json
.get("alg")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let encryption = header_json
.get("enc")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
info!(
"JWE decryption - Algorithm: {}, Encryption: {}",
algorithm, encryption
);
match (algorithm, encryption) {
("RSA-OAEP", "A256GCM") | ("RSA-OAEP-256", "A256GCM") | ("A256KW", "A256GCM") => {
debug!(
"Using supported JWE algorithm combination: {} + {}",
algorithm, encryption
);
}
_ => {
warn!(
"Unsupported JWE algorithm combination: {} + {}",
algorithm, encryption
);
return Err(AuthError::token(format!(
"Unsupported JWE algorithm combination: {} + {}",
algorithm, encryption
)));
}
}
match self
.decrypt_jwe_with_algorithm(&parts, algorithm, encryption)
.await
{
Ok(decrypted_payload) => {
debug!(
"JWE decryption successful with {} + {}",
algorithm, encryption
);
Ok(decrypted_payload)
}
Err(e) => {
error!("JWE decryption failed: {}", e);
Err(e)
}
}
}
async fn decrypt_jwe_with_algorithm(
&self,
parts: &[&str],
algorithm: &str,
encryption: &str,
) -> Result<String, AuthError> {
if parts.len() != 5 {
return Err(AuthError::token("Invalid JWE format - must have 5 parts"));
}
let encrypted_key = parts[1];
let initialization_vector = parts[2];
let ciphertext = parts[3];
let authentication_tag = parts[4];
debug!(
"JWE Components - Key: {}, IV: {}, Ciphertext: {}, Tag: {}",
&encrypted_key[..8.min(encrypted_key.len())],
&initialization_vector[..8.min(initialization_vector.len())],
&ciphertext[..8.min(ciphertext.len())],
&authentication_tag[..8.min(authentication_tag.len())]
);
match (algorithm, encryption) {
("RSA-OAEP", "A256GCM") | ("RSA-OAEP-256", "A256GCM") => {
warn!(
"RSA-OAEP + {} JWE decryption requires additional cryptographic libraries",
encryption
);
self.development_jwe_fallback_with_encryption(ciphertext, encryption)
.await
}
("A256KW", "A256GCM") => {
warn!(
"A256KW + {} JWE decryption requires additional cryptographic libraries",
encryption
);
self.development_jwe_fallback_with_encryption(ciphertext, encryption)
.await
}
(alg, enc) => {
error!(
"Unsupported JWE algorithm/encryption combination: {} + {}",
alg, enc
);
Err(AuthError::token(format!(
"Unsupported JWE combination: {} + {}",
alg, enc
)))
}
}
}
async fn development_jwe_fallback_with_encryption(
&self,
ciphertext: &str,
encryption: &str,
) -> Result<String, AuthError> {
warn!(
"🔧 Using development JWE fallback for encryption method '{}' - implement proper cryptography for production",
encryption
);
match encryption {
"A256GCM" => {
info!("JWE encryption method A256GCM - requires AES-256-GCM implementation");
}
"A192GCM" => {
info!("JWE encryption method A192GCM - requires AES-192-GCM implementation");
}
"A128GCM" => {
info!("JWE encryption method A128GCM - requires AES-128-GCM implementation");
}
_ => {
warn!(
"Unknown JWE encryption method '{}' - add support for proper decryption",
encryption
);
}
}
let decoded = URL_SAFE_NO_PAD.decode(ciphertext).map_err(|e| {
AuthError::token(format!(
"Failed to decode JWE ciphertext with {}: {}",
encryption, e
))
})?;
String::from_utf8(decoded).map_err(|e| {
AuthError::token(format!(
"Invalid UTF-8 in JWE ciphertext with {}: {}",
encryption, e
))
})
}
async fn perform_additional_validation(
&self,
claims: &HashMap<String, Value>,
errors: &mut Vec<String>,
) {
if let Some(iss) = claims.get("iss") {
if iss.as_str() != Some(&self.config.jarm_issuer) {
errors.push(format!("Invalid issuer: {:?}", iss));
}
} else {
errors.push("Missing issuer claim".to_string());
}
if let Some(exp) = claims.get("exp") {
if let Some(exp_time) = exp.as_i64() {
if Utc::now().timestamp() > exp_time {
errors.push("Token has expired".to_string());
}
} else {
errors.push("Invalid expiration claim format".to_string());
}
} else {
errors.push("Missing expiration claim".to_string());
}
if !claims.contains_key("jti") {
errors.push("Missing JWT ID claim".to_string());
}
}
pub async fn deliver_jarm_response(
&self,
jarm_response: &JarmResponse,
client_redirect_uri: &str,
push_endpoint: Option<&str>,
) -> Result<DeliveryResult> {
match jarm_response.delivery_mode {
JarmDeliveryMode::Query => {
let url = format!(
"{}?response={}",
client_redirect_uri, jarm_response.response_token
);
Ok(DeliveryResult::Redirect(url))
}
JarmDeliveryMode::Fragment => {
let url = format!(
"{}#response={}",
client_redirect_uri, jarm_response.response_token
);
Ok(DeliveryResult::Redirect(url))
}
JarmDeliveryMode::FormPost => {
let html = self
.generate_form_post_html(client_redirect_uri, &jarm_response.response_token);
Ok(DeliveryResult::FormPost(html))
}
JarmDeliveryMode::Push => {
if let Some(endpoint) = push_endpoint {
self.push_jarm_response(endpoint, jarm_response).await?;
Ok(DeliveryResult::Push {
success: true,
endpoint: endpoint.to_string(),
})
} else {
Err(AuthError::validation(
"Push endpoint required for push delivery".to_string(),
))
}
}
}
}
fn generate_form_post_html(&self, redirect_uri: &str, response_token: &str) -> String {
format!(
r#"<!DOCTYPE html>
<html>
<head>
<title>JARM Response</title>
<meta charset="UTF-8">
</head>
<body>
<form method="post" action="{}" id="jarm_form" style="display: none;">
<input type="hidden" name="response" value="{}" />
</form>
<script>
window.onload = function() {{
document.getElementById('jarm_form').submit();
}};
</script>
<noscript>
<h2>JavaScript Required</h2>
<p>Please enable JavaScript and reload the page, or manually submit the form below:</p>
<form method="post" action="{}">
<input type="hidden" name="response" value="{}" />
<input type="submit" value="Continue" />
</form>
</noscript>
</body>
</html>"#,
redirect_uri, response_token, redirect_uri, response_token
)
}
async fn push_jarm_response(&self, endpoint: &str, jarm_response: &JarmResponse) -> Result<()> {
let payload = json!({
"response": jarm_response.response_token,
"client_id": jarm_response.client_id,
"response_id": jarm_response.response_id,
"delivered_at": Utc::now(),
});
let response = self
.http_client
.post_json(endpoint, &payload)
.await
.map_err(|e| AuthError::internal(format!("Failed to push JARM response: {}", e)))?;
if !response.status().is_success() {
return Err(AuthError::internal(format!(
"Push delivery failed with status: {}",
response.status()
)));
}
Ok(())
}
async fn log_jarm_creation(&self, client_id: &str, delivery_mode: &JarmDeliveryMode) {
eprintln!(
"AUDIT: JARM response created for client {} with delivery mode {:?}",
client_id, delivery_mode
);
}
pub fn config(&self) -> &AdvancedJarmConfig {
&self.config
}
pub fn revoke_jarm_token(&self, jti: &str) -> Result<()> {
self.jwt_validator
.revoke_token(jti)
.map_err(|e| AuthError::validation(format!("Failed to revoke JARM token: {}", e)))
}
pub fn is_jarm_token_revoked(&self, jti: &str) -> Result<bool> {
self.jwt_validator.is_token_revoked(jti).map_err(|e| {
AuthError::validation(format!("Failed to check token revocation status: {}", e))
})
}
pub fn get_jwt_validator(&self) -> &Arc<SecureJwtValidator> {
&self.jwt_validator
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationResponse {
pub code: Option<String>,
pub access_token: Option<String>,
pub id_token: Option<String>,
pub state: Option<String>,
pub token_type: Option<String>,
pub expires_in: Option<u64>,
pub scope: Option<String>,
pub error: Option<String>,
pub error_description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JarmResponse {
pub response_token: String,
pub delivery_mode: JarmDeliveryMode,
pub expires_at: DateTime<Utc>,
pub client_id: String,
pub response_id: String,
}
#[derive(Debug, Clone)]
pub struct JarmValidationResult {
pub valid: bool,
pub claims: HashMap<String, Value>,
pub errors: Vec<String>,
}
#[derive(Debug, Clone)]
pub enum DeliveryResult {
Redirect(String),
FormPost(String),
Push {
success: bool,
endpoint: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_jarm_response_creation() {
let config = AdvancedJarmConfig {
supported_algorithms: vec![Algorithm::HS256], enable_response_validation: false, ..Default::default()
};
let manager = AdvancedJarmManager::new(config);
let auth_response = AuthorizationResponse {
code: Some("auth_code_123".to_string()),
state: Some("client_state".to_string()),
access_token: None,
id_token: None,
token_type: None,
expires_in: None,
scope: None,
error: None,
error_description: None,
};
let jarm_response = manager
.create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Query, None)
.await
.unwrap();
assert!(!jarm_response.response_token.is_empty());
assert_eq!(jarm_response.delivery_mode, JarmDeliveryMode::Query);
assert_eq!(jarm_response.client_id, "test_client");
}
#[tokio::test]
async fn test_custom_claims_validation() {
let config = AdvancedJarmConfig {
max_custom_claims: 2,
..Default::default()
};
let manager = AdvancedJarmManager::new(config);
let auth_response = AuthorizationResponse {
code: Some("code123".to_string()),
state: None,
access_token: None,
id_token: None,
token_type: None,
expires_in: None,
scope: None,
error: None,
error_description: None,
};
let mut custom_claims = HashMap::new();
custom_claims.insert("claim1".to_string(), json!("value1"));
custom_claims.insert("claim2".to_string(), json!("value2"));
custom_claims.insert("claim3".to_string(), json!("value3"));
let result = manager
.create_jarm_response(
"test_client",
&auth_response,
JarmDeliveryMode::Query,
Some(custom_claims),
)
.await;
assert!(result.is_err());
}
#[test]
fn test_form_post_html_generation() {
let config = AdvancedJarmConfig::default();
let manager = AdvancedJarmManager::new(config);
let html = manager.generate_form_post_html(
"https://client.example.com/callback",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
);
assert!(html.contains("https://client.example.com/callback"));
assert!(html.contains("eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9"));
assert!(html.contains("jarm_form"));
}
#[tokio::test]
async fn test_delivery_mode_validation() {
let config = AdvancedJarmConfig {
supported_delivery_modes: vec![JarmDeliveryMode::Query],
supported_algorithms: vec![Algorithm::HS256], ..Default::default()
};
let manager = AdvancedJarmManager::new(config);
let auth_response = AuthorizationResponse {
code: Some("code123".to_string()),
state: None,
access_token: None,
id_token: None,
token_type: None,
expires_in: None,
scope: None,
error: None,
error_description: None,
};
let result = manager
.create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Query, None)
.await;
assert!(result.is_ok());
let result = manager
.create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Push, None)
.await;
assert!(result.is_err());
}
}