use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use reqwest::Client;
use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use crate::Codec as _;
use crate::jwt::jwks::JwksDocument;
use crate::jwt::{JsonWebToken, JsonWebTokenOptions};
/// Timeout applied to the JWKS HTTP client unless the config overrides it.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Period between background JWKS refreshes unless the config overrides it.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_millis(60_000);
/// Shared, swappable slot holding the current verifier; `None` until keys are installed.
type SharedVerifier<P> = Arc<RwLock<Option<VerifierHandle<P>>>>;
/// An immutable verifier built from one JWKS snapshot, shared cheaply between readers.
type VerifierHandle<P> = Arc<JsonWebToken<P>>;
/// Settings for [`RemoteJwksVerifier`]: the JWKS endpoint, HTTP timeout, refresh period,
/// and an optional on-disk cache location.
#[derive(Debug, Clone)]
pub struct RemoteJwksVerifierConfig {
    /// URL of the JWKS document to fetch.
    pub jwks_url: String,
    /// Timeout configured on the HTTP client used for JWKS fetches.
    pub http_timeout: Duration,
    /// Period of the background refresh loop.
    pub refresh_interval: Duration,
    /// Optional file where fetched JWKS documents are persisted and read back at startup.
    pub cache_path: Option<PathBuf>,
}
impl RemoteJwksVerifierConfig {
    /// Creates a configuration for `jwks_url` using the module defaults: a 3 s HTTP
    /// timeout, a 60 s refresh interval, and no on-disk cache.
    pub fn from_jwks_url(jwks_url: impl Into<String>) -> Self {
        let jwks_url = jwks_url.into();
        Self {
            jwks_url,
            cache_path: None,
            http_timeout: DEFAULT_HTTP_TIMEOUT,
            refresh_interval: DEFAULT_REFRESH_INTERVAL,
        }
    }

    /// Overrides the timeout configured on the JWKS HTTP client.
    #[must_use]
    pub fn with_http_timeout(self, timeout: Duration) -> Self {
        Self {
            http_timeout: timeout,
            ..self
        }
    }

    /// Overrides the background refresh period.
    ///
    /// This should be non-zero, because the background refresh is driven by a periodic timer.
    #[must_use]
    pub fn with_refresh_interval(self, refresh_interval: Duration) -> Self {
        Self {
            refresh_interval,
            ..self
        }
    }

