use super::{AuthError, AuthenticatedPrincipal};
use jsonwebtoken::jwk::{Jwk, JwkSet};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
use serde_json::Value;
use solo_core::TenantId;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Settings for validating OIDC bearer tokens with [`OidcValidator`].
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// URL of the provider's discovery document; the validator GETs it and follows
    /// its `jwks_uri` member to find the signing keys.
    pub discovery_url: String,
    /// Audience the token's `aud` claim is validated against.
    pub audience: String,
    /// Name of the claim holding the tenant id. The claim must be present as a
    /// string and is parsed with `TenantId::new`.
    pub tenant_claim_name: String,
}
/// Snapshot of the provider's usable signing keys.
struct CachedJwks {
    /// Keys from the last successful JWKS fetch, indexed by JWK `kid`.
    keys: HashMap<String, KeyEntry>,
    /// When the snapshot was stored; `fetched_at.elapsed() < ttl` means it is fresh.
    fetched_at: Instant,
    /// Freshness window, derived from `Cache-Control: max-age` when the provider
    /// sends one and one hour otherwise.
    ttl: Duration,
}
/// A verification key together with the single algorithm it may be used with.
struct KeyEntry {
    /// Key built from the JWK with `DecodingKey::from_jwk`.
    key: DecodingKey,
    /// Algorithm the JWK is bound to; tokens whose header `alg` differs are rejected.
    algorithm: Algorithm,
}
/// Validates OIDC bearer tokens against the keys published by the configured provider.
///
/// Clones share a single JWKS cache because it is held behind an `Arc`.
#[derive(Clone)]
pub struct OidcValidator {
    /// Provider, audience and tenant-claim settings.
    config: OidcConfig,
    /// Client used for the discovery and JWKS requests.
    http_client: reqwest::Client,
    /// Lazily populated key snapshot; `None` until the first successful fetch.
    jwks_cache: Arc<RwLock<Option<CachedJwks>>>,
}
/// Debug output shows the configuration and a placeholder for the key cache.
///
/// The HTTP client is deliberately left out; `finish_non_exhaustive` renders a
/// trailing `..` so readers know the listing is incomplete.
impl std::fmt::Debug for OidcValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcValidator")
            .field("config", &self.config)
            .field("jwks_cache", &"<RwLock>")
            .finish_non_exhaustive()
    }
}
impl OidcValidator {
pub fn new(config: OidcConfig) -> Self {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("reqwest client");
Self {
config,
http_client,
jwks_cache: Arc::new(RwLock::new(None)),
}
}
#[cfg(test)]
pub fn with_http_client(config: OidcConfig, http_client: reqwest::Client) -> Self {
Self {
config,
http_client,
jwks_cache: Arc::new(RwLock::new(None)),
}
}
pub async fn validate(
&self,
header: Option<&str>,
) -> Result<AuthenticatedPrincipal, AuthError> {
let header = header.ok_or(AuthError::MissingAuthHeader)?;
let token = header
.strip_prefix("Bearer ")
.ok_or(AuthError::MalformedAuthHeader)?;
let jwt_header = decode_header(token).map_err(|e| AuthError::InvalidOidcToken {
reason: format!("decode header: {e}"),
})?;
let kid = jwt_header
.kid
.clone()
.ok_or_else(|| AuthError::InvalidOidcToken {
reason: "missing kid in token header".to_string(),
})?;
let entry = self.get_key(&kid).await?;
if entry.algorithm != jwt_header.alg {
return Err(AuthError::InvalidOidcToken {
reason: format!(
"token alg {:?} does not match JWK alg {:?}",
jwt_header.alg, entry.algorithm
),
});
}
let mut validation = Validation::new(entry.algorithm);
validation.set_audience(&[&self.config.audience]);
let token_data: TokenData<Value> =
decode(token, &entry.key, &validation).map_err(|e| AuthError::InvalidOidcToken {
reason: format!("{e}"),
})?;
let subject = token_data
.claims
.get("sub")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let tenant_claim_str = token_data
.claims
.get(&self.config.tenant_claim_name)
.and_then(|v| v.as_str())
.ok_or_else(|| AuthError::MissingTenantClaim {
claim_name: self.config.tenant_claim_name.clone(),
})?;
let tenant_claim = TenantId::new(tenant_claim_str.to_string())?;
let scopes = token_data
.claims
.get("scope")
.and_then(|v| v.as_str())
.map(|s| s.split_whitespace().map(String::from).collect())
.unwrap_or_default();
Ok(AuthenticatedPrincipal {
subject,
tenant_claim: Some(tenant_claim),
scopes,
claims: token_data.claims,
})
}
async fn get_key(&self, kid: &str) -> Result<KeyEntry, AuthError> {
{
let cache = self.jwks_cache.read().await;
if let Some(c) = cache.as_ref()
&& c.fetched_at.elapsed() < c.ttl
&& let Some(entry) = c.keys.get(kid)
{
return Ok(KeyEntry {
key: entry.key.clone(),
algorithm: entry.algorithm,
});
}
}
self.refresh_cache().await?;
let cache = self.jwks_cache.read().await;
cache
.as_ref()
.and_then(|c| c.keys.get(kid))
.map(|entry| KeyEntry {
key: entry.key.clone(),
algorithm: entry.algorithm,
})
.ok_or_else(|| AuthError::Jwks(format!("kid '{kid}' not found in JWKS")))
}
async fn refresh_cache(&self) -> Result<(), AuthError> {
let discovery_resp = self
.http_client
.get(&self.config.discovery_url)
.send()
.await
.map_err(|e| AuthError::Discovery(format!("{e}")))?
.error_for_status()
.map_err(|e| AuthError::Discovery(format!("{e}")))?;
let ttl = parse_max_age(
discovery_resp
.headers()
.get("cache-control")
.and_then(|h| h.to_str().ok()),
)
.unwrap_or(Duration::from_secs(3600));
let body: Value = discovery_resp
.json()
.await
.map_err(|e| AuthError::Discovery(format!("{e}")))?;
let jwks_uri = body
.get("jwks_uri")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthError::Discovery("discovery missing jwks_uri".to_string()))?;
let jwks: JwkSet = self
.http_client
.get(jwks_uri)
.send()
.await
.map_err(|e| AuthError::Jwks(format!("{e}")))?
.error_for_status()
.map_err(|e| AuthError::Jwks(format!("{e}")))?
.json()
.await
.map_err(|e| AuthError::Jwks(format!("{e}")))?;
let mut keys = HashMap::new();
for jwk in jwks.keys.iter() {
let Some(kid) = jwk.common.key_id.as_deref() else {
continue;
};
let Some(algorithm) = jwk_algorithm(jwk) else {
continue;
};
let key = match DecodingKey::from_jwk(jwk) {
Ok(k) => k,
Err(_) => continue,
};
keys.insert(kid.to_string(), KeyEntry { key, algorithm });
}
let mut cache = self.jwks_cache.write().await;
*cache = Some(CachedJwks {
keys,
fetched_at: Instant::now(),
ttl,
});
Ok(())
}
}
/// Maps a JWK to the signature algorithm its key must be used with.
///
/// A declared `alg` member is authoritative; encryption algorithms such as
/// `RSA-OAEP` map to `None`. Because `alg` is optional (RFC 7517 §4.4), a key
/// without one falls back to what its type implies:
///
/// - EC P-256 → ES256 and EC P-384 → ES384 (RFC 7518 §3.4 binds each curve to one
///   algorithm);
/// - OKP Ed25519 → EdDSA (RFC 8037 §3.1);
/// - RSA → RS256, the OpenID Connect default signing algorithm. A PS* token under
///   such a key is then rejected by the alg-mismatch check instead of the key
///   being ignored.
///
/// Symmetric (`oct`) keys without `alg` stay `None`; they could be used with
/// any HS* variant.
fn jwk_algorithm(jwk: &Jwk) -> Option<Algorithm> {
    use jsonwebtoken::jwk::{AlgorithmParameters, EllipticCurve, KeyAlgorithm};
    let Some(declared) = jwk.common.key_algorithm else {
        return match &jwk.algorithm {
            AlgorithmParameters::EllipticCurve(params) => match params.curve {
                EllipticCurve::P256 => Some(Algorithm::ES256),
                EllipticCurve::P384 => Some(Algorithm::ES384),
                _ => None,
            },
            AlgorithmParameters::OctetKeyPair(params) => match params.curve {
                EllipticCurve::Ed25519 => Some(Algorithm::EdDSA),
                _ => None,
            },
            AlgorithmParameters::RSA(_) => Some(Algorithm::RS256),
            _ => None,
        };
    };
    match declared {
        KeyAlgorithm::HS256 => Some(Algorithm::HS256),
        KeyAlgorithm::HS384 => Some(Algorithm::HS384),
        KeyAlgorithm::HS512 => Some(Algorithm::HS512),
        KeyAlgorithm::RS256 => Some(Algorithm::RS256),
        KeyAlgorithm::RS384 => Some(Algorithm::RS384),
        KeyAlgorithm::RS512 => Some(Algorithm::RS512),
        KeyAlgorithm::PS256 => Some(Algorithm::PS256),
        KeyAlgorithm::PS384 => Some(Algorithm::PS384),
        KeyAlgorithm::PS512 => Some(Algorithm::PS512),
        KeyAlgorithm::ES256 => Some(Algorithm::ES256),
        KeyAlgorithm::ES384 => Some(Algorithm::ES384),
        KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
        _ => None,
    }
}
/// Extracts the `max-age` directive from a `Cache-Control` header value.
///
/// Directive names are compared case-insensitively, and a quoted argument
/// (`max-age="300"`) is accepted, as RFC 9111 §5.2 asks of recipients. The first
/// `max-age` whose value parses as whole seconds wins. Returns `None` when the
/// header is absent or has no usable `max-age` (e.g. `no-cache, no-store`).
fn parse_max_age(header: Option<&str>) -> Option<Duration> {
    header?.split(',').find_map(|directive| {
        let (name, value) = directive.trim().split_once('=')?;
        if !name.eq_ignore_ascii_case("max-age") {
            return None;
        }
        let value = value
            .strip_prefix('"')
            .and_then(|v| v.strip_suffix('"'))
            .unwrap_or(value);
        value.parse::<u64>().ok().map(Duration::from_secs)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Mock identity provider serving a discovery document and a one-key HS256 JWKS.
    struct FakeIdp {
        server: MockServer,
        signing_secret: Vec<u8>,
        signing_kid: String,
    }

    impl FakeIdp {
        /// Starts the mock IdP publishing `signing_kid`. When `cache_max_age_secs` is
        /// set, the discovery response carries `Cache-Control: max-age=<secs>`.
        async fn start(signing_kid: &str, cache_max_age_secs: Option<u64>) -> Self {
            let server = MockServer::start().await;
            let secret = b"fixture-secret-bytes-for-hmac-tests".to_vec();
            let mut discovery_resp =
                ResponseTemplate::new(200).set_body_json(discovery_body(&server));
            if let Some(secs) = cache_max_age_secs {
                discovery_resp = discovery_resp
                    .insert_header("cache-control", format!("max-age={secs}").as_str());
            }
            Mock::given(method("GET"))
                .and(path("/.well-known/openid-configuration"))
                .respond_with(discovery_resp)
                .mount(&server)
                .await;
            Mock::given(method("GET"))
                .and(path("/jwks"))
                .respond_with(
                    ResponseTemplate::new(200).set_body_json(hs256_jwks(signing_kid, &secret)),
                )
                .mount(&server)
                .await;
            Self {
                server,
                signing_secret: secret,
                signing_kid: signing_kid.to_string(),
            }
        }

        /// Publishes `new_kid`/`new_secret` in place of the current key and makes it
        /// the key used by [`FakeIdp::mint`]. Resetting the server also drops any
        /// `Cache-Control` header configured in [`FakeIdp::start`].
        async fn rotate_to(&mut self, new_kid: &str, new_secret: &[u8]) {
            self.server.reset().await;
            Mock::given(method("GET"))
                .and(path("/.well-known/openid-configuration"))
                .respond_with(
                    ResponseTemplate::new(200).set_body_json(discovery_body(&self.server)),
                )
                .mount(&self.server)
                .await;
            Mock::given(method("GET"))
                .and(path("/jwks"))
                .respond_with(
                    ResponseTemplate::new(200).set_body_json(hs256_jwks(new_kid, new_secret)),
                )
                .mount(&self.server)
                .await;
            self.signing_secret = new_secret.to_vec();
            self.signing_kid = new_kid.to_string();
        }

        /// Mints a token with the current key; `claims_override` entries replace
        /// the defaults.
        fn mint(&self, claims_override: Value) -> String {
            self.mint_with_kid(&self.signing_kid, &self.signing_secret, claims_override)
        }

        /// Mints an HS256 token with an explicit `kid` and secret. Default claims are
        /// valid for `make_validator`; entries in `claims_override` replace them.
        fn mint_with_kid(&self, kid: &str, secret: &[u8], claims_override: Value) -> String {
            let mut header = Header::new(Algorithm::HS256);
            header.kid = Some(kid.to_string());
            let now = now_secs();
            let mut claims = json!({
                "iss": self.server.uri(),
                "sub": "test-subject",
                "aud": "test-audience",
                "exp": now + 600,
                "iat": now,
                "solo_tenant": "default",
            });
            if let (Value::Object(c), Value::Object(o)) = (&mut claims, claims_override) {
                for (k, v) in o {
                    c.insert(k, v);
                }
            }
            jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret))
                .expect("encode")
        }
    }

    /// Current Unix time in whole seconds, for `iat`/`exp` claims.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    fn base64_url(bytes: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Discovery document pointing at the mock server's `/jwks` endpoint.
    fn discovery_body(server: &MockServer) -> Value {
        json!({
            "issuer": server.uri(),
            "jwks_uri": format!("{}/jwks", server.uri()),
        })
    }

    /// JWKS body publishing a single HS256 symmetric key under `kid`.
    fn hs256_jwks(kid: &str, secret: &[u8]) -> Value {
        json!({
            "keys": [
                {
                    "kty": "oct",
                    "kid": kid,
                    "alg": "HS256",
                    "k": base64_url(secret),
                }
            ]
        })
    }

    /// Mounts discovery and JWKS mocks that must each be requested exactly
    /// `expected_fetches` times. wiremock checks the counts when `server` is dropped.
    async fn mount_counted_idp(
        server: &MockServer,
        kid: &str,
        secret: &[u8],
        discovery_cache_control: Option<&str>,
        expected_fetches: u64,
    ) {
        let mut discovery_resp = ResponseTemplate::new(200).set_body_json(discovery_body(server));
        if let Some(value) = discovery_cache_control {
            discovery_resp = discovery_resp.insert_header("cache-control", value);
        }
        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(discovery_resp)
            .expect(expected_fetches)
            .mount(server)
            .await;
        Mock::given(method("GET"))
            .and(path("/jwks"))
            .respond_with(ResponseTemplate::new(200).set_body_json(hs256_jwks(kid, secret)))
            .expect(expected_fetches)
            .mount(server)
            .await;
    }

    /// Signs an HS256 token for `issuer` with a minimal claim set `make_validator` accepts.
    fn mint_hs256(issuer: &str, kid: &str, secret: &[u8]) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        let now = now_secs();
        let claims = json!({
            "iss": issuer,
            "sub": "subj",
            "aud": "test-audience",
            "exp": now + 600,
            "iat": now,
            "solo_tenant": "default",
        });
        jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret)).unwrap()
    }

    fn make_validator(server_uri: &str) -> OidcValidator {
        OidcValidator::with_http_client(
            OidcConfig {
                discovery_url: format!("{server_uri}/.well-known/openid-configuration"),
                audience: "test-audience".to_string(),
                tenant_claim_name: "solo_tenant".to_string(),
            },
            reqwest::Client::builder()
                .timeout(Duration::from_secs(2))
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn oidc_happy_path() {
        let idp = FakeIdp::start("test-kid-1", None).await;
        let validator = make_validator(&idp.server.uri());
        let token = idp.mint(json!({ "solo_tenant": "tenant-a" }));
        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("validate");
        assert_eq!(principal.subject, "test-subject");
        assert_eq!(
            principal.tenant_claim,
            Some(TenantId::new("tenant-a").unwrap())
        );
    }

    #[tokio::test]
    async fn oidc_missing_auth_header() {
        let idp = FakeIdp::start("kid-missing-header", None).await;
        let validator = make_validator(&idp.server.uri());
        let err = validator.validate(None).await.unwrap_err();
        assert!(matches!(err, AuthError::MissingAuthHeader), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_non_bearer_scheme_is_malformed() {
        let idp = FakeIdp::start("kid-basic", None).await;
        let validator = make_validator(&idp.server.uri());
        let err = validator
            .validate(Some("Basic dXNlcjpwYXNz"))
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::MalformedAuthHeader), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_unknown_kid() {
        let idp = FakeIdp::start("published-kid", None).await;
        let validator = make_validator(&idp.server.uri());
        let token = idp.mint_with_kid("unpublished-kid", &idp.signing_secret, json!({}));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::Jwks(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn oidc_key_rotation() {
        let mut idp = FakeIdp::start("old-kid", None).await;
        let validator = make_validator(&idp.server.uri());
        let warmup_token = idp.mint(json!({}));
        let _ = validator
            .validate(Some(&format!("Bearer {warmup_token}")))
            .await
            .expect("warmup");
        let new_secret = b"new-rotated-secret-32-bytes--here".to_vec();
        idp.rotate_to("new-kid", &new_secret).await;
        // The cache is still fresh, so the new kid is found only via the miss-refresh path.
        let token = idp.mint(json!({}));
        let principal = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("post-rotation");
        assert_eq!(principal.subject, "test-subject");
    }

    #[tokio::test]
    async fn oidc_invalid_audience() {
        let idp = FakeIdp::start("kid-aud", None).await;
        let validator = make_validator(&idp.server.uri());
        let token = idp.mint(json!({ "aud": "wrong-audience" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_expired_token() {
        let idp = FakeIdp::start("kid-exp", None).await;
        let validator = make_validator(&idp.server.uri());
        let now = now_secs();
        let token = idp.mint(json!({ "exp": now - 300, "iat": now - 600 }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidOidcToken { .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_missing_tenant_claim() {
        let idp = FakeIdp::start("kid-no-tenant", None).await;
        let validator = make_validator(&idp.server.uri());
        let token = idp.mint(json!({ "solo_tenant": null }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::MissingTenantClaim { ref claim_name } if claim_name == "solo_tenant"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_invalid_tenant_claim_format() {
        let idp = FakeIdp::start("kid-bad-tenant", None).await;
        let validator = make_validator(&idp.server.uri());
        let token = idp.mint(json!({ "solo_tenant": "INVALID UPPERCASE" }));
        let err = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidTenantClaim(_)),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn oidc_jwks_cache_within_ttl_no_refetch() {
        let server = MockServer::start().await;
        let secret = b"counted-secret-32-bytes--padding".to_vec();
        let kid = "counted-kid";
        // No Cache-Control header: the one-hour default keeps the second call cached.
        mount_counted_idp(&server, kid, &secret, None, 1).await;
        let validator = make_validator(&server.uri());
        let token = mint_hs256(&server.uri(), kid, &secret);
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("first");
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("second");
    }

    #[tokio::test]
    async fn oidc_jwks_cache_respects_cache_control_max_age() {
        let server = MockServer::start().await;
        let secret = b"max-age-secret-bytes-for-tests--".to_vec();
        let kid = "max-age-kid";
        // `max-age=0` expires the cache immediately, so each validation refetches.
        mount_counted_idp(&server, kid, &secret, Some("max-age=0"), 2).await;
        let validator = make_validator(&server.uri());
        let token = mint_hs256(&server.uri(), kid, &secret);
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("first");
        let _ = validator
            .validate(Some(&format!("Bearer {token}")))
            .await
            .expect("second");
    }

    #[test]
    fn parse_max_age_handles_typical_headers() {
        assert_eq!(
            parse_max_age(Some("max-age=300")),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            parse_max_age(Some("public, max-age=86400, must-revalidate")),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(parse_max_age(Some("no-cache, no-store")), None);
        assert_eq!(parse_max_age(None), None);
    }
}