use std::sync::Arc;
use std::time::Duration;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use time::OffsetDateTime;
use tokio::sync::RwLock;
use ppoppo_token::{Jwks, KeySet as EngineKeySet};
use super::VerifyError;
/// Refresh interval (5 minutes) used when a JWKS response carries no usable
/// `Cache-Control: max-age` directive.
const DEFAULT_TTL: Duration = Duration::from_secs(300);
/// Cache of a remote JWKS key set that is refreshed lazily once it goes stale.
///
/// Cloning is cheap: clones share the same `inner` state through an `Arc`, so
/// a refresh performed through one clone is visible to all of them. The
/// `clock` field is per-handle and decides when the cached key set is stale.
#[derive(Clone)]
pub struct JwksCache {
    /// Time source used for staleness checks; the wall clock unless replaced
    /// via `with_clock`.
    clock: ArcClock,
    /// State shared by every clone of this cache.
    inner: Arc<JwksCacheInner>,
}
// Debug output exposes only the endpoint URL; key material, the clock and the
// HTTP client are deliberately left out (hence `finish_non_exhaustive`).
impl std::fmt::Debug for JwksCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let endpoint = &self.inner.url;
        let mut builder = f.debug_struct("JwksCache");
        builder.field("url", endpoint);
        builder.finish_non_exhaustive()
    }
}
impl JwksCache {
    /// Returns this cache handle with `clock` as its time source.
    ///
    /// Only the returned handle (and clones made from it afterwards) use the
    /// new clock; the shared key set and HTTP client are untouched.
    #[must_use]
    pub fn with_clock(self, clock: ArcClock) -> Self {
        Self { clock, ..self }
    }
}
/// State shared by all clones of a `JwksCache`.
struct JwksCacheInner {
    /// JWKS endpoint URL used for fetches.
    url: String,
    /// HTTP client reused for every refresh.
    http: reqwest::Client,
    /// Current key set plus freshness bookkeeping; staleness checks and
    /// snapshots take the read lock, refreshes take the write lock.
    state: RwLock<JwksCacheState>,
}
/// Mutable part of the cache, kept behind `JwksCacheInner::state`.
struct JwksCacheState {
    /// Most recently accepted key set; handed to callers as an `Arc` clone.
    keyset: Arc<EngineKeySet>,
    /// Reference instant for staleness checks, taken from the cache's clock.
    fetched_at: OffsetDateTime,
    /// How long after `fetched_at` the snapshot is considered fresh.
    ttl: Duration,
}
/// Internal error describing why a JWKS fetch failed.
///
/// Callers of `JwksCache::fetch` only see `VerifyError::KeysetUnavailable`;
/// this type's `Display` text is used for diagnostics (refresh logging).
#[derive(Debug, thiserror::Error)]
#[error("JWKS fetch failed (status={status:?}): {detail}")]
struct FetchFailure {
    /// HTTP status code, when the failure carries one.
    status: Option<u16>,
    /// Human-readable description: the transport/parse error message or the
    /// server's error response body.
    detail: String,
}
impl JwksCache {
    /// How long to keep serving the previous snapshot after a failed refresh
    /// before contacting the JWKS endpoint again.
    ///
    /// Without this, a failing endpoint would be re-fetched by every call to
    /// `snapshot`, each one waiting up to the HTTP client timeout.
    const REFRESH_RETRY_BACKOFF: Duration = Duration::from_secs(30);

