pub mod routes;
mod jws;
use jws::{
extract_groups_claim, extract_groups_claim_from_json,
extract_optional_string_claim, extract_string_claim,
jwks_signature_verifies, parse_compact_jws,
};
mod backchannel;
mod bearer;
use crate::auth::Identity;
use crate::config::OidcConfig;
use crate::metrics::Metrics;
use anyhow::{Context, Result, anyhow, bail};
use arc_swap::ArcSwap;
use openidconnect::core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow,
CoreClaimName, CoreClaimType, CoreClientAuthMethod,
CoreErrorResponseType, CoreGenderClaim, CoreGrantType,
CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
CoreJweKeyManagementAlgorithm, CoreResponseMode, CoreResponseType,
CoreRevocableToken, CoreRevocationErrorResponse,
CoreSubjectIdentifierType, CoreTokenIntrospectionResponse,
CoreTokenType,
};
use openidconnect::{
AccessToken, AdditionalProviderMetadata, AsyncHttpClient,
AuthorizationCode, ClientId,
ClientSecret, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet,
EndpointNotSet, EndpointSet, IdTokenFields, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
ProviderMetadata, RedirectUrl, RefreshToken, Scope,
StandardErrorResponse, StandardTokenResponse, TokenResponse,
UserInfoClaims,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Extra provider-metadata fields read from the discovery document in
/// addition to what `openidconnect` models natively.
///
/// Both fields are `Option`s marked `#[serde(default)]`, so a discovery
/// document that omits them deserializes with the field set to `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct LogoutMetadata {
/// The IdP's advertised `end_session_endpoint`, if any.
#[serde(default)]
end_session_endpoint: Option<url::Url>,
/// The IdP's advertised `revocation_endpoint`, if any.
#[serde(default)]
revocation_endpoint: Option<url::Url>,
}
// Marker impl so `LogoutMetadata` can fill the additional-metadata slot of
// `ProviderMetadata` (see `HypershuntProviderMetadata`).
impl AdditionalProviderMetadata for LogoutMetadata {}
/// Discovery metadata type: the `openidconnect` core types for every slot
/// except the additional-metadata slot, which is `LogoutMetadata` so the
/// end-session and revocation endpoints are captured during discovery.
type HypershuntProviderMetadata = ProviderMetadata<
LogoutMetadata,
CoreAuthDisplay,
CoreClientAuthMethod,
CoreClaimName,
CoreClaimType,
CoreGrantType,
CoreJweContentEncryptionAlgorithm,
CoreJweKeyManagementAlgorithm,
CoreJsonWebKey,
CoreResponseMode,
CoreResponseType,
CoreSubjectIdentifierType,
>;
/// Catch-all container for non-standard ID-token / UserInfo claims.
///
/// Holds the raw JSON object so that configurable claims (for example the
/// username or groups claim) can be looked up by name later.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct ExtraClaims(
pub(crate) serde_json::Map<String, serde_json::Value>,
);
// Marker impl so `ExtraClaims` can be used as the additional-claims type.
impl openidconnect::AdditionalClaims for ExtraClaims {}
/// ID-token fields of a token response, carrying `ExtraClaims` and no extra
/// token-response fields.
type HsIdTokenFields = IdTokenFields<
ExtraClaims,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
openidconnect::core::CoreJwsSigningAlgorithm,
>;
/// Token-endpoint response type used by both the code exchange and the
/// refresh-token exchange.
type HsTokenResponse =
StandardTokenResponse<HsIdTokenFields, CoreTokenType>;
pub(crate) type OidcClient = openidconnect::Client<
ExtraClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
HsTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
EndpointSet, EndpointNotSet, EndpointNotSet, EndpointSet, EndpointMaybeSet, EndpointMaybeSet, >;
type OidcClientFromMetadata = openidconnect::Client<
ExtraClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
HsTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointMaybeSet, EndpointMaybeSet, >;
/// Optional hints forwarded to the IdP as extra parameters on the
/// authorization request.
#[derive(Default, Debug, Clone)]
pub struct IdpHints {
    pub login_hint: Option<String>,
    pub prompt: Option<String>,
    pub max_age: Option<String>,
    pub acr_values: Option<String>,
    pub ui_locales: Option<String>,
}

