use std::{
collections::HashMap,
fmt,
sync::{Arc, RwLock},
time::Duration,
};
use jsonwebtoken::{
DecodingKey,
jwk::{Jwk, JwkSet},
};
use tokio::sync::{Notify, watch};
use tokio_util::sync::CancellationToken;
use url::Url;
/// One cached JWKS key plus bookkeeping about when it appeared and last changed.
struct KeyEntry {
    /// The JWK exactly as last received; compared with freshly fetched JWKs to detect
    /// changed key material.
    jwk: Jwk,
    /// Key built from `jwk` via `DecodingKey::from_jwk`.
    decoding_key: DecodingKey,
    /// When this `kid` was first inserted into the cache.
    first_seen: chrono::DateTime<chrono::Utc>,
    /// When the key material was last replaced; equals `first_seen` until it changes.
    last_updated: chrono::DateTime<chrono::Utc>,
    /// How many times the key material has been replaced; starts at 0.
    version: u64,
}

/// The JWK `kid` (key id) under which a key is cached.
type KeyId = String;
/// Cache of JWKS decoding keys, looked up by `kid`.
///
/// [`JwksKeyStore::new`] spawns a background task that refetches the JWKS document from
/// `jwks_url` every `refresh_interval`, and additionally on demand when
/// [`JwksKeyStore::await_key`] misses the cache. Clones share the same cache and the same
/// background task.
#[derive(Clone)]
pub struct JwksKeyStore {
    jwks_url: Url,
    client: reqwest::Client,
    cache: Arc<RwLock<HashMap<KeyId, KeyEntry>>>,
    /// Wakes the background task for an out-of-schedule fetch.
    fetch_notify: Arc<Notify>,
    /// Bumped by the background task after every completed fetch attempt, whether or not
    /// it succeeded.
    fetch_generation_tx: Arc<watch::Sender<u64>>,
    /// Template receiver. It is never updated itself, so clones must mark the current
    /// generation as seen before waiting on it.
    fetch_generation_rx: watch::Receiver<u64>,
    /// Stops the background task. `await_key` observes it too: once it fires no further
    /// fetch will complete, so waiting for one would never end.
    cancellation_token: CancellationToken,
}

impl fmt::Debug for JwksKeyStore {
    /// Formats the store as a map from `kid` to its cached entry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cache = self.cache.read().unwrap();
        f.debug_map()
            .entries(cache.iter().map(|(kid, entry)| (kid, EntryDebug(entry))))
            .finish()
    }
}

/// `Debug` adapter for [`KeyEntry`]. The `DecodingKey` itself is not printed; the JWK it
/// was built from is shown instead.
struct EntryDebug<'a>(&'a KeyEntry);

impl fmt::Debug for EntryDebug<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let e = self.0;
        f.debug_struct("KeyEntry")
            // Label matches the field actually printed (previously mislabelled
            // "decoding_key").
            .field("jwk", &e.jwk)
            .field("first_seen", &e.first_seen)
            .field("last_updated", &e.last_updated)
            .field("version", &e.version)
            .finish()
    }
}

impl JwksKeyStore {
    /// Upper bound for a single JWKS request, from connecting until the body is read.
    /// Without it a stalled endpoint would block the only background task — and with it
    /// every `await_key` caller waiting on a fetch — indefinitely.
    const FETCH_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates the store and spawns its background refresh task.
    ///
    /// The cache starts empty. The first fetch happens when `refresh_interval` elapses or
    /// when [`Self::await_key`] misses, whichever comes first. The task exits once
    /// `cancellation_token` is cancelled.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (the task is started with
    /// [`tokio::spawn`]), and whenever `reqwest::Client::new` panics.
    pub fn new(
        jwks_url: Url,
        refresh_interval: Duration,
        cancellation_token: CancellationToken,
    ) -> Self {
        let (fetch_generation_tx, fetch_generation_rx) = watch::channel(0u64);
        let store = Self {
            jwks_url,
            client: reqwest::Client::new(),
            cache: Arc::new(RwLock::new(HashMap::new())),
            fetch_notify: Arc::new(Notify::new()),
            fetch_generation_tx: Arc::new(fetch_generation_tx),
            fetch_generation_rx,
            cancellation_token,
        };
        let bg = store.clone();
        tokio::spawn(async move {
            bg.background_loop(refresh_interval).await;
        });
        store
    }

    /// Returns the cached decoding key for `kid`, without triggering a fetch.
    pub fn get_key(&self, kid: &str) -> Option<DecodingKey> {
        self.cache
            .read()
            .unwrap()
            .get(kid)
            .map(|e| e.decoding_key.clone())
    }

