1use crate::errors::{AuthError, Result};
59use crate::security::secure_jwt::{SecureJwtConfig, SecureJwtValidator};
60use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
61use chrono::{DateTime, Duration, Utc};
62use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header};
63use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey};
64use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
65use serde::{Deserialize, Serialize};
66use serde_json::{Value, json};
67use sha2::Sha256;
68use std::collections::HashMap;
69use std::sync::Arc;
70use tracing::{debug, error, info, warn};
71use uuid::Uuid;
72
73#[derive(Debug, Clone)]
75pub struct AdvancedJarmConfig {
76 pub supported_algorithms: Vec<Algorithm>,
78 pub default_token_expiry: Duration,
80 pub enable_jwe_encryption: bool,
82 pub supported_delivery_modes: Vec<JarmDeliveryMode>,
84 pub enable_custom_claims: bool,
86 pub max_custom_claims: usize,
88 pub enable_response_validation: bool,
90 pub jarm_issuer: String,
92 pub enable_audit_logging: bool,
94 pub jwe_algorithm: Option<String>,
96 pub jwe_content_encryption: Option<String>,
98 pub rsa_private_key_pem: Option<String>,
104 pub rsa_public_key_pem: Option<String>,
110 pub jwe_recipient_public_key_pem: Option<String>,
116 pub jwe_recipient_private_key_pem: Option<String>,
122}
123
124impl Default for AdvancedJarmConfig {
125 fn default() -> Self {
126 Self {
127 supported_algorithms: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
128 default_token_expiry: Duration::minutes(10),
129 enable_jwe_encryption: false,
130 supported_delivery_modes: vec![
131 JarmDeliveryMode::Query,
132 JarmDeliveryMode::Fragment,
133 JarmDeliveryMode::FormPost,
134 JarmDeliveryMode::Push,
135 ],
136 enable_custom_claims: true,
137 max_custom_claims: 20,
138 enable_response_validation: true,
139 jarm_issuer: "https://auth-server.example.com".to_string(),
140 enable_audit_logging: true,
141 jwe_algorithm: Some("RSA-OAEP-256".to_string()),
142 jwe_content_encryption: Some("A256GCM".to_string()),
143 rsa_private_key_pem: None,
144 rsa_public_key_pem: None,
145 jwe_recipient_public_key_pem: None,
146 jwe_recipient_private_key_pem: None,
147 }
148 }
149}
150
151impl AdvancedJarmConfig {
152 pub fn builder() -> AdvancedJarmConfigBuilder {
164 AdvancedJarmConfigBuilder {
165 inner: Self::default(),
166 }
167 }
168}
169
170pub struct AdvancedJarmConfigBuilder {
172 inner: AdvancedJarmConfig,
173}
174
175impl AdvancedJarmConfigBuilder {
176 pub fn supported_algorithms(mut self, algos: Vec<Algorithm>) -> Self {
178 self.inner.supported_algorithms = algos;
179 self
180 }
181
182 pub fn default_token_expiry(mut self, expiry: Duration) -> Self {
184 self.inner.default_token_expiry = expiry;
185 self
186 }
187
188 pub fn enable_jwe_encryption(mut self, enable: bool) -> Self {
190 self.inner.enable_jwe_encryption = enable;
191 self
192 }
193
194 pub fn supported_delivery_modes(mut self, modes: Vec<JarmDeliveryMode>) -> Self {
196 self.inner.supported_delivery_modes = modes;
197 self
198 }
199
200 pub fn enable_custom_claims(mut self, enable: bool) -> Self {
202 self.inner.enable_custom_claims = enable;
203 self
204 }
205
206 pub fn max_custom_claims(mut self, max: usize) -> Self {
208 self.inner.max_custom_claims = max;
209 self
210 }
211
212 pub fn enable_response_validation(mut self, enable: bool) -> Self {
214 self.inner.enable_response_validation = enable;
215 self
216 }
217
218 pub fn jarm_issuer(mut self, issuer: impl Into<String>) -> Self {
220 self.inner.jarm_issuer = issuer.into();
221 self
222 }
223
224 pub fn enable_audit_logging(mut self, enable: bool) -> Self {
226 self.inner.enable_audit_logging = enable;
227 self
228 }
229
230 pub fn jwe_content_encryption(mut self, enc: impl Into<String>) -> Self {
232 self.inner.jwe_content_encryption = Some(enc.into());
233 self
234 }
235
236 pub fn jwe_algorithm(mut self, alg: impl Into<String>) -> Self {
238 self.inner.jwe_algorithm = Some(alg.into());
239 self
240 }
241
242 pub fn rsa_private_key_pem(mut self, pem: impl Into<String>) -> Self {
244 self.inner.rsa_private_key_pem = Some(pem.into());
245 self
246 }
247
248 pub fn rsa_public_key_pem(mut self, pem: impl Into<String>) -> Self {
250 self.inner.rsa_public_key_pem = Some(pem.into());
251 self
252 }
253
254 pub fn jwe_recipient_public_key_pem(mut self, pem: impl Into<String>) -> Self {
256 self.inner.jwe_recipient_public_key_pem = Some(pem.into());
257 self
258 }
259
260 pub fn jwe_recipient_private_key_pem(mut self, pem: impl Into<String>) -> Self {
262 self.inner.jwe_recipient_private_key_pem = Some(pem.into());
263 self
264 }
265
266 pub fn build(self) -> AdvancedJarmConfig {
268 self.inner
269 }
270}
271
272#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
274pub enum JarmDeliveryMode {
275 Query,
277 Fragment,
279 FormPost,
281 Push,
283}
284
/// Creates, optionally JWE-encrypts, validates and delivers JARM response tokens.
///
/// Holds signing and decryption key material; it does not implement `Debug`.
pub struct AdvancedJarmManager {
    /// Settings the manager was constructed with.
    config: AdvancedJarmConfig,
    /// Validator applied to freshly created tokens (when enabled) and to tokens
    /// passed to `validate_jarm_response*`; exposed via `get_jwt_validator`.
    jwt_validator: Arc<SecureJwtValidator>,
    /// Key used to sign response JWTs: RSA when a key pair was loaded in `new`,
    /// otherwise the development-only HMAC fallback.
    encoding_key: EncodingKey,
    /// Verification counterpart of `encoding_key`.
    decoding_key: DecodingKey,
    /// HTTP client used for `JarmDeliveryMode::Push`.
    /// NOTE(review): `new` builds it from a fixed "https://localhost" endpoint
    /// config — confirm how absolute push endpoints are resolved.
    http_client: crate::server::core::common_http::HttpClient,
    /// JWE recipient public key used to wrap the content-encryption key;
    /// `None` makes encryption fail in `encrypt_key`.
    jwe_public_key: Option<RsaPublicKey>,
    /// JWE recipient private key used to unwrap the content-encryption key;
    /// `None` makes decryption fail.
    jwe_private_key: Option<RsaPrivateKey>,
}
302
303impl AdvancedJarmManager {
304 pub fn new(config: AdvancedJarmConfig) -> Self {
306 let private_pem = config
311 .rsa_private_key_pem
312 .clone()
313 .or_else(|| std::env::var("JARM_RSA_PRIVATE_KEY_PEM").ok());
314 let public_pem = config
315 .rsa_public_key_pem
316 .clone()
317 .or_else(|| std::env::var("JARM_RSA_PUBLIC_KEY_PEM").ok());
318
319 fn make_validator_secret() -> String {
324 use ring::rand::{SecureRandom, SystemRandom};
325 let rng = SystemRandom::new();
326 let mut bytes = [0u8; 32];
327 rng.fill(&mut bytes)
328 .expect("System CSPRNG unavailable; cannot initialize JARM JWT validator secret");
329 bytes.iter().fold(String::with_capacity(64), |mut s, b| {
330 s.push_str(&format!("{b:02x}"));
331 s
332 })
333 }
334
335 let (encoding_key, decoding_key, validator_jwt_secret, rsa_pub_pem) = match (private_pem, public_pem) {
337 (Some(priv_pem), Some(pub_pem)) => {
338 match (
339 EncodingKey::from_rsa_pem(priv_pem.as_bytes()),
340 DecodingKey::from_rsa_pem(pub_pem.as_bytes()),
341 ) {
342 (Ok(enc), Ok(dec)) => {
343 info!("JARM: loaded RSA signing/verification keys from configuration");
344 let secret = std::env::var("JARM_JWT_SECRET")
347 .unwrap_or_else(|_| make_validator_secret());
348 (enc, dec, secret, Some(pub_pem))
349 }
350 (Err(e), _) | (_, Err(e)) => {
351 warn!(
352 "JARM: failed to parse provided RSA keys ({}). \
353 Falling back to development-only symmetric key — \
354 DO NOT use in production.",
355 e
356 );
357 (
358 EncodingKey::from_secret(b"test_key_for_development_only_123456"),
359 DecodingKey::from_secret(b"test_key_for_development_only_123456"),
360 "test_key_for_development_only_123456".to_string(),
361 None,
362 )
363 }
364 }
365 }
366 _ => {
367 warn!(
372 "SECURITY WARNING: AdvancedJarmManager is using a development-only \
373 symmetric fallback key for JARM signing. This is NOT secure. Supply \
374 an RSA private key via AdvancedJarmConfig::rsa_private_key_pem or \
375 the JARM_RSA_PRIVATE_KEY_PEM environment variable before deploying \
376 to production."
377 );
378 (
379 EncodingKey::from_secret(b"test_key_for_development_only_123456"),
380 DecodingKey::from_secret(b"test_key_for_development_only_123456"),
381 "test_key_for_development_only_123456".to_string(),
382 None,
383 )
384 }
385 };
386
387 let jwe_pub_pem = config
389 .jwe_recipient_public_key_pem
390 .clone()
391 .or_else(|| std::env::var("JARM_JWE_RECIPIENT_PUBLIC_KEY_PEM").ok());
392 let jwe_priv_pem = config
393 .jwe_recipient_private_key_pem
394 .clone()
395 .or_else(|| std::env::var("JARM_JWE_RECIPIENT_PRIVATE_KEY_PEM").ok());
396
397 let jwe_public_key =
398 jwe_pub_pem
399 .as_deref()
400 .and_then(|pem| match RsaPublicKey::from_public_key_pem(pem) {
401 Ok(k) => {
402 info!("JARM JWE: loaded RSA recipient public key");
403 Some(k)
404 }
405 Err(e) => {
406 warn!("JARM JWE: could not parse recipient public key: {e}");
407 None
408 }
409 });
410 let jwe_private_key =
411 jwe_priv_pem
412 .as_deref()
413 .and_then(|pem| match RsaPrivateKey::from_pkcs8_pem(pem) {
414 Ok(k) => {
415 info!("JARM JWE: loaded RSA recipient private key");
416 Some(k)
417 }
418 Err(e) => {
419 warn!("JARM JWE: could not parse recipient private key: {e}");
420 None
421 }
422 });
423
424 let mut required_issuers = std::collections::HashSet::new();
425 required_issuers.insert(config.jarm_issuer.clone());
426
427 let jwt_config = SecureJwtConfig {
428 allowed_algorithms: config.supported_algorithms.clone(),
429 required_issuers,
430 required_audiences: std::collections::HashSet::new(), max_token_lifetime: std::time::Duration::from_secs(
432 config.default_token_expiry.num_seconds() as u64,
433 ),
434 clock_skew: std::time::Duration::from_secs(30),
435 require_jti: true,
436 validate_nbf: true,
437 allowed_token_types: {
438 let mut types = std::collections::HashSet::new();
439 types.insert("JARM".to_string());
440 types
441 },
442 require_secure_transport: true,
443 jwt_secret: validator_jwt_secret,
444 rsa_public_key_pem: rsa_pub_pem,
445 ec_public_key_pem: None,
446 ed_public_key_pem: None,
447 };
448
449 Self {
450 config,
451 jwt_validator: Arc::new(SecureJwtValidator::new(jwt_config).unwrap_or_else(|e| {
452 tracing::error!("Failed to initialize SecureJwtValidator for JARM: {e}");
453 panic!("SecureJwtValidator init failed");
454 })),
455 encoding_key,
456 decoding_key,
457 http_client: {
458 use crate::server::core::common_config::EndpointConfig;
459 let endpoint_config = EndpointConfig::new("https://localhost");
460 crate::server::core::common_http::HttpClient::new(endpoint_config).unwrap()
461 },
462 jwe_public_key,
463 jwe_private_key,
464 }
465 }
466
467 pub async fn create_jarm_response(
469 &self,
470 client_id: &str,
471 authorization_response: &AuthorizationResponse,
472 delivery_mode: JarmDeliveryMode,
473 custom_claims: Option<HashMap<String, Value>>,
474 ) -> Result<JarmResponse> {
475 if !self
477 .config
478 .supported_delivery_modes
479 .contains(&delivery_mode)
480 {
481 return Err(AuthError::validation(format!(
482 "Unsupported delivery mode: {:?}",
483 delivery_mode
484 )));
485 }
486
487 if let Some(ref claims) = custom_claims {
489 if self.config.enable_custom_claims {
490 if claims.len() > self.config.max_custom_claims {
491 return Err(AuthError::validation(format!(
492 "Too many custom claims: {} > {}",
493 claims.len(),
494 self.config.max_custom_claims
495 )));
496 }
497 } else {
498 return Err(AuthError::validation(
499 "Custom claims are disabled".to_string(),
500 ));
501 }
502 }
503
504 let now = Utc::now();
505 let expires_at = now + self.config.default_token_expiry;
506
507 let jti = Uuid::new_v4().to_string();
509 let mut claims = json!({
510 "iss": self.config.jarm_issuer,
511 "aud": client_id,
512 "iat": now.timestamp(),
513 "exp": expires_at.timestamp(),
514 "nbf": now.timestamp(), "jti": jti,
516 "typ": "JARM", "scope": "", "sub": format!("jarm_{}", client_id), });
520
521 if let Some(code) = &authorization_response.code {
523 claims["code"] = json!(code);
524 }
525 if let Some(access_token) = &authorization_response.access_token {
526 claims["access_token"] = json!(access_token);
527 }
528 if let Some(id_token) = &authorization_response.id_token {
529 claims["id_token"] = json!(id_token);
530 }
531 if let Some(state) = &authorization_response.state {
532 claims["state"] = json!(state);
533 }
534 if let Some(error) = &authorization_response.error {
535 claims["error"] = json!(error);
536 }
537 if let Some(error_description) = &authorization_response.error_description {
538 claims["error_description"] = json!(error_description);
539 }
540
541 if authorization_response.access_token.is_some() {
543 claims["token_type"] = json!("Bearer");
544 if let Some(expires_in) = authorization_response.expires_in {
545 claims["expires_in"] = json!(expires_in);
546 }
547 }
548
549 if let Some(scope) = &authorization_response.scope {
551 claims["scope"] = json!(scope);
552 }
553
554 if let Some(custom) = custom_claims {
556 for (key, value) in custom {
557 claims[key] = value;
558 }
559 }
560
561 let header = Header {
563 typ: Some("JWT".to_string()),
564 alg: self.config.supported_algorithms[0], kid: Some("jarm-key-1".to_string()),
566 ..Default::default()
567 };
568
569 let token = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
571 .map_err(|e| AuthError::token(format!("Failed to create JARM token: {}", e)))?;
572
573 if self.config.enable_response_validation {
575 let _validated_claims = self
576 .jwt_validator
577 .validate_token(&token, &self.decoding_key)
578 .map_err(|e| {
579 AuthError::token(format!(
580 "Created JARM token failed security validation: {}",
581 e
582 ))
583 })?;
584 }
585
586 let final_token = if self.config.enable_jwe_encryption {
588 self.encrypt_jwt_response(&token).await?
589 } else {
590 token
591 };
592
593 if self.config.enable_audit_logging {
595 self.log_jarm_creation(client_id, &delivery_mode).await;
596 }
597
598 Ok(JarmResponse {
599 response_token: final_token,
600 delivery_mode,
601 expires_at,
602 client_id: client_id.to_string(),
603 response_id: Uuid::new_v4().to_string(),
604 })
605 }
606
607 async fn encrypt_jwt_response(&self, jwt_token: &str) -> Result<String> {
609 use base64::Engine;
614
615 let cek = self.generate_content_encryption_key();
617
618 let encrypted_payload = self.encrypt_payload(jwt_token, &cek)?;
620
621 let encrypted_key = self.encrypt_key(&cek)?;
623
624 let jwe_header = self.create_jwe_header();
626 let header_b64 = URL_SAFE_NO_PAD.encode(jwe_header.as_bytes());
627 let key_b64 = URL_SAFE_NO_PAD.encode(&encrypted_key);
628 let payload_parts: Vec<&str> = encrypted_payload.split('.').collect();
629
630 if payload_parts.len() != 3 {
631 return Err(AuthError::auth_method(
632 "jarm",
633 "Invalid encrypted payload format",
634 ));
635 }
636
637 let jwe_token = format!(
638 "{}.{}.{}.{}.{}",
639 header_b64,
640 key_b64,
641 payload_parts[0], payload_parts[1], payload_parts[2] );
645
646 tracing::debug!("Created JWE-encrypted JARM response");
647 Ok(jwe_token)
648 }
649
650 fn generate_content_encryption_key(&self) -> Vec<u8> {
652 use rand::Rng;
653 let mut key = vec![0u8; 32];
654 rand::rng().fill_bytes(&mut key);
655 key
656 }
657
658 fn encrypt_payload(&self, payload: &str, cek: &[u8]) -> Result<String> {
663 use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
664 use rand::Rng;
665
666 if cek.len() != 32 {
667 return Err(AuthError::crypto("CEK must be 32 bytes for AES-256-GCM"));
668 }
669
670 let mut nonce_bytes = [0u8; 12];
672 rand::rng().fill_bytes(&mut nonce_bytes);
673 let nonce = Nonce::from_slice(&nonce_bytes);
674
675 let key = Key::<Aes256Gcm>::from_slice(cek);
676 let cipher = Aes256Gcm::new(key);
677
678 let ciphertext_with_tag = cipher
680 .encrypt(nonce, payload.as_bytes())
681 .map_err(|e| AuthError::crypto(format!("AES-256-GCM encryption failed: {}", e)))?;
682
683 let tag_pos = ciphertext_with_tag.len().saturating_sub(16);
684 let ciphertext = &ciphertext_with_tag[..tag_pos];
685 let tag = &ciphertext_with_tag[tag_pos..];
686
687 Ok(format!(
688 "{}.{}.{}",
689 URL_SAFE_NO_PAD.encode(nonce_bytes),
690 URL_SAFE_NO_PAD.encode(ciphertext),
691 URL_SAFE_NO_PAD.encode(tag)
692 ))
693 }
694
695 fn encrypt_key(&self, cek: &[u8]) -> Result<Vec<u8>> {
701 let pub_key = self.jwe_public_key.as_ref().ok_or_else(|| {
702 AuthError::crypto(
703 "JARM JWE requires an RSA recipient public key for CEK wrapping. \
704 Set AdvancedJarmConfig::jwe_recipient_public_key_pem or the \
705 JARM_JWE_RECIPIENT_PUBLIC_KEY_PEM environment variable.",
706 )
707 })?;
708
709 let mut rng = rand_core::OsRng;
711 let padding = Oaep::new::<Sha256>();
712 pub_key
713 .encrypt(&mut rng, padding, cek)
714 .map_err(|e| AuthError::crypto(format!("RSA-OAEP CEK wrap failed: {e}")))
715 }
716
717 fn create_jwe_header(&self) -> String {
719 serde_json::json!({
720 "alg": "RSA-OAEP",
721 "enc": "A256GCM",
722 "typ": "JOSE",
723 "cty": "JWT"
724 })
725 .to_string()
726 }
727
728 pub async fn validate_jarm_response(&self, token: &str) -> Result<JarmValidationResult> {
730 self.validate_jarm_response_with_transport(token, true)
731 .await
732 }
733
734 pub async fn validate_jarm_response_with_transport(
736 &self,
737 token: &str,
738 _transport_secure: bool,
739 ) -> Result<JarmValidationResult> {
740 if !self.config.enable_response_validation {
741 return Ok(JarmValidationResult {
742 valid: true,
743 claims: HashMap::new(),
744 errors: vec![],
745 });
746 }
747
748 let mut errors = vec![];
749 let mut claims = HashMap::new();
750
751 let jwt_token = if token.starts_with("JWE.") {
753 match self.decrypt_jwe_response(token).await {
754 Ok(decrypted) => decrypted,
755 Err(e) => {
756 errors.push(format!("JWE decryption failed: {}", e));
757 return Ok(JarmValidationResult {
758 valid: false,
759 claims,
760 errors,
761 });
762 }
763 }
764 } else {
765 token.to_string()
766 };
767
768 match self
770 .jwt_validator
771 .validate_token(&jwt_token, &self.decoding_key)
772 {
773 Ok(secure_claims) => {
774 let claims_value = serde_json::to_value(&secure_claims).map_err(|e| {
776 AuthError::validation(format!("Failed to serialize claims: {}", e))
777 })?;
778
779 if let serde_json::Value::Object(claim_map) = claims_value {
780 for (key, value) in claim_map {
781 claims.insert(key, value);
782 }
783 }
784
785 self.perform_additional_validation(&claims, &mut errors)
787 .await;
788 }
789 Err(e) => {
790 errors.push(format!("Enhanced JWT validation failed: {}", e));
791 }
792 }
793
794 let valid = errors.is_empty();
795
796 Ok(JarmValidationResult {
797 valid,
798 claims,
799 errors,
800 })
801 }
802
803 async fn decrypt_jwe_response(&self, jwe_token: &str) -> Result<String> {
805 let parts: Vec<&str> = jwe_token.split('.').collect();
807 if parts.len() != 5 {
808 return Err(AuthError::InvalidRequest(
809 "JWE must have 5 parts".to_string(),
810 ));
811 }
812
813 let header = URL_SAFE_NO_PAD
815 .decode(parts[0])
816 .map_err(|e| AuthError::InvalidRequest(format!("Invalid header: {}", e)))?;
817 let header_str = String::from_utf8(header)
818 .map_err(|e| AuthError::InvalidRequest(format!("Invalid header UTF-8: {}", e)))?;
819
820 let header_json: serde_json::Value = serde_json::from_str(&header_str)
822 .map_err(|e| AuthError::InvalidRequest(format!("Invalid header JSON: {}", e)))?;
823
824 let algorithm = header_json
826 .get("alg")
827 .and_then(|v| v.as_str())
828 .unwrap_or("unknown");
829 let encryption = header_json
830 .get("enc")
831 .and_then(|v| v.as_str())
832 .unwrap_or("unknown");
833
834 info!(
835 "JWE decryption - Algorithm: {}, Encryption: {}",
836 algorithm, encryption
837 );
838
839 match (algorithm, encryption) {
841 ("RSA-OAEP", "A256GCM") | ("RSA-OAEP-256", "A256GCM") => {
842 debug!(
844 "Using supported JWE algorithm combination: {} + {}",
845 algorithm, encryption
846 );
847 }
848 _ => {
849 warn!(
850 "Unsupported JWE algorithm combination: {} + {}",
851 algorithm, encryption
852 );
853 return Err(AuthError::token(format!(
854 "Unsupported JWE algorithm combination: {} + {}",
855 algorithm, encryption
856 )));
857 }
858 }
859
860 match self
862 .decrypt_jwe_with_algorithm(&parts, algorithm, encryption)
863 .await
864 {
865 Ok(decrypted_payload) => {
866 debug!(
867 "JWE decryption successful with {} + {}",
868 algorithm, encryption
869 );
870 Ok(decrypted_payload)
871 }
872 Err(e) => {
873 error!("JWE decryption failed: {}", e);
874 Err(e)
875 }
876 }
877 }
878
879 async fn decrypt_jwe_with_algorithm(
881 &self,
882 parts: &[&str],
883 algorithm: &str,
884 encryption: &str,
885 ) -> Result<String, AuthError> {
886 if parts.len() != 5 {
888 return Err(AuthError::token("Invalid JWE format - must have 5 parts"));
889 }
890
891 let encrypted_key = parts[1];
893 let initialization_vector = parts[2];
894 let ciphertext = parts[3];
895 let authentication_tag = parts[4];
896
897 debug!(
898 "JWE Components - Key: {}, IV: {}, Ciphertext: {}, Tag: {}",
899 &encrypted_key[..8.min(encrypted_key.len())],
900 &initialization_vector[..8.min(initialization_vector.len())],
901 &ciphertext[..8.min(ciphertext.len())],
902 &authentication_tag[..8.min(authentication_tag.len())]
903 );
904
905 match (algorithm, encryption) {
907 ("RSA-OAEP", "A256GCM") | ("RSA-OAEP-256", "A256GCM") => self.decrypt_rsa_oaep_a256gcm(
908 encrypted_key,
909 initialization_vector,
910 ciphertext,
911 authentication_tag,
912 ),
913 (alg, enc) => {
914 error!(
915 "Unsupported JWE algorithm/encryption combination: {} + {}",
916 alg, enc
917 );
918 Err(AuthError::token(format!(
919 "Unsupported JWE combination: {} + {}",
920 alg, enc
921 )))
922 }
923 }
924 }
925
926 fn decrypt_rsa_oaep_a256gcm(
928 &self,
929 encrypted_key_b64: &str,
930 iv_b64: &str,
931 ciphertext_b64: &str,
932 tag_b64: &str,
933 ) -> Result<String, AuthError> {
934 use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
935
936 let priv_key = self.jwe_private_key.as_ref().ok_or_else(|| {
937 AuthError::crypto(
938 "JARM JWE decryption requires an RSA private key. \
939 Set AdvancedJarmConfig::jwe_recipient_private_key_pem or the \
940 JARM_JWE_RECIPIENT_PRIVATE_KEY_PEM environment variable.",
941 )
942 })?;
943
944 let encrypted_cek = URL_SAFE_NO_PAD
946 .decode(encrypted_key_b64)
947 .map_err(|e| AuthError::token(format!("Bad encrypted_key encoding: {e}")))?;
948 let padding = Oaep::new::<Sha256>();
949 let cek = priv_key
950 .decrypt(padding, &encrypted_cek)
951 .map_err(|e| AuthError::crypto(format!("RSA-OAEP CEK unwrap failed: {e}")))?;
952
953 if cek.len() != 32 {
954 return Err(AuthError::crypto(format!(
955 "Unwrapped CEK has unexpected length {} (expected 32)",
956 cek.len()
957 )));
958 }
959
960 let nonce_bytes = URL_SAFE_NO_PAD
962 .decode(iv_b64)
963 .map_err(|e| AuthError::token(format!("Bad IV encoding: {e}")))?;
964 let mut ciphertext = URL_SAFE_NO_PAD
965 .decode(ciphertext_b64)
966 .map_err(|e| AuthError::token(format!("Bad ciphertext encoding: {e}")))?;
967 let tag = URL_SAFE_NO_PAD
968 .decode(tag_b64)
969 .map_err(|e| AuthError::token(format!("Bad tag encoding: {e}")))?;
970
971 if nonce_bytes.len() != 12 {
972 return Err(AuthError::crypto("JWE IV must be 12 bytes for AES-256-GCM"));
973 }
974
975 ciphertext.extend_from_slice(&tag);
977 let nonce = Nonce::from_slice(&nonce_bytes);
978 let key = Key::<Aes256Gcm>::from_slice(&cek);
979 let cipher = Aes256Gcm::new(key);
980 let plaintext = cipher
981 .decrypt(nonce, ciphertext.as_slice())
982 .map_err(|e| AuthError::crypto(format!("AES-256-GCM decryption failed: {e}")))?;
983
984 String::from_utf8(plaintext)
985 .map_err(|e| AuthError::token(format!("Decrypted payload is not valid UTF-8: {e}")))
986 }
987
988 async fn perform_additional_validation(
990 &self,
991 claims: &HashMap<String, Value>,
992 errors: &mut Vec<String>,
993 ) {
994 if let Some(iss) = claims.get("iss") {
996 if iss.as_str() != Some(&self.config.jarm_issuer) {
997 errors.push(format!("Invalid issuer: {:?}", iss));
998 }
999 } else {
1000 errors.push("Missing issuer claim".to_string());
1001 }
1002
1003 if let Some(exp) = claims.get("exp") {
1005 if let Some(exp_time) = exp.as_i64() {
1006 if Utc::now().timestamp() > exp_time {
1007 errors.push("Token has expired".to_string());
1008 }
1009 } else {
1010 errors.push("Invalid expiration claim format".to_string());
1011 }
1012 } else {
1013 errors.push("Missing expiration claim".to_string());
1014 }
1015
1016 if !claims.contains_key("jti") {
1018 errors.push("Missing JWT ID claim".to_string());
1019 }
1020 }
1021
1022 pub async fn deliver_jarm_response(
1024 &self,
1025 jarm_response: &JarmResponse,
1026 client_redirect_uri: &str,
1027 push_endpoint: Option<&str>,
1028 ) -> Result<DeliveryResult> {
1029 match jarm_response.delivery_mode {
1030 JarmDeliveryMode::Query => {
1031 let url = format!(
1032 "{}?response={}",
1033 client_redirect_uri, jarm_response.response_token
1034 );
1035 Ok(DeliveryResult::Redirect(url))
1036 }
1037 JarmDeliveryMode::Fragment => {
1038 let url = format!(
1039 "{}#response={}",
1040 client_redirect_uri, jarm_response.response_token
1041 );
1042 Ok(DeliveryResult::Redirect(url))
1043 }
1044 JarmDeliveryMode::FormPost => {
1045 let html = self
1046 .generate_form_post_html(client_redirect_uri, &jarm_response.response_token);
1047 Ok(DeliveryResult::FormPost(html))
1048 }
1049 JarmDeliveryMode::Push => {
1050 if let Some(endpoint) = push_endpoint {
1051 self.push_jarm_response(endpoint, jarm_response).await?;
1052 Ok(DeliveryResult::Push {
1053 success: true,
1054 endpoint: endpoint.to_string(),
1055 })
1056 } else {
1057 Err(AuthError::validation(
1058 "Push endpoint required for push delivery".to_string(),
1059 ))
1060 }
1061 }
1062 }
1063 }
1064
1065 fn generate_form_post_html(&self, redirect_uri: &str, response_token: &str) -> String {
1067 format!(
1068 r#"<!DOCTYPE html>
1069<html>
1070<head>
1071 <title>JARM Response</title>
1072 <meta charset="UTF-8">
1073</head>
1074<body>
1075 <form method="post" action="{}" id="jarm_form" style="display: none;">
1076 <input type="hidden" name="response" value="{}" />
1077 </form>
1078 <script>
1079 window.onload = function() {{
1080 document.getElementById('jarm_form').submit();
1081 }};
1082 </script>
1083 <noscript>
1084 <h2>JavaScript Required</h2>
1085 <p>Please enable JavaScript and reload the page, or manually submit the form below:</p>
1086 <form method="post" action="{}">
1087 <input type="hidden" name="response" value="{}" />
1088 <input type="submit" value="Continue" />
1089 </form>
1090 </noscript>
1091</body>
1092</html>"#,
1093 redirect_uri, response_token, redirect_uri, response_token
1094 )
1095 }
1096
1097 async fn push_jarm_response(&self, endpoint: &str, jarm_response: &JarmResponse) -> Result<()> {
1099 let payload = json!({
1100 "response": jarm_response.response_token,
1101 "client_id": jarm_response.client_id,
1102 "response_id": jarm_response.response_id,
1103 "delivered_at": Utc::now(),
1104 });
1105
1106 let response = self
1107 .http_client
1108 .post_json(endpoint, &payload)
1109 .await
1110 .map_err(|e| AuthError::internal(format!("Failed to push JARM response: {}", e)))?;
1111
1112 if !response.status().is_success() {
1113 return Err(AuthError::internal(format!(
1114 "Push delivery failed with status: {}",
1115 response.status()
1116 )));
1117 }
1118
1119 Ok(())
1120 }
1121
1122 async fn log_jarm_creation(&self, client_id: &str, delivery_mode: &JarmDeliveryMode) {
1124 tracing::info!(
1125 client_id = %client_id,
1126 delivery_mode = ?delivery_mode,
1127 "AUDIT: JARM response created"
1128 );
1129 }
1130
1131 pub fn config(&self) -> &AdvancedJarmConfig {
1133 &self.config
1134 }
1135
1136 pub fn revoke_jarm_token(&self, jti: &str) -> Result<()> {
1138 self.jwt_validator
1139 .revoke_token(jti)
1140 .map_err(|e| AuthError::validation(format!("Failed to revoke JARM token: {}", e)))
1141 }
1142
1143 pub fn is_jarm_token_revoked(&self, jti: &str) -> Result<bool> {
1145 self.jwt_validator.is_token_revoked(jti).map_err(|e| {
1146 AuthError::validation(format!("Failed to check token revocation status: {}", e))
1147 })
1148 }
1149
1150 pub fn get_jwt_validator(&self) -> &Arc<SecureJwtValidator> {
1152 &self.jwt_validator
1153 }
1154}
1155
1156#[derive(Debug, Clone, Serialize, Deserialize)]
1158pub struct AuthorizationResponse {
1159 pub code: Option<String>,
1161 pub access_token: Option<String>,
1163 pub id_token: Option<String>,
1165 pub state: Option<String>,
1167 pub token_type: Option<String>,
1169 pub expires_in: Option<u64>,
1171 pub scope: Option<String>,
1173 pub error: Option<String>,
1175 pub error_description: Option<String>,
1177}
1178
/// A created JARM response, returned by `AdvancedJarmManager::create_jarm_response`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JarmResponse {
    /// The signed JWT, or a compact JWE wrapping it when JWE encryption is enabled.
    pub response_token: String,
    /// Delivery mode requested at creation time.
    pub delivery_mode: JarmDeliveryMode,
    /// Expiry instant; the same value as the token's `exp` claim.
    pub expires_at: DateTime<Utc>,
    /// Client the response was issued for (the token's `aud`).
    pub client_id: String,
    /// Random UUID v4 identifying this response; distinct from the token's `jti`.
    pub response_id: String,
}
1193
/// Outcome of `AdvancedJarmManager::validate_jarm_response*`.
#[derive(Debug, Clone)]
pub struct JarmValidationResult {
    /// `true` when no validation errors were recorded. It is also `true`, with
    /// no checks run, when response validation is disabled in the config.
    pub valid: bool,
    /// Claims of the validated token (empty if decoding or validation failed early).
    pub claims: HashMap<String, Value>,
    /// Human-readable description of every failed check.
    pub errors: Vec<String>,
}
1204
/// What `AdvancedJarmManager::deliver_jarm_response` produced for the caller to act on.
#[derive(Debug, Clone)]
pub enum DeliveryResult {
    /// URL the user agent should be redirected to (query or fragment mode).
    Redirect(String),
    /// HTML page to return to the user agent (form_post mode).
    FormPost(String),
    /// Result of a server-to-server push.
    Push {
        /// Push outcome; `deliver_jarm_response` only constructs this after a successful POST.
        success: bool,
        /// Endpoint the response was pushed to.
        endpoint: String,
    },
}
1220
// Unit tests for `AdvancedJarmManager`.
//
// Most tests select `Algorithm::HS256`. When no RSA keys are configured
// (neither in the config nor via environment variables), `AdvancedJarmManager::new`
// signs with a development-only HMAC secret.
#[cfg(test)]
mod tests {
    use super::*;