impl IdpHints {
    /// Yields `(parameter_name, value)` for each hint that is set.
    ///
    /// Hints come out in the fixed order `login_hint`, `prompt`, `max_age`,
    /// `acr_values`, `ui_locales`; unset hints are skipped.
    fn pairs(&self) -> impl Iterator<Item = (&'static str, &str)> {
        let Self {
            login_hint,
            prompt,
            max_age,
            acr_values,
            ui_locales,
        } = self;
        let slots: [(&'static str, &Option<String>); 5] = [
            ("login_hint", login_hint),
            ("prompt", prompt),
            ("max_age", max_age),
            ("acr_values", acr_values),
            ("ui_locales", ui_locales),
        ];
        IntoIterator::into_iter(slots)
            .filter_map(|(name, slot)| Some((name, slot.as_deref()?)))
    }
}
/// A pending authorization request, stored under its `state` value between
/// `begin_login` and `complete_login`.
struct StateEntry {
/// PKCE verifier matching the challenge sent in the authorization URL.
pkce_verifier: PkceCodeVerifier,
/// Nonce the returned ID token must carry.
nonce: Nonce,
/// Where to send the user after a successful login.
return_to: String,
/// Creation time; compared against `state_ttl` at callback and eviction.
created: Instant,
}
/// A server-side refresh session, stored under an opaque random session id.
struct RefreshEntry {
/// The IdP refresh token for this session.
refresh_token: RefreshToken,
/// When the session expires; extended on every successful refresh.
expires_at: Instant,
/// Serialized ID token from the latest login/refresh (returned by
/// `take_logout_session`).
id_token: String,
/// `sub` claim of the session's ID token.
subject: String,
/// The IdP's `sid` claim, when the ID token carried one.
idp_sid: Option<String>,
}
/// OIDC relying-party state: the discovered client plus in-memory stores for
/// pending logins, refresh sessions, seen `jti`s and cached bearer identities.
pub struct OidcProvider {
/// Discovered client; `None` until the first successful discovery.
client: ArcSwap<Option<Arc<OidcClient>>>,
cfg: OidcConfig,
metrics: Arc<Metrics>,
/// Pending authorization requests keyed by their `state` value.
states: Mutex<HashMap<String, StateEntry>>,
/// Maximum age of a pending login state (`cfg.state_ttl_secs`).
state_ttl: Duration,
/// Refresh sessions keyed by the session id returned from
/// `complete_login` / `refresh`.
refreshes: Mutex<HashMap<String, RefreshEntry>>,
/// Lifetime granted to a refresh session when it is created or refreshed.
refresh_ttl: Duration,
/// `end_session_endpoint` from the last successful discovery.
end_session_url: ArcSwap<Option<url::Url>>,
/// `revocation_endpoint` from the last successful discovery.
revocation_url: ArcSwap<Option<url::Url>>,
/// JWKS fetched during the last successful discovery.
jwks: ArcSwap<Option<Arc<openidconnect::core::CoreJsonWebKeySet>>>,
/// `jti` values mapped to the instant after which they may be forgotten;
/// pruned by `evict_expired`. Presumably used for replay detection in the
/// backchannel-logout module — confirm there.
seen_jtis: Mutex<HashMap<String, Instant>>,
/// LRU of validated bearer identities keyed by a 32-byte digest
/// (presumably of the bearer token — confirm in the bearer module).
bearer_cache: Mutex<lru::LruCache<[u8; 32], BearerCacheEntry>>,
/// Shared HTTP client (redirect following disabled) for IdP requests.
http_client: openidconnect::reqwest::Client,
}
/// A cached bearer-token validation result.
#[derive(Clone)]
struct BearerCacheEntry {
/// Identity derived from the validated token.
identity: Identity,
/// Expiry of the cached entry; presumably Unix seconds from the token's
/// `exp` — TODO confirm in the bearer module.
expires_at: u64,
}
/// Builds the HTTP client used for all requests to the IdP.
///
/// Redirect following is disabled; the `openidconnect` documentation
/// recommends this to avoid SSRF through IdP-controlled redirects.
///
/// # Errors
/// Returns an error if `reqwest` cannot construct the client.
fn build_http_client() -> Result<openidconnect::reqwest::Client> {
    use openidconnect::reqwest::{redirect::Policy, ClientBuilder};

    let builder = ClientBuilder::new().redirect(Policy::none());
    builder.build().context("building OIDC HTTP client")
}
async fn run_discovery<'c, C>(
cfg: &'c OidcConfig,
http_client: &'c C,
) -> Result<(
OidcClient,
Option<url::Url>,
Option<url::Url>,
openidconnect::core::CoreJsonWebKeySet,
)>
where
C: AsyncHttpClient<'c>,
<C as AsyncHttpClient<'c>>::Error: Send + Sync,
{
let issuer_url = IssuerUrl::new(cfg.issuer.clone())
.with_context(|| format!("invalid OIDC issuer URL: {}", cfg.issuer))?;
let metadata = HypershuntProviderMetadata::discover_async(
issuer_url,
http_client,
)
.await
.with_context(|| format!("OIDC discovery failed for {}", cfg.issuer))?;
let end_session_url =
metadata.additional_metadata().end_session_endpoint.clone();
let revocation_url =
metadata.additional_metadata().revocation_endpoint.clone();
let jwks = metadata.jwks().clone();
let redirect = RedirectUrl::new(cfg.redirect_uri.clone())
.with_context(|| {
format!("invalid redirect-uri: {}", cfg.redirect_uri)
})?;
let revocation_for_client = revocation_url
.clone()
.unwrap_or_else(|| metadata.issuer().url().clone());
let client = OidcClientFromMetadata::from_provider_metadata(
metadata,
ClientId::new(cfg.client_id.clone()),
cfg.client_secret.clone().map(ClientSecret::new),
)
.set_redirect_uri(redirect)
.set_revocation_url(openidconnect::RevocationUrl::from_url(
revocation_for_client,
));
Ok((client, end_session_url, revocation_url, jwks))
}
impl OidcProvider {
pub fn new(cfg: OidcConfig, metrics: Arc<Metrics>) -> Arc<Self> {
let http_client = build_http_client().unwrap_or_else(|e| {
tracing::error!(
error = %format!("{e:#}"),
"OIDC HTTP client build failed; using minimal \
redirect-disabled client"
);
openidconnect::reqwest::Client::builder()
.redirect(openidconnect::reqwest::redirect::Policy::none())
.pool_max_idle_per_host(0)
.build()
.expect("redirect-disabled reqwest client must build")
});
let provider = Arc::new(Self {
http_client,
client: ArcSwap::new(Arc::new(None)),
state_ttl: Duration::from_secs(cfg.state_ttl_secs),
refresh_ttl: Duration::from_secs(cfg.refresh_ttl_secs),
metrics,
end_session_url: ArcSwap::new(Arc::new(None)),
revocation_url: ArcSwap::new(Arc::new(None)),
jwks: ArcSwap::new(Arc::new(None)),
seen_jtis: Mutex::new(HashMap::new()),
bearer_cache: Mutex::new(lru::LruCache::new(
NonZeroUsize::new(cfg.bearer_cache_size.max(1))
.expect("bearer_cache_size >= 1"),
)),
states: Mutex::new(HashMap::new()),
refreshes: Mutex::new(HashMap::new()),
cfg,
});
let weak = Arc::downgrade(&provider);
crate::task::spawn_supervised("oidc.discovery", async move {
let mut attempt: u32 = 0;
loop {
let Some(p) = weak.upgrade() else { return };
match run_discovery(&p.cfg, &p.http_client).await {
Ok((client, end_session, revocation, jwks)) => {
p.client.store(Arc::new(Some(Arc::new(client))));
p.end_session_url.store(Arc::new(end_session));
p.revocation_url.store(Arc::new(revocation));
p.jwks.store(Arc::new(Some(Arc::new(jwks))));
p.metrics.oidc_discoveries.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::info!(
issuer = %p.cfg.issuer,
"discovery succeeded"
);
break;
}
Err(e) => {
p.metrics.oidc_discovery_failures.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
if !p.cfg.discovery_retry {
tracing::error!(
issuer = %p.cfg.issuer,
error = %format!("{e:#}"),
"discovery failed (retry disabled); \
provider will remain unavailable"
);
return;
}
let secs = std::cmp::min(1u64 << attempt.min(8), 300);
tracing::warn!(
issuer = %p.cfg.issuer,
retry_in = secs,
error = %format!("{e:#}"),
"discovery failed; retrying"
);
drop(p);
tokio::time::sleep(Duration::from_secs(secs)).await;
attempt = attempt.saturating_add(1);
}
}
}
let Some(p) = weak.upgrade() else { return };
let interval_secs = p.cfg.discovery_refresh_secs;
drop(p);
if interval_secs == 0 {
return;
}
let mut ticker = tokio::time::interval(
Duration::from_secs(interval_secs),
);
ticker.tick().await;
loop {
ticker.tick().await;
let Some(p) = weak.upgrade() else { return };
match run_discovery(&p.cfg, &p.http_client).await {
Ok((client, end_session, revocation, jwks)) => {
p.client.store(Arc::new(Some(Arc::new(client))));
p.end_session_url.store(Arc::new(end_session));
p.revocation_url.store(Arc::new(revocation));
p.jwks.store(Arc::new(Some(Arc::new(jwks))));
p.metrics.oidc_discoveries.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::debug!(
issuer = %p.cfg.issuer,
"discovery refreshed"
);
}
Err(e) => {
p.metrics.oidc_discovery_failures.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::warn!(
issuer = %p.cfg.issuer,
error = %format!("{e:#}"),
"periodic discovery failed; \
keeping previous client"
);
}
}
}
});
let weak = Arc::downgrade(&provider);
let ttl = provider.state_ttl;
crate::task::spawn_supervised("oidc.eviction", async move {
let interval = std::cmp::max(ttl / 10, Duration::from_secs(30));
let mut ticker = tokio::time::interval(interval);
loop {
ticker.tick().await;
let Some(p) = weak.upgrade() else { break };
p.evict_expired();
}
});
provider
}
pub fn client(&self) -> Option<Arc<OidcClient>> {
self.client.load().as_ref().clone()
}
pub fn is_ready(&self) -> bool {
self.client.load().is_some()
}
async fn merge_userinfo(
&self,
client: &OidcClient,
access_token: &AccessToken,
id_token_username: &str,
id_token_groups: Vec<String>,
) -> (String, Vec<String>) {
if !self.cfg.userinfo {
return (id_token_username.to_owned(), id_token_groups);
}
let request = match client
.user_info(access_token.clone(), None)
{
Ok(r) => r,
Err(e) => {
tracing::warn!(
error = %format!("{e:#}"),
"userinfo not configurable for this IdP"
);
return (id_token_username.to_owned(), id_token_groups);
}
};
let info: UserInfoClaims<
ExtraClaims,
openidconnect::core::CoreGenderClaim,
> = match request.request_async(&self.http_client).await {
Ok(c) => c,
Err(e) => {
self.metrics.oidc_userinfo_failures.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::warn!(
error = %format!("{e:#}"),
"userinfo request failed; falling back \
to ID-token claims"
);
return (id_token_username.to_owned(), id_token_groups);
}
};
let json = match serde_json::to_value(&info) {
Ok(v) => v,
Err(_) => return (id_token_username.to_owned(), id_token_groups),
};
let username = match json
.get(&self.cfg.username_claim)
.and_then(|v| v.as_str())
{
Some(s) if !s.is_empty() => s.to_owned(),
_ => id_token_username.to_owned(),
};
let groups = extract_groups_claim_from_json(
&self.cfg.groups_claim,
&json,
);
let groups = if groups.is_empty() {
id_token_groups
} else {
groups
};
(username, groups)
}
pub fn begin_login(
&self,
return_to: String,
hints: IdpHints,
) -> Option<(url::Url, String)> {
let client = self.client()?;
let (pkce_challenge, pkce_verifier) =
PkceCodeChallenge::new_random_sha256();
let mut req = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in &self.cfg.scopes {
req = req.add_scope(Scope::new(scope.clone()));
}
for r in &self.cfg.resources {
req = req.add_extra_param("resource", r.clone());
}
for (name, value) in hints.pairs() {
req = req.add_extra_param(name, value);
}
let (auth_url, csrf, nonce) =
req.set_pkce_challenge(pkce_challenge).url();
let state_id = csrf.secret().clone();
let entry = StateEntry {
pkce_verifier,
nonce,
return_to,
created: Instant::now(),
};
self.states.lock().expect("oidc state mutex").insert(state_id.clone(), entry);
Some((auth_url, state_id))
}
pub fn refresh_enabled(&self) -> bool {
self.cfg.refresh
}
pub fn refresh_cookie_name(&self) -> &str {
&self.cfg.refresh_cookie_name
}
pub fn refresh_ttl_secs(&self) -> u64 {
self.cfg.refresh_ttl_secs
}
pub fn logout_path(&self) -> &str {
&self.cfg.logout_path
}
pub fn post_logout_uri(&self) -> &str {
&self.cfg.post_logout_uri
}
pub fn idp_logout_enabled(&self) -> bool {
self.cfg.idp_logout
}
pub fn end_session_url(&self) -> Option<url::Url> {
(*self.end_session_url.load_full()).clone()
}
pub fn client_id(&self) -> &str {
&self.cfg.client_id
}
pub fn take_logout_session(
&self,
sid: &str,
) -> Option<(String, RefreshToken)> {
self.refreshes
.lock()
.unwrap()
.remove(sid)
.map(|e| (e.id_token, e.refresh_token))
}
pub fn issuer(&self) -> &str {
self.cfg.issuer.trim_end_matches('/')
}
pub fn require_iss(&self) -> bool {
self.cfg.require_iss
}
pub fn revoke_refresh_token(
self: &Arc<Self>,
refresh_token: RefreshToken,
) {
if !self.cfg.revoke_on_logout {
return;
}
if self.revocation_url.load().is_none() {
return;
}
let Some(client) = self.client() else { return };
let metrics = self.metrics.clone();
let http_client = self.http_client.clone();
crate::task::spawn_supervised("oidc.revocation", async move {
let request = match client.revoke_token(refresh_token.into()) {
Ok(r) => r,
Err(e) => {
tracing::debug!(
error = %format!("{e:#}"),
"revocation not configurable on this \
IdP; skipping"
);
return;
}
};
match request.request_async(&http_client).await {
Ok(()) => {
metrics.oidc_revocations.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::debug!("refresh token revoked");
}
Err(e) => {
metrics.oidc_revocation_failures.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
tracing::warn!(
error = %format!("{e:#}"),
"refresh token revocation failed"
);
}
}
});
}
pub async fn complete_login(
&self,
code: String,
state_id: &str,
) -> Result<(Identity, String, Option<String>)> {
let client = self
.client()
.ok_or_else(|| anyhow!("OIDC provider not ready"))?;
let entry = self
.states
.lock()
.unwrap()
.remove(state_id)
.ok_or_else(|| anyhow!("unknown or expired OIDC state"))?;
if entry.created.elapsed() > self.state_ttl {
bail!("OIDC state expired before callback");
}
let mut exchange = client
.exchange_code(AuthorizationCode::new(code))
.context("OIDC token endpoint not configured")?
.set_pkce_verifier(entry.pkce_verifier);
for r in &self.cfg.resources {
exchange = exchange.add_extra_param("resource", r.clone());
}
let token_response = exchange
.request_async(&self.http_client)
.await
.context("OIDC token exchange failed")?;
let id_token = token_response
.id_token()
.ok_or_else(|| anyhow!("IdP response did not include an id_token"))?;
let id_token_str = id_token.to_string();
let claims = id_token
.claims(&client.id_token_verifier(), &entry.nonce)
.context("ID token validation failed")?;
let claims_json = serde_json::to_value(claims)
.context("serialising ID token claims")?;
let id_username = extract_string_claim(
&self.cfg.username_claim,
&claims_json,
claims.subject().as_str(),
);
let id_groups =
extract_groups_claim(&self.cfg.groups_claim, &claims_json);
let subject = claims.subject().as_str().to_owned();
let idp_sid = extract_optional_string_claim("sid", &claims_json);
let (username, groups) = self
.merge_userinfo(
&client,
token_response.access_token(),
&id_username,
id_groups,
)
.await;
let sid = if self.cfg.refresh {
token_response.refresh_token().map(|rt| {
let id = CsrfToken::new_random().secret().clone();
self.refreshes.lock().expect("oidc refresh mutex").insert(
id.clone(),
RefreshEntry {
refresh_token: rt.clone(),
expires_at: Instant::now() + self.refresh_ttl,
id_token: id_token_str.clone(),
subject: subject.clone(),
idp_sid: idp_sid.clone(),
},
);
id
})
} else {
None
};
Ok((Identity { username, groups }, entry.return_to, sid))
}
pub async fn refresh(
&self,
sid: &str,
) -> Result<(Identity, String)> {
let client = self
.client()
.ok_or_else(|| anyhow!("OIDC provider not ready"))?;
let rt = {
let map = self.refreshes.lock().expect("oidc refresh mutex");
let entry = map.get(sid).ok_or_else(|| {
anyhow!("unknown OIDC refresh session")
})?;
if Instant::now() > entry.expires_at {
drop(map);
self.refreshes.lock().expect("oidc refresh mutex").remove(sid);
bail!("refresh session expired");
}
entry.refresh_token.clone()
};
let mut exchange = client
.exchange_refresh_token(&rt)
.context("OIDC token endpoint not configured")?;
for r in &self.cfg.resources {
exchange = exchange.add_extra_param("resource", r.clone());
}
let token_response = exchange
.request_async(&self.http_client)
.await
.inspect_err(|_| {
self.refreshes.lock().expect("oidc refresh mutex").remove(sid);
})
.context("OIDC refresh exchange failed")?;
let id_token = token_response
.id_token()
.ok_or_else(|| anyhow!("refresh response had no id_token"))?;
let new_id_token_str = id_token.to_string();
let claims = id_token
.claims(&client.id_token_verifier(), |_: Option<&Nonce>| Ok(()))
.context("refreshed ID token validation failed")?;
let claims_json = serde_json::to_value(claims)
.context("serialising refreshed ID token claims")?;
let id_username = extract_string_claim(
&self.cfg.username_claim,
&claims_json,
claims.subject().as_str(),
);
let id_groups =
extract_groups_claim(&self.cfg.groups_claim, &claims_json);
let new_subject = claims.subject().as_str().to_owned();
let new_idp_sid =
extract_optional_string_claim("sid", &claims_json);
let (username, groups) = self
.merge_userinfo(
&client,
token_response.access_token(),
&id_username,
id_groups,
)
.await;
let new_sid = match token_response.refresh_token() {
Some(new_rt) => {
let id = CsrfToken::new_random().secret().clone();
let mut map = self.refreshes.lock().expect("oidc refresh mutex");
map.remove(sid);
map.insert(
id.clone(),
RefreshEntry {
refresh_token: new_rt.clone(),
expires_at: Instant::now() + self.refresh_ttl,
id_token: new_id_token_str,
subject: new_subject,
idp_sid: new_idp_sid,
},
);
id
}
None => {
let mut map = self.refreshes.lock().expect("oidc refresh mutex");
if let Some(e) = map.get_mut(sid) {
e.expires_at = Instant::now() + self.refresh_ttl;
e.id_token = new_id_token_str;
e.subject = new_subject;
e.idp_sid = new_idp_sid;
}
sid.to_owned()
}
};
Ok((Identity { username, groups }, new_sid))
}
pub fn login_path(&self) -> &str {
&self.cfg.login_path
}
pub fn callback_path(&self) -> &str {
&self.cfg.callback_path
}
fn evict_expired(&self) {
let now = Instant::now();
let ttl = self.state_ttl;
self.states
.lock()
.unwrap()
.retain(|_, e| now.duration_since(e.created) <= ttl);
self.refreshes
.lock()
.unwrap()
.retain(|_, e| now <= e.expires_at);
self.seen_jtis
.lock()
.unwrap()
.retain(|_, expires_at| now <= *expires_at);
}
#[cfg(test)]
fn refresh_count(&self) -> usize {
self.refreshes.lock().expect("oidc refresh mutex").len()
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use sha2::Digest;
use std::time::SystemTime;
/// Error type returned by the in-memory fake HTTP client.
#[derive(Debug)]
struct FakeHttpError(String);

impl std::fmt::Display for FakeHttpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("fake http error: ")?;
        f.write_str(&self.0)
    }
}