    /// Enables the on-disk JWKS cache at `cache_path`.
    #[must_use]
    pub fn with_cache_path(self, cache_path: impl Into<PathBuf>) -> Self {
        Self {
            cache_path: Some(cache_path.into()),
            ..self
        }
    }
}
#[derive(thiserror::Error, Debug)]
pub enum RemoteJwksVerifierError {
#[error("failed to build HTTP client: {0}")]
HttpClientBuild(#[from] reqwest::Error),
#[error("failed to fetch JWKS document from {url}: {message}")]
Fetch {
url: String,
message: String,
},
#[error("failed to parse JWKS response: {0}")]
ParseResponse(String),
#[error("JWKS document did not contain any valid ES384 keys")]
NoValidKeys,
#[error("failed to persist JWKS cache at {path}: {message}")]
CacheWrite {
path: String,
message: String,
},
#[error("failed to read JWKS cache at {path}: {message}")]
CacheRead {
path: String,
message: String,
},
#[error("missing JWT `kid` and refresh did not provide a fallback key")]
MissingKidWithoutFallback,
#[error("JWT key id `{kid}` not found after refresh")]
UnknownKid {
kid: String,
},
#[error("token verification failed: {0}")]
Verify(String),
#[error("startup failed because no live JWKS or cached JWKS was available")]
StartupNoKeys,
}
/// Verifies JWTs against ES384 keys fetched from a remote JWKS endpoint, with optional
/// on-disk caching and periodic or on-demand refresh.
///
/// Clones share the current key set and the refresh lock, since both sit behind `Arc`.
#[derive(Clone)]
pub struct RemoteJwksVerifier<P>
where
    P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    /// Endpoint, timeout, refresh, and cache settings.
    config: RemoteJwksVerifierConfig,
    /// HTTP client built with `config.http_timeout`.
    client: Client,
    /// The currently installed verifier, replaced wholesale on each successful refresh.
    verifier: SharedVerifier<P>,
    /// Serializes refreshes so that only one fetch-and-install runs at a time.
    refresh_lock: Arc<tokio::sync::Mutex<()>>,
}
impl<P> RemoteJwksVerifier<P>
where
P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
pub async fn bootstrap(
config: RemoteJwksVerifierConfig,
) -> Result<Self, RemoteJwksVerifierError> {
let client = Client::builder().timeout(config.http_timeout).build()?;
let verifier: SharedVerifier<P> = Arc::new(RwLock::new(None));
let this = Self {
config,
client,
verifier,
refresh_lock: Arc::new(tokio::sync::Mutex::new(())),
};
let mut has_cache = false;
if let Some(cached) = this.load_cached_verifier().await? {
*this.verifier.write().await = Some(cached);
has_cache = true;
tracing::warn!("starting with cached JWKS keys while attempting live refresh");
}
match this.refresh().await {
Ok(()) => {}
Err(error) if has_cache => {
tracing::warn!(error = %error, "live JWKS refresh failed, continuing with cached keys");
}
Err(_) => return Err(RemoteJwksVerifierError::StartupNoKeys),
}
Ok(this)
}
pub fn start_background_refresh(&self) -> JoinHandle<()> {
let refresh_interval = self.config.refresh_interval;
let this = self.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(refresh_interval);
loop {
ticker.tick().await;
if let Err(error) = this.refresh().await {
tracing::warn!(error = %error, "background JWKS refresh failed");
}
}
})
}
pub async fn refresh(&self) -> Result<(), RemoteJwksVerifierError> {
let _lock = self.refresh_lock.lock().await;
let jwks = self.fetch_jwks().await?;
let codec = Arc::new(codec_from_jwks(&jwks)?);
if let Some(cache_path) = &self.config.cache_path {
persist_jwks_cache(cache_path, &jwks).await?;
}
*self.verifier.write().await = Some(codec);
Ok(())
}
pub async fn verify_token(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
match self.verify_once(token).await {
Ok(claims) => Ok(claims),
Err(RemoteJwksVerifierError::Verify(ref message))
if message.contains("missing `kid`") || message.contains("not configured") =>
{
let message = message.clone();
self.refresh().await?;
match self.verify_once(token).await {
Ok(claims) => Ok(claims),
Err(RemoteJwksVerifierError::Verify(ref refreshed_message))
if refreshed_message.contains("missing `kid`") =>
{
Err(RemoteJwksVerifierError::MissingKidWithoutFallback)
}
Err(RemoteJwksVerifierError::Verify(ref refreshed_message)) => {
if let Some(kid) = kid_from_token_error(refreshed_message) {
Err(RemoteJwksVerifierError::UnknownKid { kid })
} else {
Err(RemoteJwksVerifierError::Verify(refreshed_message.clone()))
}
}
Err(error) => {
let _ = message;
Err(error)
}
}
}
Err(error) => Err(error),
}
}
async fn verify_once(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
let verifier = self
.verifier
.read()
.await
.clone()
.ok_or(RemoteJwksVerifierError::StartupNoKeys)?;
verifier
.decode(token.as_bytes())
.map_err(|error: crate::Error| RemoteJwksVerifierError::Verify(error.to_string()))
}
async fn load_cached_verifier(
&self,
) -> Result<Option<Arc<JsonWebToken<P>>>, RemoteJwksVerifierError> {
let Some(cache_path) = &self.config.cache_path else {
return Ok(None);
};
if !cache_path.exists() {
return Ok(None);
}
let raw = tokio::fs::read_to_string(cache_path)
.await
.map_err(|error| RemoteJwksVerifierError::CacheRead {
path: cache_path.display().to_string(),
message: error.to_string(),
})?;
let jwks: JwksDocument = serde_json::from_str(&raw)
.map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
let codec = Arc::new(codec_from_jwks(&jwks)?);
Ok(Some(codec))
}
async fn fetch_jwks(&self) -> Result<JwksDocument, RemoteJwksVerifierError> {
let response = self
.client
.get(&self.config.jwks_url)
.send()
.await
.map_err(|error| RemoteJwksVerifierError::Fetch {
url: self.config.jwks_url.clone(),
message: error.to_string(),
})?;
if !response.status().is_success() {
return Err(RemoteJwksVerifierError::Fetch {
url: self.config.jwks_url.clone(),
message: format!("unexpected HTTP status {}", response.status()),
});
}
response
.json::<JwksDocument>()
.await
.map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))
}
}
/// Extracts the key id embedded in a verification error message.
///
/// For a message such as ``JWT `kid` `next-key` is not configured for verification``
/// this returns `Some("next-key")`. It returns `None` if the marker or the closing
/// backtick is missing.
fn kid_from_token_error(message: &str) -> Option<String> {
    const KID_MARKER: &str = "JWT `kid` `";
    let (_, after_marker) = message.split_once(KID_MARKER)?;
    let (kid, _) = after_marker.split_once('`')?;
    Some(kid.to_owned())
}
fn codec_from_jwks<P>(document: &JwksDocument) -> Result<JsonWebToken<P>, RemoteJwksVerifierError>
where
P: Serialize + DeserializeOwned + Clone,
{
let keys: Vec<_> = document
.keys
.iter()
.filter(|key| {
key.alg == "ES384" && key.crv == "P-384" && key.kty == "EC" && key.use_ == "sig"
})
.cloned()
.collect();
if keys.is_empty() {
return Err(RemoteJwksVerifierError::NoValidKeys);
}
let options = JsonWebTokenOptions::for_es384_jwks_keys(&keys)
.map_err(|error| RemoteJwksVerifierError::Verify(error.to_string()))?;
Ok(JsonWebToken::new_with_options(options))
}
async fn persist_jwks_cache(
cache_path: &PathBuf,
jwks: &JwksDocument,
) -> Result<(), RemoteJwksVerifierError> {
if let Some(parent) = cache_path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await.map_err(|error| {
RemoteJwksVerifierError::CacheWrite {
path: parent.display().to_string(),
message: error.to_string(),
}
})?;
}
let raw = serde_json::to_string_pretty(jwks)
.map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
tokio::fs::write(cache_path, raw)
.await
.map_err(|error| RemoteJwksVerifierError::CacheWrite {
path: cache_path.display().to_string(),
message: error.to_string(),
})
}
#[cfg(test)]
// Tests run against fixed, known-good fixtures, so panicking on an unexpected `Err` is the
// intended failure mode.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::jwt::jwks::EcP384Jwk;

    /// P-384 public key fixture in PEM form, used to build a single-key JWKS.
    const TEST_ES384_PUBLIC_KEY_PEM: &[u8] = br#"-----BEGIN PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEsjQ/XkOUJO2bXkhDzKRMW1SXp0VsMqGx
