use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use jsonwebtoken::jwk::{AlgorithmParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm};
use jsonwebtoken::{Algorithm, DecodingKey};
use tokio::sync::{Mutex, RwLock};
/// A JWT decoding key together with the algorithm selected for it.
///
/// Cloning is cheap: the key itself sits behind an `Arc`, so a clone only
/// bumps the reference count.
#[derive(Clone)]
pub struct VerifyingKey {
/// Decoding key built from a JWK via `DecodingKey::from_jwk`.
pub key: Arc<DecodingKey>,
/// Algorithm chosen for this key by `algorithm_for`.
pub algorithm: Algorithm,
}
/// Cache of verifying keys fetched from a remote JWKS endpoint, indexed by `kid`.
pub struct JwksCache {
// URL the JWKS document is fetched from.
uri: String,
// HTTP client used for the fetch; built in `new` with `JWKS_HTTP_TIMEOUT`.
client: reqwest::Client,
// Parsed keys by key id; replaced wholesale on every successful refresh.
keys: RwLock<HashMap<String, VerifyingKey>>,
// Instant of the most recent refresh attempt, `None` until the first one.
last_refresh: Mutex<Option<Instant>>,
}
// Minimum spacing between refresh attempts while the cache holds at least one
// key; an empty cache is not subject to this limit (see `JwksCache::refresh`).
const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(60);
// Timeout configured on the HTTP client used to fetch the JWKS document.
const JWKS_HTTP_TIMEOUT: Duration = Duration::from_secs(5);
impl JwksCache {
    /// Creates a cache for the JWKS document served at `uri`.
    ///
    /// No request is made here: the key map starts empty and is populated by
    /// the first [`JwksCache::key_for`] lookup that misses.
    pub fn new(uri: String) -> Self {
        // NOTE(review): if the builder fails, `unwrap_or_default` substitutes
        // `reqwest::Client::default()`, which does not carry the timeout
        // configured here.
        let client = reqwest::Client::builder()
            .timeout(JWKS_HTTP_TIMEOUT)
            .build()
            .unwrap_or_default();
        Self {
            uri,
            client,
            keys: RwLock::new(HashMap::new()),
            last_refresh: Mutex::new(None),
        }
    }

    /// Returns the verifying key registered under `kid`.
    ///
    /// A cache hit is answered directly. On a miss the JWKS document is
    /// refreshed once and the lookup repeated. Returns `None` when the refresh
    /// fails or is throttled, or when `kid` is still unknown afterwards.
    pub async fn key_for(&self, kid: &str) -> Option<VerifyingKey> {
        if let Some(cached) = self.keys.read().await.get(kid).cloned() {
            return Some(cached);
        }
        self.refresh().await.ok()?;
        self.keys.read().await.get(kid).cloned()
    }

