use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use crate::types::AuthError;
/// A single JSON Web Key entry as it appears in a JWKS document.
///
/// Only the members declared here are modelled. `kid`, `kty`, `n` and `e` are
/// plain `String`s, so an entry missing any of them fails to deserialize.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Jwk {
/// Key identifier (`kid` member).
pub kid: String,
/// Key type (`kty` member); the tests in this file use `"RSA"`.
pub kty: String,
/// Optional algorithm (`alg` member).
// NOTE(review): unlike `use`, a `None` here is not skipped when serializing —
// confirm the asymmetry is intentional.
pub alg: Option<String>,
/// Optional intended use (`use` member); left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub r#use: Option<String>,
/// `n` member — presumably the base64url-encoded RSA modulus (TODO confirm with consumers).
pub n: String,
/// `e` member — presumably the base64url-encoded RSA public exponent (TODO confirm with consumers).
pub e: String,
}
/// Asynchronous source of [`Jwk`] signing keys.
///
/// Implementors must be `Send + Sync` so a provider can be shared between tasks.
#[async_trait]
pub trait JwksProvider: Send + Sync {
/// Returns the provider's current set of keys, or an [`AuthError`] when they
/// cannot be obtained.
async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError>;
/// Asks the provider to reload its keys, reporting an [`AuthError`] on failure.
async fn refresh(&self) -> Result<(), AuthError>;
}
/// A fetched key set together with the information needed to decide whether it
/// is still fresh.
struct CachedKeys {
/// Keys returned by the last successful fetch.
keys: Vec<Jwk>,
/// Moment the keys were stored in the cache.
fetched_at: Instant,
/// How long after `fetched_at` the entry is considered fresh.
ttl: Duration,
}
/// [`JwksProvider`] that downloads keys from `jwks_uri` and keeps the most
/// recent result in memory.
pub struct RemoteJwksProvider {
/// URL the key set is fetched from.
jwks_uri: String,
/// HTTP client used for every fetch.
http: reqwest::Client,
/// Most recently fetched key set; `None` until the first successful fetch.
cache: RwLock<Option<CachedKeys>>,
/// Held by `get_signing_keys` around a cache-miss fetch so that concurrent
/// misses queue up instead of each calling the endpoint.
in_flight: Mutex<()>,
/// TTL applied when the response has no usable `Cache-Control: max-age`.
default_ttl: Duration,
}
impl RemoteJwksProvider {
    /// Maximum number of redirect hops a JWKS fetch may follow.
    const MAX_REDIRECTS: usize = 5;

    /// Upper bound on a server-advertised `max-age`, so a single response cannot
    /// pin a (possibly rotated or revoked) key set in the cache indefinitely.
    const MAX_TTL: Duration = Duration::from_secs(24 * 60 * 60);