    // A basic Query-mode response can be created and echoes the request metadata.
    #[tokio::test]
    async fn test_jarm_response_creation() {
        // Self-validation is disabled, so only token creation is exercised.
        let config = AdvancedJarmConfig::builder()
            .supported_algorithms(vec![Algorithm::HS256])
            .enable_response_validation(false)
            .build();
        let manager = AdvancedJarmManager::new(config);

        let auth_response = AuthorizationResponse {
            code: Some("auth_code_123".to_string()),
            state: Some("client_state".to_string()),
            access_token: None,
            id_token: None,
            token_type: None,
            expires_in: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let jarm_response = manager
            .create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Query, None)
            .await
            .unwrap();

        assert!(!jarm_response.response_token.is_empty());
        assert_eq!(jarm_response.delivery_mode, JarmDeliveryMode::Query);
        assert_eq!(jarm_response.client_id, "test_client");
    }

    // Supplying more custom claims than `max_custom_claims` must be rejected.
    #[tokio::test]
    async fn test_custom_claims_validation() {
        let config = AdvancedJarmConfig::builder()
            .max_custom_claims(2)
            .supported_algorithms(vec![Algorithm::HS256])
            .build();
        let manager = AdvancedJarmManager::new(config);

        let auth_response = AuthorizationResponse {
            code: Some("code123".to_string()),
            state: None,
            access_token: None,
            id_token: None,
            token_type: None,
            expires_in: None,
            scope: None,
            error: None,
            error_description: None,
        };

        // Three claims against a limit of two.
        let mut custom_claims = HashMap::new();
        custom_claims.insert("claim1".to_string(), json!("value1"));
        custom_claims.insert("claim2".to_string(), json!("value2"));
        custom_claims.insert("claim3".to_string(), json!("value3"));
        let result = manager
            .create_jarm_response(
                "test_client",
                &auth_response,
                JarmDeliveryMode::Query,
                Some(custom_claims),
            )
            .await;

        assert!(result.is_err());
    }

