Skip to main content

fraiseql_core/security/
oidc.rs

1//! OIDC Discovery and JWKS Support
2//!
3//! This module provides OpenID Connect discovery and JSON Web Key Set (JWKS)
4//! support for validating JWT tokens from any OIDC-compliant provider.
5//!
6//! Supported providers include:
7//! - Auth0
8//! - Keycloak
9//! - Okta
10//! - AWS Cognito
11//! - Microsoft Entra ID (Azure AD)
12//! - Google Identity
13//! - Any OIDC-compliant provider
14//!
15//! # Architecture
16//!
17//! ```text
18//! JWT Token from Client
19//!     ↓
20//! OidcValidator::validate_token()
21//!     ├─ Extract kid (key ID) from JWT header
22//!     ├─ Fetch/cache JWKS from provider
23//!     ├─ Find matching key by kid
24//!     ├─ Verify JWT signature
25//!     └─ Validate claims (iss, aud, exp)
26//!     ↓
27//! AuthenticatedUser (if valid)
28//! ```
29//!
30//! # Example
31//!
32//! ```ignore
33//! use fraiseql_core::security::oidc::{OidcConfig, OidcValidator};
34//!
35//! let config = OidcConfig {
36//!     issuer: "https://your-tenant.auth0.com/".to_string(),
37//!     audience: Some("your-api-identifier".to_string()),
38//!     ..Default::default()
39//! };
40//!
41//! let validator = OidcValidator::new(config).await?;
42//! let user = validator.validate_token("eyJhbG...").await?;
43//! ```
44
45use std::{
46    sync::Arc,
47    time::{Duration, Instant},
48};
49
50use chrono::{DateTime, Utc};
51use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54
55use crate::security::{
56    auth_middleware::AuthenticatedUser,
57    errors::{Result, SecurityError},
58};
59
60// ============================================================================
61// OIDC Configuration
62// ============================================================================
63
64/// OIDC authentication configuration.
65///
66/// Configure this with your identity provider's issuer URL.
67/// The validator will automatically discover JWKS endpoint.
68///
69/// **SECURITY CRITICAL**: You MUST configure the `audience` field to prevent
70/// token confusion attacks. See the `audience` field documentation for details.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct OidcConfig {
73    /// Issuer URL (e.g., `https://your-tenant.auth0.com/`)
74    ///
75    /// Must match the `iss` claim in tokens exactly.
76    /// Should include trailing slash if provider expects it.
77    pub issuer: String,
78
79    /// Expected audience claim (REQUIRED for security).
80    ///
81    /// **SECURITY CRITICAL**: This field is mandatory. Tokens must have this value in their `aud`
82    /// claim. This prevents token confusion attacks where tokens from one service can be used
83    /// in another.
84    ///
85    /// For Auth0, this is typically your API identifier (e.g., `https://api.example.com`).
86    /// For other providers, use a unique identifier that represents your application.
87    ///
88    /// Set at least one of:
89    /// - `audience` (primary audience)
90    /// - `additional_audiences` (secondary audiences)
91    #[serde(default)]
92    pub audience: Option<String>,
93
94    /// Additional allowed audiences (optional).
95    ///
96    /// Some tokens may have multiple audiences. Add extras here.
97    #[serde(default)]
98    pub additional_audiences: Vec<String>,
99
100    /// JWKS cache TTL in seconds.
101    ///
102    /// How long to cache the JWKS before refetching.
103    /// Default: 3600 (1 hour)
104    #[serde(default = "default_jwks_cache_ttl")]
105    pub jwks_cache_ttl_secs: u64,
106
107    /// Allowed token algorithms.
108    ///
109    /// Default: RS256 (most common for OIDC providers)
110    #[serde(default = "default_algorithms")]
111    pub allowed_algorithms: Vec<String>,
112
113    /// Clock skew tolerance in seconds.
114    ///
115    /// Allow this many seconds of clock difference when
116    /// validating exp/nbf/iat claims.
117    /// Default: 60 seconds
118    #[serde(default = "default_clock_skew")]
119    pub clock_skew_secs: u64,
120
121    /// Custom JWKS URI (optional).
122    ///
123    /// If set, skip OIDC discovery and use this URI directly.
124    /// Useful for providers that don't support standard discovery.
125    #[serde(default)]
126    pub jwks_uri: Option<String>,
127
128    /// Require authentication for all requests.
129    ///
130    /// If false, requests without tokens are allowed (anonymous access).
131    /// Default: true
132    #[serde(default = "default_required")]
133    pub required: bool,
134
135    /// Scope claim name.
136    ///
137    /// The claim containing user scopes/permissions.
138    /// Default: "scope" (space-separated string)
139    /// Some providers use "scp" or "permissions" (array)
140    #[serde(default = "default_scope_claim")]
141    pub scope_claim: String,
142}
143
/// Serde default for `OidcConfig::jwks_cache_ttl_secs`: five minutes.
///
/// SECURITY: kept short (rather than an hour) because the cache holds the provider's signing
/// keys; a key the provider has withdrawn stays trusted until the cached copy expires.
fn default_jwks_cache_ttl() -> u64 {
    5 * 60
}

/// Serde default for `OidcConfig::allowed_algorithms`: only `RS256`.
fn default_algorithms() -> Vec<String> {
    vec![String::from("RS256")]
}

/// Serde default for `OidcConfig::clock_skew_secs`: one minute of leeway.
fn default_clock_skew() -> u64 {
    60
}

/// Serde default for `OidcConfig::required`: authentication is mandatory.
fn default_required() -> bool {
    true
}

