1use std::{
46 sync::Arc,
47 time::{Duration, Instant},
48};
49
50use chrono::{DateTime, Utc};
51use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54
55use crate::security::{
56 auth_middleware::AuthenticatedUser,
57 errors::{Result, SecurityError},
58};
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct OidcConfig {
73 pub issuer: String,
78
79 #[serde(default)]
92 pub audience: Option<String>,
93
94 #[serde(default)]
98 pub additional_audiences: Vec<String>,
99
100 #[serde(default = "default_jwks_cache_ttl")]
105 pub jwks_cache_ttl_secs: u64,
106
107 #[serde(default = "default_algorithms")]
111 pub allowed_algorithms: Vec<String>,
112
113 #[serde(default = "default_clock_skew")]
119 pub clock_skew_secs: u64,
120
121 #[serde(default)]
126 pub jwks_uri: Option<String>,
127
128 #[serde(default = "default_required")]
133 pub required: bool,
134
135 #[serde(default = "default_scope_claim")]
141 pub scope_claim: String,
142}
143
144fn default_jwks_cache_ttl() -> u64 {
145 300
148}
149
150fn default_algorithms() -> Vec<String> {
151 vec!["RS256".to_string()]
152}
153
154fn default_clock_skew() -> u64 {
155 60
156}
157
158fn default_required() -> bool {
159 true
160}
161
162fn default_scope_claim() -> String {
163 "scope".to_string()
164}
165
166impl Default for OidcConfig {
167 fn default() -> Self {
168 Self {
169 issuer: String::new(),
170 audience: None,
171 additional_audiences: Vec::new(),
172 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
173 allowed_algorithms: default_algorithms(),
174 clock_skew_secs: default_clock_skew(),
175 jwks_uri: None,
176 required: default_required(),
177 scope_claim: default_scope_claim(),
178 }
179 }
180}
181
182impl OidcConfig {
183 #[must_use]
190 pub fn auth0(domain: &str, audience: &str) -> Self {
191 Self {
192 issuer: format!("https://{domain}/"),
193 audience: Some(audience.to_string()),
194 ..Default::default()
195 }
196 }
197
198 #[must_use]
206 pub fn keycloak(base_url: &str, realm: &str, client_id: &str) -> Self {
207 Self {
208 issuer: format!("{base_url}/realms/{realm}"),
209 audience: Some(client_id.to_string()),
210 ..Default::default()
211 }
212 }
213
214 #[must_use]
221 pub fn okta(domain: &str, audience: &str) -> Self {
222 Self {
223 issuer: format!("https://{domain}"),
224 audience: Some(audience.to_string()),
225 ..Default::default()
226 }
227 }
228
229 #[must_use]
237 pub fn cognito(region: &str, user_pool_id: &str, client_id: &str) -> Self {
238 Self {
239 issuer: format!("https://cognito-idp.{region}.amazonaws.com/{user_pool_id}"),
240 audience: Some(client_id.to_string()),
241 ..Default::default()
242 }
243 }
244
245 #[must_use]
252 pub fn azure_ad(tenant_id: &str, client_id: &str) -> Self {
253 Self {
254 issuer: format!("https://login.microsoftonline.com/{tenant_id}/v2.0"),
255 audience: Some(client_id.to_string()),
256 ..Default::default()
257 }
258 }
259
260 #[must_use]
266 pub fn google(client_id: &str) -> Self {
267 Self {
268 issuer: "https://accounts.google.com".to_string(),
269 audience: Some(client_id.to_string()),
270 ..Default::default()
271 }
272 }
273
274 pub fn validate(&self) -> Result<()> {
276 if self.issuer.is_empty() {
277 return Err(SecurityError::SecurityConfigError(
278 "OIDC issuer URL is required".to_string(),
279 ));
280 }
281
282 if !self.issuer.starts_with("https://") && !self.issuer.starts_with("http://localhost") {
283 return Err(SecurityError::SecurityConfigError(
284 "OIDC issuer must use HTTPS (except localhost for development)".to_string(),
285 ));
286 }
287
288 if self.audience.is_none() && self.additional_audiences.is_empty() {
292 return Err(SecurityError::SecurityConfigError(
293 "OIDC audience is REQUIRED for security. Set 'audience' in auth config to your API identifier. \
294 This prevents token confusion attacks where tokens from one service can be used in another. \
295 Example: audience = \"https://api.example.com\" or audience = \"my-api-id\"".to_string(),
296 ));
297 }
298
299 if self.allowed_algorithms.is_empty() {
300 return Err(SecurityError::SecurityConfigError(
301 "At least one algorithm must be allowed".to_string(),
302 ));
303 }
304
305 Ok(())
306 }
307}
308
309#[derive(Debug, Clone, Deserialize)]
317pub struct OidcDiscoveryDocument {
318 pub issuer: String,
320
321 pub jwks_uri: String,
323
324 #[serde(default)]
326 pub id_token_signing_alg_values_supported: Vec<String>,
327
328 #[serde(default)]
330 pub authorization_endpoint: Option<String>,
331
332 #[serde(default)]
334 pub token_endpoint: Option<String>,
335}
336
337#[derive(Debug, Clone, Deserialize)]
343pub struct Jwks {
344 pub keys: Vec<Jwk>,
346}
347
348#[derive(Debug, Clone, Deserialize)]
350pub struct Jwk {
351 pub kty: String,
353
354 pub kid: Option<String>,
356
357 #[serde(default)]
359 pub alg: Option<String>,
360
361 #[serde(rename = "use")]
363 pub key_use: Option<String>,
364
365 pub n: Option<String>,
367
368 pub e: Option<String>,
370
371 #[serde(default)]
373 pub x5c: Vec<String>,
374}
375
376#[derive(Debug)]
378struct CachedJwks {
379 jwks: Jwks,
380 fetched_at: Instant,
381 ttl: Duration,
382}
383
384impl CachedJwks {
385 fn is_expired(&self) -> bool {
386 self.fetched_at.elapsed() > self.ttl
387 }
388}
389
390#[derive(Debug, Clone, Deserialize)]
396pub struct JwtClaims {
397 pub sub: Option<String>,
399
400 pub iss: Option<String>,
402
403 #[serde(default)]
405 pub aud: Audience,
406
407 pub exp: Option<i64>,
409
410 pub iat: Option<i64>,
412
413 pub nbf: Option<i64>,
415
416 pub scope: Option<String>,
418
419 pub scp: Option<Vec<String>>,
421
422 pub permissions: Option<Vec<String>>,
424
425 pub email: Option<String>,
427
428 pub email_verified: Option<bool>,
430
431 pub name: Option<String>,
433}
434
435#[derive(Debug, Clone, Default, Deserialize)]
437#[serde(untagged)]
438pub enum Audience {
439 #[default]
441 None,
442 Single(String),
444 Multiple(Vec<String>),
446}
447
448impl Audience {
449 pub fn contains(&self, value: &str) -> bool {
451 match self {
452 Self::None => false,
453 Self::Single(s) => s == value,
454 Self::Multiple(v) => v.iter().any(|s| s == value),
455 }
456 }
457
458 pub fn to_vec(&self) -> Vec<String> {
460 match self {
461 Self::None => Vec::new(),
462 Self::Single(s) => vec![s.clone()],
463 Self::Multiple(v) => v.clone(),
464 }
465 }
466}
467
468pub struct OidcValidator {
477 config: OidcConfig,
478 http_client: reqwest::Client,
479 jwks_cache: Arc<RwLock<Option<CachedJwks>>>,
480 jwks_uri: String,
481}
482
483impl OidcValidator {
484 pub async fn new(config: OidcConfig) -> Result<Self> {
496 config.validate()?;
497
498 let http_client = reqwest::Client::builder()
499 .timeout(Duration::from_secs(30))
500 .build()
501 .map_err(|e| SecurityError::SecurityConfigError(format!("HTTP client error: {e}")))?;
502
503 let jwks_uri = if let Some(ref uri) = config.jwks_uri {
505 uri.clone()
506 } else {
507 let discovery_url =
509 format!("{}/.well-known/openid-configuration", config.issuer.trim_end_matches('/'));
510
511 tracing::debug!(url = %discovery_url, "Performing OIDC discovery");
512
513 let response = http_client.get(&discovery_url).send().await.map_err(|e| {
514 SecurityError::SecurityConfigError(format!("OIDC discovery failed: {e}"))
515 })?;
516
517 if !response.status().is_success() {
518 return Err(SecurityError::SecurityConfigError(format!(
519 "OIDC discovery failed with status: {}",
520 response.status()
521 )));
522 }
523
524 let discovery: OidcDiscoveryDocument = response.json().await.map_err(|e| {
525 SecurityError::SecurityConfigError(format!("Invalid OIDC discovery response: {e}"))
526 })?;
527
528 tracing::info!(
529 issuer = %discovery.issuer,
530 jwks_uri = %discovery.jwks_uri,
531 "OIDC discovery successful"
532 );
533
534 discovery.jwks_uri
535 };
536
537 Ok(Self {
538 config,
539 http_client,
540 jwks_cache: Arc::new(RwLock::new(None)),
541 jwks_uri,
542 })
543 }
544
545 #[must_use]
549 pub fn with_jwks_uri(config: OidcConfig, jwks_uri: String) -> Self {
550 Self {
551 config,
552 http_client: reqwest::Client::new(),
553 jwks_cache: Arc::new(RwLock::new(None)),
554 jwks_uri,
555 }
556 }
557
558 pub async fn validate_token(&self, token: &str) -> Result<AuthenticatedUser> {
577 let header = decode_header(token).map_err(|e| {
579 tracing::debug!(error = %e, "Failed to decode JWT header");
580 SecurityError::InvalidToken
581 })?;
582
583 let kid = header.kid.as_ref().ok_or_else(|| {
584 tracing::debug!("JWT missing kid (key ID) in header");
585 SecurityError::InvalidToken
586 })?;
587
588 let decoding_key = self.get_decoding_key(kid).await?;
590
591 let mut validation = Validation::new(self.get_algorithm(&header)?);
593 validation.set_issuer(&[&self.config.issuer]);
594
595 if let Some(ref aud) = self.config.audience {
597 let mut audiences = vec![aud.clone()];
598 audiences.extend(self.config.additional_audiences.clone());
599 validation.set_audience(&audiences);
600 } else {
601 validation.validate_aud = false;
602 }
603
604 validation.leeway = self.config.clock_skew_secs;
606
607 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
609 tracing::debug!(error = %e, "JWT validation failed");
610 match e.kind() {
611 jsonwebtoken::errors::ErrorKind::ExpiredSignature => SecurityError::TokenExpired {
612 expired_at: Utc::now(), },
614 jsonwebtoken::errors::ErrorKind::InvalidIssuer => SecurityError::InvalidToken,
615 jsonwebtoken::errors::ErrorKind::InvalidAudience => SecurityError::InvalidToken,
616 jsonwebtoken::errors::ErrorKind::InvalidSignature => SecurityError::InvalidToken,
617 _ => SecurityError::InvalidToken,
618 }
619 })?;
620
621 let claims = token_data.claims;
622
623 let scopes = self.extract_scopes(&claims);
625
626 let user_id = claims.sub.ok_or(SecurityError::TokenMissingClaim {
628 claim: "sub".to_string(),
629 })?;
630
631 let exp = claims.exp.ok_or(SecurityError::TokenMissingClaim {
633 claim: "exp".to_string(),
634 })?;
635
636 let expires_at =
637 DateTime::<Utc>::from_timestamp(exp, 0).ok_or(SecurityError::InvalidToken)?;
638
639 tracing::debug!(
640 user_id = %user_id,
641 scopes = ?scopes,
642 expires_at = %expires_at,
643 "Token validated successfully"
644 );
645
646 Ok(AuthenticatedUser {
647 user_id,
648 scopes,
649 expires_at,
650 })
651 }
652
653 async fn get_decoding_key(&self, kid: &str) -> Result<DecodingKey> {
655 {
657 let cache = self.jwks_cache.read();
658 if let Some(ref cached) = *cache {
659 if !cached.is_expired() {
660 if let Some(key) = self.find_key(&cached.jwks, kid) {
661 return self.jwk_to_decoding_key(key);
662 }
663 }
664 }
665 }
666
667 let jwks = self.fetch_jwks().await?;
669
670 if self.detect_key_rotation(&jwks) {
672 tracing::warn!(
673 "OIDC key rotation detected: some previously cached keys no longer available"
674 );
675 }
676
677 let key_index =
679 jwks.keys.iter().position(|k| k.kid.as_deref() == Some(kid)).ok_or_else(|| {
680 tracing::debug!(kid = %kid, "Key not found in JWKS");
681 SecurityError::InvalidToken
682 })?;
683
684 let key = jwks.keys[key_index].clone();
686
687 {
689 let mut cache = self.jwks_cache.write();
690 *cache = Some(CachedJwks {
691 jwks,
692 fetched_at: Instant::now(),
693 ttl: Duration::from_secs(self.config.jwks_cache_ttl_secs),
694 });
695 }
696
697 self.jwk_to_decoding_key(&key)
698 }
699
700 async fn fetch_jwks(&self) -> Result<Jwks> {
702 tracing::debug!(uri = %self.jwks_uri, "Fetching JWKS");
703
704 let response = self.http_client.get(&self.jwks_uri).send().await.map_err(|e| {
705 tracing::error!(error = %e, "Failed to fetch JWKS");
706 SecurityError::SecurityConfigError(format!("Failed to fetch JWKS: {e}"))
707 })?;
708
709 if !response.status().is_success() {
710 return Err(SecurityError::SecurityConfigError(format!(
711 "JWKS fetch failed with status: {}",
712 response.status()
713 )));
714 }
715
716 let jwks: Jwks = response.json().await.map_err(|e| {
717 SecurityError::SecurityConfigError(format!("Invalid JWKS response: {e}"))
718 })?;
719
720 tracing::debug!(key_count = jwks.keys.len(), "JWKS fetched successfully");
721
722 Ok(jwks)
723 }
724
725 fn find_key<'a>(&self, jwks: &'a Jwks, kid: &str) -> Option<&'a Jwk> {
727 jwks.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
728 }
729
730 fn detect_key_rotation(&self, new_jwks: &Jwks) -> bool {
735 let cache = self.jwks_cache.read();
736 if let Some(ref cached) = *cache {
737 let old_kids: std::collections::HashSet<_> =
739 cached.jwks.keys.iter().filter_map(|k| k.kid.as_deref()).collect();
740
741 let new_kids: std::collections::HashSet<_> =
743 new_jwks.keys.iter().filter_map(|k| k.kid.as_deref()).collect();
744
745 !old_kids.is_subset(&new_kids)
747 } else {
748 false
749 }
750 }
751
752 fn jwk_to_decoding_key(&self, jwk: &Jwk) -> Result<DecodingKey> {
754 match jwk.kty.as_str() {
755 "RSA" => {
756 let n = jwk.n.as_ref().ok_or(SecurityError::InvalidToken)?;
757 let e = jwk.e.as_ref().ok_or(SecurityError::InvalidToken)?;
758
759 DecodingKey::from_rsa_components(n, e).map_err(|e| {
760 tracing::debug!(error = %e, "Failed to create RSA decoding key");
761 SecurityError::InvalidToken
762 })
763 },
764 other => {
765 tracing::debug!(key_type = %other, "Unsupported key type");
766 Err(SecurityError::InvalidTokenAlgorithm {
767 algorithm: other.to_string(),
768 })
769 },
770 }
771 }
772
773 fn get_algorithm(&self, header: &jsonwebtoken::Header) -> Result<Algorithm> {
775 let alg_str = format!("{:?}", header.alg);
776
777 if !self.config.allowed_algorithms.contains(&alg_str) {
779 return Err(SecurityError::InvalidTokenAlgorithm { algorithm: alg_str });
780 }
781
782 Ok(header.alg)
783 }
784
785 fn extract_scopes(&self, claims: &JwtClaims) -> Vec<String> {
792 if self.config.scope_claim == "scope" {
794 if let Some(ref scope) = claims.scope {
795 return scope.split_whitespace().map(String::from).collect();
796 }
797 }
798
799 if let Some(ref scp) = claims.scp {
801 return scp.clone();
802 }
803
804 if let Some(ref perms) = claims.permissions {
806 return perms.clone();
807 }
808
809 if let Some(ref scope) = claims.scope {
811 return scope.split_whitespace().map(String::from).collect();
812 }
813
814 Vec::new()
815 }
816
817 #[must_use]
819 pub fn is_required(&self) -> bool {
820 self.config.required
821 }
822
823 #[must_use]
825 pub fn issuer(&self) -> &str {
826 &self.config.issuer
827 }
828
829 pub fn clear_cache(&self) {
833 let mut cache = self.jwks_cache.write();
834 *cache = None;
835 }
836}
837
838#[cfg(test)]
843mod tests {
844 use super::*;
845
846 #[test]
847 fn test_oidc_config_default() {
848 let config = OidcConfig::default();
849 assert!(config.issuer.is_empty());
850 assert!(config.audience.is_none());
851 assert_eq!(config.jwks_cache_ttl_secs, 300);
853 assert_eq!(config.allowed_algorithms, vec!["RS256"]);
854 assert_eq!(config.clock_skew_secs, 60);
855 assert!(config.required);
856 }
857
858 #[test]
859 fn test_oidc_config_auth0() {
860 let config = OidcConfig::auth0("my-tenant.auth0.com", "my-api");
861 assert_eq!(config.issuer, "https://my-tenant.auth0.com/");
862 assert_eq!(config.audience, Some("my-api".to_string()));
863 }
864
865 #[test]
866 fn test_oidc_config_keycloak() {
867 let config = OidcConfig::keycloak("https://keycloak.example.com", "myrealm", "myclient");
868 assert_eq!(config.issuer, "https://keycloak.example.com/realms/myrealm");
869 assert_eq!(config.audience, Some("myclient".to_string()));
870 }
871
872 #[test]
873 fn test_oidc_config_okta() {
874 let config = OidcConfig::okta("myorg.okta.com", "api://default");
875 assert_eq!(config.issuer, "https://myorg.okta.com");
876 assert_eq!(config.audience, Some("api://default".to_string()));
877 }
878
879 #[test]
880 fn test_oidc_config_cognito() {
881 let config = OidcConfig::cognito("us-east-1", "us-east-1_abc123", "client123");
882 assert_eq!(config.issuer, "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_abc123");
883 assert_eq!(config.audience, Some("client123".to_string()));
884 }
885
886 #[test]
887 fn test_oidc_config_azure_ad() {
888 let config = OidcConfig::azure_ad("tenant-id-123", "client-id-456");
889 assert_eq!(config.issuer, "https://login.microsoftonline.com/tenant-id-123/v2.0");
890 assert_eq!(config.audience, Some("client-id-456".to_string()));
891 }
892
893 #[test]
894 fn test_oidc_config_google() {
895 let config = OidcConfig::google("123456.apps.googleusercontent.com");
896 assert_eq!(config.issuer, "https://accounts.google.com");
897 assert_eq!(config.audience, Some("123456.apps.googleusercontent.com".to_string()));
898 }
899
900 #[test]
901 fn test_oidc_config_validate_empty_issuer() {
902 let config = OidcConfig::default();
903 let result = config.validate();
904 assert!(result.is_err());
905 assert!(matches!(result, Err(SecurityError::SecurityConfigError(_))));
906 }
907
908 #[test]
909 fn test_oidc_config_validate_http_issuer() {
910 let config = OidcConfig {
911 issuer: "http://insecure.example.com".to_string(),
912 ..Default::default()
913 };
914 let result = config.validate();
915 assert!(result.is_err());
916 }
917
918 #[test]
919 fn test_oidc_config_validate_localhost_allowed() {
920 let config = OidcConfig {
921 issuer: "http://localhost:8080".to_string(),
922 audience: Some("my-api".to_string()),
923 ..Default::default()
924 };
925 let result = config.validate();
926 assert!(result.is_ok());
927 }
928
929 #[test]
930 fn test_oidc_config_validate_https_required() {
931 let config = OidcConfig {
932 issuer: "https://secure.example.com".to_string(),
933 audience: Some("https://api.example.com".to_string()),
934 ..Default::default()
935 };
936 let result = config.validate();
937 assert!(result.is_ok());
938 }
939
940 #[test]
941 fn test_audience_none() {
942 let aud = Audience::None;
943 assert!(!aud.contains("test"));
944 assert!(aud.to_vec().is_empty());
945 }
946
947 #[test]
948 fn test_audience_single() {
949 let aud = Audience::Single("my-api".to_string());
950 assert!(aud.contains("my-api"));
951 assert!(!aud.contains("other"));
952 assert_eq!(aud.to_vec(), vec!["my-api"]);
953 }
954
955 #[test]
956 fn test_audience_multiple() {
957 let aud = Audience::Multiple(vec!["api1".to_string(), "api2".to_string()]);
958 assert!(aud.contains("api1"));
959 assert!(aud.contains("api2"));
960 assert!(!aud.contains("api3"));
961 assert_eq!(aud.to_vec(), vec!["api1", "api2"]);
962 }
963
964 #[test]
965 fn test_jwk_deserialization() {
966 let jwk_json = r#"{
967 "kty": "RSA",
968 "kid": "test-key-id",
969 "alg": "RS256",
970 "use": "sig",
971 "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
972 "e": "AQAB"
973 }"#;
974
975 let jwk: Jwk = serde_json::from_str(jwk_json).unwrap();
976 assert_eq!(jwk.kty, "RSA");
977 assert_eq!(jwk.kid, Some("test-key-id".to_string()));
978 assert_eq!(jwk.alg, Some("RS256".to_string()));
979 assert!(jwk.n.is_some());
980 assert!(jwk.e.is_some());
981 }
982
983 #[test]
984 fn test_jwks_deserialization() {
985 let jwks_json = r#"{
986 "keys": [
987 {
988 "kty": "RSA",
989 "kid": "key1",
990 "n": "test_n",
991 "e": "AQAB"
992 },
993 {
994 "kty": "RSA",
995 "kid": "key2",
996 "n": "test_n2",
997 "e": "AQAB"
998 }
999 ]
1000 }"#;
1001
1002 let jwks: Jwks = serde_json::from_str(jwks_json).unwrap();
1003 assert_eq!(jwks.keys.len(), 2);
1004 assert_eq!(jwks.keys[0].kid, Some("key1".to_string()));
1005 assert_eq!(jwks.keys[1].kid, Some("key2".to_string()));
1006 }
1007
1008 #[test]
1009 fn test_jwt_claims_deserialization() {
1010 let claims_json = r#"{
1011 "sub": "user123",
1012 "iss": "https://issuer.example.com",
1013 "aud": "my-api",
1014 "exp": 1735689600,
1015 "iat": 1735686000,
1016 "scope": "read write",
1017 "email": "user@example.com"
1018 }"#;
1019
1020 let claims: JwtClaims = serde_json::from_str(claims_json).unwrap();
1021 assert_eq!(claims.sub, Some("user123".to_string()));
1022 assert_eq!(claims.iss, Some("https://issuer.example.com".to_string()));
1023 assert!(claims.aud.contains("my-api"));
1024 assert_eq!(claims.exp, Some(1_735_689_600));
1025 assert_eq!(claims.scope, Some("read write".to_string()));
1026 }
1027
1028 #[test]
1029 fn test_jwt_claims_array_audience() {
1030 let claims_json = r#"{
1031 "sub": "user123",
1032 "aud": ["api1", "api2"],
1033 "exp": 1735689600
1034 }"#;
1035
1036 let claims: JwtClaims = serde_json::from_str(claims_json).unwrap();
1037 assert!(claims.aud.contains("api1"));
1038 assert!(claims.aud.contains("api2"));
1039 }
1040
1041 #[test]
1042 fn test_oidc_discovery_document_deserialization() {
1043 let doc_json = r#"{
1044 "issuer": "https://issuer.example.com",
1045 "jwks_uri": "https://issuer.example.com/.well-known/jwks.json",
1046 "authorization_endpoint": "https://issuer.example.com/authorize",
1047 "token_endpoint": "https://issuer.example.com/oauth/token",
1048 "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"]
1049 }"#;
1050
1051 let doc: OidcDiscoveryDocument = serde_json::from_str(doc_json).unwrap();
1052 assert_eq!(doc.issuer, "https://issuer.example.com");
1053 assert_eq!(doc.jwks_uri, "https://issuer.example.com/.well-known/jwks.json");
1054 assert_eq!(doc.id_token_signing_alg_values_supported.len(), 3);
1055 }
1056
1057 #[test]
1058 fn test_jwks_cache_ttl_reduced_for_security() {
1059 assert_eq!(default_jwks_cache_ttl(), 300, "Cache TTL should be 5 minutes (300 seconds)");
1062 }
1063
1064 #[test]
1065 fn test_cached_jwks_expiration() {
1066 let jwks = Jwks { keys: vec![] };
1068 let cached = CachedJwks {
1069 jwks,
1070 fetched_at: Instant::now(),
1071 ttl: Duration::from_secs(1),
1072 };
1073
1074 assert!(!cached.is_expired());
1076
1077 std::thread::sleep(Duration::from_millis(1100));
1079 assert!(cached.is_expired());
1080 }
1081
1082 #[test]
1083 fn test_detect_key_rotation_when_no_cache() {
1084 let config = OidcConfig {
1086 issuer: "http://localhost:8080".to_string(),
1087 ..Default::default()
1088 };
1089
1090 let validator = OidcValidator {
1091 config,
1092 http_client: reqwest::Client::new(),
1093 jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1094 jwks_cache: Arc::new(RwLock::new(None)),
1095 };
1096
1097 let new_jwks = Jwks {
1098 keys: vec![Jwk {
1099 kty: "RSA".to_string(),
1100 kid: Some("key1".to_string()),
1101 alg: None,
1102 key_use: None,
1103 n: None,
1104 e: None,
1105 x5c: vec![],
1106 }],
1107 };
1108
1109 assert!(!validator.detect_key_rotation(&new_jwks));
1111 }
1112
1113 #[test]
1114 fn test_detect_key_rotation_when_keys_removed() {
1115 let config = OidcConfig {
1117 issuer: "http://localhost:8080".to_string(),
1118 ..Default::default()
1119 };
1120
1121 let validator = OidcValidator {
1122 config,
1123 http_client: reqwest::Client::new(),
1124 jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1125 jwks_cache: Arc::new(RwLock::new(None)),
1126 };
1127
1128 let old_jwks = Jwks {
1130 keys: vec![
1131 Jwk {
1132 kty: "RSA".to_string(),
1133 kid: Some("old_key_1".to_string()),
1134 alg: None,
1135 key_use: None,
1136 n: None,
1137 e: None,
1138 x5c: vec![],
1139 },
1140 Jwk {
1141 kty: "RSA".to_string(),
1142 kid: Some("old_key_2".to_string()),
1143 alg: None,
1144 key_use: None,
1145 n: None,
1146 e: None,
1147 x5c: vec![],
1148 },
1149 ],
1150 };
1151
1152 {
1153 let mut cache = validator.jwks_cache.write();
1154 *cache = Some(CachedJwks {
1155 jwks: old_jwks,
1156 fetched_at: Instant::now(),
1157 ttl: Duration::from_secs(300),
1158 });
1159 }
1160
1161 let new_jwks = Jwks {
1163 keys: vec![
1164 Jwk {
1165 kty: "RSA".to_string(),
1166 kid: Some("old_key_1".to_string()),
1167 alg: None,
1168 key_use: None,
1169 n: None,
1170 e: None,
1171 x5c: vec![],
1172 },
1173 Jwk {
1174 kty: "RSA".to_string(),
1175 kid: Some("new_key_1".to_string()),
1176 alg: None,
1177 key_use: None,
1178 n: None,
1179 e: None,
1180 x5c: vec![],
1181 },
1182 ],
1183 };
1184
1185 assert!(validator.detect_key_rotation(&new_jwks));
1187 }
1188
1189 #[test]
1190 fn test_detect_key_rotation_when_no_keys_removed() {
1191 let config = OidcConfig {
1193 issuer: "http://localhost:8080".to_string(),
1194 ..Default::default()
1195 };
1196
1197 let validator = OidcValidator {
1198 config,
1199 http_client: reqwest::Client::new(),
1200 jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1201 jwks_cache: Arc::new(RwLock::new(None)),
1202 };
1203
1204 let old_jwks = Jwks {
1206 keys: vec![
1207 Jwk {
1208 kty: "RSA".to_string(),
1209 kid: Some("key_1".to_string()),
1210 alg: None,
1211 key_use: None,
1212 n: None,
1213 e: None,
1214 x5c: vec![],
1215 },
1216 Jwk {
1217 kty: "RSA".to_string(),
1218 kid: Some("key_2".to_string()),
1219 alg: None,
1220 key_use: None,
1221 n: None,
1222 e: None,
1223 x5c: vec![],
1224 },
1225 ],
1226 };
1227
1228 {
1229 let mut cache = validator.jwks_cache.write();
1230 *cache = Some(CachedJwks {
1231 jwks: old_jwks,
1232 fetched_at: Instant::now(),
1233 ttl: Duration::from_secs(300),
1234 });
1235 }
1236
1237 let new_jwks = Jwks {
1239 keys: vec![
1240 Jwk {
1241 kty: "RSA".to_string(),
1242 kid: Some("key_1".to_string()),
1243 alg: None,
1244 key_use: None,
1245 n: None,
1246 e: None,
1247 x5c: vec![],
1248 },
1249 Jwk {
1250 kty: "RSA".to_string(),
1251 kid: Some("key_2".to_string()),
1252 alg: None,
1253 key_use: None,
1254 n: None,
1255 e: None,
1256 x5c: vec![],
1257 },
1258 Jwk {
1259 kty: "RSA".to_string(),
1260 kid: Some("new_key".to_string()),
1261 alg: None,
1262 key_use: None,
1263 n: None,
1264 e: None,
1265 x5c: vec![],
1266 },
1267 ],
1268 };
1269
1270 assert!(!validator.detect_key_rotation(&new_jwks));
1272 }
1273
1274 #[test]
1275 fn test_find_key_by_kid() {
1276 let config = OidcConfig {
1278 issuer: "http://localhost:8080".to_string(),
1279 ..Default::default()
1280 };
1281
1282 let validator = OidcValidator {
1283 config,
1284 http_client: reqwest::Client::new(),
1285 jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1286 jwks_cache: Arc::new(RwLock::new(None)),
1287 };
1288
1289 let jwks = Jwks {
1290 keys: vec![
1291 Jwk {
1292 kty: "RSA".to_string(),
1293 kid: Some("key1".to_string()),
1294 alg: None,
1295 key_use: None,
1296 n: None,
1297 e: None,
1298 x5c: vec![],
1299 },
1300 Jwk {
1301 kty: "RSA".to_string(),
1302 kid: Some("key2".to_string()),
1303 alg: None,
1304 key_use: None,
1305 n: None,
1306 e: None,
1307 x5c: vec![],
1308 },
1309 ],
1310 };
1311
1312 assert!(validator.find_key(&jwks, "key1").is_some());
1314 assert!(validator.find_key(&jwks, "key2").is_some());
1315
1316 assert!(validator.find_key(&jwks, "key3").is_none());
1318 }
1319
1320 #[test]
1321 fn test_find_key_without_kid() {
1322 let config = OidcConfig {
1324 issuer: "http://localhost:8080".to_string(),
1325 ..Default::default()
1326 };
1327
1328 let validator = OidcValidator {
1329 config,
1330 http_client: reqwest::Client::new(),
1331 jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1332 jwks_cache: Arc::new(RwLock::new(None)),
1333 };
1334
1335 let jwks = Jwks {
1336 keys: vec![Jwk {
1337 kty: "RSA".to_string(),
1338 kid: None, alg: None,
1340 key_use: None,
1341 n: None,
1342 e: None,
1343 x5c: vec![],
1344 }],
1345 };
1346
1347 assert!(validator.find_key(&jwks, "any_kid").is_none());
1349 }
1350
1351 #[test]
1352 fn test_oidc_config_with_custom_cache_ttl() {
1353 let config = OidcConfig {
1355 issuer: "http://localhost:8080".to_string(),
1356 jwks_cache_ttl_secs: 600, ..Default::default()
1358 };
1359
1360 assert_eq!(config.jwks_cache_ttl_secs, 600);
1361 }
1362
1363 #[test]
1364 fn test_oidc_config_default_cache_ttl_is_short() {
1365 let config = OidcConfig::default();
1367 assert!(
1368 config.jwks_cache_ttl_secs <= 300,
1369 "Default cache TTL should be short (≤ 300 seconds) to prevent token poisoning"
1370 );
1371 }
1372}