    /// Fetches the JWKS document and replaces the cached key map with it.
    ///
    /// # Errors
    ///
    /// Returns a message when the refresh is throttled, when the request fails
    /// or the endpoint answers with a non-success HTTP status, or when the
    /// body cannot be decoded as a JWK set.
    async fn refresh(&self) -> Result<(), String> {
        {
            let mut last = self.last_refresh.lock().await;
            if let Some(previous) = *last {
                // An empty cache bypasses the throttle so that a failed
                // initial fetch can be retried on the next lookup.
                let has_keys = !self.keys.read().await.is_empty();
                if has_keys && previous.elapsed() < MIN_REFRESH_INTERVAL {
                    return Err("refresh throttled".to_string());
                }
            }
            // Recorded before the request is sent, so failed attempts count
            // toward the throttle as well.
            *last = Some(Instant::now());
        }

        let set: JwkSet = self
            .client
            .get(&self.uri)
            .send()
            .await
            // Report 4xx/5xx responses as fetch failures instead of handing an
            // error body to the JSON decoder.
            .and_then(reqwest::Response::error_for_status)
            .map_err(|e| format!("JWKS fetch failed: {e}"))?
            .json()
            .await
            .map_err(|e| format!("JWKS decode failed: {e}"))?;

        *self.keys.write().await = parse_jwks(&set);
        Ok(())
    }
}
fn parse_jwks(set: &JwkSet) -> HashMap<String, VerifyingKey> {
let mut map = HashMap::new();
for jwk in &set.keys {
let Some(kid) = jwk.common.key_id.clone() else {
continue;
};
let Some(algorithm) = algorithm_for(jwk) else {
continue;
};
if let Ok(key) = DecodingKey::from_jwk(jwk) {
map.insert(
kid,
VerifyingKey {
key: Arc::new(key),
algorithm,
},
);
}
}
map
}
fn algorithm_for(jwk: &Jwk) -> Option<Algorithm> {
if let Some(alg) = jwk.common.key_algorithm.and_then(key_algorithm_to_alg) {
return Some(alg);
}
match &jwk.algorithm {
AlgorithmParameters::RSA(_) => Some(Algorithm::RS256),
AlgorithmParameters::EllipticCurve(ec) => match ec.curve {
EllipticCurve::P256 => Some(Algorithm::ES256),
EllipticCurve::P384 => Some(Algorithm::ES384),
_ => None,
},
AlgorithmParameters::OctetKeyPair(_) => Some(Algorithm::EdDSA),
AlgorithmParameters::OctetKey(_) => None,
}
}
/// Translates a JWK `alg` value into the matching JWT signature algorithm.
///
/// Only the ECDSA, RSA (PKCS#1 and PSS) and EdDSA variants listed below are
/// mapped; every other `KeyAlgorithm` variant yields `None`.
fn key_algorithm_to_alg(ka: KeyAlgorithm) -> Option<Algorithm> {
    match ka {
        KeyAlgorithm::ES256 => Some(Algorithm::ES256),
        KeyAlgorithm::ES384 => Some(Algorithm::ES384),
        KeyAlgorithm::RS256 => Some(Algorithm::RS256),
        KeyAlgorithm::RS384 => Some(Algorithm::RS384),
        KeyAlgorithm::RS512 => Some(Algorithm::RS512),
        KeyAlgorithm::PS256 => Some(Algorithm::PS256),
        KeyAlgorithm::PS384 => Some(Algorithm::PS384),
        KeyAlgorithm::PS512 => Some(Algorithm::PS512),
        KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
        _ => None,
    }
}
// Unit tests for JWK parsing and algorithm selection; none of them touch the
// network.
#[cfg(test)]
mod tests {
use super::*;
// An RSA signing key with a `kid` and no `alg` is kept and mapped to RS256.
#[test]
fn parse_jwks_keeps_asymmetric_keys_and_maps_algorithms() {
let set: JwkSet = serde_json::from_value(serde_json::json!({
"keys": [{
"kty": "RSA",
"kid": "rsa-1",
"use": "sig",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368Qen-JS7-zw04o6sJ9qjp6lFm5_T4nzcCqRfMOgRA_g_S0d7e9k7B0v0vqHr0e1V_o-z0ow5dWpql8-zKj4hQp8sg_Pn8O0R5ZQS4t8hUE-3-r3ftt1YzQ",
"e": "AQAB"
}]
})).unwrap();
let keys = parse_jwks(&set);
assert!(keys.contains_key("rsa-1"));
assert_eq!(keys["rsa-1"].algorithm, Algorithm::RS256);
}
// When the JWK carries an explicit `alg`, that value determines the algorithm.
#[test]
fn algorithm_prefers_explicit_jwk_alg() {
let jwk: Jwk = serde_json::from_value(serde_json::json!({
"kty": "EC", "crv": "P-384", "alg": "ES384", "kid": "k",
"x": "AAAA", "y": "AAAA"
}))
.unwrap();
assert_eq!(algorithm_for(&jwk), Some(Algorithm::ES384));
}
// Without `alg`, an EC key's algorithm follows its curve: P-384 gives ES384,
// not a blanket ES256.
#[test]
fn algorithm_falls_back_to_curve_not_es256() {
let jwk: Jwk = serde_json::from_value(serde_json::json!({
"kty": "EC", "crv": "P-384", "kid": "k", "x": "AAAA", "y": "AAAA"
}))
.unwrap();
assert_eq!(algorithm_for(&jwk), Some(Algorithm::ES384));
}
// Symmetric (`oct`) keys and keys without a `kid` are both left out of the map.
#[test]
fn parse_jwks_skips_symmetric_and_keyless() {
let set: JwkSet = serde_json::from_value(serde_json::json!({
"keys": [
{ "kty": "oct", "kid": "hmac", "k": "c2VjcmV0" },
{ "kty": "RSA", "n": "0vx7ag", "e": "AQAB" }
]
}))
.unwrap();
assert!(parse_jwks(&set).is_empty());
}
}