use crate::internal::domain::{ErrorCode, GatewayError};
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, time::Duration as StdDuration};
use time::{Duration, OffsetDateTime};
use url::Url;
/// Timeout the default client uses for both connecting and the whole JWKS request (5 s).
const DEFAULT_TIMEOUT: StdDuration = StdDuration::from_millis(5_000);
/// Largest JWKS response body, in bytes, the default client accepts (64 KiB).
const DEFAULT_MAX_BODY_BYTES: usize = 64 * 1024;
/// Freshness window a default `JwksCache` applies to each cached key set (10 minutes).
const DEFAULT_CACHE_TTL: Duration = Duration::seconds(10 * 60);
/// One JSON Web Key from a JWKS document.
///
/// Only the members listed here are captured. The derived `Deserialize` ignores any other
/// members present in the JSON, since no `deny_unknown_fields` attribute is set. Optional
/// members that are absent deserialize as `None`; `kty` is required.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Jwk {
    /// Key ID; `Jwks::select_key` matches it against a requested `kid`.
    pub kid: Option<String>,
    /// Key type (the JWK `kty` member); deserialization fails if it is missing.
    pub kty: String,
    /// Algorithm the key set declares for this key, if any.
    pub alg: Option<String>,
    /// `k` member (symmetric key value for `oct` keys per RFC 7518), kept as the raw string.
    pub k: Option<String>,
    /// `n` member (RSA modulus per RFC 7518), kept as the raw string.
    pub n: Option<String>,
    /// `e` member (RSA public exponent per RFC 7518), kept as the raw string.
    pub e: Option<String>,
}
/// A JSON Web Key Set: the document served at a JWKS URL, deserialized from its `keys` array.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Jwks {
    /// Keys in document order; `select_key` returns the first one whose `kid` matches.
    pub keys: Vec<Jwk>,
}
impl Jwks {
    /// Chooses the key that should verify a token.
    ///
    /// When `kid` is given, returns the first key whose `kid` equals it exactly, or `None` if
    /// no key matches. When `kid` is absent, a key is returned only if the set contains exactly
    /// one key. With several keys the choice would be ambiguous, so the result is `None`.
    #[must_use]
    pub fn select_key(&self, kid: Option<&str>) -> Option<&Jwk> {
        if let Some(wanted) = kid {
            return self
                .keys
                .iter()
                .find(|candidate| candidate.kid.as_deref() == Some(wanted));
        }
        match self.keys.as_slice() {
            [only] => Some(only),
            _ => None,
        }
    }
}
/// HTTP client used to download JWKS documents.
#[derive(Clone, Debug)]
pub struct JwksHttpClient {
    /// Underlying reqwest client used for every JWKS request.
    http: reqwest::Client,
    /// Upper bound, in bytes, on the JWKS response bodies that `fetch` accepts.
    max_body_bytes: usize,
}
impl JwksHttpClient {
    /// Builds a JWKS client with hardened transport settings.
    ///
    /// `timeout` bounds both connection establishment and the whole request. Redirects are
    /// never followed, so the configured URL cannot bounce the request to another host.
    /// `fetch` rejects response bodies larger than `max_body_bytes`.
    ///
    /// # Errors
    ///
    /// Returns an `AuthTokenInvalid` [`GatewayError`] if the underlying HTTP client cannot be
    /// constructed.
    pub fn new(timeout: StdDuration, max_body_bytes: usize) -> Result<Self, GatewayError> {
        let http = reqwest::Client::builder()
            .timeout(timeout)
            .connect_timeout(timeout)
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|_| jwks_error("Unable to initialize JWKS HTTP client"))?;
        Ok(Self {
            http,
            max_body_bytes,
        })
    }

    /// Downloads and parses the JWKS document at `url`.
    ///
    /// # Errors
    ///
    /// Returns an `AuthTokenInvalid` [`GatewayError`] when any of the following holds:
    /// - `url` is not `https`;
    /// - the request fails or the response status is not 2xx (redirects are not followed, so
    ///   a 3xx also counts as failure);
    /// - the body exceeds `max_body_bytes` or cannot be read;
    /// - the body is not a valid JWKS document.
    pub async fn fetch(&self, url: &Url) -> Result<Jwks, GatewayError> {
        if url.scheme() != "https" {
            return Err(jwks_error("JWKS URL must use the https scheme"));
        }
        let response = self
            .http
            .get(url.clone())
            .send()
            .await
            .map_err(|_| jwks_error("Unable to fetch JWKS"))?;
        // With redirects disabled, 3xx responses arrive here. `error_for_status` only rejects
        // 4xx/5xx, so require a success status explicitly instead of handing a redirect body
        // to the JSON parser.
        if !response.status().is_success() {
            return Err(jwks_error("JWKS endpoint returned an unsuccessful status"));
        }
        let body = self.read_limited_body(response).await?;
        serde_json::from_slice::<Jwks>(&body).map_err(|_| jwks_error("Unable to parse JWKS"))
    }

    /// Reads the response body, failing as soon as it would exceed `max_body_bytes`.
    ///
    /// The limit is enforced while streaming, so an oversized body is never fully buffered.
    async fn read_limited_body(
        &self,
        mut response: reqwest::Response,
    ) -> Result<Vec<u8>, GatewayError> {
        let too_large = || jwks_error("JWKS response body is too large");
        // Cheap early rejection when the server declares an oversized body up front.
        let limit = u64::try_from(self.max_body_bytes).unwrap_or(u64::MAX);
        if response
            .content_length()
            .is_some_and(|declared| declared > limit)
        {
            return Err(too_large());
        }
        // Content-Length may be absent (chunked encoding) or wrong, so also count actual bytes.
        let mut body = Vec::new();
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|_| jwks_error("Unable to read JWKS response body"))?
        {
            if chunk.len() > self.max_body_bytes.saturating_sub(body.len()) {
                return Err(too_large());
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body)
    }
}
impl Default for JwksHttpClient {
    /// Builds a client with `DEFAULT_TIMEOUT` and `DEFAULT_MAX_BODY_BYTES`.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built, which per reqwest happens only when the TLS
    /// backend or the DNS resolver cannot be initialized. Callers that need to handle this
    /// case should use [`JwksHttpClient::new`].
    fn default() -> Self {
        match Self::new(DEFAULT_TIMEOUT, DEFAULT_MAX_BODY_BYTES) {
            Ok(client) => client,
            // `reqwest::Client::new()` panics for exactly these build failures, so there is no
            // non-panicking fallback. Falling back to it could only ever produce a client with
            // no timeout that follows redirects, silently dropping the hardening applied in `new`.
            Err(_) => panic!("Unable to initialize JWKS HTTP client with default settings"),
        }
    }
}
/// A cached JWKS document plus the timestamp it was stored with.
#[derive(Clone, Debug)]
struct CachedJwks {
    /// The cached key set.
    jwks: Jwks,
    /// Caller-supplied `now` passed to `JwksCache::insert`; freshness is measured from it.
    fetched_at: OffsetDateTime,
}
/// In-memory cache of JWKS documents keyed by JWKS URL, with a fixed freshness window.
///
/// Time is always supplied by the caller through `now` parameters; this type reads no clock.
#[derive(Clone, Debug)]
pub struct JwksCache {
    /// Freshness window applied to each entry's `fetched_at`.
    ttl: Duration,
    /// Cached documents keyed by the URL's string form.
    entries: BTreeMap<String, CachedJwks>,
}
impl JwksCache {
    /// Creates an empty cache whose entries are fresh within `ttl` of their fetch time.
    #[must_use]
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            ttl,
            entries: BTreeMap::new(),
        }
    }

    /// Stores `jwks` for `url`, stamped with `now`, replacing any previous entry for that URL.
    pub fn insert(&mut self, url: &Url, jwks: Jwks, now: OffsetDateTime) {
        // Key on `as_str()` so the key form matches the lookup in `get_cached`.
        self.entries.insert(
            url.as_str().to_owned(),
            CachedJwks {
                jwks,
                fetched_at: now,
            },
        );
    }

    /// Returns the cached JWKS for `url` if it is still fresh at `now`, otherwise `None`.
    ///
    /// An entry is fresh while `now` lies within `ttl` of its fetch time in either direction.
    /// `now` is caller-supplied wall-clock time and can move backwards (clock adjustment). With
    /// only an upper bound, a backwards jump would keep the entry fresh for the whole regression
    /// plus `ttl`, continuing to serve keys that may have been rotated. Small negative ages,
    /// such as a `now` captured just before another task refreshed the entry, still count as
    /// fresh.
    #[must_use]
    pub fn get_cached(&self, url: &Url, now: OffsetDateTime) -> Option<&Jwks> {
        let cached = self.entries.get(url.as_str())?;
        let age = now - cached.fetched_at;
        (age.abs() <= self.ttl).then_some(&cached.jwks)
    }

    /// Returns the JWKS for `url`, serving it from the cache when fresh at `now`.
    ///
    /// On a miss or stale entry, fetches with `client`, caches the result stamped with `now`,
    /// and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`JwksHttpClient::fetch`]. On failure the cache is left
    /// unchanged, so a stale entry is neither replaced nor served.
    pub async fn get_or_fetch(
        &mut self,
        client: &JwksHttpClient,
        url: &Url,
        now: OffsetDateTime,
    ) -> Result<Jwks, GatewayError> {
        if let Some(jwks) = self.get_cached(url, now) {
            return Ok(jwks.clone());
        }
        let jwks = client.fetch(url).await?;
        self.insert(url, jwks.clone(), now);
        Ok(jwks)
    }
}
impl Default for JwksCache {
fn default() -> Self {
Self::with_ttl(DEFAULT_CACHE_TTL)
}
}
/// Fetches the JWKS document at `url` with a freshly built default [`JwksHttpClient`].
///
/// A new client is constructed on every call. Callers that fetch repeatedly can keep their own
/// client and a [`JwksCache`] instead.
///
/// # Errors
///
/// Returns whatever [`JwksHttpClient::fetch`] returns.
pub async fn fetch_jwks(url: &Url) -> Result<Jwks, GatewayError> {
    let client = JwksHttpClient::default();
    client.fetch(url).await
}
/// Builds the `GatewayError` used for every JWKS failure in this module.
///
/// All such errors share `ErrorCode::AuthTokenInvalid` and one remediation hint; only
/// `message` varies. The third `GatewayError::new` argument is always `true` here
/// (NOTE(review): its meaning is defined by `GatewayError`, presumably "retryable" — confirm).
fn jwks_error(message: &str) -> GatewayError {
    let hint = String::from("Verify the configured JWKS URL and authorization server response");
    GatewayError::new(ErrorCode::AuthTokenInvalid, message, true, Some(hint))
}
#[cfg(test)]
mod tests {
    use super::JwksHttpClient;
    use crate::internal::domain::ErrorCode;
    use std::time::Duration as StdDuration;
    use url::Url;

    /// Client with a short timeout and a small body limit; scheme checks never reach the network.
    fn small_client() -> JwksHttpClient {
        JwksHttpClient::new(StdDuration::from_secs(1), 1024)
            .unwrap_or_else(|_| unreachable!("client should build"))
    }

    /// Parses a URL literal that is known to be valid.
    fn static_url(raw: &str) -> Url {
        Url::parse(raw).unwrap_or_else(|_| unreachable!("static URL should parse"))
    }

    #[tokio::test]
    async fn rejects_non_https_jwks_url() {
        let client = small_client();
        let url = static_url("http://issuer.example.com/jwks.json");
        let error = client
            .fetch(&url)
            .await
            .err()
            .unwrap_or_else(|| unreachable!("http scheme must be rejected without a network call"));
        assert_eq!(error.code, ErrorCode::AuthTokenInvalid);
    }

    #[tokio::test]
    async fn rejects_file_scheme_jwks_url() {
        let client = small_client();
        let url = static_url("file:///etc/passwd");
        let error = client
            .fetch(&url)
            .await
            .err()
            .unwrap_or_else(|| unreachable!("file scheme must be rejected"));
        assert_eq!(error.code, ErrorCode::AuthTokenInvalid);
    }
}