use openidconnect::{
core::{CoreProviderMetadata, CoreResponseType, CoreTokenResponse},
AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndUserEmail,
IssuerUrl, Nonce, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use unicode_normalization::UnicodeNormalization;
use uuid::Uuid;
use crate::errors::AppError;
use crate::models::sso::{SsoAuthState, SsoProvider};
use crate::repositories::SsoRepository;
/// TTL, in seconds, handed to `SsoAuthState::new` for a pending SSO login
/// (five minutes).
const AUTH_STATE_TTL_SECS: i64 = 5 * 60;
/// Timeout, in seconds, applied to each OIDC HTTP request and to a whole
/// discovery round-trip.
const OIDC_HTTP_TIMEOUT_SECS: u64 = 10;
/// How long, in seconds, discovered provider metadata is reused from the
/// in-memory cache before it is fetched again (one hour).
const OIDC_METADATA_CACHE_TTL_SECS: u64 = 60 * 60;
fn validate_email_domain(email: &str, expected_domain: &str) -> Result<(), AppError> {
let email_domain = email
.split('@')
.nth(1)
.ok_or_else(|| AppError::Validation("Invalid email format".into()))?;
let normalized_email_domain: String = email_domain.to_lowercase().nfkc().collect();
let normalized_expected_domain: String = expected_domain.to_lowercase().nfkc().collect();
if normalized_email_domain != normalized_expected_domain {
return Err(AppError::Forbidden(format!(
"Email domain not allowed. Must be @{}",
expected_domain
)));
}
Ok(())
}
/// Issuer URL -> discovered metadata, shared behind an async read/write lock.
type MetadataCache = Arc<RwLock<HashMap<String, CachedProviderMetadata>>>;

/// Drives the OpenID Connect authorization-code flow (with PKCE) against
/// configured SSO providers.
pub struct OidcService {
    /// Redirect URI registered on every client this service builds.
    callback_url: String,
    /// HTTP client used for discovery and token-endpoint requests.
    http_client: reqwest::Client,
    /// Cache of provider discovery documents, keyed by issuer URL.
    metadata_cache: MetadataCache,
}
/// Outcome of `OidcService::start_auth`: the provider authorization URL and
/// the id of the auth state that was persisted for the callback.
#[derive(Debug)]
pub struct SsoAuthStart {
    /// Provider authorization URL, carrying the PKCE challenge, nonce and state.
    pub auth_url: String,
    /// Id of the stored `SsoAuthState`; its string form is also sent to the
    /// provider as the OAuth `state` parameter.
    pub state_id: Uuid,
}
/// Identity claims produced by `OidcService::complete_auth` from a verified ID token.
#[derive(Debug, Clone)]
pub struct OidcClaims {
    /// Subject identifier (`sub` claim) assigned by the provider.
    pub sub: String,
    /// `email` claim, if the ID token carried one.
    pub email: Option<String>,
    /// `email_verified` claim as reported by the provider; `None` when absent.
    pub email_verified: Option<bool>,
    /// Display name (`name` claim); `None` when not populated.
    pub name: Option<String>,
    /// `given_name` claim; `None` when not populated.
    pub given_name: Option<String>,
    /// `family_name` claim; `None` when not populated.
    pub family_name: Option<String>,
    /// `picture` claim (profile picture URL); `None` when not populated.
    pub picture: Option<String>,
}
/// One cached discovery result for a single issuer.
#[derive(Debug, Clone)]
struct CachedProviderMetadata {
    /// Provider metadata returned by OIDC discovery.
    metadata: CoreProviderMetadata,
    /// Once `Instant::now()` reaches this value, `get_provider_metadata`
    /// ignores the entry and runs discovery again.
    expires_at: std::time::Instant,
}
impl OidcService {
pub fn new(callback_url: String) -> Self {
let http_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(Duration::from_secs(OIDC_HTTP_TIMEOUT_SECS))
.build()
.unwrap_or_else(|e| {
tracing::error!(
error = %e,
"Failed to build OIDC HTTP client; falling back to defaults"
);
reqwest::Client::new()
});
Self {
callback_url,
http_client,
metadata_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
async fn get_provider_metadata(&self, issuer: &str) -> Result<CoreProviderMetadata, AppError> {
let now = std::time::Instant::now();
{
let cache = self.metadata_cache.read().await;
if let Some(cached) = cache.get(issuer) {
if now < cached.expires_at {
return Ok(cached.metadata.clone());
}
}
}
let issuer_url = IssuerUrl::new(issuer.to_string())
.map_err(|e| AppError::Config(format!("Invalid issuer URL: {}", e)))?;
let provider_metadata = tokio::time::timeout(
Duration::from_secs(OIDC_HTTP_TIMEOUT_SECS),
CoreProviderMetadata::discover_async(issuer_url, &self.http_client),
)
.await
.map_err(|_| {
AppError::Internal(anyhow::anyhow!(
"OIDC discovery failed: request timed out after {}s",
OIDC_HTTP_TIMEOUT_SECS
))
})?
.map_err(|e| AppError::Internal(anyhow::anyhow!("OIDC discovery failed: {:?}", e)))?;
{
let mut cache = self.metadata_cache.write().await;
cache.insert(
issuer.to_string(),
CachedProviderMetadata {
metadata: provider_metadata.clone(),
expires_at: now + Duration::from_secs(OIDC_METADATA_CACHE_TTL_SECS),
},
);
}
Ok(provider_metadata)
}
pub async fn start_auth(
&self,
provider: &SsoProvider,
client_secret: &str,
redirect_uri: Option<String>,
access_code: Option<String>,
referral: Option<String>,
sso_repo: &Arc<dyn SsoRepository>,
) -> Result<SsoAuthStart, AppError> {
let provider_metadata = self.get_provider_metadata(&provider.issuer_url).await?;
let client_id = ClientId::new(provider.client_id.clone());
let client_secret_obj = ClientSecret::new(client_secret.to_string());
let redirect_url = RedirectUrl::new(self.callback_url.clone())
.map_err(|e| AppError::Config(format!("Invalid redirect URL: {}", e)))?;
let client = openidconnect::core::CoreClient::from_provider_metadata(
provider_metadata,
client_id,
Some(client_secret_obj),
)
.set_redirect_uri(redirect_url);
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let state_id = Uuid::new_v4();
let state_id_string = state_id.to_string();
let state_id_for_closure = state_id_string.clone();
let nonce = Nonce::new_random();
let nonce_secret = nonce.secret().to_string();
let nonce_for_url = nonce.clone();
let mut auth_request = client.authorize_url(
AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
move || CsrfToken::new(state_id_for_closure.clone()),
move || nonce_for_url.clone(),
);
for scope in &provider.scopes {
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
}
let (auth_url, _csrf_token, _nonce) = auth_request.set_pkce_challenge(pkce_challenge).url();
let mut auth_state = SsoAuthState::new(
provider.id,
provider.org_id,
pkce_verifier.secret().to_string(),
nonce_secret,
redirect_uri,
AUTH_STATE_TTL_SECS,
);
auth_state.state_id = state_id;
auth_state.access_code = access_code;
auth_state.referral = referral;
sso_repo.store_auth_state(auth_state).await?;
Ok(SsoAuthStart {
auth_url: auth_url.to_string(),
state_id,
})
}
pub async fn complete_auth(
&self,
code: &str,
state_id: Uuid,
provider: &SsoProvider,
client_secret: &str,
sso_repo: &Arc<dyn SsoRepository>,
) -> Result<OidcClaims, AppError> {
let auth_state = sso_repo
.consume_auth_state(state_id)
.await?
.ok_or_else(|| AppError::Validation("Invalid or expired SSO state".into()))?;
if auth_state.provider_id != provider.id {
return Err(AppError::Validation("Provider mismatch".into()));
}
let provider_metadata = self.get_provider_metadata(&provider.issuer_url).await?;
let client_id = ClientId::new(provider.client_id.clone());
let client_secret_obj = ClientSecret::new(client_secret.to_string());
let redirect_url = RedirectUrl::new(self.callback_url.clone())
.map_err(|e| AppError::Config(format!("Invalid redirect URL: {}", e)))?;
let client = openidconnect::core::CoreClient::from_provider_metadata(
provider_metadata,
client_id,
Some(client_secret_obj),
)
.set_redirect_uri(redirect_url);
let pkce_verifier = PkceCodeVerifier::new(auth_state.pkce_verifier);
let token_request = client
.exchange_code(AuthorizationCode::new(code.to_string()))
.map_err(|e| AppError::Config(format!("OIDC token endpoint error: {:?}", e)))?;
let token_response: CoreTokenResponse = token_request
.set_pkce_verifier(pkce_verifier)
.request_async(&self.http_client)
.await
.map_err(|_| AppError::InvalidCredentials)?;
let id_token = token_response
.id_token()
.ok_or_else(|| AppError::Validation("No ID token in response".into()))?;
let nonce = Nonce::new(auth_state.nonce);
let claims = id_token
.claims(&client.id_token_verifier(), &nonce)
.map_err(|e| AppError::Validation(format!("Token verification failed: {:?}", e)))?;
let email: Option<String> = claims
.email()
.map(|e: &EndUserEmail| e.as_str().to_string());
let email_verified = claims.email_verified();
if let (Some(expected_domain), Some(email_str)) = (&provider.email_domain, &email) {
validate_email_domain(email_str, expected_domain)?;
}
Ok(OidcClaims {
sub: claims.subject().to_string(),
email,
email_verified,
name: None,
given_name: None,
family_name: None,
picture: None,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::sso::SsoProvider;
    use crate::repositories::{InMemorySsoRepository, SsoRepository};
    use openidconnect::core::{CoreJwsSigningAlgorithm, CoreSubjectIdentifierType};
    use openidconnect::{AuthUrl, JsonWebKeySetUrl, ResponseTypes};
    use std::sync::Arc;
    use uuid::Uuid;

    const CALLBACK_URL: &str = "https://api.example.com/auth/sso/callback";
    const ISSUER: &str = "https://issuer.example.com";

    /// Minimal discovery document for `ISSUER`: code flow, public subject
    /// identifiers, RS256-signed ID tokens.
    fn sample_metadata() -> CoreProviderMetadata {
        CoreProviderMetadata::new(
            IssuerUrl::new(ISSUER.to_string()).unwrap(),
            AuthUrl::new("https://issuer.example.com/auth".to_string()).unwrap(),
            JsonWebKeySetUrl::new("https://issuer.example.com/jwks".to_string()).unwrap(),
            vec![ResponseTypes::new(vec![CoreResponseType::Code])],
            vec![CoreSubjectIdentifierType::Public],
            vec![CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256],
            Default::default(),
        )
    }

    /// Puts `metadata` in the service cache for `issuer`, valid for 60 seconds,
    /// so `get_provider_metadata` answers from the cache.
    async fn seed_metadata_cache(
        service: &OidcService,
        issuer: &str,
        metadata: CoreProviderMetadata,
    ) {
        let entry = CachedProviderMetadata {
            metadata,
            expires_at: std::time::Instant::now() + Duration::from_secs(60),
        };
        service
            .metadata_cache
            .write()
            .await
            .insert(issuer.to_string(), entry);
    }

    #[test]
    fn test_oidc_service_creation() {
        let service = OidcService::new(CALLBACK_URL.into());
        assert_eq!(
            service.callback_url,
            "https://api.example.com/auth/sso/callback"
        );
    }

    #[tokio::test]
    async fn test_provider_metadata_cache_hit() {
        let service = OidcService::new(CALLBACK_URL.into());
        let expected = sample_metadata();
        seed_metadata_cache(&service, ISSUER, expected.clone()).await;

        let fetched = service.get_provider_metadata(ISSUER).await.unwrap();

        // Compare through the serialized form.
        assert_eq!(
            serde_json::to_string(&fetched).unwrap(),
            serde_json::to_string(&expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_start_auth_persists_access_code_and_referral() {
        let service = OidcService::new(CALLBACK_URL.into());
        seed_metadata_cache(&service, ISSUER, sample_metadata()).await;

        let provider = SsoProvider::new(
            Uuid::new_v4(),
            "Okta".into(),
            ISSUER.to_string(),
            "client-id".into(),
            "encrypted-secret".into(),
        );
        let repo: Arc<dyn SsoRepository> = Arc::new(InMemorySsoRepository::new());

        let started = service
            .start_auth(
                &provider,
                "client-secret",
                Some("https://app.example.com/after-login".into()),
                Some("ACCESS123".into()),
                Some("REFERRAL1".into()),
                &repo,
            )
            .await
            .unwrap();

        let stored = repo
            .get_auth_state(started.state_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(stored.access_code.as_deref(), Some("ACCESS123"));
        assert_eq!(stored.referral.as_deref(), Some("REFERRAL1"));
        assert_eq!(
            stored.redirect_uri.as_deref(),
            Some("https://app.example.com/after-login")
        );
    }

    #[test]
    fn test_email_domain_validation_exact_match() {
        assert!(validate_email_domain("user@example.com", "example.com").is_ok());
    }

    #[test]
    fn test_email_domain_validation_case_insensitive() {
        let cases = [
            ("user@EXAMPLE.COM", "example.com"),
            ("user@example.com", "EXAMPLE.COM"),
        ];
        for &(email, domain) in cases.iter() {
            assert!(
                validate_email_domain(email, domain).is_ok(),
                "{} should match {}",
                email,
                domain
            );
        }
    }

    #[test]
    fn test_email_domain_validation_rejects_subdomain_bypass() {
        let err = validate_email_domain("user@evil.example.com", "example.com")
            .expect_err("a subdomain must not satisfy the parent domain");
        assert!(err.to_string().contains("not allowed"));
    }

    #[test]
    fn test_email_domain_validation_rejects_different_domain() {
        assert!(validate_email_domain("user@attacker.com", "example.com").is_err());
    }

    #[test]
    fn test_email_domain_validation_rejects_unicode_lookalike() {
        // U+0430 is CYRILLIC SMALL LETTER A, visually identical to Latin 'a'.
        assert!(validate_email_domain("user@ex\u{0430}mple.com", "example.com").is_err());
    }

    #[test]
    fn test_email_domain_validation_invalid_email_format() {
        let err = validate_email_domain("no-at-symbol", "example.com")
            .expect_err("an address without '@' must be rejected");
        assert!(err.to_string().contains("Invalid email"));
    }
}