use std::sync::Arc;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::sync::RwLock;
use crate::error::Error;
use crate::token::{VerifiedClaims, verify_v4_with_keyset};
use crate::well_known::WellKnownPasetoDocument;
#[derive(Clone)]
pub struct KeySet {
inner: Arc<KeySetInner>,
}
struct KeySetInner {
url: String,
http: reqwest::Client,
state: RwLock<KeySetState>,
}
struct KeySetState {
document: WellKnownPasetoDocument,
fetched_at: OffsetDateTime,
}
impl KeySet {
pub async fn fetch(url: impl Into<String>) -> Result<Self, Error> {
let url = url.into();
let http = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(5))
.build()?;
let document = fetch_document(&http, &url).await?;
Ok(Self {
inner: Arc::new(KeySetInner {
url,
http,
state: RwLock::new(KeySetState {
document,
fetched_at: OffsetDateTime::now_utc(),
}),
}),
})
}
#[must_use]
pub fn with_initial(
url: impl Into<String>,
http: reqwest::Client,
document: WellKnownPasetoDocument,
) -> Self {
Self {
inner: Arc::new(KeySetInner {
url: url.into(),
http,
state: RwLock::new(KeySetState {
document,
fetched_at: OffsetDateTime::now_utc(),
}),
}),
}
}
pub async fn verify(
&self,
token: &str,
expected_issuer: &str,
expected_audience: &str,
) -> Result<VerifiedClaims, Error> {
self.refresh_if_stale().await;
let guard = self.inner.state.read().await;
verify_v4_with_keyset(&guard.document, token, expected_issuer, expected_audience)
}
pub async fn refresh_now(&self) -> Result<OffsetDateTime, Error> {
let document = fetch_document(&self.inner.http, &self.inner.url).await?;
let mut guard = self.inner.state.write().await;
guard.document = document;
guard.fetched_at = OffsetDateTime::now_utc();
Ok(guard.fetched_at)
}
pub async fn snapshot(&self) -> WellKnownPasetoDocument {
self.inner.state.read().await.document.clone()
}
async fn refresh_if_stale(&self) {
let needs_refresh = {
let guard = self.inner.state.read().await;
let ttl = Duration::from_secs(guard.document.cache_ttl_seconds);
let elapsed = OffsetDateTime::now_utc() - guard.fetched_at;
elapsed >= time::Duration::try_from(ttl).unwrap_or(time::Duration::seconds(3600))
};
if !needs_refresh {
return;
}
if let Err(e) = self.refresh_now().await {
#[cfg(feature = "axum")]
tracing::warn!(error = %e, "KeySet refresh failed, serving stale cache");
let _ = e;
}
}
}
async fn fetch_document(
http: &reqwest::Client,
url: &str,
) -> Result<WellKnownPasetoDocument, Error> {
let response = http.get(url).send().await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(Error::OAuth {
operation: "well-known fetch",
status: Some(status),
detail: body,
});
}
response.json::<WellKnownPasetoDocument>().await.map_err(|e| {
Error::OAuth {
operation: "well-known fetch",
status: None,
detail: format!("parse failed: {e}"),
}
})
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::types::KeyId;
    use crate::well_known::{WellKnownKeyStatus, WellKnownPasetoKey};

    /// Builds a single-key, active-status document whose cache TTL is `ttl_secs`.
    fn doc(ttl_secs: u64) -> WellKnownPasetoDocument {
        let active_key = WellKnownPasetoKey {
            kid: KeyId("test-1".into()),
            public_key_hex: "a".repeat(64),
            status: WellKnownKeyStatus::Active,
            created_at: OffsetDateTime::now_utc(),
        };
        WellKnownPasetoDocument {
            issuer: "accounts.ppoppo.com".into(),
            version: "v4.public".into(),
            keys: vec![active_key],
            cache_ttl_seconds: ttl_secs,
        }
    }

    /// Builds a `KeySet` for `url`, seeded with a test document of `ttl_secs`.
    fn seeded(url: &str, ttl_secs: u64) -> KeySet {
        KeySet::with_initial(url, reqwest::Client::new(), doc(ttl_secs))
    }

    #[tokio::test]
    async fn snapshot_returns_initial_document() {
        let keyset = seeded("http://example.invalid", 3600);
        let current = keyset.snapshot().await;
        assert_eq!(current.cache_ttl_seconds, 3600);
        assert_eq!(current.keys.len(), 1);
    }

    #[tokio::test]
    async fn refresh_failure_preserves_cache() {
        // The `.invalid` TLD never resolves, so the refresh is expected to fail.
        let keyset = seeded("http://nonexistent.invalid", 0);
        let cached = keyset.snapshot().await;
        let _ = keyset.refresh_now().await;
        let after_refresh = keyset.snapshot().await;
        assert_eq!(cached, after_refresh, "cache must survive refresh failure");
    }
}