use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use jsonwebtoken::jwk::JwkSet;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;
use url::Url;
use crate::error::PodError;
use crate::security::ssrf::SsrfPolicy;
/// Overall request timeout (10 seconds) applied to every pinned HTTP client
/// built by `pinned_client`.
const FETCH_TIMEOUT: Duration = Duration::from_secs(10);
/// TTL (900 seconds, i.e. 15 minutes) used by the `Default` impls of
/// `OidcConfigCache` and `JwksCache`.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(900);
/// Deserialized form of the JSON served at
/// `<issuer>/.well-known/openid-configuration`.
///
/// `issuer`, `jwks_uri`, `authorization_endpoint` and `token_endpoint` carry no
/// serde default, so a document missing any of them fails to deserialize.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OidcDiscoveryDoc {
/// Issuer identifier claimed by the document; `fetch_oidc_config` rejects the
/// document when this does not match the URL it was fetched from.
pub issuer: String,
/// URL of the key set; `fetch_jwks` parses it and downloads the JWKS from it.
pub jwks_uri: String,
/// Authorization endpoint as published by the provider (not interpreted here).
pub authorization_endpoint: String,
/// Token endpoint as published by the provider (not interpreted here).
pub token_endpoint: String,
/// Optional registration endpoint; `None` when the key is absent.
#[serde(default)]
pub registration_endpoint: Option<String>,
/// Optional list of supported scopes; `None` when the key is absent.
#[serde(default)]
pub scopes_supported: Option<Vec<String>>,
/// Every other top-level key of the document, kept verbatim through
/// `#[serde(flatten)]` so it survives a deserialize/serialize round trip.
#[serde(flatten, default)]
pub extra: HashMap<String, serde_json::Value>,
}
/// One entry of `OidcConfigCache`: a discovery document plus its insertion time.
#[derive(Clone)]
struct CachedConfig {
/// Instant at which the entry was stored; compared against the cache TTL.
fetched: Instant,
/// The cached discovery document.
doc: OidcDiscoveryDoc,
}
/// One entry of `JwksCache`: a key set plus its insertion time.
#[derive(Clone)]
struct CachedJwks {
/// Instant at which the entry was stored; compared against the cache TTL.
fetched: Instant,
/// The cached JSON Web Key Set.
set: JwkSet,
}
/// TTL-bounded, lock-protected cache of discovery documents keyed by issuer string.
pub struct OidcConfigCache {
/// Maximum age at which an entry is still returned by `get`.
ttl: Duration,
/// Issuer string -> cached document, guarded by a std `RwLock`.
inner: Arc<RwLock<HashMap<String, CachedConfig>>>,
}
impl OidcConfigCache {
    /// Creates an empty cache whose entries are served for at most `ttl`
    /// after they were inserted.
    pub fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a clone of the discovery document cached under `issuer`.
    ///
    /// Yields `None` when there is no entry, when the entry is older than the
    /// TTL, or when the lock is poisoned (treated as a cache miss).
    pub fn get(&self, issuer: &str) -> Option<OidcDiscoveryDoc> {
        let guard = self.inner.read().ok()?;
        guard
            .get(issuer)
            .filter(|entry| entry.fetched.elapsed() <= self.ttl)
            .map(|entry| entry.doc.clone())
    }

    /// Stores `doc` under `issuer`, stamped with the current instant.
    ///
    /// Entries that have outlived the TTL are dropped first: `get` only holds
    /// a read lock and never removes anything, so without this purge stale
    /// issuers would accumulate for the lifetime of the process. A poisoned
    /// lock makes this a silent no-op (best-effort caching).
    fn put(&self, issuer: String, doc: OidcDiscoveryDoc) {
        if let Ok(mut guard) = self.inner.write() {
            let ttl = self.ttl;
            // Purge before inserting so the fresh entry is never evicted here,
            // even with a very small TTL.
            guard.retain(|_, entry| entry.fetched.elapsed() <= ttl);
            guard.insert(
                issuer,
                CachedConfig {
                    fetched: Instant::now(),
                    doc,
                },
            );
        }
    }
}
impl Default for OidcConfigCache {
fn default() -> Self {
Self::new(DEFAULT_CACHE_TTL)
}
}
/// TTL-bounded, lock-protected cache of JSON Web Key Sets keyed by issuer string.
pub struct JwksCache {
/// Maximum age at which an entry is still returned by `get`.
ttl: Duration,
/// Issuer string -> cached key set, guarded by a std `RwLock`.
inner: Arc<RwLock<HashMap<String, CachedJwks>>>,
}
impl JwksCache {
    /// Creates an empty cache whose entries are served for at most `ttl`
    /// after they were inserted.
    pub fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a clone of the key set cached under `issuer`.
    ///
    /// Yields `None` when there is no entry, when the entry is older than the
    /// TTL, or when the lock is poisoned (treated as a cache miss).
    pub fn get(&self, issuer: &str) -> Option<JwkSet> {
        let guard = self.inner.read().ok()?;
        guard
            .get(issuer)
            .filter(|entry| entry.fetched.elapsed() <= self.ttl)
            .map(|entry| entry.set.clone())
    }