impl std::error::Error for FakeHttpError {}
/// Builds an in-memory `AsyncHttpClient` for discovery tests.
///
/// Each request is answered by looking up its URI path (query, method and
/// host are ignored) in `routes` and returning the stored `(status, body)`
/// with a JSON content type. Unknown paths get status 404 with body `{}`.
fn fake_client(
routes: std::collections::HashMap<String, (u16, String)>,
) -> impl for<'c> AsyncHttpClient<'c, Error = FakeHttpError> {
move |req: openidconnect::HttpRequest| {
let path = req.uri().path().to_owned();
// Cloned per request so the returned future owns its route table.
let routes = routes.clone();
async move {
let (status, body) = routes
.get(&path)
.cloned()
.unwrap_or((404, "{}".to_owned()));
openidconnect::http::Response::builder()
.status(status)
.header(
openidconnect::http::header::CONTENT_TYPE,
"application/json",
)
.body(body.into_bytes())
.map_err(|e| FakeHttpError(e.to_string()))
}
}
}
fn discovery_doc(issuer: &str, with_logout: bool) -> String {
let mut doc = serde_json::json!({
"issuer": issuer,
"authorization_endpoint": format!("{issuer}/authorize"),
"token_endpoint": format!("{issuer}/token"),
"jwks_uri": format!("{issuer}/jwks"),
"response_types_supported": ["code"],
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256"],
});
if with_logout {
doc["end_session_endpoint"] =
format!("{issuer}/logout").into();
doc["revocation_endpoint"] =
format!("{issuer}/revoke").into();
}
doc.to_string()
}
/// An empty JSON Web Key Set, serialized compactly.
fn empty_jwks() -> String {
    r#"{"keys":[]}"#.to_owned()
}
/// Route table for `fake_client` serving a healthy discovery endpoint and an
/// empty JWKS at `/jwks`.
fn discovery_routes(
    issuer: &str,
    with_logout: bool,
) -> std::collections::HashMap<String, (u16, String)> {
    std::collections::HashMap::from([
        (
            "/.well-known/openid-configuration".to_owned(),
            (200, discovery_doc(issuer, with_logout)),
        ),
        ("/jwks".to_owned(), (200, empty_jwks())),
    ])
}
/// Runs discovery against the fake IdP at `https://idp.example` and returns
/// the discovered end-session and revocation URLs as strings.
async fn discover_logout_endpoints_via_fake_idp(
    with_logout: bool,
) -> (Option<String>, Option<String>) {
    let issuer = "https://idp.example";
    let cfg = mock_cfg(issuer);
    let client = fake_client(discovery_routes(issuer, with_logout));
    let (_oidc, end_session, revocation, _jwks) = run_discovery(&cfg, &client)
        .await
        .expect("discovery should succeed");
    (
        end_session.map(|u| u.to_string()),
        revocation.map(|u| u.to_string()),
    )
}

