async_oidc_jwt_validator/
lib.rs

1use jsonwebtoken::errors::{Error as JwtError, ErrorKind, Result as JwtResult};
2use jsonwebtoken::jwk::{Jwk, JwkSet};
3use jsonwebtoken::{decode, DecodingKey};
4use serde::Deserialize;
5use std::collections::HashMap;
6
7// Re-export for user convenience
8pub use jsonwebtoken::{Algorithm, Validation};
9
10/// OpenID Connect Discovery document structure
11#[derive(Debug, Deserialize)]
12struct OidcDiscovery {
13    issuer: String,
14    jwks_uri: String,
15}
16
/// Configuration describing the single OIDC provider whose tokens are accepted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OidcConfig {
    /// Issuer identifier; `OidcValidator::validate` requires the token's `iss` claim to
    /// equal this value exactly.
    pub issuer_url: String,
    /// OAuth client ID; `OidcValidator::validate` uses it as the expected `aud` value.
    pub client_id: String,
    /// URL from which the JSON Web Key Set (signing keys) is fetched.
    pub jwks_uri: String,
}
24
25/// OIDC JWT validator with JWKS caching
26#[derive(Clone)]
27pub struct OidcValidator {
28    config: OidcConfig,
29    jwks_cache: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Jwk>>>,
30}
31
32impl OidcConfig {
33    /// Creates a new OidcConfig with custom parameters
34    pub fn new(issuer_url: String, client_id: String, jwks_uri: String) -> Self {
35        Self {
36            issuer_url,
37            client_id,
38            jwks_uri,
39        }
40    }
41
42    pub async fn new_with_discovery(issuer_url: String, client_id: String) -> JwtResult<Self> {
43        let jwks_uri = Self::discover_jwks_uri(&issuer_url).await?;
44        Ok(Self {
45            issuer_url,
46            client_id,
47            jwks_uri,
48        })
49    }
50
51    async fn discover_jwks_uri(issuer_url: &str) -> JwtResult<String> {
52        let discovery_url = format!("{}/.well-known/openid-configuration", issuer_url);
53
54        log::debug!("Fetching OpenID Connect Discovery from: {}", discovery_url);
55
56        let response = reqwest::get(&discovery_url).await.map_err(|e| {
57            JwtError::from(ErrorKind::InvalidRsaKey(format!(
58                "Failed to fetch OIDC discovery document: {}",
59                e
60            )))
61        })?;
62
63        let content_type = response
64            .headers()
65            .get("content-type")
66            .and_then(|value| value.to_str().ok())
67            .unwrap_or_default();
68
69        if !content_type.starts_with("application/json") {
70            return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
71                "Unexpected Content-Type: '{}', expected 'application/json'",
72                content_type
73            ))));
74        }
75
76        if !response.status().is_success() {
77            return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
78                "OIDC discovery request failed with status: {}",
79                response.status()
80            ))));
81        }
82
83        let discovery: OidcDiscovery = response.json().await.map_err(|e| {
84            JwtError::from(ErrorKind::InvalidRsaKey(format!(
85                "Failed to parse OIDC discovery response: {}",
86                e
87            )))
88        })?;
89
90        if discovery.issuer != issuer_url {
91            return Err(JwtError::from(ErrorKind::InvalidIssuer));
92        }
93
94        log::debug!("Discovered JWKS URI: {}", discovery.jwks_uri);
95        Ok(discovery.jwks_uri)
96    }
97}
98
99impl OidcValidator {
100    /// Creates a new OidcValidator with the given configuration
101    pub fn new(config: OidcConfig) -> Self {
102        Self {
103            config,
104            jwks_cache: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
105        }
106    }
107
108    async fn fetch_jwks(&self) -> JwtResult<JwkSet> {
109        let jwks_url = self.config.jwks_uri.clone();
110
111        log::debug!("Fetching JWKS from: {}", jwks_url);
112
113        let response = reqwest::get(&jwks_url).await.map_err(|e| {
114            JwtError::from(ErrorKind::InvalidRsaKey(format!(
115                "Failed to fetch JWKS: {}",
116                e
117            )))
118        })?;
119
120        if !response.status().is_success() {
121            return Err(JwtError::from(ErrorKind::InvalidRsaKey(format!(
122                "JWKS request failed with status: {}",
123                response.status()
124            ))));
125        }
126
127        let jwks: JwkSet = response.json().await.map_err(|e| {
128            JwtError::from(ErrorKind::InvalidRsaKey(format!(
129                "Failed to parse JWKS response: {}",
130                e
131            )))
132        })?;
133
134        log::debug!("Fetched {} keys from JWKS", jwks.keys.len());
135        Ok(jwks)
136    }
137
138    async fn get_jwk(&self, kid: &str) -> JwtResult<Jwk> {
139        // Check cache first
140        {
141            let cache = self.jwks_cache.read().await;
142            if let Some(jwk) = cache.get(kid) {
143                return Ok(jwk.clone());
144            }
145        }
146
147        // If not found, refresh cache and try again
148        self.refresh_jwks_cache().await?;
149
150        let cache = self.jwks_cache.read().await;
151        cache
152            .get(kid)
153            .cloned()
154            .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))
155    }
156
157    pub async fn validate_custom<T>(&self, token: &str, validation: &Validation) -> JwtResult<T>
158    where
159        T: for<'de> Deserialize<'de>,
160    {
161        log::debug!("Verifying JWT token");
162
163        // Decode header to get kid
164        let header = jsonwebtoken::decode_header(token)?;
165
166        let kid = header
167            .kid
168            .ok_or_else(|| JwtError::from(ErrorKind::InvalidToken))?;
169        log::debug!("Token kid: {}", kid);
170
171        // Get JWK for this kid (will refresh cache if not found)
172        let jwk = self.get_jwk(&kid).await?;
173
174        log::debug!("Found matching key with kid: {}", kid);
175
176        let decoding_key = DecodingKey::from_jwk(&jwk)
177            .map_err(|_e| JwtError::from(ErrorKind::InvalidKeyFormat))?;
178
179        // Decode and validate token
180        let token_data = decode::<T>(token, &decoding_key, validation)?;
181
182        log::debug!("Token verified successfully");
183        Ok(token_data.claims)
184    }
185
186    pub async fn validate<T>(&self, token: &str) -> JwtResult<T>
187    where
188        T: for<'de> Deserialize<'de>,
189    {
190        log::debug!("Validating JWT token with minimal validation");
191
192        // Create a minimal validation configuration
193        let mut validation = Validation::new(Algorithm::RS256);
194        validation.set_issuer(&[&self.config.issuer_url]);
195        validation.set_audience(&[&self.config.client_id]);
196
197        self.validate_custom(token, &validation).await
198    }
199
200    /// Refreshes the JWKS cache by fetching the latest keys
201    pub async fn refresh_jwks_cache(&self) -> JwtResult<()> {
202        log::info!("Refreshing JWKS cache");
203        let new_jwks = self.fetch_jwks().await?;
204
205        // Check if an update is needed using a read lock
206        let needs_update = {
207            let cache = self.jwks_cache.read().await;
208
209            // Condition 1: The number of keys is different.
210            let lengths_are_different = new_jwks.keys.len() != cache.len();
211
212            // Condition 2: There is at least one new key that wasn't in the old cache.
213            // This only needs to run if the lengths are the same.
214            let has_added_keys = if lengths_are_different {
215                false // No need to run this check if we already know we need an update.
216            } else {
217                new_jwks.keys.iter().any(|jwk| {
218                    if let Some(kid) = &jwk.common.key_id {
219                        !cache.contains_key(kid)
220                    } else {
221                        false // Skip keys without kid
222                    }
223                })
224            };
225            lengths_are_different || has_added_keys
226        }; // Read lock released here
227
228        // Only acquire write lock if there are new keys
229        if needs_update {
230            log::info!("New keys detected, replacing entire cache");
231
232            // Build new HashMap from fetched JWKS
233            let mut new_cache = HashMap::new();
234            for jwk in new_jwks.keys {
235                if let Some(kid) = jwk.common.key_id.clone() {
236                    log::debug!("Adding key to new cache: {}", kid);
237                    new_cache.insert(kid, jwk);
238                }
239            }
240
241            // Replace entire cache
242            let mut cache = self.jwks_cache.write().await;
243            *cache = new_cache;
244
245            log::info!("Successfully replaced JWKS cache with {} keys", cache.len());
246        } else {
247            log::debug!("No new keys found in JWKS, cache unchanged");
248        }
249
250        Ok(())
251    }
252}