use axum::{
extract::{Query, Request, State},
http::{header, HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
routing::get,
Json, Router,
};
use base64::{
engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
Engine as _,
};
use chrono::{DateTime, Duration, Utc};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{info, warn};
use uuid::Uuid;
use argentor_core::{ArgentorError, ArgentorResult};
/// Configuration for single sign-on.
///
/// `Debug` is implemented by hand (below) so that `client_secret` is never
/// written to logs when the configuration is printed with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SsoConfig {
    /// Which login protocol the [`SsoManager`] speaks.
    pub provider: SsoProvider,
    /// OAuth/OIDC client identifier; sent in the authorize and token requests.
    pub client_id: String,
    /// OAuth/OIDC client secret; sent only in the token-exchange request.
    pub client_secret: String,
    /// Callback URL registered with the identity provider.
    pub redirect_uri: String,
    /// Base URL of the identity provider (OIDC issuer, or the SAML login URL).
    pub issuer_url: String,
    /// Email domains allowed to log in; an empty list allows every domain.
    pub allowed_domains: Vec<String>,
    /// OIDC scopes, joined with `+` into the login URL.
    pub scopes: Vec<String>,
    /// Lifetime of a newly created session, in hours.
    pub session_ttl_hours: u32,
}

impl std::fmt::Debug for SsoConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a secret is configured without revealing it.
        let secret: &str = if self.client_secret.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("SsoConfig")
            .field("provider", &self.provider)
            .field("client_id", &self.client_id)
            .field("client_secret", &secret)
            .field("redirect_uri", &self.redirect_uri)
            .field("issuer_url", &self.issuer_url)
            .field("allowed_domains", &self.allowed_domains)
            .field("scopes", &self.scopes)
            .field("session_ttl_hours", &self.session_ttl_hours)
            .finish()
    }
}
impl Default for SsoConfig {
    /// An OIDC configuration with empty credentials and URLs, no domain
    /// restriction, the standard `openid profile email` scopes, and 24-hour sessions.
    fn default() -> Self {
        let scopes: Vec<String> = ["openid", "profile", "email"]
            .into_iter()
            .map(String::from)
            .collect();
        Self {
            provider: SsoProvider::Oidc,
            client_id: String::default(),
            client_secret: String::default(),
            redirect_uri: String::default(),
            issuer_url: String::default(),
            allowed_domains: Vec::default(),
            scopes,
            session_ttl_hours: 24,
        }
    }
}
/// Which identity-provider protocol an [`SsoManager`] speaks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SsoProvider {
/// OpenID Connect authorization-code flow (discovery + token exchange).
Oidc,
/// SAML 2.0; the callback's `code` value carries the base64-encoded SAML response.
Saml,
/// API-key login via `POST /auth/api-key`; the SSO callback flow is rejected.
ApiKey,
}
/// Endpoints read from an issuer's `/.well-known/openid-configuration` document.
#[derive(Debug, Clone)]
// `authorization_endpoint` and `userinfo_endpoint` are not read anywhere in the
// visible code (the login URL is built as `{issuer}/authorize`), hence the allow.
#[allow(dead_code)]
pub(crate) struct OidcEndpoints {
/// URL the user agent is sent to for login (from discovery).
pub authorization_endpoint: String,
/// URL used for the authorization-code → token exchange.
pub token_endpoint: String,
/// Optional UserInfo endpoint; absent when discovery omits it.
pub userinfo_endpoint: Option<String>,
/// Issuer identifier as advertised by the discovery document.
pub issuer: String,
}
/// Fetches the OIDC discovery document for `issuer_url`.
///
/// Requests `{issuer_url}/.well-known/openid-configuration` (a trailing `/` on
/// the issuer is ignored) with a bounded timeout, so an unresponsive identity
/// provider cannot stall the login callback indefinitely.
///
/// # Errors
///
/// Returns [`ArgentorError::Http`] when the request fails or times out, the
/// response status is not 2xx, the body is not JSON, or any of
/// `authorization_endpoint`, `token_endpoint` or `issuer` is missing.
async fn discover_oidc_endpoints(issuer_url: &str) -> ArgentorResult<OidcEndpoints> {
    const DISCOVERY_TIMEOUT_SECS: u64 = 10;

    let url = format!(
        "{}/.well-known/openid-configuration",
        issuer_url.trim_end_matches('/')
    );
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(DISCOVERY_TIMEOUT_SECS))
        .build()
        .map_err(|e| {
            ArgentorError::Http(format!("OIDC discovery: failed to build HTTP client: {e}"))
        })?;
    let resp = client
        .get(&url)
        .send()
        .await
        .map_err(|e| ArgentorError::Http(format!("OIDC discovery request failed: {e}")))?;
    if !resp.status().is_success() {
        return Err(ArgentorError::Http(format!(
            "OIDC discovery returned HTTP {}",
            resp.status()
        )));
    }
    let config: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| ArgentorError::Http(format!("OIDC discovery response not valid JSON: {e}")))?;

    // Required string member of the discovery document.
    let required = |key: &str| -> ArgentorResult<String> {
        config[key]
            .as_str()
            .map(String::from)
            .ok_or_else(|| ArgentorError::Http(format!("OIDC discovery: missing {key}")))
    };

    Ok(OidcEndpoints {
        authorization_endpoint: required("authorization_endpoint")?,
        token_endpoint: required("token_endpoint")?,
        userinfo_endpoint: config["userinfo_endpoint"].as_str().map(String::from),
        issuer: required("issuer")?,
    })
}
/// Decodes the claims (middle segment) of a compact-serialized JWT into JSON.
///
/// Only the structure is checked: the token must have exactly three
/// `.`-separated segments, and the middle one must be unpadded base64url JSON.
/// The signature segment is **not** verified here.
///
/// # Errors
///
/// Returns [`ArgentorError::Security`] for a wrong segment count, a base64url
/// decode failure, or a payload that is not valid JSON.
fn decode_jwt_payload(token: &str) -> ArgentorResult<serde_json::Value> {
    let segments: Vec<&str> = token.split('.').collect();
    let [_header, claims, _signature] = segments.as_slice() else {
        return Err(ArgentorError::Security(format!(
            "Invalid JWT format: expected 3 parts separated by '.', got {}",
            segments.len()
        )));
    };

    let raw_claims = URL_SAFE_NO_PAD.decode(*claims).map_err(|err| {
        ArgentorError::Security(format!("JWT payload base64url decode failed: {err}"))
    })?;

    serde_json::from_slice::<serde_json::Value>(&raw_claims)
        .map_err(|err| ArgentorError::Security(format!("JWT payload is not valid JSON: {err}")))
}
/// Identity fields pulled out of a SAML response by [`parse_saml_response`].
#[derive(Debug, Clone)]
struct SamlClaims {
/// Text of the `<NameID>` element.
name_id: String,
/// Email attribute, falling back to `name_id` when no email attribute is present.
email: String,
/// Display name from the `name`/`displayName`/claims-URI attribute, if any.
name: Option<String>,
/// All role attribute values; empty when the response carries none.
roles: Vec<String>,
}
/// Returns the trimmed text content of the first `<tag>…</tag>` element in `xml`.
///
/// Any namespace prefix (`saml:`, `samlp:`, …) and any attributes on the opening
/// tag are accepted. Only plain-text content (no nested elements) is matched.
/// Returns `None` if no such element exists.
fn extract_xml_element(xml: &str, tag: &str) -> Option<String> {
    let escaped = regex::escape(tag);
    let matcher = Regex::new(&format!(
        r"<(?:[a-zA-Z0-9_]+:)?{escaped}(?:\s[^>]*)?>([^<]+)</(?:[a-zA-Z0-9_]+:)?{escaped}>"
    ))
    .ok()?;
    let content = matcher.captures(xml)?.get(1)?;
    Some(content.as_str().trim().to_string())
}
/// Returns the value of attribute `attr` on the first `<element …>` tag in `xml`.
///
/// The element may carry a namespace prefix. Only double-quoted, non-empty
/// values are recognised. The value is returned verbatim (not trimmed).
fn extract_xml_attr(xml: &str, element: &str, attr: &str) -> Option<String> {
    let element_pat = regex::escape(element);
    let attr_pat = regex::escape(attr);
    let matcher = Regex::new(&format!(
        r#"<(?:[a-zA-Z0-9_]+:)?{element_pat}[^>]*?\s{attr_pat}\s*=\s*"([^"]+)""#
    ))
    .ok()?;
    let value = matcher.captures(xml)?.get(1)?;
    Some(value.as_str().to_string())
}
fn extract_saml_attribute(xml: &str, attr_name: &str) -> Option<String> {
let pattern = format!(
r#"<(?:[a-zA-Z0-9_]+:)?Attribute\s[^>]*Name\s*=\s*"{name}"[^>]*>.*?<(?:[a-zA-Z0-9_]+:)?AttributeValue[^>]*>([^<]+)</(?:[a-zA-Z0-9_]+:)?AttributeValue>"#,
name = regex::escape(attr_name),
);
let re = Regex::new(&pattern).ok()?;
re.captures(xml)
.and_then(|caps| caps.get(1))
.map(|m| m.as_str().trim().to_string())
}
/// Returns every trimmed `<AttributeValue>` inside the SAML `<Attribute>` whose
/// `Name` attribute equals `attr_name`, or `None` if none are found.
///
/// `Name` must be preceded by whitespace, so an attribute that only has a
/// matching `FriendlyName="…"` is not selected by mistake.
fn extract_saml_attribute_values(xml: &str, attr_name: &str) -> Option<Vec<String>> {
    let block_pattern = format!(
        r#"<(?:[a-zA-Z0-9_]+:)?Attribute\s(?:[^>]*\s)?Name\s*=\s*"{name}"[^>]*>([\s\S]*?)</(?:[a-zA-Z0-9_]+:)?Attribute>"#,
        name = regex::escape(attr_name),
    );
    let block_re = Regex::new(&block_pattern).ok()?;
    let block = block_re.captures(xml)?.get(1)?.as_str();
    let value_re = Regex::new(
        r"<(?:[a-zA-Z0-9_]+:)?AttributeValue[^>]*>([^<]+)</(?:[a-zA-Z0-9_]+:)?AttributeValue>",
    )
    .ok()?;
    let values: Vec<String> = value_re
        .captures_iter(block)
        .filter_map(|caps| caps.get(1).map(|m| m.as_str().trim().to_string()))
        .collect();
    if values.is_empty() {
        None
    } else {
        Some(values)
    }
}
/// Decodes a base64 SAML response and extracts the identity claims.
///
/// ASCII whitespace in the input is ignored, because IdPs commonly line-wrap
/// the base64 payload. The top-level `StatusCode` is checked first, so a failed
/// login reports the IdP's status rather than a misleading "missing NameID".
///
/// SECURITY(review): the XML signature, `Conditions` (validity window,
/// audience) and `Issuer` are NOT validated here. Whoever can reach the callback
/// can therefore present an arbitrary assertion. Do not rely on this in
/// production without adding signature verification.
///
/// # Errors
///
/// Returns [`ArgentorError::Security`] for invalid base64, non-UTF-8 XML, a
/// non-success status, or a missing `NameID`.
fn parse_saml_response(saml_response: &str) -> ArgentorResult<SamlClaims> {
    // The STANDARD engine rejects embedded whitespace such as 76-column line breaks.
    let compact: String = saml_response
        .chars()
        .filter(|c| !c.is_ascii_whitespace())
        .collect();
    let decoded_bytes = STANDARD
        .decode(compact.as_bytes())
        .map_err(|e| ArgentorError::Security(format!("SAML response base64 decode failed: {e}")))?;
    let xml = String::from_utf8(decoded_bytes)
        .map_err(|e| ArgentorError::Security(format!("SAML response is not valid UTF-8: {e}")))?;

    if let Some(status_value) = extract_xml_attr(&xml, "StatusCode", "Value") {
        if !status_value.contains("Success") {
            return Err(ArgentorError::Security(format!(
                "SAML authentication failed with status: {status_value}"
            )));
        }
    }

    let name_id = extract_xml_element(&xml, "NameID")
        .ok_or_else(|| ArgentorError::Security("SAML response missing NameID element".into()))?;
    let name = extract_saml_attribute(&xml, "name")
        .or_else(|| extract_saml_attribute(&xml, "displayName"))
        .or_else(|| {
            extract_saml_attribute(
                &xml,
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name",
            )
        });
    // Fall back to NameID, which is frequently the email address itself.
    let email = extract_saml_attribute(&xml, "email")
        .or_else(|| {
            extract_saml_attribute(
                &xml,
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
            )
        })
        .unwrap_or_else(|| name_id.clone());
    let roles: Vec<String> = extract_saml_attribute_values(&xml, "role")
        .or_else(|| {
            extract_saml_attribute_values(
                &xml,
                "http://schemas.microsoft.com/ws/2008/06/identity/claims/role",
            )
        })
        .unwrap_or_default();

    Ok(SamlClaims {
        name_id,
        email,
        name,
        roles,
    })
}
/// An authenticated user, as stored in a session and returned to clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserIdentity {
/// Provider subject (OIDC `sub`, SAML NameID, or `apikey-…` for API keys).
pub id: String,
/// User email address.
pub email: String,
/// Optional display name.
pub name: Option<String>,
/// Text after the last `@` of `email` (empty when there is no `@`).
pub domain: String,
/// Roles granted to the user.
pub roles: Vec<String>,
/// Optional tenant; every visible caller of `build_identity` passes `None`.
pub tenant_id: Option<String>,
/// When the identity was built.
pub authenticated_at: DateTime<Utc>,
/// `authenticated_at` plus the configured session TTL.
pub expires_at: DateTime<Utc>,
}
impl UserIdentity {
    /// `true` once the current time has reached or passed `expires_at`.
    pub fn is_expired(&self) -> bool {
        let now = Utc::now();
        self.expires_at <= now
    }
}
/// Drives SSO logins and owns the in-memory session store.
pub struct SsoManager {
/// Immutable configuration supplied at construction.
config: SsoConfig,
/// Session token (a UUID v4 from `create_session`) → authenticated identity.
sessions: RwLock<HashMap<String, UserIdentity>>,
/// CSRF `state` value → time it was issued by `login_url`.
pending_states: RwLock<HashMap<String, DateTime<Utc>>>,
}
impl SsoManager {
pub fn new(config: SsoConfig) -> Self {
Self {
config,
sessions: RwLock::new(HashMap::new()),
pending_states: RwLock::new(HashMap::new()),
}
}
pub fn config(&self) -> &SsoConfig {
&self.config
}
pub fn login_url(&self, state: &str) -> String {
if let Ok(mut states) = self.pending_states.write() {
states.insert(state.to_string(), Utc::now());
}
match self.config.provider {
SsoProvider::Oidc => {
let scopes = self.config.scopes.join("+");
format!(
"{issuer}/authorize?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&scope={scopes}&state={state}",
issuer = self.config.issuer_url.trim_end_matches('/'),
client_id = self.config.client_id,
redirect_uri = self.config.redirect_uri,
)
}
SsoProvider::Saml => {
format!(
"{issuer}?SAMLRequest=placeholder&RelayState={state}",
issuer = self.config.issuer_url.trim_end_matches('/'),
)
}
SsoProvider::ApiKey => {
format!(
"{redirect_uri}?mode=api_key&state={state}",
redirect_uri = self.config.redirect_uri,
)
}
}
}
pub async fn handle_callback(
&self,
code: &str,
state: &str,
) -> ArgentorResult<(String, UserIdentity)> {
let state_valid = self
.pending_states
.read()
.map_err(|e| ArgentorError::Gateway(format!("lock poisoned: {e}")))?
.contains_key(state);
if !state_valid {
return Err(ArgentorError::Security(
"Invalid or expired SSO state parameter — possible CSRF attack".into(),
));
}
if let Ok(mut states) = self.pending_states.write() {
states.remove(state);
}
if code.is_empty() {
return Err(ArgentorError::Gateway(
"Empty authorization code in SSO callback".into(),
));
}
match self.config.provider {
SsoProvider::Oidc => {
let endpoints = discover_oidc_endpoints(&self.config.issuer_url).await?;
let client = reqwest::Client::new();
let token_response = client
.post(&endpoints.token_endpoint)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", self.config.redirect_uri.as_str()),
("client_id", self.config.client_id.as_str()),
("client_secret", self.config.client_secret.as_str()),
])
.send()
.await
.map_err(|e| {
ArgentorError::Http(format!("OIDC token exchange request failed: {e}"))
})?;
if !token_response.status().is_success() {
let status = token_response.status();
let body = token_response
.text()
.await
.unwrap_or_else(|_| "<unreadable>".into());
return Err(ArgentorError::Http(format!(
"OIDC token endpoint returned HTTP {status}: {body}"
)));
}
let token_json: serde_json::Value = token_response.json().await.map_err(|e| {
ArgentorError::Http(format!("OIDC token response not valid JSON: {e}"))
})?;
let id_token_str = token_json["id_token"].as_str().ok_or_else(|| {
ArgentorError::Security("OIDC token response missing id_token field".into())
})?;
let payload = decode_jwt_payload(id_token_str)?;
let discovered_issuer = endpoints.issuer.trim_end_matches('/');
let configured_issuer = self.config.issuer_url.trim_end_matches('/');
if discovered_issuer != configured_issuer {
return Err(ArgentorError::Security(format!(
"OIDC discovered issuer mismatch: configured '{configured_issuer}', \
discovered '{discovered_issuer}'"
)));
}
let iss = payload["iss"].as_str().unwrap_or("");
if iss.trim_end_matches('/') != discovered_issuer {
return Err(ArgentorError::Security(format!(
"OIDC id_token issuer mismatch: expected '{discovered_issuer}', got '{iss}'"
)));
}
let email_verified = payload["email_verified"].as_bool().unwrap_or(false);
if !email_verified {
return Err(ArgentorError::Security(
"OIDC email not verified by the identity provider".into(),
));
}
let email = payload["email"].as_str().ok_or_else(|| {
ArgentorError::Security("OIDC id_token missing email claim".into())
})?;
let name = payload["name"].as_str().map(String::from);
let sub = payload["sub"].as_str().unwrap_or("unknown");
if !self.is_domain_allowed(email) {
return Err(ArgentorError::Security(format!(
"Email domain not in allowed list for '{email}'"
)));
}
let identity = self.build_identity(
sub,
email,
name.as_deref(),
vec!["oidc-user".into()],
None,
);
let session_token = self.create_session(identity.clone());
info!(
email = email,
sub = sub,
"OIDC authentication successful, session created"
);
Ok((session_token, identity))
}
SsoProvider::Saml => {
let claims = parse_saml_response(code)?;
if !self.is_domain_allowed(&claims.email) {
return Err(ArgentorError::Security(format!(
"Email domain not in allowed list for '{}'",
claims.email
)));
}
let mut roles = claims.roles.clone();
if roles.is_empty() {
roles.push("saml-user".into());
}
let identity = self.build_identity(
&claims.name_id,
&claims.email,
claims.name.as_deref(),
roles,
None,
);
let session_token = self.create_session(identity.clone());
info!(
email = %claims.email,
name_id = %claims.name_id,
"SAML authentication successful, session created"
);
Ok((session_token, identity))
}
SsoProvider::ApiKey => {
Err(ArgentorError::Agent(
"API key authentication does not use the SSO callback flow \u{2014} \
use POST /auth/api-key instead"
.into(),
))
}
}
}
pub fn create_session(&self, identity: UserIdentity) -> String {
let token = Uuid::new_v4().to_string();
if let Ok(mut sessions) = self.sessions.write() {
sessions.insert(token.clone(), identity);
}
token
}
pub fn validate_session(&self, token: &str) -> Option<UserIdentity> {
let identity = {
let sessions = self.sessions.read().ok()?;
sessions.get(token).cloned()
};
match identity {
Some(id) if id.is_expired() => {
self.revoke_session(token);
None
}
other => other,
}
}
pub fn revoke_session(&self, token: &str) -> bool {
if let Ok(mut sessions) = self.sessions.write() {
sessions.remove(token).is_some()
} else {
false
}
}
pub fn is_domain_allowed(&self, email: &str) -> bool {
if self.config.allowed_domains.is_empty() {
return true;
}
let domain = match email.rsplit_once('@') {
Some((_, d)) => d.to_lowercase(),
None => return false, };
self.config
.allowed_domains
.iter()
.any(|d| d.to_lowercase() == domain)
}
pub fn active_sessions(&self) -> Vec<(String, UserIdentity)> {
let sessions = match self.sessions.read() {
Ok(s) => s,
Err(_) => return Vec::new(),
};
sessions
.iter()
.filter(|(_, id)| !id.is_expired())
.map(|(token, id)| (token.clone(), id.clone()))
.collect()
}
pub fn cleanup_expired(&self) -> usize {
let mut sessions = match self.sessions.write() {
Ok(s) => s,
Err(_) => return 0,
};
let before = sessions.len();
sessions.retain(|_, id| !id.is_expired());
before - sessions.len()
}
pub fn build_identity(
&self,
id: &str,
email: &str,
name: Option<&str>,
roles: Vec<String>,
tenant_id: Option<String>,
) -> UserIdentity {
let domain = email
.rsplit_once('@')
.map(|(_, d)| d.to_string())
.unwrap_or_default();
let now = Utc::now();
let ttl = Duration::hours(i64::from(self.config.session_ttl_hours));
UserIdentity {
id: id.to_string(),
email: email.to_string(),
name: name.map(ToString::to_string),
domain,
roles,
tenant_id,
authenticated_at: now,
expires_at: now + ttl,
}
}
}
/// Shared axum state for the routes built by `sso_router`.
#[derive(Clone)]
pub struct SsoState {
/// Manager that owns configuration, pending states and sessions.
pub manager: Arc<SsoManager>,
}
/// Query string accepted by `GET /auth/login`.
#[derive(Deserialize)]
pub struct LoginQuery {
/// Requested post-login redirect; currently read but not acted on by the login handler.
pub redirect: Option<String>,
}
/// Query string accepted by `GET /auth/callback`.
#[derive(Deserialize)]
pub struct CallbackQuery {
/// OIDC authorization code, or the base64 SAML response for the SAML provider.
pub code: Option<String>,
/// CSRF state previously issued by the login URL.
pub state: Option<String>,
}
/// JSON body for `POST /auth/api-key`.
#[derive(Deserialize)]
pub struct ApiKeyAuthRequest {
/// Caller-supplied API key; the handler only checks that it is non-empty.
pub api_key: String,
/// Email to attach to the session; defaults to `api-key-user@local` in the handler.
pub email: Option<String>,
}
/// JSON body returned after a successful login.
#[derive(Serialize)]
pub struct AuthResponse {
/// Session token to present as `Bearer` or in the `argentor_session` cookie.
pub token: String,
/// The authenticated identity.
pub identity: UserIdentity,
/// Copy of `identity.expires_at`, for client convenience.
pub expires_at: DateTime<Utc>,
}
/// `GET /auth/login`: issues a CSRF state and redirects (307) to the provider login URL.
async fn sso_login_handler(
    State(state): State<Arc<SsoState>>,
    Query(query): Query<LoginQuery>,
) -> impl IntoResponse {
    // A fresh random value per attempt; the manager records it as a pending state
    // so the callback can reject forged or replayed requests.
    let csrf_token = Uuid::new_v4().to_string();
    let location = state.manager.login_url(&csrf_token);

    // NOTE(review): the requested redirect is parsed but not stored with the CSRF
    // state, so the callback cannot honour it yet.
    let _redirect = query.redirect.unwrap_or_else(|| "/".to_string());

    let redirect_headers = [(header::LOCATION, location)];
    (StatusCode::TEMPORARY_REDIRECT, redirect_headers, "").into_response()
}
async fn sso_callback_handler(
State(state): State<Arc<SsoState>>,
Query(query): Query<CallbackQuery>,
) -> impl IntoResponse {
let code = match query.code {
Some(c) => c,
None => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "missing_code",
"message": "Authorization code is required in the callback"
})),
)
.into_response();
}
};
let callback_state = query.state.unwrap_or_default();
match state.manager.handle_callback(&code, &callback_state).await {
Ok((token, identity)) => {
let response = AuthResponse {
expires_at: identity.expires_at,
token,
identity,
};
(StatusCode::OK, Json(response)).into_response()
}
Err(e) => (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "sso_callback_failed",
"message": e.to_string()
})),
)
.into_response(),
}
}
/// `GET /auth/logout`: revokes the caller's session.
///
/// Returns 200 with `logged_out` or `no_session`. Returns 400 with
/// `missing_token` when neither a Bearer header nor an `argentor_session`
/// cookie is present.
async fn sso_logout_handler(
    State(state): State<Arc<SsoState>>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let Some(session_token) = extract_session_token(&headers) else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "missing_token",
                "message": "No session token found in Authorization header or cookie"
            })),
        )
            .into_response();
    };

    let (status, message) = if state.manager.revoke_session(&session_token) {
        ("logged_out", "Session revoked successfully")
    } else {
        ("no_session", "No active session found for the given token")
    };

    (
        StatusCode::OK,
        Json(serde_json::json!({
            "status": status,
            "message": message
        })),
    )
        .into_response()
}
/// `GET /auth/me`: returns the caller's [`UserIdentity`], or 401 `not_authenticated`.
async fn sso_me_handler(
    State(state): State<Arc<SsoState>>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let current = extract_session_token(&headers)
        .and_then(|session_token| state.manager.validate_session(&session_token));

    if let Some(identity) = current {
        return (StatusCode::OK, Json(identity)).into_response();
    }

    (
        StatusCode::UNAUTHORIZED,
        Json(serde_json::json!({
            "error": "not_authenticated",
            "message": "No valid session — please log in via /auth/login"
        })),
    )
        .into_response()
}
/// `POST /auth/api-key`: opens a session for an API-key caller.
///
/// Responds 400 `missing_api_key` for an empty key and 403 `domain_not_allowed`
/// when the email's domain is not permitted. On success it returns an
/// [`AuthResponse`] whose identity id is `apikey-` plus the first 8 characters
/// of the key.
///
/// SECURITY(review): the key is never checked against any store, and the caller
/// chooses the email. As written, any non-empty key yields a session for any
/// allowed-domain address. Add real key verification before exposing this route.
async fn sso_api_key_handler(
    State(state): State<Arc<SsoState>>,
    Json(body): Json<ApiKeyAuthRequest>,
) -> impl IntoResponse {
    if body.api_key.is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "missing_api_key",
                "message": "API key is required"
            })),
        )
            .into_response();
    }
    let email = body.email.unwrap_or_else(|| "api-key-user@local".into());
    if !state.manager.is_domain_allowed(&email) {
        return (
            StatusCode::FORBIDDEN,
            Json(serde_json::json!({
                "error": "domain_not_allowed",
                "message": format!(
                    "Email domain is not in the allowed list: {:?}",
                    state.manager.config().allowed_domains
                )
            })),
        )
            .into_response();
    }
    // Take 8 *characters*: byte-slicing `&api_key[..8]` panics when byte 8 falls
    // inside a multi-byte UTF-8 character.
    let key_prefix: String = body.api_key.chars().take(8).collect();
    let identity = state.manager.build_identity(
        &format!("apikey-{key_prefix}"),
        &email,
        Some("API Key User"),
        vec!["api-user".into()],
        None,
    );
    let token = state.manager.create_session(identity.clone());
    let response = AuthResponse {
        expires_at: identity.expires_at,
        token,
        identity,
    };
    (StatusCode::OK, Json(response)).into_response()
}
fn extract_session_token(headers: &HeaderMap) -> Option<String> {
if let Some(auth) = headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
{
return Some(auth.to_string());
}
if let Some(cookie_header) = headers.get("cookie").and_then(|v| v.to_str().ok()) {
for cookie in cookie_header.split(';') {
let cookie = cookie.trim();
if let Some(value) = cookie.strip_prefix("argentor_session=") {
return Some(value.to_string());
}
}
}
None
}
/// State for [`sso_auth_middleware`].
#[derive(Clone)]
pub struct SsoMiddlewareState {
/// Manager used to validate incoming session tokens.
pub manager: Arc<SsoManager>,
}
pub async fn sso_auth_middleware(
State(state): State<Arc<SsoMiddlewareState>>,
headers: HeaderMap,
mut request: Request,
next: Next,
) -> Response {
let token = extract_session_token(&headers);
match token {
Some(t) => match state.manager.validate_session(&t) {
Some(identity) => {
request.extensions_mut().insert(identity);
next.run(request).await
}
None => {
warn!("SSO session token invalid or expired");
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "session_expired",
"message": "Session token is invalid or expired — please log in again"
})),
)
.into_response()
}
},
None => {
warn!("No SSO session token in request");
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "not_authenticated",
"message": "Authentication required — include a Bearer token or argentor_session cookie"
})),
)
.into_response()
}
}
}
/// Builds the `/auth/*` router (login, callback, logout, me, api-key) over `manager`.
pub fn sso_router(manager: Arc<SsoManager>) -> Router {
    let shared = Arc::new(SsoState { manager });
    let router = Router::new()
        .route("/auth/login", get(sso_login_handler))
        .route("/auth/callback", get(sso_callback_handler))
        .route("/auth/logout", get(sso_logout_handler))
        .route("/auth/me", get(sso_me_handler))
        .route("/auth/api-key", axum::routing::post(sso_api_key_handler));
    router.with_state(shared)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::Request as HttpRequest;
use tower::ServiceExt;
fn test_config() -> SsoConfig {
SsoConfig {
provider: SsoProvider::Oidc,
client_id: "test-client-id".into(),
client_secret: "test-client-secret".into(),
redirect_uri: "https://app.example.com/auth/callback".into(),
issuer_url: "https://accounts.example.com".into(),
allowed_domains: vec!["example.com".into(), "corp.example.com".into()],
scopes: vec!["openid".into(), "profile".into(), "email".into()],
session_ttl_hours: 24,
}
}
fn test_manager() -> SsoManager {
SsoManager::new(test_config())
}
fn test_identity(email: &str) -> UserIdentity {
let domain = email
.rsplit_once('@')
.map(|(_, d)| d.to_string())
.unwrap_or_default();
UserIdentity {
id: Uuid::new_v4().to_string(),
email: email.to_string(),
name: Some("Test User".into()),
domain,
roles: vec!["user".into()],
tenant_id: Some("tenant-1".into()),
authenticated_at: Utc::now(),
expires_at: Utc::now() + Duration::hours(24),
}
}
fn expired_identity(email: &str) -> UserIdentity {
let domain = email
.rsplit_once('@')
.map(|(_, d)| d.to_string())
.unwrap_or_default();
UserIdentity {
id: Uuid::new_v4().to_string(),
email: email.to_string(),
name: Some("Expired User".into()),
domain,
roles: vec!["user".into()],
tenant_id: None,
authenticated_at: Utc::now() - Duration::hours(48),
expires_at: Utc::now() - Duration::hours(1),
}
}
#[test]
fn sso_config_default_values() {
let config = SsoConfig::default();
assert_eq!(config.provider, SsoProvider::Oidc);
assert_eq!(config.session_ttl_hours, 24);
assert!(config.scopes.contains(&"openid".to_string()));
assert!(config.allowed_domains.is_empty());
}
#[test]
fn sso_config_serialization_roundtrip() {
let config = test_config();
let json = serde_json::to_string(&config).unwrap();
let deserialized: SsoConfig = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.client_id, "test-client-id");
assert_eq!(deserialized.allowed_domains.len(), 2);
assert_eq!(deserialized.provider, SsoProvider::Oidc);
}
#[test]
fn sso_provider_variants_serialize() {
let oidc = serde_json::to_string(&SsoProvider::Oidc).unwrap();
let saml = serde_json::to_string(&SsoProvider::Saml).unwrap();
let api = serde_json::to_string(&SsoProvider::ApiKey).unwrap();
assert_eq!(oidc, "\"Oidc\"");
assert_eq!(saml, "\"Saml\"");
assert_eq!(api, "\"ApiKey\"");
}
#[test]
fn login_url_oidc_contains_required_params() {
let manager = test_manager();
let url = manager.login_url("csrf-state-123");
assert!(url.contains("accounts.example.com/authorize"));
assert!(url.contains("client_id=test-client-id"));
assert!(url.contains("redirect_uri=https://app.example.com/auth/callback"));
assert!(url.contains("response_type=code"));
assert!(url.contains("state=csrf-state-123"));
assert!(url.contains("scope=openid+profile+email"));
}
#[test]
fn login_url_saml_contains_relay_state() {
let mut config = test_config();
config.provider = SsoProvider::Saml;
let manager = SsoManager::new(config);
let url = manager.login_url("saml-state-456");
assert!(url.contains("SAMLRequest="));
assert!(url.contains("RelayState=saml-state-456"));
}
#[test]
fn login_url_api_key_uses_redirect_uri() {
let mut config = test_config();
config.provider = SsoProvider::ApiKey;
let manager = SsoManager::new(config);
let url = manager.login_url("api-state");
assert!(url.contains("app.example.com/auth/callback"));
assert!(url.contains("mode=api_key"));
}
#[test]
fn login_url_records_pending_state() {
let manager = test_manager();
let _url = manager.login_url("track-me");
let states = manager.pending_states.read().unwrap();
assert!(states.contains_key("track-me"));
}
#[test]
fn domain_allowed_exact_match() {
let manager = test_manager();
assert!(manager.is_domain_allowed("alice@example.com"));
assert!(manager.is_domain_allowed("bob@corp.example.com"));
}
#[test]
fn domain_blocked_when_not_in_list() {
let manager = test_manager();
assert!(!manager.is_domain_allowed("eve@evil.com"));
assert!(!manager.is_domain_allowed("mallory@other.org"));
}
#[test]
fn domain_case_insensitive() {
let manager = test_manager();
assert!(manager.is_domain_allowed("Alice@EXAMPLE.COM"));
assert!(manager.is_domain_allowed("bob@Corp.Example.Com"));
}
#[test]
fn domain_all_allowed_when_list_empty() {
let mut config = test_config();
config.allowed_domains = vec![];
let manager = SsoManager::new(config);
assert!(manager.is_domain_allowed("anyone@anywhere.com"));
}
#[test]
fn domain_invalid_email_rejected() {
let manager = test_manager();
assert!(!manager.is_domain_allowed("not-an-email"));
assert!(!manager.is_domain_allowed(""));
}
#[test]
fn session_create_and_validate() {
let manager = test_manager();
let identity = test_identity("alice@example.com");
let token = manager.create_session(identity.clone());
let retrieved = manager.validate_session(&token);
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.email, "alice@example.com");
assert_eq!(retrieved.name, Some("Test User".into()));
}
#[test]
fn session_validate_unknown_token_returns_none() {
let manager = test_manager();
assert!(manager.validate_session("nonexistent-token").is_none());
}
#[test]
fn session_revoke_removes_session() {
let manager = test_manager();
let identity = test_identity("bob@example.com");
let token = manager.create_session(identity);
assert!(manager.validate_session(&token).is_some());
assert!(manager.revoke_session(&token));
assert!(manager.validate_session(&token).is_none());
}
#[test]
fn session_revoke_unknown_returns_false() {
let manager = test_manager();
assert!(!manager.revoke_session("no-such-token"));
}
#[test]
fn session_expired_auto_revoked_on_validate() {
let manager = test_manager();
let identity = expired_identity("expired@example.com");
let token = manager.create_session(identity);
assert!(manager.validate_session(&token).is_none());
let sessions = manager.sessions.read().unwrap();
assert!(!sessions.contains_key(&token));
}
#[test]
fn cleanup_removes_expired_sessions() {
let manager = test_manager();
let _valid_token = manager.create_session(test_identity("valid@example.com"));
let _expired1 = manager.create_session(expired_identity("old1@example.com"));
let _expired2 = manager.create_session(expired_identity("old2@example.com"));
let removed = manager.cleanup_expired();
assert_eq!(removed, 2);
let active = manager.active_sessions();
assert_eq!(active.len(), 1);
assert_eq!(active[0].1.email, "valid@example.com");
}
#[test]
fn cleanup_returns_zero_when_none_expired() {
let manager = test_manager();
let _t = manager.create_session(test_identity("fresh@example.com"));
assert_eq!(manager.cleanup_expired(), 0);
}
#[test]
fn active_sessions_lists_only_valid() {
let manager = test_manager();
let _t1 = manager.create_session(test_identity("a@example.com"));
let _t2 = manager.create_session(test_identity("b@example.com"));
let _t3 = manager.create_session(expired_identity("c@example.com"));
let active = manager.active_sessions();
assert_eq!(active.len(), 2);
let emails: Vec<&str> = active.iter().map(|(_, id)| id.email.as_str()).collect();
assert!(emails.contains(&"a@example.com"));
assert!(emails.contains(&"b@example.com"));
}
#[test]
fn build_identity_sets_domain_and_ttl() {
let manager = test_manager();
let identity = manager.build_identity(
"user-1",
"alice@example.com",
Some("Alice"),
vec!["admin".into()],
Some("tenant-x".into()),
);
assert_eq!(identity.id, "user-1");
assert_eq!(identity.email, "alice@example.com");
assert_eq!(identity.domain, "example.com");
assert_eq!(identity.name, Some("Alice".into()));
assert_eq!(identity.roles, vec!["admin"]);
assert_eq!(identity.tenant_id, Some("tenant-x".into()));
assert!(identity.expires_at > Utc::now());
let ttl = identity.expires_at - identity.authenticated_at;
assert_eq!(ttl.num_hours(), 24);
}
#[test]
fn user_identity_is_expired() {
let valid = test_identity("valid@example.com");
assert!(!valid.is_expired());
let expired = expired_identity("old@example.com");
assert!(expired.is_expired());
}
#[tokio::test]
async fn middleware_valid_session_passes_through() {
let manager = Arc::new(test_manager());
let identity = test_identity("auth@example.com");
let token = manager.create_session(identity);
let mw_state = Arc::new(SsoMiddlewareState {
manager: manager.clone(),
});
let app = Router::new()
.route(
"/protected",
get(|req: HttpRequest<Body>| async move {
let id = req.extensions().get::<UserIdentity>();
match id {
Some(u) => (StatusCode::OK, u.email.clone()).into_response(),
None => (StatusCode::INTERNAL_SERVER_ERROR, "no identity").into_response(),
}
}),
)
.layer(axum::middleware::from_fn_with_state(
mw_state,
sso_auth_middleware,
));
let response = app
.oneshot(
HttpRequest::builder()
.uri("/protected")
.header("authorization", format!("Bearer {token}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn middleware_missing_session_returns_401() {
let manager = Arc::new(test_manager());
let mw_state = Arc::new(SsoMiddlewareState {
manager: manager.clone(),
});
let app = Router::new()
.route("/protected", get(|| async { "ok" }))
.layer(axum::middleware::from_fn_with_state(
mw_state,
sso_auth_middleware,
));
let response = app
.oneshot(
HttpRequest::builder()
.uri("/protected")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn middleware_expired_session_returns_401() {
let manager = Arc::new(test_manager());
let identity = expired_identity("old@example.com");
let token = "expired-token-123".to_string();
manager
.sessions
.write()
.unwrap()
.insert(token.clone(), identity);
let mw_state = Arc::new(SsoMiddlewareState {
manager: manager.clone(),
});
let app = Router::new()
.route("/protected", get(|| async { "ok" }))
.layer(axum::middleware::from_fn_with_state(
mw_state,
sso_auth_middleware,
));
let response = app
.oneshot(
HttpRequest::builder()
.uri("/protected")
.header("authorization", format!("Bearer {token}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn middleware_invalid_token_returns_401() {
let manager = Arc::new(test_manager());
let mw_state = Arc::new(SsoMiddlewareState {
manager: manager.clone(),
});
let app = Router::new()
.route("/protected", get(|| async { "ok" }))
.layer(axum::middleware::from_fn_with_state(
mw_state,
sso_auth_middleware,
));
let response = app
.oneshot(
HttpRequest::builder()
.uri("/protected")
.header("authorization", "Bearer bogus-token")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[test]
fn extract_token_from_bearer_header() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer my-token-123".parse().unwrap());
assert_eq!(extract_session_token(&headers), Some("my-token-123".into()));
}
#[test]
fn extract_token_from_cookie() {
let mut headers = HeaderMap::new();
headers.insert(
"cookie",
"other=val; argentor_session=cookie-token-456; another=x"
.parse()
.unwrap(),
);
assert_eq!(
extract_session_token(&headers),
Some("cookie-token-456".into())
);
}
#[test]
fn extract_token_bearer_takes_precedence_over_cookie() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer header-token".parse().unwrap());
headers.insert("cookie", "argentor_session=cookie-token".parse().unwrap());
assert_eq!(extract_session_token(&headers), Some("header-token".into()));
}
#[test]
fn extract_token_returns_none_when_absent() {
let headers = HeaderMap::new();
assert_eq!(extract_session_token(&headers), None);
}
#[tokio::test]
async fn callback_rejects_invalid_state() {
let manager = test_manager();
let result = manager.handle_callback("some-code", "unknown-state").await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("CSRF"), "Error should mention CSRF: {err}");
}
#[tokio::test]
async fn callback_rejects_empty_code() {
let manager = test_manager();
let _url = manager.login_url("valid-state");
let result = manager.handle_callback("", "valid-state").await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Empty authorization code"),
"Error should mention empty code: {err}"
);
}
/// With a valid code/state pair, the OIDC callback proceeds to provider
/// discovery. The test expects this to fail, with an error mentioning either
/// "OIDC discovery" or "HTTP".
///
/// NOTE(review): this assumes `test_config()` points at an issuer that is not
/// reachable from the test environment. The test depends on network behaviour.
#[tokio::test]
async fn callback_oidc_discovery_fails_on_unreachable_issuer() {
    let mgr = test_manager();
    let _url = mgr.login_url("oidc-state");
    let outcome = mgr.handle_callback("auth-code-123", "oidc-state").await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    let mentions_failure = err.contains("OIDC discovery") || err.contains("HTTP");
    assert!(
        mentions_failure,
        "Should indicate OIDC discovery/HTTP failure: {err}"
    );
}
/// In SAML mode the `code` argument is treated as the encoded SAML response.
/// A value that is not valid base64 is rejected with a decode error.
#[tokio::test]
async fn callback_saml_invalid_base64_rejected() {
    let mut saml_config = test_config();
    saml_config.provider = SsoProvider::Saml;
    let mgr = SsoManager::new(saml_config);
    let _url = mgr.login_url("saml-state");
    let outcome = mgr.handle_callback("saml-response", "saml-state").await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("base64 decode failed"),
        "Should indicate base64 decode failure: {err}"
    );
}
/// Builds an unsigned SAML 2.0 `Response` document for parser tests.
///
/// * `name_id` is placed inside `<saml:NameID>`.
/// * `status` becomes the `Value` of `<samlp:StatusCode>`.
/// * `attrs` is a list of `(Name, values)` pairs. Each pair becomes one
///   `<saml:Attribute>` element holding one `<saml:AttributeValue>` per value,
///   followed by a newline.
///
/// Values are interpolated verbatim; no XML escaping is performed.
fn make_saml_xml(name_id: &str, status: &str, attrs: &[(&str, &[&str])]) -> String {
    let attr_xml: String = attrs
        .iter()
        .map(|&(name, values)| {
            let mut element = format!(r#"<saml:Attribute Name="{name}">"#);
            for val in values {
                element.push_str(&format!("<saml:AttributeValue>{val}</saml:AttributeValue>"));
            }
            element.push_str("</saml:Attribute>\n");
            element
        })
        .collect();
    format!(
        r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="_response123" Version="2.0" IssueInstant="2025-01-15T10:00:00Z" Destination="https://app.example.com/auth/callback">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<samlp:Status>
<samlp:StatusCode Value="{status}"/>
</samlp:Status>
<saml:Assertion ID="_assertion456" Version="2.0" IssueInstant="2025-01-15T10:00:00Z">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:Subject>
<saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">{name_id}</saml:NameID>
</saml:Subject>
<saml:Conditions NotBefore="2025-01-15T09:55:00Z" NotOnOrAfter="2025-01-15T10:05:00Z">
<saml:AudienceRestriction>
<saml:Audience>https://app.example.com</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AttributeStatement>
{attr_xml}
</saml:AttributeStatement>
</saml:Assertion>
</samlp:Response>"#
    )
}
/// Base64-encodes `xml` with the standard (padded) alphabet. This is the form
/// fed to `parse_saml_response` and `handle_callback` in these tests.
fn encode_saml(xml: &str) -> String {
    let raw: &[u8] = xml.as_bytes();
    STANDARD.encode(raw)
}
/// A successful response maps NameID, the `email` and `name` attributes, and
/// the multi-valued `role` attribute onto the parsed claims.
#[test]
fn saml_parse_valid_response() {
    let claims = parse_saml_response(&encode_saml(&make_saml_xml(
        "alice@example.com",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[
            ("name", &["Alice Smith"]),
            ("email", &["alice@example.com"]),
            ("role", &["admin", "viewer"]),
        ],
    )))
    .unwrap();
    assert_eq!(claims.name_id, "alice@example.com");
    assert_eq!(claims.email, "alice@example.com");
    assert_eq!(claims.name, Some("Alice Smith".into()));
    assert_eq!(claims.roles, vec!["admin", "viewer"]);
}
/// With no attributes present, NameID is extracted and also used as the email.
#[test]
fn saml_parse_extracts_name_id() {
    let subject = "bob@corp.example.com";
    let xml = make_saml_xml(subject, "urn:oasis:names:tc:SAML:2.0:status:Success", &[]);
    let claims = parse_saml_response(&encode_saml(&xml)).unwrap();
    assert_eq!(claims.name_id, subject);
    assert_eq!(claims.email, subject);
}
/// The `name` and `email` attributes populate the claims even when NameID is
/// not an email address.
#[test]
fn saml_parse_extracts_attributes() {
    let attributes: &[(&str, &[&str])] = &[
        ("name", &["Charlie Brown"]),
        ("email", &["charlie@example.com"]),
    ];
    let xml = make_saml_xml(
        "user123",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        attributes,
    );
    let claims = parse_saml_response(&encode_saml(&xml)).unwrap();
    assert_eq!(claims.name, Some("Charlie Brown".into()));
    assert_eq!(claims.email, "charlie@example.com");
}
/// A successful status with an empty `<saml:Subject>` (no NameID) is rejected.
#[test]
fn saml_parse_missing_name_id_rejected() {
    let xml = r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion">
<samlp:Status>
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
</samlp:Status>
<saml:Assertion>
<saml:Subject></saml:Subject>
</saml:Assertion>
</samlp:Response>"#;
    let outcome = parse_saml_response(&encode_saml(xml));
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("missing NameID"),
        "Should mention missing NameID: {err}"
    );
}
/// A non-Success status code (here `Requester`) fails parsing. The error
/// reports an authentication failure and echoes the status code.
#[test]
fn saml_parse_failed_status_code_rejected() {
    let encoded = encode_saml(&make_saml_xml(
        "alice@example.com",
        "urn:oasis:names:tc:SAML:2.0:status:Requester",
        &[],
    ));
    let outcome = parse_saml_response(&encoded);
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("SAML authentication failed"),
        "Should indicate SAML auth failure: {err}"
    );
    assert!(
        err.contains("Requester"),
        "Should include the status code: {err}"
    );
}
/// A well-formed, successful SAML assertion for a user outside the allowed
/// domains is rejected by the callback.
///
/// NOTE(review): this assumes `test_config()`'s `allowed_domains` excludes
/// `attacker.com`. Confirm against the helper.
#[tokio::test]
async fn saml_callback_domain_validation() {
    let mut saml_config = test_config();
    saml_config.provider = SsoProvider::Saml;
    let mgr = SsoManager::new(saml_config);
    let _url = mgr.login_url("saml-state");
    let encoded = encode_saml(&make_saml_xml(
        "evil@attacker.com",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[("email", &["evil@attacker.com"])],
    ));
    let outcome = mgr.handle_callback(&encoded, "saml-state").await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("not in allowed list"),
        "Should reject unauthorized domain: {err}"
    );
}
/// Every `AttributeValue` of the `role` attribute becomes a role, in
/// document order.
#[test]
fn saml_parse_multiple_roles() {
    let expected_roles = ["admin", "editor", "viewer", "auditor"];
    let xml = make_saml_xml(
        "admin@example.com",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[("role", &expected_roles)],
    );
    let claims = parse_saml_response(&encode_saml(&xml)).unwrap();
    assert_eq!(claims.roles.len(), 4);
    assert_eq!(claims.roles, vec!["admin", "editor", "viewer", "auditor"]);
}
/// Azure AD-style claim URIs are recognised as the name, email and role
/// attributes.
#[test]
fn saml_parse_azure_ad_uri_attributes() {
    let name_uri = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name";
    let email_uri = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress";
    let role_uri = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role";
    let xml = make_saml_xml(
        "azure-user-id",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[
            (name_uri, &["Azure User"]),
            (email_uri, &["azure@example.com"]),
            (role_uri, &["GlobalAdmin", "Reader"]),
        ],
    );
    let claims = parse_saml_response(&encode_saml(&xml)).unwrap();
    assert_eq!(claims.name_id, "azure-user-id");
    assert_eq!(claims.name, Some("Azure User".into()));
    assert_eq!(claims.email, "azure@example.com");
    assert_eq!(claims.roles, vec!["GlobalAdmin", "Reader"]);
}
/// End-to-end SAML callback. A successful, in-domain assertion yields a
/// non-empty session token and an identity carrying email, display name,
/// domain and roles. The returned token is then accepted by
/// `validate_session`.
#[tokio::test]
async fn saml_callback_full_success_flow() {
    let mut config = test_config();
    config.provider = SsoProvider::Saml;
    let manager = SsoManager::new(config);
    // Issue a login URL first so "saml-ok-state" is a known state value
    // (unknown states are rejected by the callback).
    let _url = manager.login_url("saml-ok-state");
    let xml = make_saml_xml(
        "alice@example.com",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[
            ("name", &["Alice Smith"]),
            ("email", &["alice@example.com"]),
            ("role", &["engineer"]),
        ],
    );
    let encoded = encode_saml(&xml);
    // Match instead of `assert!(is_ok())` + `unwrap()` so a failure reports
    // the actual callback error rather than a bare boolean assertion.
    let (token, identity) = match manager.handle_callback(&encoded, "saml-ok-state").await {
        Ok(pair) => pair,
        Err(e) => panic!("SAML callback should succeed: {e}"),
    };
    assert!(!token.is_empty());
    assert_eq!(identity.email, "alice@example.com");
    assert_eq!(identity.name, Some("Alice Smith".into()));
    assert_eq!(identity.domain, "example.com");
    assert_eq!(identity.roles, vec!["engineer"]);
    let session = manager.validate_session(&token);
    assert!(
        session.is_some(),
        "token returned by the callback should validate as a session"
    );
}
/// With no `name` attribute, the `displayName` attribute supplies the claim's
/// name.
#[test]
fn saml_parse_displayname_fallback() {
    let encoded = encode_saml(&make_saml_xml(
        "user@example.com",
        "urn:oasis:names:tc:SAML:2.0:status:Success",
        &[("displayName", &["Display Name User"])],
    ));
    let name = parse_saml_response(&encoded).unwrap().name;
    assert_eq!(name, Some("Display Name User".into()));
}
/// Input that is not base64 fails before any XML parsing, with a decode error.
#[test]
fn saml_parse_invalid_base64() {
    let outcome = parse_saml_response("!!!not-base64!!!");
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("base64 decode failed"),
        "Should mention base64 failure: {err}"
    );
}
/// The text content of an un-namespaced element is returned.
#[test]
fn extract_xml_element_simple() {
    let found = extract_xml_element("<NameID>alice@example.com</NameID>", "NameID");
    assert_eq!(found, Some("alice@example.com".into()));
}
/// A namespace prefix and attributes on the opening tag do not prevent the
/// element's text from being found by local name.
#[test]
fn extract_xml_element_with_namespace() {
    let found = extract_xml_element(
        r#"<saml:NameID Format="email">alice@example.com</saml:NameID>"#,
        "NameID",
    );
    assert_eq!(found, Some("alice@example.com".into()));
}
/// Looking up an element that does not occur yields `None`.
#[test]
fn extract_xml_element_not_found() {
    let found = extract_xml_element("<Issuer>https://idp.example.com</Issuer>", "NameID");
    assert_eq!(found, None);
}
/// An attribute value is read from a self-closing, namespaced element.
#[test]
fn extract_xml_attr_status_code() {
    let expected = "urn:oasis:names:tc:SAML:2.0:status:Success";
    let xml = r#"<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>"#;
    assert_eq!(
        extract_xml_attr(xml, "StatusCode", "Value"),
        Some(expected.into())
    );
}
/// The single `AttributeValue` of a named SAML attribute is returned.
#[test]
fn extract_saml_attribute_simple() {
    let xml = r#"<saml:Attribute Name="email"><saml:AttributeValue>alice@example.com</saml:AttributeValue></saml:Attribute>"#;
    let value = extract_saml_attribute(xml, "email");
    assert_eq!(value, Some("alice@example.com".into()));
}
/// All `AttributeValue` children of a named attribute are returned in order.
#[test]
fn extract_saml_attribute_values_multiple() {
    let xml = r#"<saml:Attribute Name="role"><saml:AttributeValue>admin</saml:AttributeValue><saml:AttributeValue>user</saml:AttributeValue></saml:Attribute>"#;
    assert_eq!(
        extract_saml_attribute_values(xml, "role"),
        Some(vec!["admin".into(), "user".into()])
    );
}
/// Assembles a three-segment `header.payload.signature` token for tests.
///
/// The header is `{"alg":"RS256","typ":"JWT"}`. Header and `payload` are
/// serialized to JSON, and every segment is base64url-encoded without padding.
/// The signature segment is the literal bytes `fake-signature`, not a real
/// signature.
fn make_test_jwt(payload: &serde_json::Value) -> String {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
    let header_json = serde_json::json!({"alg": "RS256", "typ": "JWT"}).to_string();
    let segments = [
        URL_SAFE_NO_PAD.encode(header_json),
        URL_SAFE_NO_PAD.encode(payload.to_string()),
        URL_SAFE_NO_PAD.encode("fake-signature"),
    ];
    segments.join(".")
}
/// The payload segment of a well-formed token decodes back to the original
/// string and boolean claims.
#[test]
fn decode_jwt_payload_valid_token() {
    let token = make_test_jwt(&serde_json::json!({
        "sub": "user-123",
        "email": "alice@example.com",
        "name": "Alice",
        "iss": "https://accounts.example.com",
        "email_verified": true
    }));
    let decoded = decode_jwt_payload(&token).unwrap();
    for (key, want) in [
        ("sub", "user-123"),
        ("email", "alice@example.com"),
        ("name", "Alice"),
    ] {
        assert_eq!(decoded[key].as_str().unwrap(), want);
    }
    assert!(decoded["email_verified"].as_bool().unwrap());
}
/// A string with no `.` separators is rejected as an invalid JWT format.
#[test]
fn decode_jwt_payload_invalid_format_no_dots() {
    let outcome = decode_jwt_payload("not-a-jwt");
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("Invalid JWT format"),
        "Should mention invalid format: {err}"
    );
}
/// A two-segment token is rejected, and the error states that three parts
/// were expected.
#[test]
fn decode_jwt_payload_invalid_format_two_parts() {
    let outcome = decode_jwt_payload("part1.part2");
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("expected 3 parts"),
        "Should mention expected 3 parts: {err}"
    );
}
/// A payload segment that is not base64url fails with a decode error.
#[test]
fn decode_jwt_payload_invalid_base64() {
    let outcome = decode_jwt_payload("header.!!!invalid!!!.signature");
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("base64url decode failed"),
        "Should mention base64 decode failure: {err}"
    );
}
/// A payload that base64url-decodes correctly but is not JSON is rejected.
#[test]
fn decode_jwt_payload_invalid_json() {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
    let raw_segments: [&[u8]; 3] = [b"{}", b"not json at all", b"sig"];
    let token = raw_segments
        .iter()
        .map(|segment| URL_SAFE_NO_PAD.encode(segment))
        .collect::<Vec<String>>()
        .join(".");
    let outcome = decode_jwt_payload(&token);
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("not valid JSON"),
        "Should mention invalid JSON: {err}"
    );
}
/// Object-valued custom claims survive decoding and can be indexed into.
#[test]
fn decode_jwt_payload_extracts_nested_claims() {
    let token = make_test_jwt(&serde_json::json!({
        "sub": "u-456",
        "email": "bob@corp.example.com",
        "name": "Bob",
        "iss": "https://idp.example.com",
        "aud": "test-client-id",
        "email_verified": true,
        "custom_claim": {"org": "acme", "tier": "enterprise"}
    }));
    let org = decode_jwt_payload(&token).unwrap()["custom_claim"]["org"]
        .as_str()
        .unwrap()
        .to_owned();
    assert_eq!(org, "acme");
}
/// The test manager accepts `example.com` addresses and refuses `evil.com`.
#[test]
fn oidc_domain_validation_after_decode() {
    let manager = test_manager();
    for (email, expected) in [("alice@example.com", true), ("alice@evil.com", false)] {
        assert_eq!(manager.is_domain_allowed(email), expected);
    }
}
/// `OidcEndpoints` stores the values it is constructed with, including an
/// optional userinfo endpoint.
#[test]
fn oidc_endpoints_struct_fields() {
    let OidcEndpoints {
        token_endpoint,
        issuer,
        userinfo_endpoint,
        ..
    } = OidcEndpoints {
        authorization_endpoint: "https://idp.example.com/authorize".into(),
        token_endpoint: "https://idp.example.com/token".into(),
        userinfo_endpoint: Some("https://idp.example.com/userinfo".into()),
        issuer: "https://idp.example.com".into(),
    };
    assert_eq!(token_endpoint, "https://idp.example.com/token");
    assert_eq!(issuer, "https://idp.example.com");
    assert!(userinfo_endpoint.is_some());
}
/// The userinfo endpoint may be absent.
#[test]
fn oidc_endpoints_optional_userinfo() {
    let without_userinfo = OidcEndpoints {
        authorization_endpoint: "https://idp.example.com/authorize".into(),
        token_endpoint: "https://idp.example.com/token".into(),
        userinfo_endpoint: None,
        issuer: "https://idp.example.com".into(),
    };
    assert!(without_userinfo.userinfo_endpoint.is_none());
}
/// Stripping trailing slashes makes an issuer URL with and without a trailing
/// `/` compare equal.
///
/// NOTE(review): this exercises `str::trim_end_matches` directly, not the
/// production issuer check. It documents the intended normalization only.
#[test]
fn issuer_validation_trailing_slash_normalization() {
    fn normalize(issuer: &str) -> &str {
        issuer.trim_end_matches('/')
    }
    let with_slash = "https://accounts.example.com/";
    let without_slash = "https://accounts.example.com";
    assert_eq!(normalize(with_slash), normalize(without_slash));
}
/// An explicit `"email_verified": false` reads as unverified.
///
/// NOTE(review): this checks the `as_bool().unwrap_or(false)` reading rule on
/// a `serde_json::Value`, not the production callback.
#[test]
fn email_verified_false_is_rejected() {
    let payload = serde_json::json!({
        "sub": "user-1",
        "email": "alice@example.com",
        "email_verified": false,
        "iss": "https://accounts.example.com"
    });
    let email_verified = payload
        .get("email_verified")
        .and_then(serde_json::Value::as_bool)
        .unwrap_or(false);
    assert!(!email_verified, "email_verified=false should be rejected");
}
/// A payload with no `email_verified` claim reads as unverified.
///
/// NOTE(review): like its sibling, this checks the reading rule on a
/// `serde_json::Value`, not the production callback.
#[test]
fn email_verified_missing_is_rejected() {
    let payload = serde_json::json!({
        "sub": "user-1",
        "email": "alice@example.com",
        "iss": "https://accounts.example.com"
    });
    let email_verified = payload
        .get("email_verified")
        .and_then(serde_json::Value::as_bool)
        .unwrap_or(false);
    assert!(
        !email_verified,
        "missing email_verified should default to false"
    );
}
/// A token-endpoint response that lacks `id_token` is detectable as such.
///
/// NOTE(review): this checks a `serde_json::Value` lookup, not the production
/// token-exchange code.
#[test]
fn missing_id_token_in_response_detected() {
    let token_json = serde_json::json!({
        "access_token": "abc123",
        "token_type": "Bearer"
    });
    let id_token = token_json
        .get("id_token")
        .and_then(serde_json::Value::as_str);
    assert!(id_token.is_none(), "Missing id_token should be detected");
}
}