    // With JWE enabled, the response is a 5-part compact JWE whose plaintext
    // decrypts back to a 3-part JWT using a freshly generated RSA key pair.
    #[tokio::test]
    async fn test_jwe_encrypt_decrypt_roundtrip() {
        use rsa::RsaPrivateKey;
        use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, LineEnding};

        // 2048-bit key generation dominates this test's runtime.
        let mut rng = rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("RSA key generation failed");
        let public_key = private_key.to_public_key();

        let priv_pem = private_key
            .to_pkcs8_pem(LineEnding::LF)
            .expect("private key PEM serialisation failed")
            .to_string();
        let pub_pem = public_key
            .to_public_key_pem(LineEnding::LF)
            .expect("public key PEM serialisation failed");

        let config = AdvancedJarmConfig::builder()
            .supported_algorithms(vec![Algorithm::HS256])
            .enable_jwe_encryption(true)
            .enable_response_validation(false)
            .jwe_recipient_public_key_pem(pub_pem)
            .jwe_recipient_private_key_pem(priv_pem)
            .build();
        let manager = AdvancedJarmManager::new(config);

        assert!(
            manager.jwe_public_key.is_some(),
            "JWE public key should have been loaded"
        );
        assert!(
            manager.jwe_private_key.is_some(),
            "JWE private key should have been loaded"
        );