    /// Returns the decoding key for `kid`, fetching the JWKS if it is not cached.
    ///
    /// On a miss this requests a fetch from the background task, waits until the next
    /// fetch attempt finishes (successfully or not) and then re-checks the cache.
    /// Returns `None` if the key is still unknown, or as soon as the store is cancelled.
    pub async fn await_key(&self, kid: &str) -> Option<DecodingKey> {
        // A fresh clone inherits the template receiver's "seen" version, which is never
        // advanced, so `changed()` would return at once after the first fetch ever. Mark
        // the current generation as seen *before* checking the cache, so that a fetch
        // finishing between the check and the wait still wakes us.
        let mut rx = self.fetch_generation_rx.clone();
        rx.borrow_and_update();
        if let Some(key) = self.get_key(kid) {
            return Some(key);
        }
        self.fetch_notify.notify_one();
        tokio::select! {
            // `Err` means the sender is gone; either way there is nothing left to wait for.
            _ = rx.changed() => {}
            // After cancellation no fetch will ever complete; don't hang.
            _ = self.cancellation_token.cancelled() => {}
        }
        self.get_key(kid)
    }

    /// Runs until cancellation: waits for the refresh interval or an on-demand request,
    /// fetches, then publishes a new fetch generation to waiting `await_key` callers.
    async fn background_loop(self, refresh_interval: Duration) {
        loop {
            tokio::select! {
                biased;
                _ = self.cancellation_token.cancelled() => break,
                _ = tokio::time::sleep(refresh_interval) => {},
                _ = self.fetch_notify.notified() => {},
            }
            // Don't let a slow fetch delay shutdown. Dropping `do_fetch` midway is safe:
            // it takes the cache lock only after its last `.await`.
            tokio::select! {
                biased;
                _ = self.cancellation_token.cancelled() => break,
                _ = self.do_fetch() => {},
            }
            // Misses recorded while the fetch was in flight left one stored `Notify`
            // permit. Those callers are released by the generation bump below, so honouring
            // the permit would only cause a redundant back-to-back fetch. Discard it
            // *before* bumping; a miss recorded after this point keeps its permit and gets
            // its own fetch.
            self.discard_pending_fetch_request().await;
            self.fetch_generation_tx.send_modify(|g| *g += 1);
        }
    }

    /// Consumes a `fetch_notify` permit stored while nobody was waiting, if any, without
    /// blocking.
    async fn discard_pending_fetch_request(&self) {
        tokio::select! {
            biased;
            _ = self.fetch_notify.notified() => {}
            _ = std::future::ready(()) => {}
        }
    }

    /// Performs one JWKS fetch and merges the result into the cache. Failures are logged
    /// and leave the cache untouched.
    async fn do_fetch(&self) {
        if let Some(jwks) = self.fetch_jwks().await {
            self.merge_into_cache(&jwks);
        }
    }

    /// Downloads and parses the JWKS document. Logs a warning and returns `None` on any
    /// transport, timeout, HTTP-status or JSON error.
    async fn fetch_jwks(&self) -> Option<JwkSet> {
        let response = self
            .client
            .get(self.jwks_url.clone())
            .timeout(Self::FETCH_TIMEOUT)
            .send()
            .await
            .and_then(|r| r.error_for_status());
        let response = match response {
            Ok(r) => r,
            Err(e) => {
                tracing::warn!(url = %self.jwks_url, error = %e, "failed to fetch JWKS");
                return None;
            }
        };
        match response.json::<JwkSet>().await {
            Ok(jwks) => Some(jwks),
            Err(e) => {
                tracing::warn!(
                    url = %self.jwks_url,
                    error = %e,
                    "failed to parse JWKS response"
                );
                None
            }
        }
    }