MSTG+tppqd3gOxbM8vLgWy4/B0Qdest0Gy3E8QgaKJXQV3zRczNd9zrk1dmwVl6u
Yd+JfgNIeIFP6HWeu/C3wIJ60WDBuGY1
-----END PUBLIC KEY-----
"#;

    #[test]
    fn config_defaults_require_only_jwks_url() {
        const URL: &str = "https://example.invalid/.well-known/jwks.json";
        let defaults = RemoteJwksVerifierConfig::from_jwks_url(URL);
        assert_eq!(defaults.jwks_url, URL);
        assert_eq!(defaults.http_timeout, Duration::from_secs(3));
        assert_eq!(defaults.refresh_interval, Duration::from_secs(60));
        assert!(defaults.cache_path.is_none());
    }

    #[test]
    fn jwks_document_rejects_empty_keys() {
        let empty = JwksDocument { keys: Vec::new() };
        let outcome = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&empty);
        assert!(matches!(outcome, Err(RemoteJwksVerifierError::NoValidKeys)));
    }

    #[test]
    fn kid_parser_extracts_unknown_kid() {
        let parsed =
            kid_from_token_error("JWT `kid` `next-key` is not configured for verification");
        assert_eq!(parsed.as_deref(), Some("next-key"));
    }

    #[test]
    fn codec_builds_from_valid_es384_jwks() {
        let jwk = EcP384Jwk::from_public_key_pem("key-a", TEST_ES384_PUBLIC_KEY_PEM)
            .expect("jwk generation should succeed");
        let single_key = JwksDocument { keys: vec![jwk] };
        let built = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&single_key)
            .expect("codec should be created");
        assert_eq!(built.verification_key_count(), 1);
    }
}