    /// Fetches the JWKS at `url` once and builds a cache around it.
    ///
    /// The refresh interval comes from the response's `Cache-Control: max-age`
    /// directive, falling back to `DEFAULT_TTL` when absent. The cache uses the
    /// wall clock; replace it with [`JwksCache::with_clock`].
    ///
    /// # Errors
    ///
    /// Returns [`VerifyError::KeysetUnavailable`] if the HTTP client cannot be
    /// built, the request fails or gets a non-success status, the body is not
    /// a valid JWKS, or the JWKS cannot be converted into a key set.
    pub async fn fetch(url: impl Into<String>) -> Result<Self, VerifyError> {
        let clock: ArcClock = Arc::new(WallClock);
        let url = url.into();
        let http = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(5))
            .build()
            .map_err(|_| VerifyError::KeysetUnavailable)?;
        let (jwks, ttl) = fetch_jwks(&http, &url)
            .await
            .map_err(|_| VerifyError::KeysetUnavailable)?;
        let keyset = jwks.into_key_set().map_err(|_| VerifyError::KeysetUnavailable)?;
        Ok(Self {
            inner: Arc::new(JwksCacheInner {
                url,
                http,
                state: RwLock::new(JwksCacheState {
                    keyset: Arc::new(keyset),
                    fetched_at: clock.now_utc(),
                    ttl,
                }),
            }),
            clock,
        })
    }

    /// Builds a cache holding an empty key set that stays fresh for a day and
    /// never performs a network fetch within that window.
    #[cfg(any(test, feature = "test-support"))]
    pub fn for_test_empty() -> Self {
        let clock: ArcClock = Arc::new(WallClock);
        Self {
            inner: Arc::new(JwksCacheInner {
                url: String::from("test://empty"),
                http: reqwest::Client::new(),
                state: RwLock::new(JwksCacheState {
                    keyset: Arc::new(EngineKeySet::default()),
                    fetched_at: clock.now_utc(),
                    ttl: Duration::from_secs(86_400),
                }),
            }),
            clock,
        }
    }

    /// Returns the current key set, refreshing it first if it is stale.
    ///
    /// A failed refresh is not an error: the previously cached key set is
    /// returned instead.
    pub async fn snapshot(&self) -> Arc<EngineKeySet> {
        self.refresh_if_stale().await;
        self.inner.state.read().await.keyset.clone()
    }

    /// Re-fetches the JWKS if the cached copy has outlived its TTL.
    ///
    /// On success the key set, timestamp and TTL are replaced. On failure
    /// (transport/HTTP error or an unusable JWKS) the old key set is kept and
    /// the next attempt is deferred by `REFRESH_RETRY_BACKOFF`.
    ///
    /// No lock is held across the fetch, so concurrent callers that all see a
    /// stale snapshot may each issue their own request.
    async fn refresh_if_stale(&self) {
        if !self.is_stale().await {
            return;
        }
        match fetch_jwks(&self.inner.http, &self.inner.url).await {
            Ok((jwks, ttl)) => match jwks.into_key_set() {
                Ok(keyset) => {
                    let mut guard = self.inner.state.write().await;
                    guard.keyset = Arc::new(keyset);
                    guard.fetched_at = self.clock.now_utc();
                    guard.ttl = ttl;
                }
                Err(_) => {
                    #[cfg(feature = "axum")]
                    tracing::warn!("JwksCache refresh returned an unusable JWKS; serving stale snapshot");
                    self.defer_next_refresh().await;
                }
            },
            Err(_e) => {
                #[cfg(feature = "axum")]
                tracing::warn!(error = %_e, "JwksCache refresh failed; serving stale snapshot");
                self.defer_next_refresh().await;
            }
        }
    }

    /// Reports whether at least `ttl` has elapsed since `fetched_at`.
    async fn is_stale(&self) -> bool {
        let guard = self.inner.state.read().await;
        let elapsed = self.clock.now_utc() - guard.fetched_at;
        // Conversion only fails for TTLs too large for `time::Duration`; fall
        // back to the same 5-minute interval as `DEFAULT_TTL`.
        let ttl = time::Duration::try_from(guard.ttl).unwrap_or(time::Duration::seconds(300));
        elapsed >= ttl
    }

    /// Keeps the current key set but postpones the next refresh attempt by
    /// `REFRESH_RETRY_BACKOFF`, so a failing endpoint is not hit on every call.
    async fn defer_next_refresh(&self) {
        let mut guard = self.inner.state.write().await;
        guard.fetched_at = self.clock.now_utc();
        guard.ttl = Self::REFRESH_RETRY_BACKOFF;
    }
}
/// GETs the JWKS document at `url` and returns it with its cache lifetime.
///
/// The lifetime comes from the response's `Cache-Control: max-age` directive,
/// or `DEFAULT_TTL` when that is missing or unparsable.
///
/// # Errors
///
/// Returns a [`FetchFailure`] when the request cannot be sent, the server
/// answers with a non-success status (the body, capped in length, becomes the
/// detail), or the body cannot be read or deserialized as a [`Jwks`].
async fn fetch_jwks(http: &reqwest::Client, url: &str) -> Result<(Jwks, Duration), FetchFailure> {
    let response = http.get(url).send().await.map_err(|e| FetchFailure {
        status: None,
        detail: format!("send: {e}"),
    })?;
    let status = response.status().as_u16();
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(FetchFailure {
            status: Some(status),
            detail: truncate_detail(body),
        });
    }
    let ttl = parse_max_age(&response).unwrap_or(DEFAULT_TTL);
    let jwks = response.json::<Jwks>().await.map_err(|e| FetchFailure {
        // A response was received, so keep its status: `None` is reserved
        // for failures where the server never answered.
        status: Some(status),
        detail: format!("parse failed: {e}"),
    })?;
    Ok((jwks, ttl))
}

/// Caps an error body at 512 bytes (cut on a UTF-8 char boundary) so a large
/// HTML error page does not end up verbatim in error messages and logs.
fn truncate_detail(mut detail: String) -> String {
    const MAX_DETAIL_BYTES: usize = 512;
    if detail.len() > MAX_DETAIL_BYTES {
        let mut cut = MAX_DETAIL_BYTES;
        // Index 0 is always a char boundary, so this loop terminates.
        while !detail.is_char_boundary(cut) {
            cut -= 1;
        }
        detail.truncate(cut);
        detail.push_str(" [truncated]");
    }
    detail
}
/// Extracts the `max-age` directive from the response's `Cache-Control` header.
///
/// Returns `None` when the header is absent, cannot be read as ASCII text, or
/// has no `max-age` directive with a non-negative integer value.
fn parse_max_age(response: &reqwest::Response) -> Option<Duration> {
    let header = response.headers().get(reqwest::header::CACHE_CONTROL)?;
    max_age_from_cache_control(header.to_str().ok()?)
}

/// Parses the first valid `max-age` directive out of a `Cache-Control` value.
///
/// Directive names are matched case-insensitively, and the quoted-string form
/// (`max-age="60"`) is accepted, as RFC 9111 §5.2 asks of recipients.
/// Malformed `max-age` entries are skipped in favour of later ones.
fn max_age_from_cache_control(cache_control: &str) -> Option<Duration> {
    cache_control.split(',').find_map(|directive| {
        let (name, value) = directive.split_once('=')?;
        if !name.trim().eq_ignore_ascii_case("max-age") {
            return None;
        }
        let value = value.trim();
        let value = value
            .strip_prefix('"')
            .and_then(|unquoted| unquoted.strip_suffix('"'))
            .unwrap_or(value);
        value.parse::<u64>().ok().map(Duration::from_secs)
    })
}