Skip to main content

structured_proxy/auth/
jwks.rs

//! JWKS fetching and key cache.
//!
//! Keys are fetched from the configured JWKS URI and cached by `kid`. An unknown
//! `kid` triggers a (throttled) refresh, which is how key rotation is picked up.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use jsonwebtoken::jwk::{AlgorithmParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm};
11use jsonwebtoken::{Algorithm, DecodingKey};
12use tokio::sync::{Mutex, RwLock};
13
14/// A decoding key plus the signature algorithm it is valid for.
15#[derive(Clone)]
16pub struct VerifyingKey {
17    pub key: Arc<DecodingKey>,
18    pub algorithm: Algorithm,
19}
20
21/// Fetches and caches JWKS keys by `kid`.
22pub struct JwksCache {
23    uri: String,
24    client: reqwest::Client,
25    keys: RwLock<HashMap<String, VerifyingKey>>,
26    last_refresh: Mutex<Option<Instant>>,
27}
28
/// Smallest gap allowed between two refreshes caused by an unknown `kid`, so a
/// stream of made-up `kid`s cannot turn into a stream of JWKS requests.
const MIN_REFRESH_INTERVAL: Duration = Duration::from_millis(60_000);

/// Total per-request timeout for the JWKS HTTP client, capping how long a slow
/// or stalled endpoint can delay a lookup.
const JWKS_HTTP_TIMEOUT: Duration = Duration::from_millis(5_000);
35
36impl JwksCache {
37    /// Create a cache for `uri` (keys are loaded lazily on first lookup).
38    pub fn new(uri: String) -> Self {
39        let client = reqwest::Client::builder()
40            .timeout(JWKS_HTTP_TIMEOUT)
41            .build()
42            .unwrap_or_default();
43        Self {
44            uri,
45            client,
46            keys: RwLock::new(HashMap::new()),
47            last_refresh: Mutex::new(None),
48        }
49    }
50
51    /// Resolve the verifying key for `kid`, refreshing from the JWKS endpoint
52    /// once (throttled) if it is not already cached.
53    pub async fn key_for(&self, kid: &str) -> Option<VerifyingKey> {
54        if let Some(k) = self.keys.read().await.get(kid).cloned() {
55            return Some(k);
56        }
57        if self.refresh().await.is_err() {
58            return None;
59        }
60        self.keys.read().await.get(kid).cloned()
61    }
62
63    /// Fetch the JWKS and replace the cache. Throttled by [`MIN_REFRESH_INTERVAL`]
64    /// unless the cache is still empty (first load).
65    async fn refresh(&self) -> Result<(), String> {
66        // Claim the refresh slot atomically: hold the lock across the throttle
67        // check and the timestamp update so concurrent callers cannot all pass.
68        {
69            let mut last = self.last_refresh.lock().await;
70            if let Some(t) = *last {
71                let empty = self.keys.read().await.is_empty();
72                if !empty && t.elapsed() < MIN_REFRESH_INTERVAL {
73                    return Err("refresh throttled".to_string());
74                }
75            }
76            *last = Some(Instant::now());
77        }
78
79        let set: JwkSet = self
80            .client
81            .get(&self.uri)
82            .send()
83            .await
84            .map_err(|e| format!("JWKS fetch failed: {e}"))?
85            .json()
86            .await
87            .map_err(|e| format!("JWKS decode failed: {e}"))?;
88
89        let new_keys = parse_jwks(&set);
90        *self.keys.write().await = new_keys;
91        Ok(())
92    }
93}
94
95/// Build the `kid → VerifyingKey` map from a JWK set, skipping keys without a
96/// `kid`, symmetric keys, or those that fail to convert.
97fn parse_jwks(set: &JwkSet) -> HashMap<String, VerifyingKey> {
98    let mut map = HashMap::new();
99    for jwk in &set.keys {
100        let Some(kid) = jwk.common.key_id.clone() else {
101            continue;
102        };
103        let Some(algorithm) = algorithm_for(jwk) else {
104            continue;
105        };
106        if let Ok(key) = DecodingKey::from_jwk(jwk) {
107            map.insert(
108                kid,
109                VerifyingKey {
110                    key: Arc::new(key),
111                    algorithm,
112                },
113            );
114        }
115    }
116    map
117}
118
119/// Pick the signature algorithm for a key.
120///
121/// The JWK's explicit `alg` is authoritative (so ES384 / RS512 / PS256 keys are
122/// not mis-pinned). Without it, fall back to the key type and EC curve.
123/// Symmetric keys (`OctetKey`) and unsupported variants are rejected.
124fn algorithm_for(jwk: &Jwk) -> Option<Algorithm> {
125    if let Some(alg) = jwk.common.key_algorithm.and_then(key_algorithm_to_alg) {
126        return Some(alg);
127    }
128    match &jwk.algorithm {
129        AlgorithmParameters::RSA(_) => Some(Algorithm::RS256),
130        AlgorithmParameters::EllipticCurve(ec) => match ec.curve {
131            EllipticCurve::P256 => Some(Algorithm::ES256),
132            EllipticCurve::P384 => Some(Algorithm::ES384),
133            // P-521 (ES512) is not supported by the verifier.
134            _ => None,
135        },
136        AlgorithmParameters::OctetKeyPair(_) => Some(Algorithm::EdDSA),
137        AlgorithmParameters::OctetKey(_) => None,
138    }
139}
140
141/// Map a JWK signature `alg` to a verifier algorithm, rejecting symmetric and
142/// encryption algorithms (only asymmetric signatures are usable from a JWKS).
143fn key_algorithm_to_alg(ka: KeyAlgorithm) -> Option<Algorithm> {
144    Some(match ka {
145        KeyAlgorithm::ES256 => Algorithm::ES256,
146        KeyAlgorithm::ES384 => Algorithm::ES384,
147        KeyAlgorithm::RS256 => Algorithm::RS256,
148        KeyAlgorithm::RS384 => Algorithm::RS384,
149        KeyAlgorithm::RS512 => Algorithm::RS512,
150        KeyAlgorithm::PS256 => Algorithm::PS256,
151        KeyAlgorithm::PS384 => Algorithm::PS384,
152        KeyAlgorithm::PS512 => Algorithm::PS512,
153        KeyAlgorithm::EdDSA => Algorithm::EdDSA,
154        _ => return None,
155    })
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_jwks_keeps_asymmetric_keys_and_maps_algorithms() {
        // A minimal RSA JWK that has a `kid` but no explicit `alg`, so the
        // algorithm has to be inferred from the key type.
        let fixture = serde_json::json!({
            "keys": [{
                "kty": "RSA",
                "kid": "rsa-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368Qen-JS7-zw04o6sJ9qjp6lFm5_T4nzcCqRfMOgRA_g_S0d7e9k7B0v0vqHr0e1V_o-z0ow5dWpql8-zKj4hQp8sg_Pn8O0R5ZQS4t8hUE-3-r3ftt1YzQ",
                "e": "AQAB"
            }]
        });
        let set: JwkSet = serde_json::from_value(fixture).unwrap();

        let keys = parse_jwks(&set);
        let rsa = keys
            .get("rsa-1")
            .expect("RSA key with a kid must be kept");
        assert_eq!(rsa.algorithm, Algorithm::RS256);
    }

    #[test]
    fn algorithm_prefers_explicit_jwk_alg() {
        // The key declares ES384 itself; it must stay ES384 instead of being
        // pinned to ES256.
        let fixture = serde_json::json!({
            "kty": "EC", "crv": "P-384", "alg": "ES384", "kid": "k",
            "x": "AAAA", "y": "AAAA"
        });
        let jwk: Jwk = serde_json::from_value(fixture).unwrap();
        assert_eq!(algorithm_for(&jwk), Some(Algorithm::ES384));
    }

    #[test]
    fn algorithm_falls_back_to_curve_not_es256() {
        // With no `alg`, the P-384 curve alone must select ES384.
        let fixture = serde_json::json!({
            "kty": "EC", "crv": "P-384", "kid": "k", "x": "AAAA", "y": "AAAA"
        });
        let jwk: Jwk = serde_json::from_value(fixture).unwrap();
        assert_eq!(algorithm_for(&jwk), Some(Algorithm::ES384));
    }

    #[test]
    fn parse_jwks_skips_symmetric_and_keyless() {
        // One symmetric (`oct`) key and one RSA key lacking a `kid`: neither
        // may make it into the map.
        let fixture = serde_json::json!({
            "keys": [
                { "kty": "oct", "kid": "hmac", "k": "c2VjcmV0" },
                { "kty": "RSA", "n": "0vx7ag", "e": "AQAB" }
            ]
        });
        let set: JwkSet = serde_json::from_value(fixture).unwrap();
        assert!(parse_jwks(&set).is_empty());
    }
}