Skip to main content

slim_auth/
resolver.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::jwk::KeyAlgorithm;
8use jsonwebtoken::{
9    Algorithm, DecodingKey, Header,
10    jwk::{Jwk, JwkSet},
11};
12use parking_lot::RwLock;
13use reqwest::{Client as ReqwestClient, StatusCode};
14use url::Url;
15
16use crate::errors::AuthError;
17
18/// Cache entry for a JWKS.
19#[derive(Clone, Debug)]
20pub struct JwksCache {
21    pub jwks: JwkSet,
22    pub fetched_at: Instant,
23    pub ttl: Duration,
24}
25
26/// This struct provides methods to resolve JWT decoding keys from various sources.
27///
28/// The `KeyResolver` is responsible for fetching and caching JSON Web Keys (JWK)
29/// from OpenID Connect providers. It supports:
30///
31/// 1. OpenID Connect Discovery via the standard `.well-known/openid-configuration` endpoint
32/// 2. Direct retrieval from the `.well-known/jwks.json` endpoint as a fallback
33/// 3. Caching of retrieved keys to minimize network requests
34///
35/// Example usage:
36///
37/// ```
38/// let resolver = KeyResolver::new()
39///     .with_jwks_ttl(Duration::from_secs(1800));  // 30 minute cache TTL
40///
41/// let jwt = Jwt::builder()
42///     .issuer("https://your-oidc-provider.com")
43///     .key_resolver(resolver)
44///     .build()?;
45/// ```
46#[derive(Debug)]
47pub struct KeyResolver {
48    client: ReqwestClient,
49    jwks_cache: RwLock<HashMap<String, JwksCache>>,
50    default_jwks_ttl: Duration,
51}
52
53impl Default for KeyResolver {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl KeyResolver {
60    const STATIC_JWKS_ENTRY: &'static str = "static_jwks";
61
62    /// Create a new KeyResolver with default settings
63    pub fn new() -> Self {
64        // Create a reqwest client with default TLS configuration
65        let client = ReqwestClient::builder()
66            .user_agent("AGNTCY Slim Auth")
67            .build()
68            .expect("Failed to create reqwest client");
69
70        Self {
71            client,
72            jwks_cache: RwLock::new(HashMap::new()),
73            default_jwks_ttl: Duration::from_secs(3600), // 1 hour default TTL
74        }
75    }
76
77    pub fn with_jwks(jwks: JwkSet) -> Self {
78        // Initialize the cache with the provided JWKS
79        let mut cache = HashMap::new();
80        cache.insert(
81            Self::STATIC_JWKS_ENTRY.to_string(),
82            JwksCache {
83                jwks,
84                fetched_at: Instant::now(),
85                ttl: Duration::from_secs(u64::MAX), // static JWKS, infinite TTL
86            },
87        );
88
89        let client = ReqwestClient::builder()
90            .user_agent("AGNTCY Slim Auth")
91            .build()
92            .expect("Failed to create reqwest client");
93
94        Self {
95            client,
96            jwks_cache: RwLock::new(cache),
97            default_jwks_ttl: Duration::from_secs(3600), // 1 hour default TTL
98        }
99    }
100
101    /// Set the default TTL for cached JWKS
102    pub fn with_jwks_ttl(mut self, ttl: Duration) -> Self {
103        self.default_jwks_ttl = ttl;
104        self
105    }
106
107    /// Resolve a decoding key from various sources
108    ///
109    /// This function will attempt to resolve the key in the following order:
110    /// 1. If a decoding key is already provided, return it
111    /// 2. If a kid (Key ID) is specified in the token header, fetch the key from the JWKS endpoint
112    /// 3. If no kid is specified, use the first suitable key from the JWKS endpoint
113    ///
114    /// # Arguments
115    /// * `issuer` - The token issuer URL
116    /// * `token_header` - The JWT header containing the algorithm and key ID (if available)
117    pub async fn resolve_key(
118        &self,
119        issuer: &str,
120        token_header: &Header,
121    ) -> Result<DecodingKey, AuthError> {
122        // Check if we have a static JWKS entry
123        if let Some(cache_entry) = self.jwks_cache.read().get(Self::STATIC_JWKS_ENTRY) {
124            // If we have a static JWKS, use it directly
125            return self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header);
126        }
127
128        // Try to get cached key if available
129        if let Ok(cached_key) = self.get_cached_key(issuer, token_header) {
130            return Ok(cached_key);
131        }
132
133        // Try to fetch the keys from the well-known JWKS endpoint
134        let jwks = self.fetch_jwks(issuer).await?;
135
136        // Try to decode the key from the JWKS
137        self.get_decoded_key_from_jwks(&jwks, token_header)
138    }
139
140    /// Convert a JWK to a DecodingKey
141    fn jwk_to_decoding_key(&self, jwk: &Jwk) -> Result<DecodingKey, AuthError> {
142        let ret = DecodingKey::from_jwk(jwk)?;
143        Ok(ret)
144    }
145
146    fn key_alg_to_algorithm(&self, alg: &KeyAlgorithm) -> Result<Algorithm, AuthError> {
147        match alg {
148            KeyAlgorithm::HS256 => Ok(Algorithm::HS256),
149            KeyAlgorithm::HS384 => Ok(Algorithm::HS384),
150            KeyAlgorithm::HS512 => Ok(Algorithm::HS512),
151            KeyAlgorithm::ES256 => Ok(Algorithm::ES256),
152            KeyAlgorithm::ES384 => Ok(Algorithm::ES384),
153            KeyAlgorithm::RS256 => Ok(Algorithm::RS256),
154            KeyAlgorithm::RS384 => Ok(Algorithm::RS384),
155            KeyAlgorithm::RS512 => Ok(Algorithm::RS512),
156            KeyAlgorithm::PS256 => Ok(Algorithm::PS256),
157            KeyAlgorithm::PS384 => Ok(Algorithm::PS384),
158            KeyAlgorithm::PS512 => Ok(Algorithm::PS512),
159            KeyAlgorithm::EdDSA => Ok(Algorithm::EdDSA),
160            _ => Err(AuthError::JwtUnsupportedKeyAlgorithm(*alg)),
161        }
162    }
163
164    fn get_decoded_key_from_jwks(
165        &self,
166        jwks: &JwkSet,
167        token_header: &Header,
168    ) -> Result<DecodingKey, AuthError> {
169        // At this point, we have a valid cache entry
170        if let Some(kid) = &token_header.kid {
171            // Look for a key with a matching ID
172            for key in &jwks.keys {
173                if let Some(id) = &key.common.key_id
174                    && id == kid
175                {
176                    return self.jwk_to_decoding_key(key);
177                }
178            }
179        } else {
180            // If no key ID is specified, use the first suitable key
181            for key in &jwks.keys {
182                if let Some(alg) = &key.common.key_algorithm
183                    && let Ok(algorithm) = self.key_alg_to_algorithm(alg)
184                {
185                    // Check if the algorithm matches the token's algorithm
186                    if algorithm == token_header.alg {
187                        return self.jwk_to_decoding_key(key);
188                    }
189                }
190            }
191        }
192
193        // If no suitable key is found, return an error
194        Err(AuthError::JwksNoSuitableKey)
195    }
196
197    /// Check the cache for a JWKS entry
198    pub fn get_cached_key(
199        &self,
200        issuer: &str,
201        token_header: &Header,
202    ) -> Result<DecodingKey, AuthError> {
203        // Check if we have a cached JWKS that's still valid
204        let cache = self.jwks_cache.read();
205
206        // Check static JWKS entry first
207        if let Some(cache_entry) = cache.get(Self::STATIC_JWKS_ENTRY) {
208            // no need to check the elapsed time for static JWKS
209            return self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header);
210        }
211
212        let cache_entry = cache.get(issuer);
213        if cache_entry.is_none() {
214            return Err(AuthError::JwksCacheMiss {
215                issuer: issuer.to_string(),
216            });
217        }
218
219        let cache_entry = cache_entry.unwrap();
220
221        if cache_entry.fetched_at.elapsed() > cache_entry.ttl {
222            return Err(AuthError::JwksCacheExpired {
223                issuer: issuer.to_string(),
224            });
225        }
226
227        // If we have a valid cache entry, try to decode the key
228        self.get_decoded_key_from_jwks(&cache_entry.jwks, token_header)
229    }
230
231    /// Fetch JWKS from the issuer's endpoint
232    ///
233    /// This function will discover the JWKS URI (either via OpenID Connect Discovery
234    /// or the standard well-known endpoint), fetch the JWKS, and cache it for future use.
235    async fn fetch_jwks(&self, issuer: &str) -> Result<JwkSet, AuthError> {
236        // Build the JWKS URI (this now handles both OpenID discovery and fallback)
237        let jwks_uri = self.build_jwks_uri(issuer).await?;
238
239        // Fetch the JWKS
240        let jwks = self.fetch_jwks_from_uri(&jwks_uri).await?;
241
242        // Cache the JWKS
243        self.jwks_cache.write().insert(
244            issuer.to_string(),
245            JwksCache {
246                jwks: jwks.clone(),
247                fetched_at: Instant::now(),
248                ttl: self.default_jwks_ttl,
249            },
250        );
251
252        Ok(jwks)
253    }
254
255    /// Build the JWKS URI from the issuer
256    ///
257    /// This function first tries to discover the JWKS URI via OpenID Connect Discovery
258    /// (.well-known/openid-configuration), and falls back to the standard .well-known/jwks.json
259    /// location if that fails.
260    async fn build_jwks_uri(&self, issuer: &str) -> Result<String, AuthError> {
261        // Parse the issuer URL
262        // Use typed URL parse error propagation
263        let mut issuer_url = Url::parse(issuer)?;
264
265        // First try OpenID Connect Discovery endpoint
266        let mut openid_config_url = issuer_url.clone();
267        let mut openid_path = openid_config_url.path().trim_end_matches('/').to_owned();
268        openid_path.push_str("/.well-known/openid-configuration");
269        openid_config_url.set_path(&openid_path);
270
271        // Try to fetch the OpenID configuration
272        let openid_config_response = self.client.get(openid_config_url.to_string()).send().await;
273
274        // If we successfully got the OpenID configuration, extract the jwks_uri
275        if let Ok(response) = openid_config_response
276            && response.status() == StatusCode::OK
277            && let Ok(config) = response.json::<serde_json::Value>().await
278            && let Some(jwks_uri) = config.get("jwks_uri").and_then(|v| v.as_str())
279        {
280            return Ok(jwks_uri.to_string());
281        }
282
283        // Fallback to standard well-known JWKS location
284        let mut path = issuer_url.path().trim_end_matches('/').to_owned();
285        path.push_str("/.well-known/jwks.json");
286        issuer_url.set_path(&path);
287
288        Ok(issuer_url.to_string())
289    }
290
291    /// Fetch JWKS from the specified URI
292    async fn fetch_jwks_from_uri(&self, uri: &str) -> Result<JwkSet, AuthError> {
293        // Send the GET request using reqwest
294        let response = self.client.get(uri).send().await?;
295
296        // Check the response status
297        if response.status() != StatusCode::OK {
298            return Err(AuthError::JwtFetchJwksFailed(response.status()));
299        }
300
301        // Get the response body as bytes
302        let body = response.bytes().await?;
303
304        // Parse the JWKS
305        let jwks: JwkSet = serde_json::from_slice(&body)?;
306
307        Ok(jwks)
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Helper to create a test resolver with a client that can talk to the mock server
    async fn create_test_resolver() -> (KeyResolver, MockServer) {
        let server = MockServer::start().await;
        let resolver = KeyResolver::new();
        (resolver, server)
    }

    // JWKS with a single symmetric HS256 key; `k` is the base64 encoding of "secret".
    fn hs256_jwks_json() -> serde_json::Value {
        json!({
            "keys": [
                {
                    "kty": "oct",
                    "kid": "hs-key",
                    "alg": "HS256",
                    "k": "c2VjcmV0"
                }
            ]
        })
    }

    // HS256 token header, optionally carrying a key ID.
    fn hs256_header(kid: Option<&str>) -> Header {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_owned);
        header
    }

    // Resolver backed by the static HS256 JWKS above.
    fn static_resolver() -> KeyResolver {
        let jwks: JwkSet = serde_json::from_value(hs256_jwks_json()).expect("valid JWKS");
        KeyResolver::with_jwks(jwks)
    }

    #[tokio::test]
    async fn test_build_jwks_uri_with_openid_discovery() {
        let (resolver, mock_server) = create_test_resolver().await;

        // Setup mock for OpenID discovery endpoint
        let jwks_uri = "https://example.com/custom/path/to/jwks.json";
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "issuer": "https://example.com",
                "jwks_uri": jwks_uri
            })))
            .mount(&mock_server)
            .await;

        // Test that we can discover the JWKS URI from OpenID configuration
        let uri = resolver.build_jwks_uri(&mock_server.uri()).await.unwrap();
        assert_eq!(uri, jwks_uri);
    }

    #[tokio::test]
    async fn test_build_jwks_uri_fallback() {
        let (resolver, mock_server) = create_test_resolver().await;

        // Setup mock to return 404 for OpenID discovery endpoint
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&mock_server)
            .await;

        // Test that we fall back to the standard JWKS URI
        let uri = resolver.build_jwks_uri(&mock_server.uri()).await.unwrap();
        assert_eq!(uri, format!("{}/.well-known/jwks.json", mock_server.uri()));
    }

    #[tokio::test]
    async fn test_fetch_jwks_from_uri() {
        let (resolver, mock_server) = create_test_resolver().await;

        // Setup mock for JWKS endpoint
        let jwks = json!({
            "keys": [
                {
                    "kty": "RSA",
                    "kid": "test-key",
                    "n": "some-modulus",
                    "e": "AQAB"
                }
            ]
        });

        Mock::given(method("GET"))
            .and(path("/.well-known/jwks.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(jwks))
            .mount(&mock_server)
            .await;

        // Fetch the JWKS from the mock server
        let jwks_uri = format!("{}/.well-known/jwks.json", mock_server.uri());
        let fetched_jwks = resolver.fetch_jwks_from_uri(&jwks_uri).await.unwrap();

        assert_eq!(fetched_jwks.keys.len(), 1);
        assert_eq!(
            fetched_jwks.keys[0].common.key_id.as_deref(),
            Some("test-key")
        );
    }

    #[tokio::test]
    async fn test_get_cached_key_miss() {
        let resolver = KeyResolver::new();
        let result =
            resolver.get_cached_key("https://issuer.example", &hs256_header(Some("hs-key")));
        assert!(matches!(result, Err(AuthError::JwksCacheMiss { .. })));
    }

    #[tokio::test]
    async fn test_static_jwks_resolves_by_kid_and_by_alg() {
        let resolver = static_resolver();

        // Lookup by key ID.
        let by_kid = resolver
            .resolve_key("https://ignored.example", &hs256_header(Some("hs-key")))
            .await;
        assert!(by_kid.is_ok());

        // No key ID: the key is matched through its declared algorithm.
        let by_alg = resolver
            .resolve_key("https://ignored.example", &hs256_header(None))
            .await;
        assert!(by_alg.is_ok());
    }

    #[tokio::test]
    async fn test_static_jwks_unknown_kid_does_not_fetch() {
        let resolver = static_resolver();

        // The issuer is unreachable: if a fetch were attempted the error would be a
        // transport error rather than `JwksNoSuitableKey`.
        let result = resolver
            .resolve_key("http://127.0.0.1:1", &hs256_header(Some("other-key")))
            .await;
        assert!(matches!(result, Err(AuthError::JwksNoSuitableKey)));
    }

    #[tokio::test]
    async fn test_resolve_key_fetches_once_then_uses_cache() {
        let (resolver, mock_server) = create_test_resolver().await;

        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&mock_server)
            .await;

        // The JWKS endpoint must be hit exactly once; verified when the server drops.
        Mock::given(method("GET"))
            .and(path("/.well-known/jwks.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(hs256_jwks_json()))
            .expect(1)
            .mount(&mock_server)
            .await;

        let issuer = mock_server.uri();
        let header = hs256_header(Some("hs-key"));

        assert!(resolver.resolve_key(&issuer, &header).await.is_ok());
        assert!(resolver.get_cached_key(&issuer, &header).is_ok());
        assert!(resolver.resolve_key(&issuer, &header).await.is_ok());
    }
}