#[tokio::test]
async fn run_discovery_parses_logout_extension_endpoints() {
    let (end_session, revocation) = discover_logout_endpoints_via_fake_idp(true).await;
    assert_eq!(end_session, Some("https://idp.example/logout".to_owned()));
    assert_eq!(revocation, Some("https://idp.example/revoke".to_owned()));
}

#[tokio::test]
async fn run_discovery_missing_logout_endpoints_yields_none() {
    let (end_session, revocation) = discover_logout_endpoints_via_fake_idp(false).await;
    assert!(end_session.is_none());
    assert!(revocation.is_none());
}
/// Runs discovery against a fake IdP whose discovery endpoint answers with
/// `status` / `body` and reports whether discovery failed.
async fn discovery_with_config_response_fails(status: u16, body: String) -> bool {
    let issuer = "https://idp.example";
    let cfg = mock_cfg(issuer);
    let mut routes = discovery_routes(issuer, true);
    routes.insert(
        "/.well-known/openid-configuration".to_owned(),
        (status, body),
    );
    let client = fake_client(routes);
    run_discovery(&cfg, &client).await.is_err()
}

#[tokio::test]
async fn run_discovery_http_error_is_err() {
    assert!(discovery_with_config_response_fails(500, "{}".to_owned()).await);
}