        let auth_response = AuthorizationResponse {
            code: Some("enc_test_code".to_string()),
            state: Some("enc_state".to_string()),
            access_token: None,
            id_token: None,
            token_type: None,
            expires_in: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let jarm = manager
            .create_jarm_response("enc_client", &auth_response, JarmDeliveryMode::Query, None)
            .await
            .expect("create_jarm_response with JWE failed");

        // Compact JWE: header.encrypted_key.iv.ciphertext.tag
        let parts: Vec<&str> = jarm.response_token.split('.').collect();
        assert_eq!(
            parts.len(),
            5,
            "JWE token should have 5 parts, got: {}",
            parts.len()
        );

        let recovered = manager
            .decrypt_jwe_with_algorithm(&parts, "RSA-OAEP-256", "A256GCM")
            .await
            .expect("JWE decryption failed");

        // The plaintext is the signed JWT: header.payload.signature
        assert_eq!(
            recovered.split('.').count(),
            3,
            "Recovered payload should be a 3-part JWT"
        );
    }

    // The form_post page contains the target URI, the token and the form id.
    #[test]
    fn test_form_post_html_generation() {
        let config = AdvancedJarmConfig::builder()
            .supported_algorithms(vec![Algorithm::HS256])
            .build();
        let manager = AdvancedJarmManager::new(config);

        let html = manager.generate_form_post_html(
            "https://client.example.com/callback",
            "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
        );

        assert!(html.contains("https://client.example.com/callback"));
        assert!(html.contains("eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9"));
        assert!(html.contains("jarm_form"));
    }

    // Only configured delivery modes are accepted: Query succeeds, Push is rejected.
    #[tokio::test]
    async fn test_delivery_mode_validation() {
        let config = AdvancedJarmConfig::builder()
            .supported_delivery_modes(vec![JarmDeliveryMode::Query])
            .supported_algorithms(vec![Algorithm::HS256])
            .build();
        let manager = AdvancedJarmManager::new(config);

        let auth_response = AuthorizationResponse {
            code: Some("code123".to_string()),
            state: None,
            access_token: None,
            id_token: None,
            token_type: None,
            expires_in: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let result = manager
            .create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Query, None)
            .await;
        assert!(result.is_ok());

        let result = manager
            .create_jarm_response("test_client", &auth_response, JarmDeliveryMode::Push, None)
            .await;
        assert!(result.is_err());
    }
}