    /// Creates a provider for `jwks_uri`.
    ///
    /// The URI must pass [`validate_https_public_uri`]. The HTTP client applies
    /// the same check to every redirect hop, because the default redirect
    /// handling would otherwise let the endpoint bounce the request to an
    /// `http://` or private/loopback target and sidestep the SSRF guard.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` when the URI is rejected or the HTTP
    /// client cannot be built.
    pub fn new(jwks_uri: String) -> Result<Self, AuthError> {
        validate_https_public_uri(&jwks_uri, "JWKS URI")?;
        let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
            if attempt.previous().len() >= Self::MAX_REDIRECTS {
                return attempt.error("too many redirects while fetching JWKS");
            }
            match validate_https_public_uri(attempt.url().as_str(), "JWKS redirect target") {
                Ok(()) => attempt.follow(),
                Err(_) => {
                    attempt.error("JWKS redirect target must be an HTTPS URL on a public host")
                }
            }
        });
        let http = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(10))
            .redirect(redirect_policy)
            .build()
            .map_err(|e| AuthError::ConfigError(format!("failed to build HTTP client: {e}")))?;
        Ok(Self::with_client(jwks_uri, http))
    }

    /// Test-only constructor: skips URI validation and uses a default client so
    /// tests can point at a local mock server.
    #[cfg(test)]
    pub fn new_for_test(jwks_uri: String) -> Self {
        Self::with_client(jwks_uri, reqwest::Client::new())
    }

    /// Assembles a provider with an empty cache and a five-minute default TTL.
    fn with_client(jwks_uri: String, http: reqwest::Client) -> Self {
        Self {
            jwks_uri,
            http,
            cache: RwLock::new(None),
            in_flight: Mutex::new(()),
            default_ttl: Duration::from_secs(300),
        }
    }

    /// Extracts the first parseable `max-age` directive from a `Cache-Control`
    /// header value.
    ///
    /// Directive names are matched case-insensitively and a quoted value
    /// (`max-age="300"`) is accepted. Directives whose value is not a
    /// non-negative integer are skipped.
    fn parse_max_age(header: &str) -> Option<Duration> {
        header.split(',').find_map(|directive| {
            let (name, value) = directive.split_once('=')?;
            if !name.trim().eq_ignore_ascii_case("max-age") {
                return None;
            }
            value
                .trim()
                .trim_matches('"')
                .parse::<u64>()
                .ok()
                .map(Duration::from_secs)
        })
    }

    /// Downloads the key set, stores it in the cache and returns it.
    ///
    /// The cache TTL comes from the response's `Cache-Control: max-age`, capped
    /// at [`Self::MAX_TTL`]; `default_ttl` is used when no usable directive is
    /// present.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ProviderUnavailable` when the request fails, the
    /// status is not a success, or the body is not a valid JWKS document.
    async fn fetch_and_store(&self) -> Result<Vec<Jwk>, AuthError> {
        let resp = self
            .http
            .get(&self.jwks_uri)
            .send()
            .await
            .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS fetch failed: {e}")))?;
        let status = resp.status();
        if !status.is_success() {
            return Err(AuthError::ProviderUnavailable(format!(
                "JWKS endpoint returned {}",
                status
            )));
        }
        // Read the TTL before `json()` consumes the response.
        let ttl = resp
            .headers()
            .get("cache-control")
            .and_then(|v| v.to_str().ok())
            .and_then(Self::parse_max_age)
            .map(|advertised| advertised.min(Self::MAX_TTL))
            .unwrap_or(self.default_ttl);

        /// Wire shape of a JWKS document: `{"keys": [...]}`.
        #[derive(Deserialize)]
        struct JwksResponse {
            keys: Vec<Jwk>,
        }

        let body: JwksResponse = resp
            .json()
            .await
            .map_err(|e| AuthError::ProviderUnavailable(format!("JWKS parse failed: {e}")))?;
        let keys = body.keys;
        *self.cache.write().await = Some(CachedKeys {
            keys: keys.clone(),
            fetched_at: Instant::now(),
            ttl,
        });
        Ok(keys)
    }
}
pub fn validate_https_public_uri(uri: &str, label: &str) -> Result<(), AuthError> {
let parsed = uri
.parse::<reqwest::Url>()
.map_err(|e| AuthError::ConfigError(format!("invalid {label} '{uri}': {e}")))?;
if parsed.scheme() != "https" {
return Err(AuthError::ConfigError(format!(
"{label} must use HTTPS (got scheme '{}')",
parsed.scheme()
)));
}
if parsed.host_str().is_some_and(is_private_or_loopback_host) {
return Err(AuthError::ConfigError(format!(
"{label} host '{}' is a private or loopback address (SSRF guard)",
parsed.host_str().unwrap_or("")
)));
}
Ok(())
}
/// Reports whether `host` (as returned by `Url::host_str`) names this machine
/// or a private-network address.
///
/// Recognised forms:
/// - `localhost`, `localhost.localdomain` and any `*.localhost` name, compared
///   case-insensitively and with or without a trailing root dot;
/// - `0.0.0.0`;
/// - IP literals (IPv6 optionally in `[...]`) that are loopback or rejected by
///   `is_private_ip`.
///
/// Other DNS names are not resolved, so this returns `false` for them.
fn is_private_or_loopback_host(host: &str) -> bool {
    // "localhost." is the fully-qualified spelling of "localhost"; without this
    // normalisation the trailing dot slips past the exact-match check below.
    let name = host.strip_suffix('.').unwrap_or(host);
    let lowered = name.to_ascii_lowercase();
    if matches!(
        lowered.as_str(),
        "localhost" | "localhost.localdomain" | "0.0.0.0"
    ) || lowered.ends_with(".localhost")
    {
        return true;
    }

    // `Url::host_str` keeps the brackets around IPv6 literals.
    let ip_str = name
        .strip_prefix('[')
        .and_then(|s| s.strip_suffix(']'))
        .unwrap_or(name);
    match ip_str.parse::<IpAddr>() {
        Ok(ip) => ip.is_loopback() || is_private_ip(ip),
        Err(_) => false,
    }
}
/// Reports whether `ip` belongs to a range that must not be reached through a
/// user-configured URL.
///
/// IPv4: private (RFC 1918), link-local, loopback, unspecified and the
/// carrier-grade NAT shared range `100.64.0.0/10`.
///
/// IPv6: loopback, unique-local (`fc00::/7`), link-local (`fe80::/10`) and
/// unspecified. IPv4-mapped addresses (`::ffff:a.b.c.d`) are judged by the
/// IPv4 address they embed, so `::ffff:127.0.0.1` cannot be used to bypass the
/// IPv4 rules.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => is_private_ipv4(embedded),
            None => {
                v6.is_loopback()
                    || v6.is_unique_local()
                    || v6.is_unicast_link_local()
                    || v6.is_unspecified()
            }
        },
    }
}