    /// Stores `set` under `issuer`, stamped with the current instant.
    ///
    /// Entries that have outlived the TTL are dropped first: `get` only holds
    /// a read lock and never removes anything, so without this purge stale
    /// issuers would accumulate for the lifetime of the process. A poisoned
    /// lock makes this a silent no-op (best-effort caching).
    fn put(&self, issuer: String, set: JwkSet) {
        if let Ok(mut guard) = self.inner.write() {
            let ttl = self.ttl;
            // Purge before inserting so the fresh entry is never evicted here,
            // even with a very small TTL.
            guard.retain(|_, entry| entry.fetched.elapsed() <= ttl);
            guard.insert(
                issuer,
                CachedJwks {
                    fetched: Instant::now(),
                    set,
                },
            );
        }
    }
}
impl Default for JwksCache {
fn default() -> Self {
Self::new(DEFAULT_CACHE_TTL)
}
}
/// Renders `u` as a string with every trailing `/` removed, giving the form
/// used both as cache key and for the issuer comparison.
fn canonical_issuer(u: &Url) -> String {
    let serialized = u.as_str();
    let without_slash = serialized.trim_end_matches('/');
    String::from(without_slash)
}
/// Builds a fresh HTTP client whose DNS resolution for `host` is overridden
/// to `ip:port`, so the request goes to the address that was already vetted
/// rather than to whatever a second DNS lookup might return.
///
/// * `base` - accepted for signature compatibility; it is not used, the
///   returned client is built from scratch.
/// * `host` - host name whose resolution is pinned.
/// * `ip`, `port` - socket address the host is pinned to.
///
/// Redirects are disabled: a 3xx response could otherwise steer the request
/// to a different host that was never SSRF-checked and is not covered by the
/// pin. With redirects off, such a response is returned as-is and is rejected
/// by the callers' non-success status check.
///
/// # Errors
/// Returns the `reqwest` error produced when the client cannot be built.
fn pinned_client(
    base: &Client,
    host: &str,
    ip: std::net::IpAddr,
    port: u16,
) -> reqwest::Result<Client> {
    let _ = base;
    reqwest::Client::builder()
        .resolve(host, SocketAddr::new(ip, port))
        // Never follow redirects: the target would bypass the SSRF check and the DNS pin.
        .redirect(reqwest::redirect::Policy::none())
        .timeout(FETCH_TIMEOUT)
        .build()
}
/// Downloads and validates the OIDC discovery document for `issuer`.
///
/// Steps, in order: run the SSRF policy on the issuer URL, build a client
/// pinned to the approved address, GET `<issuer>/.well-known/openid-configuration`,
/// parse the JSON body, and check that the document's `issuer` (ignoring
/// trailing slashes) equals the URL it was fetched from.
///
/// `client` is only forwarded to `pinned_client`.
///
/// # Errors
/// * `PodError::Nip98` - SSRF check failed, the URL has no host, or the
///   document claims a different issuer.
/// * `PodError::Backend` - client construction, the request itself, a
///   non-success HTTP status, or JSON parsing failed.
/// * A failure to join the well-known path is propagated through `?`.
pub async fn fetch_oidc_config(
    issuer: &Url,
    ssrf: &SsrfPolicy,
    client: &Client,
) -> Result<OidcDiscoveryDoc, PodError> {
    let approved_ip = match ssrf.resolve_and_check(issuer).await {
        Ok(ip) => ip,
        Err(e) => return Err(PodError::Nip98(format!("issuer SSRF: {e}"))),
    };

    let host = match issuer.host_str() {
        Some(name) => name.to_string(),
        None => {
            return Err(PodError::Nip98(format!(
                "issuer URL missing host: {issuer}"
            )))
        }
    };

    // Explicit port, else the scheme's well-known port, else a scheme-based guess.
    let fallback_port: u16 = if issuer.scheme() == "https" { 443 } else { 80 };
    let port = issuer.port_or_known_default().unwrap_or(fallback_port);

    let pinned = match pinned_client(client, &host, approved_ip, port) {
        Ok(built) => built,
        Err(e) => {
            return Err(PodError::Backend(format!(
                "reqwest client build failed: {e}"
            )))
        }
    };

    // `Url::join` replaces the last path segment unless the base ends in '/',
    // so make sure it does before appending the well-known suffix.
    let mut base = issuer.clone();
    if !base.path().ends_with('/') {
        let slashed = format!("{}/", base.path());
        base.set_path(&slashed);
    }
    let discovery_url = base.join(".well-known/openid-configuration")?;

    let resp = match pinned.get(discovery_url).send().await {
        Ok(received) => received,
        Err(e) => return Err(PodError::Backend(format!("discovery fetch failed: {e}"))),
    };

    let status = resp.status();
    if !status.is_success() {
        return Err(PodError::Backend(format!(
            "discovery returned HTTP {}",
            status
        )));
    }

    let doc = match resp.json::<OidcDiscoveryDoc>().await {
        Ok(parsed) => parsed,
        Err(e) => return Err(PodError::Backend(format!("discovery parse failed: {e}"))),
    };

    // The document must name the very issuer it was fetched for.
    let expected = canonical_issuer(issuer);
    let claimed = doc.issuer.trim_end_matches('/');
    if claimed != expected {
        return Err(PodError::Nip98(format!(
            "issuer fixation: discovery doc issuer '{claimed}' does not match fetch URL '{expected}'"
        )));
    }

    debug!(issuer = %expected, jwks_uri = %doc.jwks_uri, "fetched OIDC discovery doc");
    Ok(doc)
}
/// Downloads the JSON Web Key Set advertised by `issuer`.
///
/// Fetches the discovery document through `fetch_oidc_config`, parses its
/// `jwks_uri`, runs the SSRF policy on that URL, and then GETs it with a
/// client pinned to the approved address.
///
/// `client` is only forwarded to `fetch_oidc_config` and `pinned_client`.
///
/// # Errors
/// * Any error from `fetch_oidc_config`.
/// * `PodError::Nip98` - `jwks_uri` is not a valid URL, fails the SSRF
///   check, or has no host.
/// * `PodError::Backend` - client construction, the request itself, a
///   non-success HTTP status, or JSON parsing failed.
pub async fn fetch_jwks(
    issuer: &Url,
    ssrf: &SsrfPolicy,
    client: &Client,
) -> Result<JwkSet, PodError> {
    let doc = fetch_oidc_config(issuer, ssrf, client).await?;

    let jwks_url = match Url::parse(&doc.jwks_uri) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(PodError::Nip98(format!(
                "malformed jwks_uri '{}': {e}",
                doc.jwks_uri
            )))
        }
    };

    // The JWKS may live on a different host than the issuer, so it is vetted separately.
    let approved_ip = match ssrf.resolve_and_check(&jwks_url).await {
        Ok(ip) => ip,
        Err(e) => return Err(PodError::Nip98(format!("jwks_uri SSRF: {e}"))),
    };

    let host = match jwks_url.host_str() {
        Some(name) => name.to_string(),
        None => {
            return Err(PodError::Nip98(format!(
                "jwks_uri missing host: {jwks_url}"
            )))
        }
    };

    // Explicit port, else the scheme's well-known port, else a scheme-based guess.
    let fallback_port: u16 = if jwks_url.scheme() == "https" { 443 } else { 80 };
    let port = jwks_url.port_or_known_default().unwrap_or(fallback_port);

    let pinned = match pinned_client(client, &host, approved_ip, port) {
        Ok(built) => built,
        Err(e) => {
            return Err(PodError::Backend(format!(
                "reqwest client build failed: {e}"
            )))
        }
    };

    let resp = match pinned.get(jwks_url).send().await {
        Ok(received) => received,
        Err(e) => return Err(PodError::Backend(format!("jwks fetch failed: {e}"))),
    };

    let status = resp.status();
    if !status.is_success() {
        return Err(PodError::Backend(format!(
            "jwks returned HTTP {}",
            status
        )));
    }

    let set = match resp.json::<JwkSet>().await {
        Ok(parsed) => parsed,
        Err(e) => return Err(PodError::Backend(format!("jwks parse failed: {e}"))),
    };

    debug!(issuer = %canonical_issuer(issuer), keys = set.keys.len(), "fetched JWKS");
    Ok(set)
}
/// Bundles both caches with the SSRF policy and HTTP client needed to refill them.
pub struct CachedFetcher {
/// Cache consulted and filled by `config`.
config_cache: OidcConfigCache,
/// Cache consulted and filled by `jwks`.
jwks_cache: JwksCache,
/// SSRF policy handed to the fetch functions on a cache miss.
ssrf: Arc<SsrfPolicy>,
/// HTTP client handed to the fetch functions on a cache miss.
client: Client,
}
impl CachedFetcher {
    /// Assembles a fetcher from explicitly supplied caches, SSRF policy and client.
    pub fn new(
        config_cache: OidcConfigCache,
        jwks_cache: JwksCache,
        ssrf: Arc<SsrfPolicy>,
        client: Client,
    ) -> Self {
        Self {
            config_cache,
            jwks_cache,
            ssrf,
            client,
        }
    }

