// async_oidc_jwt_validator/validator.rs

use crate::config::OidcConfig;
use jsonwebtoken::errors::{Error as JwtError, ErrorKind, Result as JwtResult};
use jsonwebtoken::jwk::{Jwk, JwkSet};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::Deserialize;
use std::collections::HashMap;
use std::time::Duration;

8/// OIDC JWT validator with JWKS caching
9#[derive(Clone)]
10pub struct OidcValidator {
11    config: OidcConfig,
12    jwks_cache: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Jwk>>>,
13}
14
15impl OidcValidator {
16    /// Creates a new OidcValidator with the given configuration
17    pub fn new(config: OidcConfig) -> Self {
18        Self {
19            config,
20            jwks_cache: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
21        }
22    }
23
24    async fn fetch_jwks(&self) -> JwtResult<JwkSet> {
25        let jwks_url = self.config.jwks_uri.clone();
26
27        log::debug!("Fetching JWKS from: {}", jwks_url);
28
29        let response = reqwest::get(&jwks_url).await.map_err(|e| {
30            JwtError::from(ErrorKind::InvalidRsaKey(format!(
31                "Failed to fetch JWKS: {}",
32                e
33            )))
34        })?;
35
36        if !response.status().is_success() {
37            return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
38                "JWKS request failed with status: {}",
39                response.status()
40            ))));
41        }
42
43        let jwks: JwkSet = response.json().await.map_err(|e| {
44            JwtError::from(ErrorKind::InvalidRsaKey(format!(
45                "Failed to parse JWKS response: {}",
46                e
47            )))
48        })?;
49
50        log::debug!("Fetched {} keys from JWKS", jwks.keys.len());
51        Ok(jwks)
52    }
53
54    async fn get_jwk(&self, kid: &str) -> JwtResult<Jwk> {
55        // Check cache first
56        {
57            let cache = self.jwks_cache.read().await;
58            if let Some(jwk) = cache.get(kid) {
59                return Ok(jwk.clone());
60            }
61        }
62
63        // If not found, refresh cache and try again
64        self.refresh_jwks_cache().await?;
65
66        let cache = self.jwks_cache.read().await;
67        cache
68            .get(kid)
69            .cloned()
70            .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))
71    }
72
73    pub async fn validate_custom<T>(&self, token: &str, validation: &Validation) -> JwtResult<T>
74    where
75        T: for<'de> Deserialize<'de>,
76    {
77        log::debug!("Verifying JWT token");
78
79        // Decode header to get kid
80        let header = jsonwebtoken::decode_header(token)?;
81
82        let kid = header
83            .kid
84            .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))?;
85        log::debug!("Token kid: {}", kid);
86
87        // Get JWK for this kid (will refresh cache if not found)
88        let jwk = self.get_jwk(&kid).await?;
89
90        log::debug!("Found matching key with kid: {}", kid);
91
92        let decoding_key = DecodingKey::from_jwk(&jwk)
93            .map_err(|_e| JwtError::from(ErrorKind::InvalidKeyFormat))?;
94
95        // Decode and validate token
96        let token_data = decode::<T>(token, &decoding_key, validation)?;
97
98        log::debug!("Token verified successfully");
99        Ok(token_data.claims)
100    }
101
102    pub async fn validate<T>(&self, token: &str) -> JwtResult<T>
103    where
104        T: for<'de> Deserialize<'de>,
105    {
106        log::debug!("Validating JWT token with minimal validation");
107
108        // Create a minimal validation configuration
109        let mut validation = Validation::new(Algorithm::RS256);
110
111        validation.set_issuer(&[&self.config.issuer_url]);
112
113        validation.set_audience(&[&self.config.client_id]);
114
115        self.validate_custom(token, &validation).await
116    }
117
118    /// Refreshes the JWKS cache by fetching the latest keys
119    pub async fn refresh_jwks_cache(&self) -> JwtResult<()> {
120        log::info!("Refreshing JWKS cache");
121        let new_jwks = self.fetch_jwks().await?;
122
123        // Check if an update is needed using a read lock
124        let needs_update = {
125            let cache = self.jwks_cache.read().await;
126
127            // Condition 1: The number of keys is different.
128            let lengths_are_different = new_jwks.keys.len() != cache.len();
129
130            // Condition 2: There is at least one new key that wasn't in the old cache.
131            // This only needs to run if the lengths are the same.
132            let has_added_keys = if lengths_are_different {
133                false // No need to run this check if we already know we need an update.
134            } else {
135                new_jwks.keys.iter().any(|jwk| {
136                    if let Some(kid) = &jwk.common.key_id {
137                        !cache.contains_key(kid)
138                    } else {
139                        false // Skip keys without kid
140                    }
141                })
142            };
143            lengths_are_different || has_added_keys
144        }; // Read lock released here
145
146        // Only acquire write lock if there are new keys
147        if needs_update {
148            log::info!("New keys detected, replacing entire cache");
149
150            // Build new HashMap from fetched JWKS
151            let mut new_cache = HashMap::new();
152            for jwk in new_jwks.keys {
153                if let Some(kid) = jwk.common.key_id.clone() {
154                    log::debug!("Adding key to new cache: {}", kid);
155                    new_cache.insert(kid, jwk);
156                }
157            }
158
159            // Replace entire cache
160            let mut cache = self.jwks_cache.write().await;
161            *cache = new_cache;
162
163            log::info!("Successfully replaced JWKS cache with {} keys", cache.len());
164        } else {
165            log::debug!("No new keys found in JWKS, cache unchanged");
166        }
167
168        Ok(())
169    }
170}