use std::net::{SocketAddr, ToSocketAddrs};
use tracing::{debug, info, warn};
use super::cache::JwksCache;
use super::key::{JwksResponse, VerificationKey, parse_jwk};
use super::url::{JwksPolicy, UrlValidationError};
/// Redirect cap for JWKS fetches.
///
/// The custom redirect policy fails the request with the redirect-limit error once the
/// chain of previously requested URLs (`attempt.previous()`) reaches this length.
const MAX_REDIRECTS: usize = 3;
/// Fetch the JWKS for one provider and store the keys in `cache` when there are any.
///
/// Returns the number of usable keys obtained, or 0 when the fetch failed.
///
/// The cache is not touched in two cases:
/// - the fetch failed (the failure category is logged);
/// - the document contained no usable keys.
///
/// In either case whatever the cache already holds for `provider_name` is left as is.
pub async fn fetch_and_cache(
    provider_name: &str,
    jwks_url: &str,
    cache: &JwksCache,
    policy: &JwksPolicy,
) -> usize {
    let keys = match fetch_jwks(jwks_url, policy).await {
        Ok(keys) => keys,
        Err(err) => {
            warn!(
                provider = %provider_name,
                url = %jwks_url,
                error = err.category(),
                "JWKS fetch failed — using cached keys if available"
            );
            return 0;
        }
    };

    let fetched = keys.len();
    if fetched == 0 {
        // An empty key set would wipe the provider's entry; keep the old one instead.
        warn!(
            provider = %provider_name,
            url = %jwks_url,
            "JWKS response contained no usable signature keys"
        );
        return fetched;
    }

    info!(
        provider = %provider_name,
        url = %jwks_url,
        keys = fetched,
        "JWKS fetched successfully"
    );
    cache.update_provider(provider_name, keys);
    fetched
}
async fn fetch_jwks(
url: &str,
policy: &JwksPolicy,
) -> Result<Vec<VerificationKey>, JwksFetchError> {
debug!(url = %url, "fetching JWKS");
policy
.check_url(url)
.map_err(JwksFetchError::UrlValidation)?;
let parsed = reqwest::Url::parse(url)
.map_err(|_| JwksFetchError::UrlValidation(UrlValidationError::Malformed))?;
let host = parsed
.host_str()
.ok_or(JwksFetchError::UrlValidation(UrlValidationError::NoHost))?
.to_owned();
let port = parsed.port_or_known_default().unwrap_or(443);
let addrs = resolve_host(&host, port).await?;
let ips: Vec<_> = addrs.iter().map(|a| a.ip()).collect();
policy
.check_resolved(&host, &ips)
.map_err(JwksFetchError::UrlValidation)?;
let pinned = addrs[0];
let policy_for_redirect = policy.clone();
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.resolve(&host, pinned)
.redirect(reqwest::redirect::Policy::custom(move |attempt| {
if attempt.previous().len() >= MAX_REDIRECTS {
return attempt.error(redirect_limit_err());
}
if policy_for_redirect.check_parsed(attempt.url()).is_err() {
return attempt.stop();
}
attempt.follow()
}))
.build()
.map_err(|_| JwksFetchError::HttpClient)?;
let response = client
.get(url)
.header("accept", "application/json")
.send()
.await
.map_err(|e| {
if format!("{e}").contains(REDIRECT_LIMIT_SENTINEL) {
JwksFetchError::RedirectLimit
} else {
JwksFetchError::HttpRequest
}
})?;
let status = response.status();
if !status.is_success() {
return Err(JwksFetchError::HttpStatus(status.as_u16()));
}
const MAX_BODY: usize = 1 << 20;
let body = response
.bytes()
.await
.map_err(|_| JwksFetchError::HttpBody)?;
if body.len() > MAX_BODY {
return Err(JwksFetchError::BodyTooLarge);
}
let jwks: JwksResponse = sonic_rs::from_slice(&body).map_err(|_| JwksFetchError::JsonParse)?;
let keys: Vec<VerificationKey> = jwks.keys.iter().filter_map(parse_jwk).collect();
Ok(keys)
}
/// Resolve `host:port` to socket addresses on tokio's blocking thread pool.
///
/// The pool is used because `ToSocketAddrs` lookups on host names block the calling
/// thread.
///
/// # Errors
/// Returns [`JwksFetchError::DnsResolution`] in either of these cases:
/// - the lookup itself fails;
/// - the blocking task fails to complete, for example because it panicked or was
///   cancelled.
async fn resolve_host(host: &str, port: u16) -> Result<Vec<SocketAddr>, JwksFetchError> {
    let owned_host = host.to_owned();
    let lookup = tokio::task::spawn_blocking(move || {
        (owned_host.as_str(), port)
            .to_socket_addrs()
            .map(|found| found.collect::<Vec<SocketAddr>>())
    });
    match lookup.await {
        Ok(Ok(addrs)) => Ok(addrs),
        Ok(Err(_)) | Err(_) => Err(JwksFetchError::DnsResolution),
    }
}
/// Marker text carried by the error that the redirect policy raises when the redirect
/// limit is hit.
///
/// Searching for it lets that failure be told apart from other request errors.
const REDIRECT_LIMIT_SENTINEL: &str = "nodedb-jwks-redirect-limit";