/// Serde default for `OidcConfig::scope_claim`: the standard `scope` claim.
fn default_scope_claim() -> String {
    String::from("scope")
}
165
166impl Default for OidcConfig {
167    fn default() -> Self {
168        Self {
169            issuer:               String::new(),
170            audience:             None,
171            additional_audiences: Vec::new(),
172            jwks_cache_ttl_secs:  default_jwks_cache_ttl(),
173            allowed_algorithms:   default_algorithms(),
174            clock_skew_secs:      default_clock_skew(),
175            jwks_uri:             None,
176            required:             default_required(),
177            scope_claim:          default_scope_claim(),
178        }
179    }
180}
181
182impl OidcConfig {
183    /// Create config for Auth0.
184    ///
185    /// # Arguments
186    ///
187    /// * `domain` - Your Auth0 domain (e.g., "your-tenant.auth0.com")
188    /// * `audience` - Your API identifier
189    #[must_use]
190    pub fn auth0(domain: &str, audience: &str) -> Self {
191        Self {
192            issuer: format!("https://{domain}/"),
193            audience: Some(audience.to_string()),
194            ..Default::default()
195        }
196    }
197
198    /// Create config for Keycloak.
199    ///
200    /// # Arguments
201    ///
202    /// * `base_url` - Keycloak server URL (e.g., `https://keycloak.example.com`)
203    /// * `realm` - Realm name
204    /// * `client_id` - Client ID (used as audience)
205    #[must_use]
206    pub fn keycloak(base_url: &str, realm: &str, client_id: &str) -> Self {
207        Self {
208            issuer: format!("{base_url}/realms/{realm}"),
209            audience: Some(client_id.to_string()),
210            ..Default::default()
211        }
212    }
213
214    /// Create config for Okta.
215    ///
216    /// # Arguments
217    ///
218    /// * `domain` - Your Okta domain (e.g., "your-org.okta.com")
219    /// * `audience` - Your API audience (often "api://default")
220    #[must_use]
221    pub fn okta(domain: &str, audience: &str) -> Self {
222        Self {
223            issuer: format!("https://{domain}"),
224            audience: Some(audience.to_string()),
225            ..Default::default()
226        }
227    }
228
229    /// Create config for AWS Cognito.
230    ///
231    /// # Arguments
232    ///
233    /// * `region` - AWS region (e.g., "us-east-1")
234    /// * `user_pool_id` - Cognito User Pool ID
235    /// * `client_id` - App client ID (used as audience)
236    #[must_use]
237    pub fn cognito(region: &str, user_pool_id: &str, client_id: &str) -> Self {
238        Self {
239            issuer: format!("https://cognito-idp.{region}.amazonaws.com/{user_pool_id}"),
240            audience: Some(client_id.to_string()),
241            ..Default::default()
242        }
243    }
244
245    /// Create config for Microsoft Entra ID (Azure AD).
246    ///
247    /// # Arguments
248    ///
249    /// * `tenant_id` - Azure AD tenant ID
250    /// * `client_id` - Application (client) ID
251    #[must_use]
252    pub fn azure_ad(tenant_id: &str, client_id: &str) -> Self {
253        Self {
254            issuer: format!("https://login.microsoftonline.com/{tenant_id}/v2.0"),
255            audience: Some(client_id.to_string()),
256            ..Default::default()
257        }
258    }
259
260    /// Create config for Google Identity.
261    ///
262    /// # Arguments
263    ///
264    /// * `client_id` - Google OAuth client ID
265    #[must_use]
266    pub fn google(client_id: &str) -> Self {
267        Self {
268            issuer: "https://accounts.google.com".to_string(),
269            audience: Some(client_id.to_string()),
270            ..Default::default()
271        }
272    }
273
274    /// Validate the configuration.
275    pub fn validate(&self) -> Result<()> {
276        if self.issuer.is_empty() {
277            return Err(SecurityError::SecurityConfigError(
278                "OIDC issuer URL is required".to_string(),
279            ));
280        }
281
282        if !self.issuer.starts_with("https://") && !self.issuer.starts_with("http://localhost") {
283            return Err(SecurityError::SecurityConfigError(
284                "OIDC issuer must use HTTPS (except localhost for development)".to_string(),
285            ));
286        }
287
288        // CRITICAL SECURITY FIX: Audience validation is now mandatory
289        // This prevents token confusion attacks where tokens intended for service A
290        // can be used for service B.
291        if self.audience.is_none() && self.additional_audiences.is_empty() {
292            return Err(SecurityError::SecurityConfigError(
293                "OIDC audience is REQUIRED for security. Set 'audience' in auth config to your API identifier. \
294                 This prevents token confusion attacks where tokens from one service can be used in another. \
295                 Example: audience = \"https://api.example.com\" or audience = \"my-api-id\"".to_string(),
296            ));
297        }
298
299        if self.allowed_algorithms.is_empty() {
300            return Err(SecurityError::SecurityConfigError(
301                "At least one algorithm must be allowed".to_string(),
302            ));
303        }
304
305        Ok(())
306    }
307}
308
309// ============================================================================
310// OIDC Discovery Response
311// ============================================================================
312
313/// OIDC Discovery document (partial).
314///
315/// Contains the fields we need from `/.well-known/openid-configuration`.
316#[derive(Debug, Clone, Deserialize)]
317pub struct OidcDiscoveryDocument {
318    /// Issuer identifier
319    pub issuer: String,
320
321    /// JWKS URI for fetching public keys
322    pub jwks_uri: String,
323
324    /// Supported signing algorithms
325    #[serde(default)]
326    pub id_token_signing_alg_values_supported: Vec<String>,
327
328    /// Authorization endpoint (for reference)
329    #[serde(default)]
330    pub authorization_endpoint: Option<String>,
331
332    /// Token endpoint (for reference)
333    #[serde(default)]
334    pub token_endpoint: Option<String>,
335}
336
337// ============================================================================
338// JWKS Types
339// ============================================================================
340
341/// JSON Web Key Set.
342#[derive(Debug, Clone, Deserialize)]
343pub struct Jwks {
344    /// Array of JSON Web Keys
345    pub keys: Vec<Jwk>,
346}
347
348/// JSON Web Key.
349#[derive(Debug, Clone, Deserialize)]
350pub struct Jwk {
351    /// Key type (e.g., "RSA")
352    pub kty: String,
353
354    /// Key ID (used to match with JWT header)
355    pub kid: Option<String>,
356
357    /// Algorithm (e.g., "RS256")
358    #[serde(default)]
359    pub alg: Option<String>,
360
361    /// Intended use (e.g., "sig" for signature)
362    #[serde(rename = "use")]
363    pub key_use: Option<String>,
364
365    /// RSA modulus (base64url encoded)
366    pub n: Option<String>,
367
368    /// RSA exponent (base64url encoded)
369    pub e: Option<String>,
370
371    /// X.509 certificate chain
372    #[serde(default)]
373    pub x5c: Vec<String>,
374}
375
376/// Cached JWKS with expiration.
377#[derive(Debug)]
378struct CachedJwks {
379    jwks:       Jwks,
380    fetched_at: Instant,
381    ttl:        Duration,
382}
383
384impl CachedJwks {
385    fn is_expired(&self) -> bool {
386        self.fetched_at.elapsed() > self.ttl
387    }
388}
389
390// ============================================================================
391// JWT Claims
392// ============================================================================
393
394/// Standard JWT claims for validation.
395#[derive(Debug, Clone, Deserialize)]
396pub struct JwtClaims {
397    /// Subject (user ID)
398    pub sub: Option<String>,
399
400    /// Issuer
401    pub iss: Option<String>,
402
403    /// Audience (can be string or array)
404    #[serde(default)]
405    pub aud: Audience,
406
407    /// Expiration time (Unix timestamp)
408    pub exp: Option<i64>,
409
410    /// Issued at (Unix timestamp)
411    pub iat: Option<i64>,
412
413    /// Not before (Unix timestamp)
414    pub nbf: Option<i64>,
415
416    /// Scope (space-separated string, common in Auth0/Okta)
417    pub scope: Option<String>,
418
419    /// Scopes (array, common in some providers)
420    pub scp: Option<Vec<String>>,
421
422    /// Permissions (array, common in Auth0)
423    pub permissions: Option<Vec<String>>,
424
425    /// Email claim
426    pub email: Option<String>,
427
428    /// Email verified
429    pub email_verified: Option<bool>,
430
431    /// Name claim
432    pub name: Option<String>,
433}
434
435/// Audience can be a single string or array of strings.
436#[derive(Debug, Clone, Default, Deserialize)]
437#[serde(untagged)]
438pub enum Audience {
439    /// No audience specified.
440    #[default]
441    None,
442    /// Single audience string.
443    Single(String),
444    /// Multiple audiences as an array.
445    Multiple(Vec<String>),
446}
447
448impl Audience {
449    /// Check if the audience contains a specific value.
450    pub fn contains(&self, value: &str) -> bool {
451        match self {
452            Self::None => false,
453            Self::Single(s) => s == value,
454            Self::Multiple(v) => v.iter().any(|s| s == value),
455        }
456    }
457
458    /// Get all audience values as a vector.
459    pub fn to_vec(&self) -> Vec<String> {
460        match self {
461            Self::None => Vec::new(),
462            Self::Single(s) => vec![s.clone()],
463            Self::Multiple(v) => v.clone(),
464        }
465    }
466}
467
468// ============================================================================
469// OIDC Validator
470// ============================================================================
471
472/// OIDC token validator with JWKS caching.
473///
474/// Validates JWT tokens against an OIDC provider's public keys.
475/// Automatically fetches and caches the JWKS for efficiency.
476pub struct OidcValidator {
477    config:      OidcConfig,
478    http_client: reqwest::Client,
479    jwks_cache:  Arc<RwLock<Option<CachedJwks>>>,
480    jwks_uri:    String,
481}
482
483impl OidcValidator {
484    /// Create a new OIDC validator.
485    ///
486    /// This will perform OIDC discovery to find the JWKS URI
487    /// unless `jwks_uri` is explicitly set in config.
488    ///
489    /// # Errors
490    ///
491    /// Returns error if:
492    /// - Config validation fails
493    /// - OIDC discovery fails
494    /// - JWKS endpoint cannot be determined
495    pub async fn new(config: OidcConfig) -> Result<Self> {
496        config.validate()?;
497
498        let http_client = reqwest::Client::builder()
499            .timeout(Duration::from_secs(30))
500            .build()
501            .map_err(|e| SecurityError::SecurityConfigError(format!("HTTP client error: {e}")))?;
502
503        // Determine JWKS URI
504        let jwks_uri = if let Some(ref uri) = config.jwks_uri {
505            uri.clone()
506        } else {
507            // Perform OIDC discovery
508            let discovery_url =
509                format!("{}/.well-known/openid-configuration", config.issuer.trim_end_matches('/'));
510
511            tracing::debug!(url = %discovery_url, "Performing OIDC discovery");
512
513            let response = http_client.get(&discovery_url).send().await.map_err(|e| {
514                SecurityError::SecurityConfigError(format!("OIDC discovery failed: {e}"))
515            })?;
516
517            if !response.status().is_success() {
518                return Err(SecurityError::SecurityConfigError(format!(
519                    "OIDC discovery failed with status: {}",
520                    response.status()
521                )));
522            }
523
524            let discovery: OidcDiscoveryDocument = response.json().await.map_err(|e| {
525                SecurityError::SecurityConfigError(format!("Invalid OIDC discovery response: {e}"))
526            })?;
527
528            tracing::info!(
529                issuer = %discovery.issuer,
530                jwks_uri = %discovery.jwks_uri,
531                "OIDC discovery successful"
532            );
533
534            discovery.jwks_uri
535        };
536
537        Ok(Self {
538            config,
539            http_client,
540            jwks_cache: Arc::new(RwLock::new(None)),
541            jwks_uri,
542        })
543    }
544
545    /// Create a validator without performing discovery.
546    ///
547    /// Use this for testing or when you have the JWKS URI directly.
548    #[must_use]
549    pub fn with_jwks_uri(config: OidcConfig, jwks_uri: String) -> Self {
550        Self {
551            config,
552            http_client: reqwest::Client::new(),
553            jwks_cache: Arc::new(RwLock::new(None)),
554            jwks_uri,
555        }
556    }
557
558    /// Validate a JWT token and extract user information.
559    ///
560    /// # Arguments
561    ///
562    /// * `token` - The JWT token string (without "Bearer " prefix)
563    ///
564    /// # Returns
565    ///
566    /// `AuthenticatedUser` if token is valid, error otherwise.
567    ///
568    /// # Errors
569    ///
570    /// Returns error if:
571    /// - Token is malformed
572    /// - Signature verification fails
573    /// - Required claims are missing
574    /// - Token is expired
575    /// - Issuer/audience don't match
576    pub async fn validate_token(&self, token: &str) -> Result<AuthenticatedUser> {
577        // Decode header to get kid
578        let header = decode_header(token).map_err(|e| {
579            tracing::debug!(error = %e, "Failed to decode JWT header");
580            SecurityError::InvalidToken
581        })?;
582
583        let kid = header.kid.as_ref().ok_or_else(|| {
584            tracing::debug!("JWT missing kid (key ID) in header");
585            SecurityError::InvalidToken
586        })?;
587
588        // Get the signing key
589        let decoding_key = self.get_decoding_key(kid).await?;
590
591        // Build validation
592        let mut validation = Validation::new(self.get_algorithm(&header)?);
593        validation.set_issuer(&[&self.config.issuer]);
594
595        // Set audience validation
596        if let Some(ref aud) = self.config.audience {
597            let mut audiences = vec![aud.clone()];
598            audiences.extend(self.config.additional_audiences.clone());
599            validation.set_audience(&audiences);
600        } else {
601            validation.validate_aud = false;
602        }
603
604        // Set clock skew tolerance
605        validation.leeway = self.config.clock_skew_secs;
606
607        // Decode and validate token
608        let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
609            tracing::debug!(error = %e, "JWT validation failed");
610            match e.kind() {
611                jsonwebtoken::errors::ErrorKind::ExpiredSignature => SecurityError::TokenExpired {
612                    expired_at: Utc::now(), // Approximate
613                },
614                jsonwebtoken::errors::ErrorKind::InvalidIssuer => SecurityError::InvalidToken,
615                jsonwebtoken::errors::ErrorKind::InvalidAudience => SecurityError::InvalidToken,
616                jsonwebtoken::errors::ErrorKind::InvalidSignature => SecurityError::InvalidToken,
617                _ => SecurityError::InvalidToken,
618            }
619        })?;
620
621        let claims = token_data.claims;
622
623        // Extract scopes first (before moving claims.sub)
624        let scopes = self.extract_scopes(&claims);
625
626        // Extract user ID (required)
627        let user_id = claims.sub.ok_or(SecurityError::TokenMissingClaim {
628            claim: "sub".to_string(),
629        })?;
630
631        // Extract expiration (required)
632        let exp = claims.exp.ok_or(SecurityError::TokenMissingClaim {
633            claim: "exp".to_string(),
634        })?;
635
636        let expires_at =
637            DateTime::<Utc>::from_timestamp(exp, 0).ok_or(SecurityError::InvalidToken)?;
638
639        tracing::debug!(
640            user_id = %user_id,
641            scopes = ?scopes,
642            expires_at = %expires_at,
643            "Token validated successfully"
644        );
645
646        Ok(AuthenticatedUser {
647            user_id,
648            scopes,
649            expires_at,
650        })
651    }
652
653    /// Get the decoding key for a specific key ID.
654    async fn get_decoding_key(&self, kid: &str) -> Result<DecodingKey> {
655        // Check cache first
656        {
657            let cache = self.jwks_cache.read();
658            if let Some(ref cached) = *cache {
659                if !cached.is_expired() {
660                    if let Some(key) = self.find_key(&cached.jwks, kid) {
661                        return self.jwk_to_decoding_key(key);
662                    }
663                }
664            }
665        }
666
667        // Fetch fresh JWKS
668        let jwks = self.fetch_jwks().await?;
669
670        // SECURITY: Detect key rotation for audit purposes
671        if self.detect_key_rotation(&jwks) {
672            tracing::warn!(
673                "OIDC key rotation detected: some previously cached keys no longer available"
674            );
675        }
676
677        // Find the key index first, then we can clone the key
678        let key_index =
679            jwks.keys.iter().position(|k| k.kid.as_deref() == Some(kid)).ok_or_else(|| {
680                tracing::debug!(kid = %kid, "Key not found in JWKS");
681                SecurityError::InvalidToken
682            })?;
683
684        // Clone the key before caching (keys are small, cloning is fine)
685        let key = jwks.keys[key_index].clone();
686
687        // Cache the JWKS
688        {
689            let mut cache = self.jwks_cache.write();
690            *cache = Some(CachedJwks {
691                jwks,
692                fetched_at: Instant::now(),
693                ttl: Duration::from_secs(self.config.jwks_cache_ttl_secs),
694            });
695        }
696
697        self.jwk_to_decoding_key(&key)
698    }
699
700    /// Fetch JWKS from the provider.
701    async fn fetch_jwks(&self) -> Result<Jwks> {
702        tracing::debug!(uri = %self.jwks_uri, "Fetching JWKS");
703
704        let response = self.http_client.get(&self.jwks_uri).send().await.map_err(|e| {
705            tracing::error!(error = %e, "Failed to fetch JWKS");
706            SecurityError::SecurityConfigError(format!("Failed to fetch JWKS: {e}"))
707        })?;
708
709        if !response.status().is_success() {
710            return Err(SecurityError::SecurityConfigError(format!(
711                "JWKS fetch failed with status: {}",
712                response.status()
713            )));
714        }
715
716        let jwks: Jwks = response.json().await.map_err(|e| {
717            SecurityError::SecurityConfigError(format!("Invalid JWKS response: {e}"))
718        })?;
719
720        tracing::debug!(key_count = jwks.keys.len(), "JWKS fetched successfully");
721
722        Ok(jwks)
723    }
724
725    /// Find a key in the JWKS by key ID.
726    fn find_key<'a>(&self, jwks: &'a Jwks, kid: &str) -> Option<&'a Jwk> {
727        jwks.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
728    }
729
730    /// Detect if JWKS keys have been rotated (old keys removed).
731    ///
732    /// Compares current cached keys with newly fetched keys.
733    /// Returns true if any previously cached keys are missing from the new JWKS.
734    fn detect_key_rotation(&self, new_jwks: &Jwks) -> bool {
735        let cache = self.jwks_cache.read();
736        if let Some(ref cached) = *cache {
737            // Get set of old key IDs
738            let old_kids: std::collections::HashSet<_> =
739                cached.jwks.keys.iter().filter_map(|k| k.kid.as_deref()).collect();
740
741            // Get set of new key IDs
742            let new_kids: std::collections::HashSet<_> =
743                new_jwks.keys.iter().filter_map(|k| k.kid.as_deref()).collect();
744
745            // Rotation detected if any old keys are missing
746            !old_kids.is_subset(&new_kids)
747        } else {
748            false
749        }
750    }
751
752    /// Convert a JWK to a jsonwebtoken DecodingKey.
753    fn jwk_to_decoding_key(&self, jwk: &Jwk) -> Result<DecodingKey> {
754        match jwk.kty.as_str() {
755            "RSA" => {
756                let n = jwk.n.as_ref().ok_or(SecurityError::InvalidToken)?;
757                let e = jwk.e.as_ref().ok_or(SecurityError::InvalidToken)?;
758
759                DecodingKey::from_rsa_components(n, e).map_err(|e| {
760                    tracing::debug!(error = %e, "Failed to create RSA decoding key");
761                    SecurityError::InvalidToken
762                })
763            },
764            other => {
765                tracing::debug!(key_type = %other, "Unsupported key type");
766                Err(SecurityError::InvalidTokenAlgorithm {
767                    algorithm: other.to_string(),
768                })
769            },
770        }
771    }
772
773    /// Get the algorithm from the JWT header.
774    fn get_algorithm(&self, header: &jsonwebtoken::Header) -> Result<Algorithm> {
775        let alg_str = format!("{:?}", header.alg);
776
777        // Check if algorithm is allowed
778        if !self.config.allowed_algorithms.contains(&alg_str) {
779            return Err(SecurityError::InvalidTokenAlgorithm { algorithm: alg_str });
780        }
781
782        Ok(header.alg)
783    }
784
785    /// Extract scopes from JWT claims.
786    ///
787    /// Handles multiple formats:
788    /// - `scope`: space-separated string (Auth0, Okta)
789    /// - `scp`: array of strings (some providers)
790    /// - `permissions`: array of strings (Auth0 RBAC)
791    fn extract_scopes(&self, claims: &JwtClaims) -> Vec<String> {
792        // Try the configured scope claim first (default: "scope")
793        if self.config.scope_claim == "scope" {
794            if let Some(ref scope) = claims.scope {
795                return scope.split_whitespace().map(String::from).collect();
796            }
797        }
798
799        // Try scp (array format)
800        if let Some(ref scp) = claims.scp {
801            return scp.clone();
802        }
803
804        // Try permissions (Auth0 RBAC)
805        if let Some(ref perms) = claims.permissions {
806            return perms.clone();
807        }
808
809        // Try scope as space-separated string
810        if let Some(ref scope) = claims.scope {
811            return scope.split_whitespace().map(String::from).collect();
812        }
813
814        Vec::new()
815    }
816
817    /// Check if authentication is required.
818    #[must_use]
819    pub fn is_required(&self) -> bool {
820        self.config.required
821    }
822
823    /// Get the configured issuer.
824    #[must_use]
825    pub fn issuer(&self) -> &str {
826        &self.config.issuer
827    }
828
829    /// Clear the JWKS cache.
830    ///
831    /// Call this if you need to force a refresh of the signing keys.
832    pub fn clear_cache(&self) {
833        let mut cache = self.jwks_cache.write();
834        *cache = None;
835    }
836}
837
838// ============================================================================
839// Tests
840// ============================================================================
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845
846    #[test]
847    fn test_oidc_config_default() {
848        let config = OidcConfig::default();
849        assert!(config.issuer.is_empty());
850        assert!(config.audience.is_none());
851        // SECURITY: Cache TTL reduced to 5 minutes to prevent token cache poisoning
852        assert_eq!(config.jwks_cache_ttl_secs, 300);
853        assert_eq!(config.allowed_algorithms, vec!["RS256"]);
854        assert_eq!(config.clock_skew_secs, 60);
855        assert!(config.required);
856    }
857
858    #[test]
859    fn test_oidc_config_auth0() {
860        let config = OidcConfig::auth0("my-tenant.auth0.com", "my-api");
861        assert_eq!(config.issuer, "https://my-tenant.auth0.com/");
862        assert_eq!(config.audience, Some("my-api".to_string()));
863    }
864
865    #[test]
866    fn test_oidc_config_keycloak() {
867        let config = OidcConfig::keycloak("https://keycloak.example.com", "myrealm", "myclient");
868        assert_eq!(config.issuer, "https://keycloak.example.com/realms/myrealm");
869        assert_eq!(config.audience, Some("myclient".to_string()));
870    }
871
872    #[test]
873    fn test_oidc_config_okta() {
874        let config = OidcConfig::okta("myorg.okta.com", "api://default");
875        assert_eq!(config.issuer, "https://myorg.okta.com");
876        assert_eq!(config.audience, Some("api://default".to_string()));
877    }
878
879    #[test]
880    fn test_oidc_config_cognito() {
881        let config = OidcConfig::cognito("us-east-1", "us-east-1_abc123", "client123");
882        assert_eq!(config.issuer, "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_abc123");
883        assert_eq!(config.audience, Some("client123".to_string()));
884    }
885
886    #[test]
887    fn test_oidc_config_azure_ad() {
888        let config = OidcConfig::azure_ad("tenant-id-123", "client-id-456");
889        assert_eq!(config.issuer, "https://login.microsoftonline.com/tenant-id-123/v2.0");
890        assert_eq!(config.audience, Some("client-id-456".to_string()));
891    }
892
893    #[test]
894    fn test_oidc_config_google() {
895        let config = OidcConfig::google("123456.apps.googleusercontent.com");
896        assert_eq!(config.issuer, "https://accounts.google.com");
897        assert_eq!(config.audience, Some("123456.apps.googleusercontent.com".to_string()));
898    }
899
900    #[test]
901    fn test_oidc_config_validate_empty_issuer() {
902        let config = OidcConfig::default();
903        let result = config.validate();
904        assert!(result.is_err());
905        assert!(matches!(result, Err(SecurityError::SecurityConfigError(_))));
906    }
907
908    #[test]
909    fn test_oidc_config_validate_http_issuer() {
910        let config = OidcConfig {
911            issuer: "http://insecure.example.com".to_string(),
912            ..Default::default()
913        };
914        let result = config.validate();
915        assert!(result.is_err());
916    }
917
918    #[test]
919    fn test_oidc_config_validate_localhost_allowed() {
920        let config = OidcConfig {
921            issuer: "http://localhost:8080".to_string(),
922            audience: Some("my-api".to_string()),
923            ..Default::default()
924        };
925        let result = config.validate();
926        assert!(result.is_ok());
927    }
928
929    #[test]
930    fn test_oidc_config_validate_https_required() {
931        let config = OidcConfig {
932            issuer: "https://secure.example.com".to_string(),
933            audience: Some("https://api.example.com".to_string()),
934            ..Default::default()
935        };
936        let result = config.validate();
937        assert!(result.is_ok());
938    }
939
940    #[test]
941    fn test_audience_none() {
942        let aud = Audience::None;
943        assert!(!aud.contains("test"));
944        assert!(aud.to_vec().is_empty());
945    }
946
947    #[test]
948    fn test_audience_single() {
949        let aud = Audience::Single("my-api".to_string());
950        assert!(aud.contains("my-api"));
951        assert!(!aud.contains("other"));
952        assert_eq!(aud.to_vec(), vec!["my-api"]);
953    }
954
955    #[test]
956    fn test_audience_multiple() {
957        let aud = Audience::Multiple(vec!["api1".to_string(), "api2".to_string()]);
958        assert!(aud.contains("api1"));
959        assert!(aud.contains("api2"));
960        assert!(!aud.contains("api3"));
961        assert_eq!(aud.to_vec(), vec!["api1", "api2"]);
962    }
963
964    #[test]
965    fn test_jwk_deserialization() {
966        let jwk_json = r#"{
967            "kty": "RSA",
968            "kid": "test-key-id",
969            "alg": "RS256",
970            "use": "sig",
971            "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
972            "e": "AQAB"
973        }"#;
974
975        let jwk: Jwk = serde_json::from_str(jwk_json).unwrap();
976        assert_eq!(jwk.kty, "RSA");
977        assert_eq!(jwk.kid, Some("test-key-id".to_string()));
978        assert_eq!(jwk.alg, Some("RS256".to_string()));
979        assert!(jwk.n.is_some());
980        assert!(jwk.e.is_some());
981    }
982
983    #[test]
984    fn test_jwks_deserialization() {
985        let jwks_json = r#"{
986            "keys": [
987                {
988                    "kty": "RSA",
989                    "kid": "key1",
990                    "n": "test_n",
991                    "e": "AQAB"
992                },
993                {
994                    "kty": "RSA",
995                    "kid": "key2",
996                    "n": "test_n2",
997                    "e": "AQAB"
998                }
999            ]
1000        }"#;
1001
1002        let jwks: Jwks = serde_json::from_str(jwks_json).unwrap();
1003        assert_eq!(jwks.keys.len(), 2);
1004        assert_eq!(jwks.keys[0].kid, Some("key1".to_string()));
1005        assert_eq!(jwks.keys[1].kid, Some("key2".to_string()));
1006    }
1007
1008    #[test]
1009    fn test_jwt_claims_deserialization() {
1010        let claims_json = r#"{
1011            "sub": "user123",
1012            "iss": "https://issuer.example.com",
1013            "aud": "my-api",
1014            "exp": 1735689600,
1015            "iat": 1735686000,
1016            "scope": "read write",
1017            "email": "user@example.com"
1018        }"#;
1019
1020        let claims: JwtClaims = serde_json::from_str(claims_json).unwrap();
1021        assert_eq!(claims.sub, Some("user123".to_string()));
1022        assert_eq!(claims.iss, Some("https://issuer.example.com".to_string()));
1023        assert!(claims.aud.contains("my-api"));
1024        assert_eq!(claims.exp, Some(1_735_689_600));
1025        assert_eq!(claims.scope, Some("read write".to_string()));
1026    }
1027
1028    #[test]
1029    fn test_jwt_claims_array_audience() {
1030        let claims_json = r#"{
1031            "sub": "user123",
1032            "aud": ["api1", "api2"],
1033            "exp": 1735689600
1034        }"#;
1035
1036        let claims: JwtClaims = serde_json::from_str(claims_json).unwrap();
1037        assert!(claims.aud.contains("api1"));
1038        assert!(claims.aud.contains("api2"));
1039    }
1040
1041    #[test]
1042    fn test_oidc_discovery_document_deserialization() {
1043        let doc_json = r#"{
1044            "issuer": "https://issuer.example.com",
1045            "jwks_uri": "https://issuer.example.com/.well-known/jwks.json",
1046            "authorization_endpoint": "https://issuer.example.com/authorize",
1047            "token_endpoint": "https://issuer.example.com/oauth/token",
1048            "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"]
1049        }"#;
1050
1051        let doc: OidcDiscoveryDocument = serde_json::from_str(doc_json).unwrap();
1052        assert_eq!(doc.issuer, "https://issuer.example.com");
1053        assert_eq!(doc.jwks_uri, "https://issuer.example.com/.well-known/jwks.json");
1054        assert_eq!(doc.id_token_signing_alg_values_supported.len(), 3);
1055    }
1056
1057    #[test]
1058    fn test_jwks_cache_ttl_reduced_for_security() {
1059        // SECURITY: Verify cache TTL is reduced to 5 minutes (300 seconds)
1060        // to prevent token cache poisoning attacks
1061        assert_eq!(default_jwks_cache_ttl(), 300, "Cache TTL should be 5 minutes (300 seconds)");
1062    }
1063
1064    #[test]
1065    fn test_cached_jwks_expiration() {
1066        // Test that CachedJwks correctly determines expiration
1067        let jwks = Jwks { keys: vec![] };
1068        let cached = CachedJwks {
1069            jwks,
1070            fetched_at: Instant::now(),
1071            ttl: Duration::from_secs(1),
1072        };
1073
1074        // Should not be expired immediately
1075        assert!(!cached.is_expired());
1076
1077        // After sleep, should be expired
1078        std::thread::sleep(Duration::from_millis(1100));
1079        assert!(cached.is_expired());
1080    }
1081
1082    #[test]
1083    fn test_detect_key_rotation_when_no_cache() {
1084        // Test that key rotation detection returns false when no cache exists
1085        let config = OidcConfig {
1086            issuer: "http://localhost:8080".to_string(),
1087            ..Default::default()
1088        };
1089
1090        let validator = OidcValidator {
1091            config,
1092            http_client: reqwest::Client::new(),
1093            jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1094            jwks_cache: Arc::new(RwLock::new(None)),
1095        };
1096
1097        let new_jwks = Jwks {
1098            keys: vec![Jwk {
1099                kty:     "RSA".to_string(),
1100                kid:     Some("key1".to_string()),
1101                alg:     None,
1102                key_use: None,
1103                n:       None,
1104                e:       None,
1105                x5c:     vec![],
1106            }],
1107        };
1108
1109        // Should not detect rotation when cache is empty
1110        assert!(!validator.detect_key_rotation(&new_jwks));
1111    }
1112
1113    #[test]
1114    fn test_detect_key_rotation_when_keys_removed() {
1115        // Test that key rotation is detected when old keys disappear
1116        let config = OidcConfig {
1117            issuer: "http://localhost:8080".to_string(),
1118            ..Default::default()
1119        };
1120
1121        let validator = OidcValidator {
1122            config,
1123            http_client: reqwest::Client::new(),
1124            jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1125            jwks_cache: Arc::new(RwLock::new(None)),
1126        };
1127
1128        // Cache with 2 keys
1129        let old_jwks = Jwks {
1130            keys: vec![
1131                Jwk {
1132                    kty:     "RSA".to_string(),
1133                    kid:     Some("old_key_1".to_string()),
1134                    alg:     None,
1135                    key_use: None,
1136                    n:       None,
1137                    e:       None,
1138                    x5c:     vec![],
1139                },
1140                Jwk {
1141                    kty:     "RSA".to_string(),
1142                    kid:     Some("old_key_2".to_string()),
1143                    alg:     None,
1144                    key_use: None,
1145                    n:       None,
1146                    e:       None,
1147                    x5c:     vec![],
1148                },
1149            ],
1150        };
1151
1152        {
1153            let mut cache = validator.jwks_cache.write();
1154            *cache = Some(CachedJwks {
1155                jwks:       old_jwks,
1156                fetched_at: Instant::now(),
1157                ttl:        Duration::from_secs(300),
1158            });
1159        }
1160
1161        // New JWKS with only 1 of the old keys (old_key_2 removed)
1162        let new_jwks = Jwks {
1163            keys: vec![
1164                Jwk {
1165                    kty:     "RSA".to_string(),
1166                    kid:     Some("old_key_1".to_string()),
1167                    alg:     None,
1168                    key_use: None,
1169                    n:       None,
1170                    e:       None,
1171                    x5c:     vec![],
1172                },
1173                Jwk {
1174                    kty:     "RSA".to_string(),
1175                    kid:     Some("new_key_1".to_string()),
1176                    alg:     None,
1177                    key_use: None,
1178                    n:       None,
1179                    e:       None,
1180                    x5c:     vec![],
1181                },
1182            ],
1183        };
1184
1185        // Should detect rotation because old_key_2 is missing
1186        assert!(validator.detect_key_rotation(&new_jwks));
1187    }
1188
1189    #[test]
1190    fn test_detect_key_rotation_when_no_keys_removed() {
1191        // Test that key rotation is NOT detected when all old keys still exist
1192        let config = OidcConfig {
1193            issuer: "http://localhost:8080".to_string(),
1194            ..Default::default()
1195        };
1196
1197        let validator = OidcValidator {
1198            config,
1199            http_client: reqwest::Client::new(),
1200            jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1201            jwks_cache: Arc::new(RwLock::new(None)),
1202        };
1203
1204        // Cache with 2 keys
1205        let old_jwks = Jwks {
1206            keys: vec![
1207                Jwk {
1208                    kty:     "RSA".to_string(),
1209                    kid:     Some("key_1".to_string()),
1210                    alg:     None,
1211                    key_use: None,
1212                    n:       None,
1213                    e:       None,
1214                    x5c:     vec![],
1215                },
1216                Jwk {
1217                    kty:     "RSA".to_string(),
1218                    kid:     Some("key_2".to_string()),
1219                    alg:     None,
1220                    key_use: None,
1221                    n:       None,
1222                    e:       None,
1223                    x5c:     vec![],
1224                },
1225            ],
1226        };
1227
1228        {
1229            let mut cache = validator.jwks_cache.write();
1230            *cache = Some(CachedJwks {
1231                jwks:       old_jwks,
1232                fetched_at: Instant::now(),
1233                ttl:        Duration::from_secs(300),
1234            });
1235        }
1236
1237        // New JWKS with old keys + new key (no removal)
1238        let new_jwks = Jwks {
1239            keys: vec![
1240                Jwk {
1241                    kty:     "RSA".to_string(),
1242                    kid:     Some("key_1".to_string()),
1243                    alg:     None,
1244                    key_use: None,
1245                    n:       None,
1246                    e:       None,
1247                    x5c:     vec![],
1248                },
1249                Jwk {
1250                    kty:     "RSA".to_string(),
1251                    kid:     Some("key_2".to_string()),
1252                    alg:     None,
1253                    key_use: None,
1254                    n:       None,
1255                    e:       None,
1256                    x5c:     vec![],
1257                },
1258                Jwk {
1259                    kty:     "RSA".to_string(),
1260                    kid:     Some("new_key".to_string()),
1261                    alg:     None,
1262                    key_use: None,
1263                    n:       None,
1264                    e:       None,
1265                    x5c:     vec![],
1266                },
1267            ],
1268        };
1269
1270        // Should NOT detect rotation because all old keys still exist
1271        assert!(!validator.detect_key_rotation(&new_jwks));
1272    }
1273
1274    #[test]
1275    fn test_find_key_by_kid() {
1276        // Test finding a specific key by kid in JWKS
1277        let config = OidcConfig {
1278            issuer: "http://localhost:8080".to_string(),
1279            ..Default::default()
1280        };
1281
1282        let validator = OidcValidator {
1283            config,
1284            http_client: reqwest::Client::new(),
1285            jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1286            jwks_cache: Arc::new(RwLock::new(None)),
1287        };
1288
1289        let jwks = Jwks {
1290            keys: vec![
1291                Jwk {
1292                    kty:     "RSA".to_string(),
1293                    kid:     Some("key1".to_string()),
1294                    alg:     None,
1295                    key_use: None,
1296                    n:       None,
1297                    e:       None,
1298                    x5c:     vec![],
1299                },
1300                Jwk {
1301                    kty:     "RSA".to_string(),
1302                    kid:     Some("key2".to_string()),
1303                    alg:     None,
1304                    key_use: None,
1305                    n:       None,
1306                    e:       None,
1307                    x5c:     vec![],
1308                },
1309            ],
1310        };
1311
1312        // Should find existing key
1313        assert!(validator.find_key(&jwks, "key1").is_some());
1314        assert!(validator.find_key(&jwks, "key2").is_some());
1315
1316        // Should not find non-existent key
1317        assert!(validator.find_key(&jwks, "key3").is_none());
1318    }
1319
1320    #[test]
1321    fn test_find_key_without_kid() {
1322        // Test handling of keys without kid
1323        let config = OidcConfig {
1324            issuer: "http://localhost:8080".to_string(),
1325            ..Default::default()
1326        };
1327
1328        let validator = OidcValidator {
1329            config,
1330            http_client: reqwest::Client::new(),
1331            jwks_uri: "http://localhost:8080/.well-known/jwks.json".to_string(),
1332            jwks_cache: Arc::new(RwLock::new(None)),
1333        };
1334
1335        let jwks = Jwks {
1336            keys: vec![Jwk {
1337                kty:     "RSA".to_string(),
1338                kid:     None, // No kid
1339                alg:     None,
1340                key_use: None,
1341                n:       None,
1342                e:       None,
1343                x5c:     vec![],
1344            }],
1345        };
1346
1347        // Should not find key without kid even if requested
1348        assert!(validator.find_key(&jwks, "any_kid").is_none());
1349    }
1350
1351    #[test]
1352    fn test_oidc_config_with_custom_cache_ttl() {
1353        // Test that custom cache TTL can be configured
1354        let config = OidcConfig {
1355            issuer: "http://localhost:8080".to_string(),
1356            jwks_cache_ttl_secs: 600, // Custom 10-minute TTL
1357            ..Default::default()
1358        };
1359
1360        assert_eq!(config.jwks_cache_ttl_secs, 600);
1361    }
1362
1363    #[test]
1364    fn test_oidc_config_default_cache_ttl_is_short() {
1365        // Test that default cache TTL is short (5 minutes) for security
1366        let config = OidcConfig::default();
1367        assert!(
1368            config.jwks_cache_ttl_secs <= 300,
1369            "Default cache TTL should be short (≤ 300 seconds) to prevent token poisoning"
1370        );
1371    }
1372}