#[tokio::test]
async fn run_discovery_malformed_json_is_err() {
    assert!(discovery_with_config_response_fails(200, "{ this is not json".to_owned()).await);
}

#[tokio::test]
async fn run_discovery_issuer_mismatch_is_err() {
    let foreign_doc = discovery_doc("https://evil.example", true);
    assert!(discovery_with_config_response_fails(200, foreign_doc).await);
}
#[test]
fn missing_groups_claim_returns_empty() {
    let no_claims = serde_json::Value::Object(serde_json::Map::new());
    let groups = extract_groups_claim("groups", &no_claims);
    assert!(groups.is_empty());
}

#[test]
fn missing_username_claim_falls_back_to_default() {
    let no_claims = serde_json::Value::Object(serde_json::Map::new());
    let username = extract_string_claim("preferred_username", &no_claims, "alice");
    assert_eq!(username, "alice");
}
/// Like `provider_for_store`, but with `end_session` pre-set as the IdP's
/// end-session URL.
pub(crate) fn provider_for_store_with_end_session(
    ttl: Duration,
    end_session: url::Url,
) -> Arc<OidcProvider> {
    let provider = provider_for_store(ttl);
    provider
        .end_session_url
        .store(Arc::new(Some(end_session)));
    provider
}
/// Builds a ready `OidcProvider` for store-level tests.
///
/// The provider is constructed directly rather than via `OidcProvider::new`,
/// so no discovery or eviction task is spawned. The client comes from
/// `dummy_client`, and logout/JWKS state starts empty. `ttl` sets both
/// `refresh_ttl` (exactly) and `cfg.refresh_ttl_secs` (whole seconds).
pub(crate) fn provider_for_store(ttl: Duration) -> Arc<OidcProvider> {
let cfg = crate::config::OidcConfig {
issuer: "https://idp.example".into(),
client_id: "id".into(),
client_secret: None,
redirect_uri: "https://app.example/cb".into(),
scopes: vec!["openid".into()],
username_claim: "sub".into(),
groups_claim: "groups".into(),
login_path: "/oidc/login".into(),
callback_path: "/oidc/callback".into(),
state_ttl_secs: 60,
refresh: true,
refresh_ttl_secs: ttl.as_secs(),
refresh_cookie_name: "__hypershunt_oidc_refresh".into(),
logout_path: "/oidc/logout".into(),
post_logout_uri: "/".into(),
idp_logout: true,
userinfo: false,
discovery_refresh_secs: 0,
discovery_retry: true,
backchannel_logout_enabled: true,
backchannel_logout_path:
"/oidc/backchannel-logout".into(),
backchannel_max_iat_skew_secs: 120,
backchannel_jti_ttl_secs: 300,
bearer: false,
bearer_audiences: vec![],
bearer_cache_size: 16,
revoke_on_logout: true,
require_iss: false,
resources: vec![],
};
let client = dummy_client(&cfg);
// `cfg` is moved into the struct below; fields derived from it
// (`state_ttl`) must be initialised before the `cfg` field.
Arc::new(OidcProvider {
client: ArcSwap::new(Arc::new(Some(Arc::new(client)))),
state_ttl: Duration::from_secs(cfg.state_ttl_secs),
refresh_ttl: ttl,
metrics: Arc::new(Metrics::new()),
cfg,
states: Mutex::new(HashMap::new()),
refreshes: Mutex::new(HashMap::new()),
end_session_url: ArcSwap::new(Arc::new(None)),
revocation_url: ArcSwap::new(Arc::new(None)),
jwks: ArcSwap::new(Arc::new(None)),
seen_jtis: Mutex::new(HashMap::new()),
bearer_cache: Mutex::new(lru::LruCache::new(
NonZeroUsize::new(16).unwrap(),
)),
http_client: build_http_client().unwrap(),
})
}
fn dummy_client(cfg: &crate::config::OidcConfig) -> OidcClient {
use openidconnect::core::{
CoreJwsSigningAlgorithm, CoreResponseType,
CoreSubjectIdentifierType,
};
let issuer = IssuerUrl::new(cfg.issuer.clone()).unwrap();
let metadata = HypershuntProviderMetadata::new(
issuer,
openidconnect::AuthUrl::new(
"https://idp.example/authorize".into(),
)
.unwrap(),
openidconnect::JsonWebKeySetUrl::new(
"https://idp.example/jwks".into(),
)
.unwrap(),
vec![openidconnect::ResponseTypes::new(vec![
CoreResponseType::Code,
])],
vec![CoreSubjectIdentifierType::Public],
vec![CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256],
LogoutMetadata {
end_session_endpoint: None,
revocation_endpoint: None,
},
)
.set_token_endpoint(Some(
openidconnect::TokenUrl::new(
"https://idp.example/token".into(),
)
.unwrap(),
))
.set_jwks(openidconnect::JsonWebKeySet::new(vec![]));
OidcClientFromMetadata::from_provider_metadata(
metadata,
ClientId::new(cfg.client_id.clone()),
None,
)
.set_redirect_uri(
RedirectUrl::new(cfg.redirect_uri.clone()).unwrap(),
)
.set_revocation_url(openidconnect::RevocationUrl::new(
"https://idp.example/revoke".into(),
)
.unwrap())
}
/// Mutable state of the mock IdP, shared between a test and the server task.
pub(crate) struct MockIdpState {
/// Nonce to embed in ID tokens issued for the authorization-code grant.
pub(crate) nonce: Option<String>,
/// Whether refresh-grant responses include a new refresh token.
pub(crate) rotate_refresh: bool,
/// Number of requests received on `/revoke`.
pub(crate) revocations: u32,
/// Number of `/token` requests served; used to number issued tokens.
pub(crate) token_seq: u32,
}
/// Handle to a running mock IdP.
pub(crate) struct MockIdp {
/// Base URL (`http://127.0.0.1:<port>`), used as the issuer.
pub(crate) issuer: String,
/// State shared with the server task.
pub(crate) state: Arc<std::sync::Mutex<MockIdpState>>,
}
impl MockIdp {
/// Starts a mock IdP on an ephemeral localhost port and returns its handle.
///
/// The server answers:
/// - `/.well-known/openid-configuration` with a discovery document that
///   includes logout and revocation endpoints;
/// - `/jwks` with the public half of an RS256 key (`kid` `test-key`);
/// - `/token` with an RS256-signed ID token for subject `alice`;
/// - `/revoke` by counting the request.
///
/// Any other path gets an empty body. Every response is sent with a JSON
/// content type.
pub(crate) async fn spawn() -> MockIdp {
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use rsa::traits::PublicKeyParts as _;
// Port 0: the OS picks a free port, which becomes part of the issuer.
let listener =
tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.unwrap();
let issuer =
format!("http://{}", listener.local_addr().unwrap());
// A 2048-bit key is generated once per test process and reused by
// every mock instance.
static RSA_KEY: std::sync::OnceLock<rsa::RsaPrivateKey> =
std::sync::OnceLock::new();
let private = RSA_KEY
.get_or_init(|| {
rsa::RsaPrivateKey::new(&mut rand_core::OsRng, 2048)
.unwrap()
})
.clone();
let signing_key = Arc::new(
rsa::pkcs1v15::SigningKey::<sha2::Sha256>::new(
private.clone(),
),
);
let public = private.to_public_key();
// JWKS with the RSA modulus/exponent as base64url big-endian bytes.
let jwks = serde_json::json!({
"keys": [{
"kty": "RSA", "alg": "RS256",
"use": "sig", "kid": "test-key",
"n": URL_SAFE_NO_PAD
.encode(public.n().to_bytes_be()),
"e": URL_SAFE_NO_PAD
.encode(public.e().to_bytes_be()),
}]
})
.to_string();
let state = Arc::new(std::sync::Mutex::new(MockIdpState {
nonce: None,
rotate_refresh: true,
revocations: 0,
token_seq: 0,
}));
let iss = issuer.clone();
let st = state.clone();
// Accept loop: one HTTP/1 connection task per accepted socket.
tokio::spawn(async move {
loop {
let Ok((stream, _)) = listener.accept().await
else {
return;
};
let iss = iss.clone();
let st = st.clone();
let jwks = jwks.clone();
let key = signing_key.clone();
tokio::spawn(async move {
let svc = hyper::service::service_fn(
move |req: hyper::Request<
hyper::body::Incoming,
>| {
let iss = iss.clone();
let st = st.clone();
let jwks = jwks.clone();
let key = key.clone();
async move {
// Routing is by path only; the method is ignored.
let path = req.uri().path().to_owned();
let body = match path.as_str() {
"/.well-known/openid-configuration" => {
serde_json::json!({
"issuer": iss,
"authorization_endpoint":
format!("{iss}/authorize"),
"token_endpoint":
format!("{iss}/token"),
"jwks_uri":
format!("{iss}/jwks"),
"end_session_endpoint":
format!("{iss}/logout"),
"revocation_endpoint":
format!("{iss}/revoke"),
"response_types_supported":
["code"],
"subject_types_supported":
["public"],
"id_token_signing_alg_values_supported":
["RS256"],
})
.to_string()
}
"/jwks" => jwks,
"/token" => {
// Read the form body and snapshot shared state; the
// grant type is detected by substring match.
let (nonce, seq, rotate, is_refresh);
{
use http_body_util::BodyExt as _;
let form = req
.into_body()
.collect()
.await
.unwrap()
.to_bytes();
let form = String::from_utf8_lossy(&form)
.into_owned();
is_refresh = form
.contains("grant_type=refresh_token");
let mut s = st.lock().unwrap();
s.token_seq += 1;
nonce = s.nonce.clone();
seq = s.token_seq;
rotate = s.rotate_refresh;
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
// Fixed claims; `aud` must match the client id used by
// `mock_cfg` ("client-1").
let mut claims = serde_json::json!({
"iss": iss,
"aud": "client-1",
"sub": "alice",
"iat": now,
"exp": now + 3600,
"preferred_username": "alice-pref",
"groups": ["devs"],
"sid": "idp-sess-1",
});
// The nonce is only echoed for the authorization-code
// grant, never for refresh.
if !is_refresh
&& let Some(n) = nonce
{
claims["nonce"] =
n.into();
}
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use rsa::signature::{
SignatureEncoding as _,
Signer as _,
};
// Hand-built compact JWS: base64url(header).base64url(payload).signature
let header = URL_SAFE_NO_PAD.encode(
br#"{"alg":"RS256","kid":"test-key"}"#,
);
let payload = URL_SAFE_NO_PAD
.encode(claims.to_string());
let signing_input =
format!("{header}.{payload}");
let sig = key
.sign(signing_input.as_bytes());
let id_token = format!(
"{signing_input}.{}",
URL_SAFE_NO_PAD
.encode(sig.to_bytes())
);
let mut resp = serde_json::json!({
"access_token":
format!("at-{seq}"),
"token_type": "Bearer",
"expires_in": 3600,
"id_token": id_token,
});
// Code grants always get a refresh token; refresh grants
// only when rotation is enabled.
if !is_refresh || rotate {
resp["refresh_token"] =
format!("rt-{seq}").into();
}
resp.to_string()
}
"/revoke" => {
st.lock().unwrap().revocations += 1;
String::new()
}
_ => String::new(),
};
Ok::<_, std::convert::Infallible>(
hyper::Response::builder()
.header(
"content-type",
"application/json",
)
.body(
http_body_util::Full::new(
bytes::Bytes::from(body),
),
)
.unwrap(),
)
}
},
);
// Connection errors are ignored; the mock only needs to
// serve well-behaved test clients.
let _ = hyper::server::conn::http1::Builder::new()
.serve_connection(
hyper_util::rt::TokioIo::new(stream),
svc,
)
.await;
});
}
});
MockIdp { issuer, state }
}
}
/// Test configuration pointing at `issuer`.
///
/// Starts from the `provider_for_store` defaults, uses client id
/// `client-1` and the `preferred_username` claim, and enables refresh with a
/// 60-second TTL.
pub(crate) fn mock_cfg(issuer: &str) -> crate::config::OidcConfig {
    let base = provider_for_store(Duration::from_secs(60)).cfg.clone();
    crate::config::OidcConfig {
        issuer: issuer.to_owned(),
        client_id: "client-1".into(),
        username_claim: "preferred_username".into(),
        refresh: true,
        refresh_ttl_secs: 60,
        ..base
    }
}
/// Polls `p.is_ready()` up to 200 times, 25 ms apart (about 5 s), and panics
/// if the provider never becomes ready.
async fn await_ready(p: &Arc<OidcProvider>) {
    let mut polls_left: u32 = 200;
    while polls_left > 0 {
        if p.is_ready() {
            return;
        }
        tokio::time::sleep(Duration::from_millis(25)).await;
        polls_left -= 1;
    }
    panic!("provider never became ready against the mock IdP");
}
/// End-to-end check against the mock IdP:
/// - discovery;
/// - a login round trip, including single-use state;
/// - a refresh with and without refresh-token rotation;
/// - a refresh of an unknown session;
/// - a revocation attempt.
#[tokio::test(flavor = "multi_thread")]
async fn full_login_refresh_revoke_flow_against_mock_idp() {
let idp = MockIdp::spawn().await;
let metrics = Arc::new(Metrics::new());
let p = OidcProvider::new(mock_cfg(&idp.issuer), metrics.clone());
await_ready(&p).await;
assert!(p.end_session_url().is_some());
// `discovery_refresh_secs` is 0 in the config, so exactly one discovery.
assert_eq!(
metrics
.oidc_discoveries
.load(std::sync::atomic::Ordering::Relaxed),
1
);
let (auth_url, state_id) = p
.begin_login("/after".into(), IdpHints::default())
.expect("ready provider must build a login URL");
// Hand the nonce from the authorization URL to the mock so its ID token
// passes nonce validation.
let nonce = auth_url
.query_pairs()
.find(|(k, _)| k == "nonce")
.map(|(_, v)| v.into_owned())
.expect("auth URL must carry a nonce");
idp.state.lock().unwrap().nonce = Some(nonce);
let (ident, return_to, sid) = p
.complete_login("any-code".into(), &state_id)
.await
.expect("token exchange against mock IdP");
assert_eq!(ident.username, "alice-pref");
assert_eq!(ident.groups, vec!["devs".to_string()]);
assert_eq!(return_to, "/after");
let sid = sid.expect("refresh enabled -> sid cookie value");
// Replaying the same state must fail: states are single-use.
let err = p
.complete_login("any-code".into(), &state_id)
.await
.unwrap_err()
.to_string();
assert!(err.contains("unknown or expired"), "got: {err}");
let (ident2, sid2) = p.refresh(&sid).await.unwrap();
assert_eq!(ident2.username, "alice-pref");
assert_ne!(sid2, sid, "rotation must re-key the session");
assert_eq!(p.refresh_count(), 1, "old sid replaced, not added");
idp.state.lock().unwrap().rotate_refresh = false;
let (_, sid3) = p.refresh(&sid2).await.unwrap();
assert_eq!(sid3, sid2, "no rotation -> sid unchanged");
assert!(p.refresh("no-such-sid").await.is_err());
// The mock issuer is plain http. This test expects revocation to be
// skipped (no request reaches /revoke) without counting a failure —
// presumably because the client refuses a non-https revocation URL;
// confirm against the oauth2/openidconnect `revoke_token` docs.
p.revoke_refresh_token(RefreshToken::new("rt-1".into()));
tokio::time::sleep(Duration::from_millis(150)).await;
assert_eq!(idp.state.lock().unwrap().revocations, 0);
assert_eq!(
metrics
.oidc_revocation_failures
.load(std::sync::atomic::Ordering::Relaxed),
0,
"skip must not be recorded as a failure"
);
}
/// A login state that outlives `state_ttl_secs` must be rejected at the callback.
///
/// The state TTL is set to zero. The short sleep ensures a non-zero amount of
/// time has passed between `begin_login` and `complete_login`, so the state
/// should be considered expired. (This assumes expiry is judged by elapsed
/// time against the TTL; confirm against `complete_login`.)
#[tokio::test(flavor = "multi_thread")]
async fn expired_state_is_rejected_at_callback() {
    let idp = MockIdp::spawn().await;
    let mut cfg = mock_cfg(&idp.issuer);
    cfg.state_ttl_secs = 0;
    let p = OidcProvider::new(cfg, Arc::new(Metrics::new()));
    await_ready(&p).await;

    let (_, state_id) = p
        .begin_login("/".into(), IdpHints::default())
        .unwrap();
    tokio::time::sleep(Duration::from_millis(20)).await;

    let err = p
        .complete_login("code".into(), &state_id)
        .await
        .unwrap_err()
        .to_string();
    assert!(err.contains("state expired"), "got: {err}");
}
/// `evict_expired` must remove a state created an hour ago and a refresh entry
/// whose `expires_at` is in the past. It must keep a freshly created login
/// state and a refresh entry that expires in the future.
#[test]
fn evict_expired_drops_stale_states_and_refreshes() {
    let p = provider_for_store(Duration::from_secs(60));
    let (_, live_state) = p.begin_login("/x".into(), IdpHints::default()).unwrap();

    // The 43-char verifier presumably mirrors the RFC 7636 minimum length.
    let stale_state = StateEntry {
        pkce_verifier: openidconnect::PkceCodeVerifier::new("v".repeat(43)),
        nonce: Nonce::new("n".into()),
        return_to: "/".into(),
        created: Instant::now() - Duration::from_secs(3600),
    };
    p.states.lock().unwrap().insert("stale".into(), stale_state);

    // Identical refresh entries that differ only in their expiry instant.
    let refresh_entry = |expires_at: Instant| RefreshEntry {
        refresh_token: RefreshToken::new("rt".into()),
        expires_at,
        id_token: String::new(),
        subject: "alice".into(),
        idp_sid: None,
    };
    let mut refreshes = p.refreshes.lock().unwrap();
    refreshes.insert("live".into(), refresh_entry(Instant::now() + Duration::from_secs(60)));
    refreshes.insert("dead".into(), refresh_entry(Instant::now() - Duration::from_secs(1)));
    // Release the lock before `evict_expired` needs it.
    drop(refreshes);

    p.evict_expired();

    {
        let states = p.states.lock().unwrap();
        assert!(states.contains_key(&live_state));
        assert!(!states.contains_key("stale"));
    }
    let refreshes = p.refreshes.lock().unwrap();
    assert!(refreshes.contains_key("live"));
    assert!(!refreshes.contains_key("dead"));
}
#[test]
fn provider_new_starts_in_not_ready_state() {
let cfg = crate::config::OidcConfig {
issuer: "https://127.0.0.1:1/".into(),
client_id: "id".into(),
client_secret: None,
redirect_uri: "https://app.example/cb".into(),
scopes: vec!["openid".into()],
username_claim: "sub".into(),
groups_claim: "groups".into(),
login_path: "/oidc/login".into(),
callback_path: "/oidc/callback".into(),
state_ttl_secs: 60,
refresh: false,
refresh_ttl_secs: 60,
refresh_cookie_name: "__hypershunt_oidc_refresh".into(),
logout_path: "/oidc/logout".into(),
post_logout_uri: "/".into(),
idp_logout: false,
userinfo: false,
discovery_refresh_secs: 0,
discovery_retry: false,
backchannel_logout_enabled: false,
backchannel_logout_path:
"/oidc/backchannel-logout".into(),
backchannel_max_iat_skew_secs: 120,
backchannel_jti_ttl_secs: 300,
bearer: false,
bearer_audiences: vec![],
bearer_cache_size: 16,
revoke_on_logout: true,
require_iss: false,
resources: vec![],
};
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let p = OidcProvider::new(cfg, Arc::new(Metrics::new()));
assert!(!p.is_ready());
assert!(p.client().is_none());
});
}
/// When userinfo merging is off, `merge_userinfo` hands back the ID-token
/// username and groups it was given.
///
/// This relies on `provider_for_store`'s config having userinfo disabled, as
/// the test name implies.
#[test]
fn userinfo_merge_disabled_returns_id_token_values() {
    let p = provider_for_store(Duration::from_secs(60));
    let client = p.client().expect("test provider has a client");
    let access = openidconnect::AccessToken::new("at".into());
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    let merge = p.merge_userinfo(&client, &access, "alice", vec!["devs".into()]);
    let (user, groups) = runtime.block_on(merge);

    assert_eq!(user, "alice");
    assert_eq!(groups, vec!["devs".to_string()]);
}
/// `extract_groups_claim_from_json` reads the groups claim in two forms: as a
/// JSON array and as a whitespace-separated string. It yields an empty list
/// when the claim is absent.
#[test]
fn extract_groups_claim_from_json_array_and_string() {
    let expected = vec!["admins", "devs"];

    let as_array = serde_json::json!({"groups": ["admins", "devs"]});
    assert_eq!(extract_groups_claim_from_json("groups", &as_array), expected);

    let as_string = serde_json::json!({"groups": "admins devs"});
    assert_eq!(extract_groups_claim_from_json("groups", &as_string), expected);

    let missing = serde_json::json!({});
    assert!(extract_groups_claim_from_json("groups", &missing).is_empty());
}
/// A refresh entry whose `expires_at` already lies in the past is still
/// counted until `evict_expired` runs. After that, the store is empty.
#[test]
fn refresh_store_evicts_expired_entries() {
    let p = provider_for_store(Duration::from_millis(1));
    let already_expired = RefreshEntry {
        refresh_token: RefreshToken::new("rt".into()),
        expires_at: Instant::now() - Duration::from_secs(1),
        id_token: "test".into(),
        subject: "alice".into(),
        idp_sid: None,
    };
    p.refreshes
        .lock()
        .expect("oidc refresh mutex")
        .insert("sid".into(), already_expired);

    assert_eq!(p.refresh_count(), 1);
    p.evict_expired();
    assert_eq!(p.refresh_count(), 0);
}
/// `take_logout_session` hands back the stored ID token and refresh token
/// exactly once, removing the entry from the refresh store.
#[test]
fn take_logout_session_returns_stored_id_token() {
    let p = provider_for_store(Duration::from_secs(60));
    let session = RefreshEntry {
        refresh_token: RefreshToken::new("rt".into()),
        expires_at: Instant::now() + Duration::from_secs(60),
        id_token: "the-id-token".into(),
        subject: "alice".into(),
        idp_sid: None,
    };
    p.refreshes
        .lock()
        .expect("oidc refresh mutex")
        .insert("sid".into(), session);

    let (id_tok, refresh_tok) = p.take_logout_session("sid").expect("first call");
    assert_eq!(id_tok, "the-id-token");
    assert_eq!(refresh_tok.secret(), "rt");

    // A second take finds nothing, and the store is now empty.
    assert!(p.take_logout_session("sid").is_none());
    assert_eq!(p.refresh_count(), 0);
}
/// A live bearer-cache entry, keyed by the SHA-256 digest of the raw token, is
/// returned by `validate_bearer_token`. The token `"anything"` is not a
/// compact JWS, so a successful result here must come from the cache.
#[test]
fn bearer_cache_returns_stored_identity() {
    let p = provider_for_store(Duration::from_secs(60));
    let token = "anything";
    let key: [u8; 32] = sha2::Sha256::digest(token.as_bytes()).into();
    let identity = Identity {
        username: "alice".into(),
        groups: vec!["devs".into()],
    };

    // Entry expires ten minutes from now (seconds since the Unix epoch).
    let now_secs = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let entry = BearerCacheEntry {
        identity: identity.clone(),
        expires_at: now_secs + 600,
    };
    p.bearer_cache
        .lock()
        .expect("oidc bearer cache mutex")
        .put(key, entry);

    let got = p.validate_bearer_token(token).expect("cache hit");
    assert_eq!(got.username, identity.username);
    assert_eq!(got.groups, identity.groups);
}
/// Looking up a cached bearer entry whose expiry (epoch second 0) has passed
/// must fail, and that failed lookup must remove the entry from the cache.
#[test]
fn bearer_cache_evicts_expired_entry_on_lookup() {
    let p = provider_for_store(Duration::from_secs(60));
    let token = "anything";
    let key: [u8; 32] = sha2::Sha256::digest(token.as_bytes()).into();
    let expired = BearerCacheEntry {
        identity: Identity {
            username: "alice".into(),
            groups: vec![],
        },
        expires_at: 0,
    };
    p.bearer_cache
        .lock()
        .expect("oidc bearer cache mutex")
        .put(key, expired);

    assert!(p.validate_bearer_token(token).is_err());

    let cache = p.bearer_cache.lock().expect("oidc bearer cache mutex");
    assert!(cache.peek(&key).is_none());
}
/// With `revoke_on_logout = false`, `revoke_refresh_token` must not bump
/// either the revocation counter or the revocation-failure counter.
///
/// The provider is assembled by hand from `provider_for_store`'s client and
/// config, using a fresh `Metrics` instance.
///
/// NOTE(review): `revocation_url` is also `None` here. A zero count could
/// therefore come from the missing endpoint rather than the config flag.
/// Consider giving it an endpoint so only `revoke_on_logout` is exercised.
#[test]
fn revoke_no_op_when_disabled_in_config() {
    use std::sync::atomic::Ordering::Relaxed;

    let base = provider_for_store(Duration::from_secs(60));
    let mut cfg = base.cfg.clone();
    cfg.revoke_on_logout = false;
    let client = Arc::new(base.client.load_full().as_ref().clone());

    let p_off = Arc::new(OidcProvider {
        client: ArcSwap::new(client),
        state_ttl: Duration::from_secs(60),
        refresh_ttl: Duration::from_secs(60),
        metrics: Arc::new(crate::metrics::Metrics::new()),
        cfg,
        states: Mutex::new(HashMap::new()),
        refreshes: Mutex::new(HashMap::new()),
        end_session_url: ArcSwap::new(Arc::new(None)),
        revocation_url: ArcSwap::new(Arc::new(None)),
        jwks: ArcSwap::new(Arc::new(None)),
        seen_jtis: Mutex::new(HashMap::new()),
        bearer_cache: Mutex::new(lru::LruCache::new(NonZeroUsize::new(16).unwrap())),
        http_client: build_http_client().unwrap(),
    });

    p_off.revoke_refresh_token(RefreshToken::new("rt".into()));

    let metrics = &p_off.metrics;
    assert_eq!(metrics.oidc_revocations.load(Relaxed), 0);
    assert_eq!(metrics.oidc_revocation_failures.load(Relaxed), 0);
}
/// `issuer()` returns the configured issuer with its trailing slash removed.
#[test]
fn issuer_strips_trailing_slash() {
    let mut p = provider_for_store(Duration::from_secs(60));
    // `get_mut` succeeds only while this is the sole reference to the provider.
    let provider = Arc::get_mut(&mut p).unwrap();
    provider.cfg.issuer = "https://idp.example/".into();
    assert_eq!(p.issuer(), "https://idp.example");
}
/// `record_jti` returns `true` the first time a JTI is seen and `false` on a
/// replay. Tracking of distinct JTIs is independent.
#[test]
fn record_jti_rejects_replay() {
    let p = provider_for_store(Duration::from_secs(60));
    let cases = [("jti-1", true), ("jti-1", false), ("jti-2", true)];
    for (jti, accepted) in cases {
        assert_eq!(p.record_jti(jti), accepted);
    }
}
/// `IdpHints::pairs` skips unset hints and yields the set ones in field order:
/// `login_hint`, `max_age`, `ui_locales`.
#[test]
fn idp_hints_pairs_filters_none_and_preserves_order() {
    let hints = IdpHints {
        login_hint: Some("alice@example".into()),
        max_age: Some("0".into()),
        ui_locales: Some("fr".into()),
        prompt: None,
        acr_values: None,
    };
    let expected = [
        ("login_hint", "alice@example"),
        ("max_age", "0"),
        ("ui_locales", "fr"),
    ];
    assert_eq!(hints.pairs().collect::<Vec<_>>(), expected);
}
/// A refresh entry that expires in the future survives `evict_expired`.
#[test]
fn refresh_store_keeps_live_entries() {
    let p = provider_for_store(Duration::from_secs(60));
    let live = RefreshEntry {
        refresh_token: RefreshToken::new("rt".into()),
        expires_at: Instant::now() + Duration::from_secs(60),
        id_token: "test".into(),
        subject: "alice".into(),
        idp_sid: None,
    };
    p.refreshes
        .lock()
        .expect("oidc refresh mutex")
        .insert("sid".into(), live);

    p.evict_expired();
    assert_eq!(p.refresh_count(), 1);
}
}