use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use std::time::Duration;
use crabka_security::{Jwks, JwksHandle};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Reasons a single JWKS fetch performed by `fetch_jwks` can fail.
#[derive(Debug, thiserror::Error)]
pub(crate) enum FetchError {
    /// Any `reqwest` failure while fetching: sending the request, a
    /// non-success HTTP status (surfaced via `error_for_status`), or reading
    /// the response body.
    #[error("jwks http request failed: {0}")]
    Http(#[from] reqwest::Error),
    /// The response body was fetched but `Jwks::from_json` rejected it. The
    /// underlying parse error is discarded.
    #[error("jwks document was not a valid key set")]
    Parse,
}
/// Downloads and parses the JWKS document served at `endpoint`.
///
/// Issues a single HTTP GET with `client`. Any non-success status is treated
/// as an error before the body is read. `ignore_key_use` is forwarded
/// unchanged to `Jwks::from_json`.
///
/// # Errors
///
/// * [`FetchError::Http`] if the request, status check, or body read fails.
/// * [`FetchError::Parse`] if the body is not an acceptable key set.
pub(crate) async fn fetch_jwks(
    client: &reqwest::Client,
    endpoint: &str,
    ignore_key_use: bool,
) -> Result<Jwks, FetchError> {
    let response = client.get(endpoint).send().await?;
    let response = response.error_for_status()?;
    let body = response.text().await?;
    match Jwks::from_json(&body, ignore_key_use) {
        Ok(key_set) => Ok(key_set),
        Err(_) => Err(FetchError::Parse),
    }
}
/// State for the background task that keeps a [`JwksHandle`] populated with
/// the key set served at `endpoint`.
pub(crate) struct JwksRefresher {
    /// URL fetched with HTTP GET on every refresh.
    pub endpoint: String,
    /// Destination for each successfully fetched key set.
    pub handle: JwksHandle,
    /// Period of the periodic refresh tick.
    pub interval: Duration,
    /// Cancelling this token stops the refresh loop.
    pub shutdown: CancellationToken,
    /// Optional PEM file used to build the TLS configuration of the JWKS
    /// HTTP client. When `None`, the client's default TLS setup is used.
    pub tls_trust: Option<PathBuf>,
    /// Receives on-demand refresh requests (logged as validator signals).
    pub signal_rx: mpsc::Receiver<()>,
    /// Minimum spacing between accepted on-demand refreshes. Signals arriving
    /// sooner are dropped.
    pub min_on_demand_pause: Duration,
    /// Unix-epoch milliseconds of the last fetch whose key set was stored.
    /// Shared via `Arc`, presumably so other components can observe
    /// freshness — confirm with callers.
    pub last_successful_fetch_ms: Arc<AtomicI64>,
    /// Unix-epoch milliseconds of the last accepted on-demand refresh.
    /// Recorded before fetching, so it advances even when the fetch fails.
    pub last_on_demand_refresh_ms: Arc<AtomicI64>,
    /// Forwarded to `Jwks::from_json`. Presumably relaxes filtering on the
    /// JWK `use` parameter — confirm against `crabka_security`.
    pub ignore_key_use: bool,
}
impl JwksRefresher {
    /// Drives the JWKS refresh loop until `shutdown` is cancelled.
    ///
    /// A refresh runs on every `interval` tick. Tokio's first interval tick
    /// completes immediately, so one fetch happens at startup. A refresh also
    /// runs on each signal received on `signal_rx`, subject to the
    /// `min_on_demand_pause` rate limit. A refresh in flight is abandoned as
    /// soon as `shutdown` fires, so shutdown does not wait out the HTTP timeout.
    ///
    /// If the TLS trust bundle cannot be loaded or the HTTP client cannot be
    /// built, this logs an error and returns early. The key set is then never
    /// refreshed.
    pub(crate) async fn run(mut self) {
        let Some(client) = self.build_client() else {
            return;
        };
        // NOTE(review): `tokio::time::interval` panics on a zero period; this
        // assumes the caller never configures `interval` as zero — confirm.
        let mut tick = tokio::time::interval(self.interval);
        // After a slow fetch, skip missed ticks instead of firing them in a burst.
        tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                _ = tick.tick() => {
                    if self.refresh_interrupted_by_shutdown(&client).await {
                        return;
                    }
                }
                // A closed channel yields `None`. That does not match the
                // pattern, so the branch is only disabled for this iteration
                // and the loop keeps running on ticks.
                Some(()) = self.signal_rx.recv() => {
                    if self.claim_on_demand_slot()
                        && self.refresh_interrupted_by_shutdown(&client).await
                    {
                        return;
                    }
                }
                () = self.shutdown.cancelled() => return,
            }
        }
    }

    /// Builds the HTTP client used for every JWKS fetch.
    ///
    /// The client has a 10 s request timeout. When `tls_trust` is set, its TLS
    /// configuration comes from that PEM file. Returns `None`, after logging an
    /// error, if the trust bundle cannot be loaded or the client cannot be built.
    fn build_client(&self) -> Option<reqwest::Client> {
        let mut builder = reqwest::Client::builder().timeout(Duration::from_secs(10));
        if let Some(path) = &self.tls_trust {
            match crabka_security::build_client_config_from_pem(path) {
                Ok(cfg) => {
                    builder = builder.use_preconfigured_tls((*cfg).clone());
                }
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        path = %path.display(),
                        "failed to load OAUTHBEARER JWKS TLS trust bundle; refresher will not start",
                    );
                    return None;
                }
            }
        }
        match builder.build() {
            Ok(client) => Some(client),
            Err(e) => {
                tracing::error!(error = %e, "failed to build JWKS HTTP client; OAUTHBEARER signed tokens will not validate");
                None
            }
        }
    }

    /// Applies the on-demand rate limit.
    ///
    /// Returns `true` if at least `min_on_demand_pause` has passed since the
    /// last accepted on-demand refresh. In that case it also stores the current
    /// time in `last_on_demand_refresh_ms`. Otherwise it logs the dropped signal
    /// and returns `false`. The timestamp is recorded before any fetch, so a
    /// failed fetch still counts against the limit.
    fn claim_on_demand_slot(&self) -> bool {
        // Wall-clock based: if the clock steps backwards, `elapsed_ms` goes
        // negative and signals are dropped until the clock catches up.
        let now_ms = current_epoch_ms();
        let last = self.last_on_demand_refresh_ms.load(Ordering::Relaxed);
        let elapsed_ms = now_ms.saturating_sub(last);
        let pause_ms = i64::try_from(self.min_on_demand_pause.as_millis()).unwrap_or(i64::MAX);
        if elapsed_ms >= pause_ms {
            self.last_on_demand_refresh_ms.store(now_ms, Ordering::Relaxed);
            tracing::debug!(
                endpoint = %self.endpoint,
                elapsed_ms,
                "on-demand JWKS refresh triggered by validator signal",
            );
            true
        } else {
            tracing::debug!(
                endpoint = %self.endpoint,
                elapsed_ms,
                pause_ms,
                "on-demand JWKS refresh rate-limited; signal dropped",
            );
            false
        }
    }

    /// Runs one refresh, abandoning it if `shutdown` fires first.
    ///
    /// Returns `true` when shutdown won the race; the caller should then stop.
    /// Dropping the refresh future is safe: the key set and timestamp are
    /// stored synchronously after the fetch completes, so nothing is half-applied.
    async fn refresh_interrupted_by_shutdown(&self, client: &reqwest::Client) -> bool {
        tokio::select! {
            () = self.shutdown.cancelled() => true,
            () = self.refresh_and_swap(client) => false,
        }
    }

    /// Fetches the key set once and, on success, stores it in `handle` and
    /// records the time in `last_successful_fetch_ms`.
    ///
    /// On failure it logs a warning and leaves the previously stored key set
    /// and timestamp untouched.
    async fn refresh_and_swap(&self, client: &reqwest::Client) {
        match fetch_jwks(client, &self.endpoint, self.ignore_key_use).await {
            Ok(jwks) => {
                tracing::debug!(
                    endpoint = %self.endpoint,
                    keys = jwks.len(),
                    "refreshed OAUTHBEARER JWKS",
                );
                self.handle.store(jwks);
                self.last_successful_fetch_ms
                    .store(current_epoch_ms(), Ordering::Relaxed);
            }
            Err(e) => tracing::warn!(
                endpoint = %self.endpoint,
                error = %e,
                "failed to refresh OAUTHBEARER JWKS; keeping previous key set",
            ),
        }
    }
}
/// Returns the wall-clock time as milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`. A value too large for `i64`
/// saturates to `i64::MAX`.
fn current_epoch_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use std::net::SocketAddr;

    /// Serves `body` at `GET /jwks` on an ephemeral loopback port until the
    /// returned token is cancelled.
    async fn serve_jwks(body: &'static str) -> (SocketAddr, CancellationToken) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let shutdown = CancellationToken::new();
        let app =
            axum::Router::new().route("/jwks", axum::routing::get(move || async move { body }));
        let srv_shutdown = shutdown.clone();
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move { srv_shutdown.cancelled().await })
                .await
                .unwrap();
        });
        (addr, shutdown)
    }

    /// A JWKS document containing a single P-256 key.
    const JWKS_BODY: &str = r#"{"keys":[{"kty":"EC","crv":"P-256","kid":"k1","x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU","y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0"}]}"#;

    /// Builds a refresher driven only by its periodic tick.
    ///
    /// The signal sender is dropped on return, so `signal_rx` is closed and
    /// no on-demand refresh can ever be requested.
    fn test_refresher(
        endpoint: String,
        handle: JwksHandle,
        interval: Duration,
        shutdown: CancellationToken,
        tls_trust: Option<PathBuf>,
    ) -> JwksRefresher {
        let (_tx, rx) = mpsc::channel::<()>(1);
        JwksRefresher {
            endpoint,
            handle,
            interval,
            shutdown,
            tls_trust,
            signal_rx: rx,
            min_on_demand_pause: Duration::from_secs(1),
            last_successful_fetch_ms: Arc::new(AtomicI64::new(0)),
            last_on_demand_refresh_ms: Arc::new(AtomicI64::new(0)),
            ignore_key_use: false,
        }
    }

    #[tokio::test]
    async fn fetch_jwks_parses_served_keyset() {
        let (addr, shutdown) = serve_jwks(JWKS_BODY).await;
        let client = reqwest::Client::new();
        let jwks = fetch_jwks(&client, &format!("http://{addr}/jwks"), false)
            .await
            .unwrap();
        assert!(jwks.len() == 1);
        shutdown.cancel();
    }

    #[tokio::test]
    async fn fetch_jwks_errors_on_dead_endpoint() {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_millis(500))
            .build()
            .unwrap();
        let err = fetch_jwks(&client, "http://127.0.0.1:1/jwks", false).await;
        assert!(err.is_err());
    }

    #[tokio::test]
    async fn refresher_populates_handle_then_stops_on_shutdown() {
        let (addr, srv_shutdown) = serve_jwks(JWKS_BODY).await;
        let handle = JwksHandle::default();
        assert!(handle.load().is_empty());
        let shutdown = CancellationToken::new();
        let refresher = test_refresher(
            format!("http://{addr}/jwks"),
            handle.clone(),
            Duration::from_millis(50),
            shutdown.clone(),
            None,
        );
        let task = tokio::spawn(refresher.run());
        for _ in 0..100 {
            if !handle.load().is_empty() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(handle.load().len() == 1);
        shutdown.cancel();
        task.await.unwrap();
        srv_shutdown.cancel();
    }

    /// Serves `body` over HTTPS with a freshly generated self-signed cert for
    /// `127.0.0.1`. Returns the address, a shutdown token, and the cert's PEM path.
    ///
    /// The responder is hand-rolled: it reads the request once (up to 1 KiB)
    /// and always answers `200 OK` with `body`.
    async fn serve_jwks_https(body: &'static str) -> (SocketAddr, CancellationToken, PathBuf) {
        use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
        use tokio::io::AsyncWriteExt as _;
        use tokio_rustls::TlsAcceptor;
        let _ = rustls::crypto::ring::default_provider().install_default();
        let params = rcgen::CertificateParams::new(vec!["127.0.0.1".to_string()]).unwrap();
        let key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
        let cert = params.self_signed(&key).unwrap();
        // Leaked on purpose so the cert file outlives this helper and stays
        // readable by the refresher under test.
        let dir = Box::leak(Box::new(tempfile::tempdir().unwrap()));
        let cert_path = dir.path().join("cert.pem");
        std::fs::write(&cert_path, cert.pem()).unwrap();
        let key_path = dir.path().join("key.pem");
        std::fs::write(&key_path, key.serialize_pem()).unwrap();
        let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(&cert_path)
            .unwrap()
            .collect::<Result<_, _>>()
            .unwrap();
        let priv_key = PrivateKeyDer::from_pem_file(&key_path).unwrap();
        let server_cfg = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(certs, priv_key)
                .unwrap(),
        );
        let acceptor = TlsAcceptor::from(server_cfg);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let shutdown = CancellationToken::new();
        let srv_shutdown = shutdown.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    () = srv_shutdown.cancelled() => break,
                    Ok((sock, _peer)) = listener.accept() => {
                        let acceptor = acceptor.clone();
                        tokio::spawn(async move {
                            use tokio::io::AsyncReadExt as _;
                            let Ok(mut tls) = acceptor.accept(sock).await else { return };
                            let mut buf = [0u8; 1024];
                            let _ = tls.read(&mut buf).await;
                            let header = format!(
                                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n",
                                body.len(),
                            );
                            let _ = tls.write_all(header.as_bytes()).await;
                            let _ = tls.write_all(body.as_bytes()).await;
                            let _ = tls.shutdown().await;
                        });
                    }
                }
            }
        });
        (addr, shutdown, cert_path)
    }

    #[tokio::test]
    async fn refresher_fetches_jwks_over_https_with_custom_trust() {
        let (addr, srv_shutdown, ca_path) = serve_jwks_https(JWKS_BODY).await;
        let handle = JwksHandle::default();
        let shutdown = CancellationToken::new();
        let refresher = test_refresher(
            format!("https://127.0.0.1:{}/jwks", addr.port()),
            handle.clone(),
            Duration::from_millis(50),
            shutdown.clone(),
            Some(ca_path),
        );
        let task = tokio::spawn(refresher.run());
        for _ in 0..100 {
            if !handle.load().is_empty() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(handle.load().len() == 1);
        shutdown.cancel();
        task.await.unwrap();
        srv_shutdown.cancel();
    }

    #[tokio::test]
    async fn refresher_https_fetch_fails_when_custom_trust_doesnt_match_server_cert() {
        let (addr, srv_shutdown, _server_cert_path) = serve_jwks_https(JWKS_BODY).await;
        let dir = tempfile::tempdir().unwrap();
        let params = rcgen::CertificateParams::new(vec!["unrelated.example".to_string()]).unwrap();
        let key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
        let cert = params.self_signed(&key).unwrap();
        let bogus_ca = dir.path().join("bogus-ca.pem");
        std::fs::write(&bogus_ca, cert.pem()).unwrap();
        let handle = JwksHandle::default();
        let shutdown = CancellationToken::new();
        let refresher = test_refresher(
            format!("https://127.0.0.1:{}/jwks", addr.port()),
            handle.clone(),
            Duration::from_millis(50),
            shutdown.clone(),
            Some(bogus_ca),
        );
        let task = tokio::spawn(refresher.run());
        tokio::time::sleep(Duration::from_millis(300)).await;
        // The refresher must have started, i.e. loaded the trust bundle and
        // built its client. Otherwise an empty handle would not prove that TLS
        // verification rejected the server.
        assert!(
            !task.is_finished(),
            "refresher should still be running with a loadable (if wrong) trust bundle",
        );
        assert!(
            handle.load().is_empty(),
            "fetch should fail verification and leave handle empty",
        );
        shutdown.cancel();
        task.await.unwrap();
        srv_shutdown.cancel();
    }

    /// Like [`serve_jwks`], but also returns a counter of served `/jwks` requests.
    async fn serve_jwks_counting(
        body: &'static str,
    ) -> (
        SocketAddr,
        CancellationToken,
        Arc<std::sync::atomic::AtomicUsize>,
    ) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let shutdown = CancellationToken::new();
        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter_cl = counter.clone();
        let app = axum::Router::new().route(
            "/jwks",
            axum::routing::get(move || {
                let c = counter_cl.clone();
                async move {
                    c.fetch_add(1, Ordering::Relaxed);
                    body
                }
            }),
        );
        let srv_shutdown = shutdown.clone();
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move { srv_shutdown.cancelled().await })
                .await
                .unwrap();
        });
        (addr, shutdown, counter)
    }

    /// Builds a refresher with a live signal channel.
    ///
    /// The one-hour interval means only the immediate startup tick fires
    /// during a test. Every other request comes from on-demand signals.
    #[allow(clippy::type_complexity)] // test-only tuple of related handles
    fn make_signal_refresher(
        endpoint: String,
        min_on_demand_pause: Duration,
    ) -> (
        JwksRefresher,
        mpsc::Sender<()>,
        Arc<AtomicI64>, // last_successful_fetch_ms
        Arc<AtomicI64>, // last_on_demand_refresh_ms
        CancellationToken,
        JwksHandle,
    ) {
        let (signal_tx, signal_rx) = mpsc::channel::<()>(1);
        let shutdown = CancellationToken::new();
        let last_successful = Arc::new(AtomicI64::new(0));
        let last_on_demand = Arc::new(AtomicI64::new(0));
        let handle = JwksHandle::new_with_refresher_handles(
            Jwks::empty(),
            last_successful.clone(),
            signal_tx.clone(),
        );
        let refresher = JwksRefresher {
            endpoint,
            handle: handle.clone(),
            interval: Duration::from_hours(1),
            shutdown: shutdown.clone(),
            tls_trust: None,
            signal_rx,
            min_on_demand_pause,
            last_successful_fetch_ms: last_successful.clone(),
            last_on_demand_refresh_ms: last_on_demand.clone(),
            ignore_key_use: false,
        };
        (
            refresher,
            signal_tx,
            last_successful,
            last_on_demand,
            shutdown,
            handle,
        )
    }

    #[tokio::test]
    async fn refresher_signal_triggers_on_demand_refresh_when_pause_elapsed() {
        let (addr, srv_shutdown, count) = serve_jwks_counting(JWKS_BODY).await;
        let endpoint = format!("http://{addr}/jwks");
        let (refresher, signal_tx, _last_successful, last_on_demand, shutdown, handle) =
            make_signal_refresher(endpoint, Duration::from_millis(0));
        let task = tokio::spawn(refresher.run());
        signal_tx.send(()).await.unwrap();
        // The startup tick accounts for one request. A second request proves
        // the on-demand refresh actually reached the server.
        for _ in 0..100 {
            let fired = last_on_demand.load(Ordering::Relaxed) > 0;
            if fired && count.load(Ordering::Relaxed) >= 2 && !handle.load().is_empty() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(
            last_on_demand.load(Ordering::Relaxed) > 0,
            "on-demand timestamp should have advanced past sentinel 0",
        );
        assert!(
            count.load(Ordering::Relaxed) >= 2,
            "server should have served the startup tick and the on-demand request"
        );
        assert!(
            handle.load().len() == 1,
            "refresher must store the fetched key set"
        );
        shutdown.cancel();
        let _ = task.await;
        srv_shutdown.cancel();
    }

    #[tokio::test]
    async fn refresher_signal_dropped_when_within_min_pause_window() {
        let (addr, srv_shutdown, count) = serve_jwks_counting(JWKS_BODY).await;
        let endpoint = format!("http://{addr}/jwks");
        let (refresher, signal_tx, _last_successful, last_on_demand, shutdown, _handle) =
            make_signal_refresher(endpoint, Duration::from_mins(1));
        let task = tokio::spawn(refresher.run());
        signal_tx.send(()).await.unwrap();
        // The interval's first tick fires immediately, so the server sees two
        // requests: the startup tick and the first on-demand refresh. Waiting
        // for only one would race the tick against the signal and could capture
        // `first_ts`/`count_after_first` before the signal was handled.
        for _ in 0..100 {
            let fired = last_on_demand.load(Ordering::Relaxed) > 0;
            if fired && count.load(Ordering::Relaxed) >= 2 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let first_ts = last_on_demand.load(Ordering::Relaxed);
        assert!(first_ts > 0, "first signal must have fired a refresh");
        let count_after_first = count.load(Ordering::Relaxed);
        assert!(count_after_first >= 2);
        signal_tx.send(()).await.unwrap();
        for _ in 0..10 {
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(
            last_on_demand.load(Ordering::Relaxed) == first_ts,
            "second signal within min_pause must not advance timestamp"
        );
        assert!(
            count.load(Ordering::Relaxed) == count_after_first,
            "server must not see a second on-demand HTTP request"
        );
        shutdown.cancel();
        let _ = task.await;
        srv_shutdown.cancel();
    }

    #[tokio::test]
    async fn refresher_successful_refresh_updates_last_successful_fetch_timestamp() {
        let (addr, srv_shutdown, _count) = serve_jwks_counting(JWKS_BODY).await;
        let endpoint = format!("http://{addr}/jwks");
        let (refresher, signal_tx, last_successful, _last_on_demand, shutdown, _handle) =
            make_signal_refresher(endpoint, Duration::from_millis(0));
        let task = tokio::spawn(refresher.run());
        // `#[tokio::test]` uses a current-thread runtime, so the spawned task
        // has not run yet at this point.
        assert!(last_successful.load(Ordering::Relaxed) == 0);
        signal_tx.send(()).await.unwrap();
        for _ in 0..100 {
            if last_successful.load(Ordering::Relaxed) > 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(
            last_successful.load(Ordering::Relaxed) > 0,
            "last_successful_fetch_ms must advance after a successful fetch",
        );
        shutdown.cancel();
        let _ = task.await;
        srv_shutdown.cancel();
    }

    #[tokio::test]
    async fn refresher_failed_refresh_does_not_advance_last_successful_fetch() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let srv_shutdown = CancellationToken::new();
        let srv_token = srv_shutdown.clone();
        let app = axum::Router::new().route(
            "/jwks",
            axum::routing::get(|| async {
                (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "boom")
            }),
        );
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move { srv_token.cancelled().await })
                .await
                .unwrap();
        });
        let endpoint = format!("http://{addr}/jwks");
        let (refresher, signal_tx, last_successful, last_on_demand, shutdown, _handle) =
            make_signal_refresher(endpoint, Duration::from_millis(0));
        let task = tokio::spawn(refresher.run());
        signal_tx.send(()).await.unwrap();
        for _ in 0..50 {
            if last_on_demand.load(Ordering::Relaxed) > 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(
            last_on_demand.load(Ordering::Relaxed) > 0,
            "on-demand rate-limit timestamp updates even when the fetch itself fails",
        );
        assert!(
            last_successful.load(Ordering::Relaxed) == 0,
            "failed fetch must leave last_successful_fetch_ms at sentinel 0"
        );
        shutdown.cancel();
        let _ = task.await;
        srv_shutdown.cancel();
    }

    #[tokio::test]
    async fn refresher_passes_ignore_key_use_through_to_jwks_parser() {
        const ENC_KEY_BODY: &str =
            r#"{"keys":[{"kty":"RSA","kid":"enc-kid","use":"enc","n":"AQAB","e":"AQAB"}]}"#;
        let (addr, srv_shutdown, _count) = serve_jwks_counting(ENC_KEY_BODY).await;
        let endpoint = format!("http://{addr}/jwks");
        let (mut refresher, signal_tx, _last_successful, _last_on_demand, shutdown, handle) =
            make_signal_refresher(endpoint, Duration::from_millis(0));
        refresher.ignore_key_use = true;
        let task = tokio::spawn(refresher.run());
        signal_tx.send(()).await.unwrap();
        for _ in 0..100 {
            if !handle.load().is_empty() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(
            handle.load().len() == 1,
            "ignore_key_use=true must keep the use=enc key in the installed set"
        );
        shutdown.cancel();
        let _ = task.await;
        srv_shutdown.cancel();
    }
}