/// Build the boxed error returned from the redirect policy when too many redirects were
/// seen. Its message is exactly [`REDIRECT_LIMIT_SENTINEL`].
fn redirect_limit_err() -> Box<dyn std::error::Error + Send + Sync> {
    REDIRECT_LIMIT_SENTINEL.into()
}
/// Reasons a JWKS fetch can fail.
///
/// Variants carry no upstream error detail beyond the URL-validation reason and the HTTP
/// status code. [`JwksFetchError::category`] gives a short static label for logging.
#[derive(Debug, thiserror::Error)]
pub enum JwksFetchError {
    /// The URL was rejected. This covers a malformed URL, a URL with no host, and a URL
    /// or resolved IP that the `JwksPolicy` rejected.
    #[error("JWKS URL failed validation: {0}")]
    UrlValidation(#[from] UrlValidationError),
    /// Resolving the JWKS host failed.
    #[error("DNS resolution failed")]
    DnsResolution,
    /// Building the reqwest client failed.
    #[error("HTTP client construction failed")]
    HttpClient,
    /// Sending the request failed for a reason other than the redirect limit.
    #[error("HTTP request failed")]
    HttpRequest,
    /// The endpoint answered with a non-success status; the numeric code is kept.
    /// This includes a redirect response that the redirect policy declined to follow.
    #[error("JWKS endpoint returned HTTP {0}")]
    HttpStatus(u16),
    /// Reading the response body failed.
    #[error("failed to read response body")]
    HttpBody,
    /// The response body was larger than the fetcher's size cap.
    #[error("JWKS response body exceeded size cap")]
    BodyTooLarge,
    /// The body could not be deserialized as a JWKS document.
    #[error("JWKS JSON parse failed")]
    JsonParse,
    /// The redirect chain reached `MAX_REDIRECTS`.
    #[error("JWKS fetch exceeded redirect limit")]
    RedirectLimit,
}
impl JwksFetchError {
    /// Short, stable, `snake_case` label for this error, suitable as a structured log
    /// field.
    ///
    /// Every label is a static literal, so no request-specific data ever leaks through it.
    pub fn category(&self) -> &'static str {
        match *self {
            Self::UrlValidation(..) => "url_validation",
            Self::DnsResolution => "dns_resolution",
            Self::HttpClient => "http_client",
            Self::HttpRequest => "http_request",
            Self::HttpStatus(..) => "http_status",
            Self::HttpBody => "http_body",
            Self::BodyTooLarge => "body_too_large",
            Self::JsonParse => "json_parse",
            Self::RedirectLimit => "redirect_limit",
        }
    }
}
/// Spawn a background task that re-fetches every provider's JWKS on a fixed interval.
///
/// `providers` is a list of `(provider_name, jwks_url)` pairs, each passed to
/// [`fetch_and_cache`].
///
/// Scheduling:
/// - The first refresh runs one full interval after spawn, because the interval's
///   immediate first tick is consumed up front.
/// - Providers are refreshed one after another. A failure for one provider is logged by
///   `fetch_and_cache` and does not stop the others.
/// - The task loops forever; stop it through the returned handle.
///
/// `refresh_interval_secs` is clamped to at least 1 second. `tokio::time::interval`
/// panics on a zero period, which would otherwise kill the task silently.
pub fn spawn_refresh_task(
    providers: Vec<(String, String)>,
    cache: std::sync::Arc<JwksCache>,
    refresh_interval_secs: u64,
    policy: std::sync::Arc<JwksPolicy>,
) -> tokio::task::JoinHandle<()> {
    let period = std::time::Duration::from_secs(refresh_interval_secs.max(1));
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(period);
        // If a refresh round overruns the period (each fetch may take up to its HTTP
        // timeout), wait a full period afterwards instead of firing a burst of rounds.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        interval.tick().await;
        loop {
            interval.tick().await;
            for (name, url) in &providers {
                fetch_and_cache(name, url, &cache, &policy).await;
            }
        }
    })
}