    /// Assembles a fetcher whose two caches are built with their `Default` impls.
    pub fn with_defaults(ssrf: Arc<SsrfPolicy>, client: Client) -> Self {
        let config_cache = OidcConfigCache::default();
        let jwks_cache = JwksCache::default();
        Self::new(config_cache, jwks_cache, ssrf, client)
    }

    /// Returns the discovery document for `issuer`, served from the cache when
    /// a fresh entry exists and fetched (then cached) otherwise.
    ///
    /// # Errors
    /// Propagates any error from `fetch_oidc_config`; nothing is cached then.
    pub async fn config(&self, issuer: &Url) -> Result<OidcDiscoveryDoc, PodError> {
        let cache_key = canonical_issuer(issuer);
        let cached = self.config_cache.get(&cache_key);
        match cached {
            Some(hit) => Ok(hit),
            None => {
                let fresh = fetch_oidc_config(issuer, &self.ssrf, &self.client).await?;
                self.config_cache.put(cache_key, fresh.clone());
                Ok(fresh)
            }
        }
    }

    /// Returns the key set for `issuer`, served from the cache when a fresh
    /// entry exists and fetched (then cached) otherwise.
    ///
    /// # Errors
    /// Propagates any error from `fetch_jwks`; nothing is cached then.
    pub async fn jwks(&self, issuer: &Url) -> Result<JwkSet, PodError> {
        let cache_key = canonical_issuer(issuer);
        let cached = self.jwks_cache.get(&cache_key);
        match cached {
            Some(hit) => Ok(hit),
            None => {
                let fresh = fetch_jwks(issuer, &self.ssrf, &self.client).await?;
                self.jwks_cache.put(cache_key, fresh.clone());
                Ok(fresh)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache key shared by the cache tests.
    const ISSUER: &str = "https://op.example";

    /// Minimal discovery document used as cache payload.
    fn sample_doc() -> OidcDiscoveryDoc {
        OidcDiscoveryDoc {
            issuer: "https://op.example".into(),
            jwks_uri: "https://op.example/jwks".into(),
            authorization_endpoint: "https://op.example/authorize".into(),
            token_endpoint: "https://op.example/token".into(),
            registration_endpoint: None,
            scopes_supported: None,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn canonical_issuer_strips_trailing_slash() {
        // With or without the trailing slash, the canonical form is the same.
        let with_slash = Url::parse("https://op.example/").unwrap();
        let without_slash = Url::parse("https://op.example").unwrap();
        assert_eq!(canonical_issuer(&with_slash), "https://op.example");
        assert_eq!(canonical_issuer(&without_slash), "https://op.example");
    }

    #[test]
    fn config_cache_misses_when_empty() {
        let cache = OidcConfigCache::new(Duration::from_secs(60));
        assert!(cache.get(ISSUER).is_none());
    }

    #[test]
    fn config_cache_hits_within_ttl() {
        let cache = OidcConfigCache::new(Duration::from_secs(60));
        cache.put(ISSUER.into(), sample_doc());
        assert!(cache.get(ISSUER).is_some());
    }

    #[test]
    fn config_cache_expires_after_ttl() {
        // A 1 ns TTL is certain to have elapsed after the short sleep below.
        let cache = OidcConfigCache::new(Duration::from_nanos(1));
        cache.put(ISSUER.into(), sample_doc());
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(ISSUER).is_none());
    }
}