    /// Merges `jwks` into the cache.
    ///
    /// - New `kid`s are inserted with `version` 0.
    /// - Cached `kid`s whose JWK changed are replaced, and their `version` is bumped.
    /// - Keys without a `kid` are skipped silently.
    /// - Keys that `DecodingKey::from_jwk` rejects are skipped with a warning.
    ///
    /// NOTE(review): keys absent from `jwks` are never evicted. A key withdrawn upstream
    /// (e.g. after a compromise) stays returned by `get_key` for the life of the process —
    /// confirm this is intended.
    fn merge_into_cache(&self, jwks: &JwkSet) {
        let now = chrono::Utc::now();
        // No `.await` while this (non-`Send`) guard is held.
        let mut cache = self.cache.write().unwrap();
        for jwk in &jwks.keys {
            let Some(kid) = &jwk.common.key_id else {
                continue;
            };
            let key = match DecodingKey::from_jwk(jwk) {
                Ok(k) => k,
                Err(e) => {
                    tracing::warn!(
                        %kid,
                        error = %e,
                        "skipping JWKS key: failed to build decoding key"
                    );
                    continue;
                }
            };
            match cache.get_mut(kid.as_str()) {
                Some(entry) => {
                    if &entry.jwk != jwk {
                        tracing::debug!(%kid, version = entry.version + 1, "JWKS key material changed");
                        entry.decoding_key = key;
                        entry.jwk = jwk.clone();
                        entry.last_updated = now;
                        entry.version += 1;
                    }
                }
                None => {
                    tracing::debug!(%kid, "caching new JWKS key");
                    cache.insert(
                        kid.clone(),
                        KeyEntry {
                            decoding_key: key,
                            jwk: jwk.clone(),
                            first_seen: now,
                            last_updated: now,
                            version: 0,
                        },
                    );
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use axum::{Json, Router, routing::get};
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use scion_sdk_token_validator::validator::insecure_const_ed25519_signing_key;
    use tokio_util::sync::CancellationToken;
    use url::Url;

    use super::JwksKeyStore;

    /// Builds a JWKS document containing the constant Ed25519 test key under `kid`.
    fn test_jwks_json(kid: &str) -> serde_json::Value {
        let signing_key = insecure_const_ed25519_signing_key();
        let x = URL_SAFE_NO_PAD.encode(signing_key.verifying_key().as_bytes());
        serde_json::json!({
            "keys": [{
                "kid": kid,
                "kty": "OKP",
                "use": "sig",
                "alg": "EdDSA",
                "crv": "Ed25519",
                "x": x
            }]
        })
    }

    /// Serves `body` at `/.well-known/jwks.json` on an ephemeral local port, sleeping for
    /// `delay` before each response if given. Returns the URL and a counter of requests.
    async fn start_jwks_server(
        body: serde_json::Value,
        delay: Option<Duration>,
    ) -> (Url, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = {
            let counter = counter.clone();
            Router::new().route(
                "/.well-known/jwks.json",
                get(move || {
                    let body = body.clone();
                    let counter = counter.clone();
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        if let Some(d) = delay {
                            tokio::time::sleep(d).await;
                        }
                        Json(body)
                    }
                }),
            )
        };
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        let url = format!("http://{}/.well-known/jwks.json", addr)
            .parse()
            .unwrap();
        (url, counter)
    }

    /// Creates a store with an hour-long refresh interval, so the only fetches during a
    /// test are those triggered by cache misses.
    fn make_store(url: Url) -> (JwksKeyStore, CancellationToken) {
        let ct = CancellationToken::new();
        let store = JwksKeyStore::new(url, Duration::from_secs(3600), ct.clone());
        (store, ct)
    }

    #[tokio::test]
    async fn cache_miss_triggers_fetch() {
        // Select the provider here too, like every other test, instead of relying on
        // another test in the same process having done it first.
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let kid = "test-kid-1";
        let (url, counter) = start_jwks_server(test_jwks_json(kid), None).await;
        let (store, _ct) = make_store(url);
        let result = store.await_key(kid).await;
        assert!(result.is_some(), "expected key for known kid");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "expected exactly one fetch"
        );
    }

    #[tokio::test]
    async fn cache_hit_avoids_second_fetch() {
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let kid = "test-kid-2";
        let (url, counter) = start_jwks_server(test_jwks_json(kid), None).await;
        let (store, _ct) = make_store(url);
        store.await_key(kid).await.unwrap();
        let result = store.get_key(kid);
        assert!(result.is_some(), "second lookup should hit the cache");
        assert_eq!(counter.load(Ordering::SeqCst), 1, "expected only one fetch");
    }

    #[tokio::test]
    async fn unknown_kid_returns_none_after_fetch() {
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let (url, _) = start_jwks_server(test_jwks_json("other-kid"), None).await;
        let (store, _ct) = make_store(url);
        let result = store.await_key("not-present").await;
        assert!(result.is_none(), "unknown kid should return None");
    }

    #[tokio::test]
    async fn fetch_failure_returns_none() {
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        // Grab an ephemeral port and release it immediately (the listener is a temporary
        // dropped at the end of the statement), so nothing is listening there. A
        // hard-coded port could be in use by an unrelated process.
        let addr = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .unwrap()
            .local_addr()
            .unwrap();
        let url: Url = format!("http://{addr}/.well-known/jwks.json")
            .parse()
            .unwrap();
        let (store, _ct) = make_store(url);
        let result = store.await_key("any-kid").await;
        assert!(result.is_none(), "fetch failure should return None");
    }

    #[tokio::test]
    async fn concurrent_requests_for_same_kid_trigger_single_fetch() {
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let kid = "test-kid-concurrent";
        let (url, counter) =
            start_jwks_server(test_jwks_json(kid), Some(Duration::from_millis(50))).await;
        let (store, _ct) = make_store(url);
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let store = store.clone();
                let kid = kid.to_string();
                tokio::spawn(async move { store.await_key(&kid).await })
            })
            .collect();
        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_some(), "all concurrent requests must succeed");
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "expected exactly one JWKS fetch despite concurrent requests"
        );
    }

    #[tokio::test]
    async fn periodic_refresh_fetches_on_interval() {
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let kid = "refresh-kid";
        let (url, counter) = start_jwks_server(test_jwks_json(kid), None).await;
        let ct = CancellationToken::new();
        let _store = JwksKeyStore::new(url, Duration::from_millis(30), ct.clone());
        tokio::time::sleep(Duration::from_millis(150)).await;
        ct.cancel();
        assert!(
            counter.load(Ordering::SeqCst) >= 2,
            "expected at least two periodic fetches, got {}",
            counter.load(Ordering::SeqCst)
        );
    }
}