/// IPv4 half of [`is_private_ip`].
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = v4.octets();
    // 100.64.0.0/10 (RFC 6598 shared address space) has no stable std helper.
    let is_shared = first == 100 && (second & 0xc0) == 64;
    v4.is_private() || v4.is_link_local() || v4.is_loopback() || v4.is_unspecified() || is_shared
}
#[async_trait]
impl JwksProvider for RemoteJwksProvider {
async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError> {
{
let cache = self.cache.read().await;
if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
return Ok(c.keys.clone());
}
}
let _guard = self.in_flight.lock().await;
{
let cache = self.cache.read().await;
if let Some(c) = cache.as_ref().filter(|c| c.fetched_at.elapsed() < c.ttl) {
return Ok(c.keys.clone());
}
}
self.fetch_and_store().await
}
async fn refresh(&self) -> Result<(), AuthError> {
self.fetch_and_store().await.map(|_| ())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider for `uri` and returns the `ConfigError` message, or
    /// `None` if construction succeeded or failed with another error.
    fn config_error(uri: &str) -> Option<String> {
        match RemoteJwksProvider::new(uri.into()) {
            Err(AuthError::ConfigError(msg)) => Some(msg),
            _ => None,
        }
    }

    /// True when `uri` is rejected with a message naming the SSRF guard's
    /// "private" or "loopback" wording.
    fn rejected_by_ssrf_guard(uri: &str) -> bool {
        config_error(uri).is_some_and(|msg| msg.contains("private") || msg.contains("loopback"))
    }

    /// RS256 RSA key fixture with no `use` member.
    fn rsa_key(kid: &str, n: &str) -> Jwk {
        Jwk {
            kid: kid.into(),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            r#use: None,
            n: n.into(),
            e: "AQAB".into(),
        }
    }

    #[test]
    fn jwk_from_json_fields() {
        let key = rsa_key("key-1", "modulus-base64url");
        assert_eq!(key.kid, "key-1");
        assert_eq!(key.e, "AQAB");
    }

    #[test]
    fn https_enforcement_rejects_http() {
        let msg = config_error("http://kc.example.com/realms/test/protocol/openid-connect/certs");
        assert!(msg.is_some_and(|m| m.contains("HTTPS")));
    }

    #[test]
    fn ssrf_guard_rejects_localhost() {
        let msg = config_error("https://localhost/realms/test/protocol/openid-connect/certs");
        assert!(msg.is_some_and(|m| m.contains("loopback")));
    }

    #[test]
    fn ssrf_guard_rejects_private_ip() {
        let msg = config_error("https://192.168.1.1/realms/test/protocol/openid-connect/certs");
        assert!(msg.is_some_and(|m| m.contains("private")));
    }

    #[test]
    fn production_url_accepted() {
        let built = RemoteJwksProvider::new(
            "https://kc.example.com/realms/test/protocol/openid-connect/certs".into(),
        );
        assert!(built.is_ok());
    }

    #[test]
    fn ssrf_guard_rejects_link_local_metadata_endpoint() {
        assert!(rejected_by_ssrf_guard(
            "https://169.254.169.254/latest/meta-data"
        ));
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_unique_local() {
        assert!(rejected_by_ssrf_guard(
            "https://[fc00::1]/realms/test/protocol/openid-connect/certs"
        ));
    }

    #[test]
    fn ssrf_guard_rejects_ipv6_loopback() {
        assert!(rejected_by_ssrf_guard(
            "https://[::1]/realms/test/protocol/openid-connect/certs"
        ));
    }

    #[tokio::test]
    async fn cache_returns_fresh_keys_without_http() {
        // The URI is never contacted: a fresh cache entry must short-circuit the fetch.
        let provider = RemoteJwksProvider::new_for_test("http://unreachable:9999/certs".into());
        let seeded = CachedKeys {
            keys: vec![rsa_key("cached-key", "n")],
            fetched_at: Instant::now(),
            ttl: Duration::from_secs(300),
        };
        *provider.cache.write().await = Some(seeded);

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "cached-key");
    }

    #[tokio::test]
    async fn refresh_via_wiremock() {
        use wiremock::matchers::method;
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let jwks_json =
            r#"{"keys":[{"kid":"key-1","kty":"RSA","alg":"RS256","n":"modulus","e":"AQAB"}]}"#;
        let response = ResponseTemplate::new(200).set_body_raw(jwks_json, "application/json");
        Mock::given(method("GET"))
            .respond_with(response)
            .mount(&server)
            .await;

        let endpoint = format!(
            "{}/realms/test/protocol/openid-connect/certs",
            server.uri()
        );
        let provider = RemoteJwksProvider::new_for_test(endpoint);
        provider.refresh().await.unwrap();

        let keys = provider.get_signing_keys().await.unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0].kid, "key-1");
    }
}