use std::sync::Arc;
use std::time::Duration as StdDuration;
use async_trait::async_trait;
use chrono::Duration as ChronoDuration;
use jsonwebtoken::Algorithm;
use serde_json::Value;
use tokio::sync::OnceCell;
use crate::social::core::{
IdToken, OAuth2Client, OAuthProvider, ProviderConfig, SocialAuthError, StandardClaims,
TokenResponse,
};
use crate::social::flow::pkce::{CodeChallenge, CodeVerifier};
use crate::social::flow::{AuthorizationFlow, RefreshFlow, TokenExchangeFlow};
use crate::social::oidc::id_token::ValidationConfig;
use crate::social::oidc::{
DiscoveryClient, IdTokenValidator, JwksCache, OIDCDiscovery, UserInfoClient,
};
use crate::social::url_validation::validate_endpoint_url;
const DEFAULT_CACHE_TTL_SECS: i64 = 3600;
const SUPPORTED_ASYMMETRIC_ALGORITHMS: &[Algorithm] = &[
Algorithm::RS256,
Algorithm::RS384,
Algorithm::RS512,
Algorithm::PS256,
Algorithm::PS384,
Algorithm::PS512,
Algorithm::ES256,
Algorithm::ES384,
];
pub type UserInfoMapper =
Box<dyn Fn(&Value) -> Result<StandardClaims, SocialAuthError> + Send + Sync>;
#[derive(Clone)]
pub struct GenericOidcConfig {
pub name: String,
pub discovery_url: String,
pub client_id: String,
pub client_secret: String,
pub redirect_uri: String,
pub scopes: Vec<String>,
pub discovery_ttl: Option<StdDuration>,
pub jwks_ttl: Option<StdDuration>,
pub extra_token_params: Option<Vec<(String, String)>>,
}
impl std::fmt::Debug for GenericOidcConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GenericOidcConfig")
.field("name", &self.name)
.field("discovery_url", &self.discovery_url)
.field("client_id", &self.client_id)
.field("client_secret", &"<redacted>")
.field("redirect_uri", &self.redirect_uri)
.field("scopes", &self.scopes)
.field("discovery_ttl", &self.discovery_ttl)
.field("jwks_ttl", &self.jwks_ttl)
.field(
"extra_token_params",
&self
.extra_token_params
.as_ref()
.map(|v| v.iter().map(|(k, _)| k.as_str()).collect::<Vec<_>>()),
)
.finish()
}
}
pub struct GenericOidcProvider {
name: String,
config: GenericOidcConfig,
client: OAuth2Client,
auth_flow: AuthorizationFlow,
token_exchange: TokenExchangeFlow,
refresh_flow: RefreshFlow,
userinfo_client: UserInfoClient,
discovery_client: DiscoveryClient,
jwks_cache: Arc<JwksCache>,
id_token_validator: OnceCell<IdTokenValidator>,
userinfo_mapper: Option<UserInfoMapper>,
}
impl GenericOidcProvider {
pub async fn new(config: GenericOidcConfig) -> Result<Self, SocialAuthError> {
Self::validate_config(&config)?;
let provider_config = build_provider_config(&config);
let client = OAuth2Client::new();
let auth_flow = AuthorizationFlow::new(provider_config.clone());
let token_exchange = TokenExchangeFlow::new(client.clone(), provider_config.clone());
let refresh_flow = RefreshFlow::new(client.clone(), provider_config.clone());
let userinfo_client = UserInfoClient::new(client.clone());
let discovery_ttl = chrono_duration_or_default(config.discovery_ttl);
let discovery_client = DiscoveryClient::with_ttl(client.clone(), discovery_ttl);
let jwks_ttl = chrono_duration_or_default(config.jwks_ttl);
let jwks_cache = Arc::new(JwksCache::with_ttl(client.clone(), jwks_ttl));
let _ = provider_config;
Ok(Self {
name: config.name.clone(),
config,
client,
auth_flow,
token_exchange,
refresh_flow,
userinfo_client,
discovery_client,
jwks_cache,
id_token_validator: OnceCell::new(),
userinfo_mapper: None,
})
}
pub fn with_userinfo_mapper<F>(mut self, mapper: F) -> Self
where
F: Fn(&Value) -> Result<StandardClaims, SocialAuthError> + Send + Sync + 'static,
{
self.userinfo_mapper = Some(Box::new(mapper));
self
}
pub fn config(&self) -> &GenericOidcConfig {
&self.config
}
fn validate_config(config: &GenericOidcConfig) -> Result<(), SocialAuthError> {
if config.name.trim().is_empty() {
return Err(SocialAuthError::InvalidConfiguration(
"GenericOidcProvider requires a non-empty name".into(),
));
}
if config.discovery_url.trim().is_empty() {
return Err(SocialAuthError::InvalidConfiguration(
"GenericOidcProvider requires a non-empty discovery_url".into(),
));
}
if config.client_id.trim().is_empty() {
return Err(SocialAuthError::InvalidConfiguration(
"GenericOidcProvider requires a non-empty client_id".into(),
));
}
if config.client_secret.is_empty() {
return Err(SocialAuthError::InvalidConfiguration(
"GenericOidcProvider requires a non-empty client_secret".into(),
));
}
if config.redirect_uri.trim().is_empty() {
return Err(SocialAuthError::InvalidConfiguration(
"GenericOidcProvider requires a non-empty redirect_uri".into(),
));
}
validate_endpoint_url(&config.discovery_url)?;
Ok(())
}
async fn discover(&self) -> Result<OIDCDiscovery, SocialAuthError> {
let issuer = issuer_from_discovery_url(&self.config.discovery_url);
self.discovery_client.discover(&issuer).await
}
async fn id_token_validator(&self) -> Result<&IdTokenValidator, SocialAuthError> {
self.id_token_validator
.get_or_try_init(|| async {
let discovery = self.discover().await?;
let allowed_algorithms = compute_allowed_algorithms(
discovery.id_token_signing_alg_values_supported.as_deref(),
);
if allowed_algorithms.is_empty() {
return Err(SocialAuthError::InvalidConfiguration(format!(
"Provider '{}' advertises no asymmetric ID token signing algorithms supported by reinhardt-auth",
self.name
)));
}
let validation_config =
ValidationConfig::new(discovery.issuer.clone(), self.config.client_id.clone())
.with_allowed_algorithms(allowed_algorithms);
Ok(IdTokenValidator::new(
self.jwks_cache.clone(),
validation_config,
))
})
.await
}
async fn exchange_with_extras(
&self,
token_endpoint: &str,
code: &str,
code_verifier: Option<&CodeVerifier>,
) -> Result<TokenResponse, SocialAuthError> {
let extras = match self.config.extra_token_params.as_ref() {
Some(extras) if !extras.is_empty() => extras,
_ => {
return self
.token_exchange
.exchange(token_endpoint, code, code_verifier)
.await;
}
};
validate_endpoint_url(token_endpoint)?;
let mut params: Vec<(String, String)> = vec![
("grant_type".to_string(), "authorization_code".to_string()),
("code".to_string(), code.to_string()),
("redirect_uri".to_string(), self.config.redirect_uri.clone()),
("client_id".to_string(), self.config.client_id.clone()),
(
"client_secret".to_string(),
self.config.client_secret.clone(),
),
];
if let Some(verifier) = code_verifier {
params.push(("code_verifier".to_string(), verifier.as_str().to_string()));
}
for (k, v) in extras {
params.push((k.clone(), v.clone()));
}
let response = self
.client
.client()
.post(token_endpoint)
.header("Accept", "application/json")
.form(¶ms)
.send()
.await
.map_err(|e| SocialAuthError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(SocialAuthError::TokenExchangeError(format!(
"Token exchange failed ({}): {}",
status, error_body
)));
}
let token: TokenResponse = response
.json()
.await
.map_err(|e| SocialAuthError::TokenExchangeError(e.to_string()))?;
Ok(token)
}
#[cfg(test)]
fn default_map_userinfo(raw: &Value) -> Result<StandardClaims, SocialAuthError> {
serde_json::from_value(raw.clone())
.map_err(|e| SocialAuthError::UserInfoError(format!("Failed to map UserInfo: {}", e)))
}
}
#[async_trait]
impl OAuthProvider for GenericOidcProvider {
fn name(&self) -> &str {
&self.name
}
fn is_oidc(&self) -> bool {
true
}
async fn authorization_url(
&self,
state: &str,
nonce: Option<&str>,
code_challenge: Option<&str>,
) -> Result<String, SocialAuthError> {
let discovery = self.discover().await?;
let challenge = code_challenge.map(|c| CodeChallenge::from_raw(c.to_string()));
self.auth_flow.build_url(
&discovery.authorization_endpoint,
state,
nonce,
challenge.as_ref(),
)
}
async fn exchange_code(
&self,
code: &str,
code_verifier: Option<&str>,
) -> Result<TokenResponse, SocialAuthError> {
let discovery = self.discover().await?;
let verifier = code_verifier.map(|v| CodeVerifier::from_raw(v.to_string()));
self.exchange_with_extras(&discovery.token_endpoint, code, verifier.as_ref())
.await
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, SocialAuthError> {
let discovery = self.discover().await?;
self.refresh_flow
.refresh(&discovery.token_endpoint, refresh_token)
.await
}
async fn validate_id_token(
&self,
id_token: &str,
nonce: Option<&str>,
) -> Result<IdToken, SocialAuthError> {
let discovery = self.discover().await?;
let validator = self.id_token_validator().await?;
validator
.validate(id_token, &discovery.jwks_uri, nonce)
.await
}
async fn get_user_info(&self, access_token: &str) -> Result<StandardClaims, SocialAuthError> {
let discovery = self.discover().await?;
let userinfo_endpoint = discovery.userinfo_endpoint.as_ref().ok_or_else(|| {
SocialAuthError::InvalidConfiguration(
"Provider's discovery document does not advertise a userinfo_endpoint".into(),
)
})?;
match self.userinfo_mapper.as_ref() {
None => {
self.userinfo_client
.get_user_info(userinfo_endpoint, access_token)
.await
}
Some(mapper) => {
validate_endpoint_url(userinfo_endpoint)?;
let response = self
.client
.client()
.get(userinfo_endpoint)
.header("User-Agent", "reinhardt-auth")
.bearer_auth(access_token)
.send()
.await
.map_err(|e| SocialAuthError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(SocialAuthError::UserInfoError(format!(
"UserInfo request failed ({}): {}",
status, error_body
)));
}
let raw: Value = response
.json()
.await
.map_err(|e| SocialAuthError::UserInfoError(e.to_string()))?;
mapper(&raw)
}
}
}
}
fn issuer_from_discovery_url(discovery_url: &str) -> String {
discovery_url
.trim_end_matches("/.well-known/openid-configuration")
.to_string()
}
fn chrono_duration_or_default(ttl: Option<StdDuration>) -> ChronoDuration {
match ttl {
Some(d) if d.as_secs() > 0 => ChronoDuration::seconds(d.as_secs() as i64),
_ => ChronoDuration::seconds(DEFAULT_CACHE_TTL_SECS),
}
}
fn compute_allowed_algorithms(advertised: Option<&[String]>) -> Vec<Algorithm> {
let advertised = match advertised {
Some(a) if !a.is_empty() => a,
_ => return SUPPORTED_ASYMMETRIC_ALGORITHMS.to_vec(),
};
advertised
.iter()
.filter_map(|alg| match alg.as_str() {
"RS256" => Some(Algorithm::RS256),
"RS384" => Some(Algorithm::RS384),
"RS512" => Some(Algorithm::RS512),
"PS256" => Some(Algorithm::PS256),
"PS384" => Some(Algorithm::PS384),
"PS512" => Some(Algorithm::PS512),
"ES256" => Some(Algorithm::ES256),
"ES384" => Some(Algorithm::ES384),
_ => None,
})
.collect()
}
fn build_provider_config(config: &GenericOidcConfig) -> ProviderConfig {
use crate::social::core::config::OIDCConfig;
ProviderConfig {
name: config.name.clone(),
client_id: config.client_id.clone(),
client_secret: config.client_secret.clone(),
redirect_uri: config.redirect_uri.clone(),
scopes: config.scopes.clone(),
oidc: Some(OIDCConfig {
discovery_url: config.discovery_url.clone(),
use_nonce: true,
}),
oauth2: None,
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use std::collections::HashMap;
use std::time::Duration as StdDuration;
fn valid_config() -> GenericOidcConfig {
GenericOidcConfig {
name: "gitlab".to_string(),
discovery_url: "https://gitlab.example.com/.well-known/openid-configuration".into(),
client_id: "client-id".into(),
client_secret: "client-secret".into(),
redirect_uri: "https://app.example.com/auth/callback".into(),
scopes: vec!["openid".into(), "email".into(), "profile".into()],
discovery_ttl: None,
jwks_ttl: None,
extra_token_params: None,
}
}
#[rstest]
#[tokio::test]
async fn new_succeeds_with_valid_config() {
let config = valid_config();
let provider = GenericOidcProvider::new(config).await;
let provider = provider.expect("provider should construct");
assert_eq!(provider.name(), "gitlab");
assert!(provider.is_oidc());
}
#[rstest]
#[case::empty_name("name")]
#[case::empty_discovery_url("discovery_url")]
#[case::empty_client_id("client_id")]
#[case::empty_client_secret("client_secret")]
#[case::empty_redirect_uri("redirect_uri")]
#[tokio::test]
async fn new_rejects_missing_required_field(#[case] field: &str) {
let mut config = valid_config();
match field {
"name" => config.name.clear(),
"discovery_url" => config.discovery_url.clear(),
"client_id" => config.client_id.clear(),
"client_secret" => config.client_secret.clear(),
"redirect_uri" => config.redirect_uri.clear(),
other => panic!("unhandled field {}", other),
}
let result = GenericOidcProvider::new(config).await;
let err = result.err().expect("missing field must reject config");
assert!(matches!(err, SocialAuthError::InvalidConfiguration(_)));
}
#[rstest]
#[case::http_non_loopback("http://gitlab.example.com/.well-known/openid-configuration")]
#[case::ftp_scheme("ftp://gitlab.example.com/.well-known/openid-configuration")]
#[tokio::test]
async fn new_rejects_insecure_discovery_url(#[case] url: &str) {
let mut config = valid_config();
config.discovery_url = url.to_string();
let result = GenericOidcProvider::new(config).await;
let err = result.err().expect("insecure URL must be rejected");
assert!(matches!(err, SocialAuthError::InsecureEndpoint(_)));
}
#[rstest]
#[tokio::test]
async fn new_accepts_localhost_http_discovery_url_for_dev() {
let mut config = valid_config();
config.discovery_url = "http://localhost:8080/.well-known/openid-configuration".to_string();
let result = GenericOidcProvider::new(config).await;
assert!(
result.is_ok(),
"loopback HTTP must be allowed for local development"
);
}
#[rstest]
#[case::standard(
"https://gitlab.com/.well-known/openid-configuration",
"https://gitlab.com"
)]
#[case::with_subpath(
"https://example.com/auth/.well-known/openid-configuration",
"https://example.com/auth"
)]
#[case::no_well_known("https://example.com/issuer", "https://example.com/issuer")]
fn issuer_from_discovery_url_strips_well_known_suffix(
#[case] discovery_url: &str,
#[case] expected: &str,
) {
let issuer = issuer_from_discovery_url(discovery_url);
assert_eq!(issuer, expected);
}
#[rstest]
fn chrono_duration_or_default_uses_default_when_none() {
let result = chrono_duration_or_default(None);
assert_eq!(result.num_seconds(), DEFAULT_CACHE_TTL_SECS);
}
#[rstest]
fn chrono_duration_or_default_uses_default_when_zero() {
let result = chrono_duration_or_default(Some(StdDuration::from_secs(0)));
assert_eq!(result.num_seconds(), DEFAULT_CACHE_TTL_SECS);
}
#[rstest]
fn chrono_duration_or_default_honors_explicit_value() {
let result = chrono_duration_or_default(Some(StdDuration::from_secs(900)));
assert_eq!(result.num_seconds(), 900);
}
#[rstest]
fn compute_allowed_algorithms_falls_back_to_supported_set_when_none() {
let allowed = compute_allowed_algorithms(None);
assert_eq!(allowed, SUPPORTED_ASYMMETRIC_ALGORITHMS.to_vec());
}
#[rstest]
fn compute_allowed_algorithms_falls_back_when_empty() {
let advertised: Vec<String> = vec![];
let allowed = compute_allowed_algorithms(Some(&advertised));
assert_eq!(allowed, SUPPORTED_ASYMMETRIC_ALGORITHMS.to_vec());
}
#[rstest]
fn compute_allowed_algorithms_intersects_with_advertised_set() {
let advertised = vec![
"RS256".to_string(),
"PS256".to_string(),
"ES256".to_string(),
"ES384".to_string(),
"ES512".to_string(),
"none".to_string(),
"HS256".to_string(),
];
let allowed = compute_allowed_algorithms(Some(&advertised));
assert_eq!(
allowed,
vec![
Algorithm::RS256,
Algorithm::PS256,
Algorithm::ES256,
Algorithm::ES384,
]
);
assert!(
!allowed.contains(&Algorithm::HS256),
"HS* must never be allowed for OIDC ID tokens"
);
}
#[rstest]
fn compute_allowed_algorithms_rejects_only_unsafe_algs() {
let advertised = vec!["HS256".to_string(), "none".to_string()];
let allowed = compute_allowed_algorithms(Some(&advertised));
assert!(
allowed.is_empty(),
"all symmetric / none algorithms must be rejected"
);
}
#[rstest]
fn build_provider_config_propagates_oidc_settings() {
let cfg = valid_config();
let provider_config = build_provider_config(&cfg);
assert_eq!(provider_config.name, cfg.name);
assert_eq!(provider_config.client_id, cfg.client_id);
assert_eq!(provider_config.redirect_uri, cfg.redirect_uri);
assert!(provider_config.oidc.is_some(), "OIDC config required");
assert!(provider_config.oauth2.is_none());
let oidc = provider_config.oidc.unwrap();
assert_eq!(oidc.discovery_url, cfg.discovery_url);
assert!(oidc.use_nonce);
}
#[rstest]
fn debug_redacts_client_secret() {
let config = valid_config();
let formatted = format!("{:?}", config);
assert!(
!formatted.contains("client-secret"),
"client_secret must not appear in Debug output"
);
assert!(formatted.contains("<redacted>"));
}
#[rstest]
#[tokio::test]
async fn default_userinfo_mapping_handles_standard_claims() {
let raw = serde_json::json!({
"sub": "user-1",
"email": "user@example.com",
"email_verified": true,
"name": "User One",
"given_name": "User",
"family_name": "One",
"picture": "https://example.com/u1.png",
"locale": "en-US",
"groups": ["admins"],
});
let claims =
GenericOidcProvider::default_map_userinfo(&raw).expect("default mapper succeeds");
assert_eq!(claims.sub, "user-1");
assert_eq!(claims.email.as_deref(), Some("user@example.com"));
assert_eq!(claims.email_verified, Some(true));
assert_eq!(claims.name.as_deref(), Some("User One"));
assert_eq!(claims.given_name.as_deref(), Some("User"));
assert_eq!(claims.family_name.as_deref(), Some("One"));
assert_eq!(
claims.picture.as_deref(),
Some("https://example.com/u1.png")
);
assert_eq!(claims.locale.as_deref(), Some("en-US"));
assert!(claims.additional_claims.contains_key("groups"));
}
#[rstest]
#[tokio::test]
async fn default_userinfo_mapping_rejects_missing_subject() {
let raw = serde_json::json!({
"email": "user@example.com",
});
let result = GenericOidcProvider::default_map_userinfo(&raw);
let err = result.err().expect("missing sub must error");
assert!(matches!(err, SocialAuthError::UserInfoError(_)));
}
#[rstest]
#[tokio::test]
async fn with_userinfo_mapper_overrides_default_logic() {
let provider = GenericOidcProvider::new(valid_config())
.await
.expect("provider construction must succeed")
.with_userinfo_mapper(|raw: &Value| {
let sub = raw
.get("uid")
.and_then(|v| v.as_str())
.ok_or_else(|| SocialAuthError::UserMapping("missing uid".into()))?
.to_string();
let mut additional = HashMap::new();
additional.insert("source".to_string(), Value::String("custom".into()));
Ok(StandardClaims {
sub,
email: raw
.get("mail")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
email_verified: Some(true),
name: None,
given_name: None,
family_name: None,
picture: None,
locale: None,
additional_claims: additional,
})
});
assert!(provider.userinfo_mapper.is_some(), "mapper must be set");
let raw = serde_json::json!({ "uid": "u-42", "mail": "u42@example.com" });
let claims = (provider.userinfo_mapper.as_ref().expect("mapper installed"))(&raw)
.expect("custom mapper succeeds");
assert_eq!(claims.sub, "u-42");
assert_eq!(claims.email.as_deref(), Some("u42@example.com"));
assert_eq!(claims.email_verified, Some(true));
assert_eq!(
claims.additional_claims.get("source"),
Some(&Value::String("custom".into()))
);
}
#[rstest]
#[tokio::test]
async fn with_userinfo_mapper_can_propagate_errors() {
let provider = GenericOidcProvider::new(valid_config())
.await
.unwrap()
.with_userinfo_mapper(|_raw: &Value| {
Err(SocialAuthError::UserMapping("forced failure".into()))
});
let raw = serde_json::json!({});
let result = (provider.userinfo_mapper.as_ref().unwrap())(&raw);
let err = result.err().expect("mapper must propagate errors");
assert!(matches!(err, SocialAuthError::UserMapping(_)));
}
}