use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{debug, info, warn};
use uuid::Uuid;
#[derive(Debug, Error)]
pub enum SsoError {
#[error("Authentication failed: {0}")]
AuthenticationFailed(String),
#[error("Token validation failed: {0}")]
TokenValidationFailed(String),
#[error("Provider not found: {0}")]
ProviderNotFound(String),
#[error("Configuration error: {0}")]
ConfigurationError(String),
#[error("Session expired for user: {0}")]
SessionExpired(String),
#[error("Attribute mapping error: {0}")]
AttributeMappingError(String),
#[error("Federation error: {0}")]
FederationError(String),
#[error("SAML assertion error: {0}")]
SamlAssertionError(String),
#[error("OIDC error: {0}")]
OidcError(String),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
pub type SsoResult<T> = Result<T, SsoError>;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SsoProtocol {
Saml2,
Oidc,
OAuth2,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Saml2Config {
pub sp_entity_id: String,
pub sp_acs_url: String,
pub idp_sso_url: String,
pub idp_entity_id: String,
pub idp_certificate: String,
pub sp_private_key: Option<String>,
pub sp_certificate: Option<String>,
pub sign_requests: bool,
pub require_encrypted_assertions: bool,
pub email_attribute: String,
pub display_name_attribute: String,
pub groups_attribute: Option<String>,
pub name_id_format: NameIdFormat,
pub clock_skew_seconds: i64,
}
impl Default for Saml2Config {
fn default() -> Self {
Self {
sp_entity_id: "https://oxirs.example.com".to_string(),
sp_acs_url: "https://oxirs.example.com/auth/saml/acs".to_string(),
idp_sso_url: String::new(),
idp_entity_id: String::new(),
idp_certificate: String::new(),
sp_private_key: None,
sp_certificate: None,
sign_requests: true,
require_encrypted_assertions: false,
email_attribute: "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"
.to_string(),
display_name_attribute:
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/displayname".to_string(),
groups_attribute: Some(
"http://schemas.microsoft.com/ws/2008/06/identity/claims/groups".to_string(),
),
name_id_format: NameIdFormat::EmailAddress,
clock_skew_seconds: 300,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NameIdFormat {
EmailAddress,
Transient,
Persistent,
Unspecified,
}
impl NameIdFormat {
pub fn as_urn(&self) -> &'static str {
match self {
Self::EmailAddress => "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
Self::Transient => "urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
Self::Persistent => "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
Self::Unspecified => "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcConfig {
pub client_id: String,
pub client_secret: String,
pub discovery_url: String,
pub redirect_uri: String,
pub scopes: Vec<String>,
pub use_pkce: bool,
pub max_age_seconds: Option<u64>,
pub audience: Vec<String>,
pub email_claim: String,
pub name_claim: String,
pub groups_claim: Option<String>,
pub validate_nonce: bool,
pub response_type: String,
}
impl Default for OidcConfig {
fn default() -> Self {
Self {
client_id: String::new(),
client_secret: String::new(),
discovery_url: String::new(),
redirect_uri: String::new(),
scopes: vec![
"openid".to_string(),
"email".to_string(),
"profile".to_string(),
],
use_pkce: true,
max_age_seconds: Some(3600),
audience: Vec::new(),
email_claim: "email".to_string(),
name_claim: "name".to_string(),
groups_claim: Some("groups".to_string()),
validate_nonce: true,
response_type: "code".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentityProvider {
pub id: String,
pub name: String,
pub protocol: SsoProtocol,
pub saml_config: Option<Saml2Config>,
pub oidc_config: Option<OidcConfig>,
pub enabled: bool,
pub priority: i32,
pub domains: Vec<String>,
pub default_role: String,
pub role_mapping: HashMap<String, String>,
pub auto_provision: bool,
pub sync_attributes: bool,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl IdentityProvider {
pub fn new_saml(id: String, name: String, saml_config: Saml2Config) -> Self {
let now = Utc::now();
Self {
id,
name,
protocol: SsoProtocol::Saml2,
saml_config: Some(saml_config),
oidc_config: None,
enabled: true,
priority: 0,
domains: Vec::new(),
default_role: "user".to_string(),
role_mapping: HashMap::new(),
auto_provision: true,
sync_attributes: true,
created_at: now,
updated_at: now,
}
}
pub fn new_oidc(id: String, name: String, oidc_config: OidcConfig) -> Self {
let now = Utc::now();
Self {
id,
name,
protocol: SsoProtocol::Oidc,
saml_config: None,
oidc_config: Some(oidc_config),
enabled: true,
priority: 0,
domains: Vec::new(),
default_role: "user".to_string(),
role_mapping: HashMap::new(),
auto_provision: true,
sync_attributes: true,
created_at: now,
updated_at: now,
}
}
pub fn map_role(&self, group: &str) -> String {
self.role_mapping
.get(group)
.cloned()
.unwrap_or_else(|| self.default_role.clone())
}
pub fn handles_domain(&self, email: &str) -> bool {
if self.domains.is_empty() {
return true; }
if let Some(at_pos) = email.rfind('@') {
let domain = &email[at_pos + 1..];
self.domains.iter().any(|d| d == domain)
} else {
false
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAssertion {
pub id: String,
pub name_id: String,
pub name_id_format: String,
pub issue_instant: DateTime<Utc>,
pub not_before: Option<DateTime<Utc>>,
pub not_on_or_after: Option<DateTime<Utc>>,
pub session_index: Option<String>,
pub issuer: String,
pub attributes: HashMap<String, Vec<String>>,
}
impl SamlAssertion {
pub fn is_valid(&self, clock_skew_seconds: i64) -> bool {
let now = Utc::now();
let skew = Duration::seconds(clock_skew_seconds);
if let Some(not_before) = self.not_before {
if now + skew < not_before {
return false;
}
}
if let Some(not_on_or_after) = self.not_on_or_after {
if now - skew >= not_on_or_after {
return false;
}
}
true
}
pub fn get_attribute(&self, name: &str) -> Option<&str> {
self.attributes
.get(name)
.and_then(|vals| vals.first())
.map(|s| s.as_str())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcIdToken {
pub sub: String,
pub iss: String,
pub aud: Vec<String>,
pub exp: DateTime<Utc>,
pub iat: DateTime<Utc>,
pub nonce: Option<String>,
pub email: Option<String>,
pub email_verified: Option<bool>,
pub name: Option<String>,
pub claims: HashMap<String, serde_json::Value>,
}
impl OidcIdToken {
pub fn is_expired(&self) -> bool {
Utc::now() >= self.exp
}
pub fn get_claim_str(&self, claim: &str) -> Option<String> {
self.claims.get(claim).and_then(|v| match v {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Number(n) => Some(n.to_string()),
serde_json::Value::Bool(b) => Some(b.to_string()),
_ => None,
})
}
pub fn get_claim_list(&self, claim: &str) -> Vec<String> {
match self.claims.get(claim) {
Some(serde_json::Value::Array(arr)) => arr
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect(),
Some(serde_json::Value::String(s)) => vec![s.clone()],
_ => Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoUserProfile {
pub subject_id: String,
pub provider_id: String,
pub email: String,
pub email_verified: bool,
pub display_name: String,
pub idp_groups: Vec<String>,
pub roles: Vec<String>,
pub attributes: HashMap<String, Vec<String>>,
pub authenticated_at: DateTime<Utc>,
pub session_expires_at: Option<DateTime<Utc>>,
}
impl SsoUserProfile {
pub fn has_role(&self, role: &str) -> bool {
self.roles.iter().any(|r| r == role)
}
pub fn has_group(&self, group: &str) -> bool {
self.idp_groups.iter().any(|g| g == group)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoSession {
pub id: String,
pub user_profile: SsoUserProfile,
pub saml_session_index: Option<String>,
pub oidc_refresh_token: Option<String>,
pub oidc_access_token: Option<String>,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub last_accessed_at: DateTime<Utc>,
pub revoked: bool,
}
impl SsoSession {
pub fn new(user_profile: SsoUserProfile, ttl_seconds: i64) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
user_profile,
saml_session_index: None,
oidc_refresh_token: None,
oidc_access_token: None,
created_at: now,
expires_at: now + Duration::seconds(ttl_seconds),
last_accessed_at: now,
revoked: false,
}
}
pub fn is_valid(&self) -> bool {
!self.revoked && Utc::now() < self.expires_at
}
pub fn touch(&mut self) {
self.last_accessed_at = Utc::now();
}
pub fn revoke(&mut self) {
self.revoked = true;
debug!("SSO session {} revoked", self.id);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PkceState {
pub code_verifier: String,
pub code_challenge: String,
pub state: String,
pub nonce: String,
pub expires_at: DateTime<Utc>,
pub provider_id: String,
}
impl PkceState {
pub fn generate(provider_id: String) -> Self {
let code_verifier = format!(
"{}-{}-{}",
Uuid::new_v4().simple(),
Uuid::new_v4().simple(),
Uuid::new_v4().simple()
);
let code_challenge = format!("challenge_{}", Uuid::new_v4().simple());
Self {
code_verifier,
code_challenge,
state: Uuid::new_v4().to_string(),
nonce: Uuid::new_v4().to_string(),
expires_at: Utc::now() + Duration::minutes(10),
provider_id,
}
}
pub fn is_expired(&self) -> bool {
Utc::now() >= self.expires_at
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAuthRequest {
pub id: String,
pub issue_instant: DateTime<Utc>,
pub assertion_consumer_service_url: String,
pub destination: String,
pub issuer: String,
pub name_id_policy_format: String,
pub allow_create: bool,
pub relay_state: Option<String>,
}
impl SamlAuthRequest {
pub fn new(provider: &IdentityProvider) -> SsoResult<Self> {
let saml = provider.saml_config.as_ref().ok_or_else(|| {
SsoError::ConfigurationError(format!(
"Provider {} does not have SAML configuration",
provider.id
))
})?;
Ok(Self {
id: format!("_{}", Uuid::new_v4().simple()),
issue_instant: Utc::now(),
assertion_consumer_service_url: saml.sp_acs_url.clone(),
destination: saml.idp_sso_url.clone(),
issuer: saml.sp_entity_id.clone(),
name_id_policy_format: saml.name_id_format.as_urn().to_string(),
allow_create: true,
relay_state: None,
})
}
pub fn to_xml(&self) -> String {
format!(
r#"<samlp:AuthnRequest
xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="{id}"
Version="2.0"
IssueInstant="{issue_instant}"
Destination="{destination}"
AssertionConsumerServiceURL="{acs_url}">
<saml:Issuer>{issuer}</saml:Issuer>
<samlp:NameIDPolicy
Format="{name_id_format}"
AllowCreate="{allow_create}"/>
</samlp:AuthnRequest>"#,
id = self.id,
issue_instant = self.issue_instant.to_rfc3339(),
destination = self.destination,
acs_url = self.assertion_consumer_service_url,
issuer = self.issuer,
name_id_format = self.name_id_policy_format,
allow_create = self.allow_create
)
}
pub fn to_base64(&self) -> String {
let xml = self.to_xml();
use std::fmt::Write as _;
let mut encoded = String::new();
for chunk in xml.as_bytes().chunks(3) {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let n = (b0 << 16) | (b1 << 8) | b2;
let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let c0 = (n >> 18) & 0x3F;
let c1 = (n >> 12) & 0x3F;
let c2 = (n >> 6) & 0x3F;
let c3 = n & 0x3F;
let char_bytes: Vec<char> = chars.chars().collect();
let _ = write!(encoded, "{}", char_bytes[c0 as usize]);
let _ = write!(encoded, "{}", char_bytes[c1 as usize]);
if chunk.len() > 1 {
let _ = write!(encoded, "{}", char_bytes[c2 as usize]);
} else {
encoded.push('=');
}
if chunk.len() > 2 {
let _ = write!(encoded, "{}", char_bytes[c3 as usize]);
} else {
encoded.push('=');
}
}
encoded
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SsoAuditEventKind {
LoginSuccess,
LoginFailure,
LogoutSuccess,
SessionExpired,
SessionRevoked,
ProviderRegistered,
ProviderUpdated,
ProviderDisabled,
TokenRefreshed,
AttributesSynced,
UnknownProvider,
InvalidAssertion,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoAuditEvent {
pub id: String,
pub kind: SsoAuditEventKind,
pub timestamp: DateTime<Utc>,
pub user_email: Option<String>,
pub provider_id: Option<String>,
pub session_id: Option<String>,
pub details: HashMap<String, String>,
pub source_ip: Option<String>,
}
impl SsoAuditEvent {
pub fn new(kind: SsoAuditEventKind) -> Self {
Self {
id: Uuid::new_v4().to_string(),
kind,
timestamp: Utc::now(),
user_email: None,
provider_id: None,
session_id: None,
details: HashMap::new(),
source_ip: None,
}
}
pub fn with_user(mut self, email: String) -> Self {
self.user_email = Some(email);
self
}
pub fn with_provider(mut self, provider_id: String) -> Self {
self.provider_id = Some(provider_id);
self
}
pub fn with_session(mut self, session_id: String) -> Self {
self.session_id = Some(session_id);
self
}
pub fn with_detail(mut self, key: String, value: String) -> Self {
self.details.insert(key, value);
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoConfig {
pub enabled: bool,
pub session_ttl_seconds: i64,
pub allow_local_login: bool,
pub enforce_sso: bool,
pub max_sessions_per_user: usize,
pub enable_audit_log: bool,
pub audit_log_buffer_size: usize,
pub auto_provision_users: bool,
pub default_role: String,
}
impl Default for SsoConfig {
fn default() -> Self {
Self {
enabled: false,
session_ttl_seconds: 28800, allow_local_login: true,
enforce_sso: false,
max_sessions_per_user: 5,
enable_audit_log: true,
audit_log_buffer_size: 10_000,
auto_provision_users: true,
default_role: "user".to_string(),
}
}
}
pub struct SsoAuthManager {
config: SsoConfig,
providers: HashMap<String, IdentityProvider>,
sessions: HashMap<String, SsoSession>,
pending_pkce: HashMap<String, PkceState>,
pending_saml: HashMap<String, SamlAuthRequest>,
audit_log: Vec<SsoAuditEvent>,
}
impl SsoAuthManager {
pub fn new(config: SsoConfig) -> Self {
info!(
"Initializing SSO authentication manager (enabled: {})",
config.enabled
);
Self {
config,
providers: HashMap::new(),
sessions: HashMap::new(),
pending_pkce: HashMap::new(),
pending_saml: HashMap::new(),
audit_log: Vec::new(),
}
}
pub fn register_provider(&mut self, provider: IdentityProvider) -> SsoResult<()> {
if !provider.enabled {
warn!("Registering disabled provider: {}", provider.id);
}
self.audit(
SsoAuditEvent::new(SsoAuditEventKind::ProviderRegistered)
.with_provider(provider.id.clone()),
);
info!(
"Registered SSO provider: {} ({:?})",
provider.name, provider.protocol
);
self.providers.insert(provider.id.clone(), provider);
Ok(())
}
pub fn get_provider(&self, provider_id: &str) -> SsoResult<&IdentityProvider> {
self.providers
.get(provider_id)
.ok_or_else(|| SsoError::ProviderNotFound(provider_id.to_string()))
}
pub fn find_provider_for_email(&self, email: &str) -> Option<&IdentityProvider> {
let mut best: Option<&IdentityProvider> = None;
for provider in self.providers.values() {
if !provider.enabled {
continue;
}
if provider.handles_domain(email) {
match best {
None => best = Some(provider),
Some(current) if provider.priority > current.priority => {
best = Some(provider);
}
_ => {}
}
}
}
best
}
pub fn list_providers(&self) -> Vec<&IdentityProvider> {
let mut providers: Vec<&IdentityProvider> = self.providers.values().collect();
providers.sort_by_key(|item| std::cmp::Reverse(item.priority));
providers
}
pub fn initiate_saml_flow(&mut self, provider_id: &str) -> SsoResult<SamlAuthRequest> {
let provider = self.get_provider(provider_id)?;
if !provider.enabled {
return Err(SsoError::AuthenticationFailed(format!(
"Provider {} is disabled",
provider_id
)));
}
let request = SamlAuthRequest::new(provider)?;
debug!("Created SAML auth request: {}", request.id);
self.pending_saml
.insert(request.id.clone(), request.clone());
Ok(request)
}
pub fn initiate_oidc_flow(&mut self, provider_id: &str) -> SsoResult<OidcAuthRequest> {
let provider = self.get_provider(provider_id)?;
if !provider.enabled {
return Err(SsoError::AuthenticationFailed(format!(
"Provider {} is disabled",
provider_id
)));
}
let oidc = provider.oidc_config.as_ref().ok_or_else(|| {
SsoError::ConfigurationError(format!(
"Provider {} does not have OIDC configuration",
provider_id
))
})?;
let pkce = PkceState::generate(provider_id.to_string());
let state = pkce.state.clone();
let nonce = pkce.nonce.clone();
let mut params = vec![
("response_type".to_string(), oidc.response_type.clone()),
("client_id".to_string(), oidc.client_id.clone()),
("redirect_uri".to_string(), oidc.redirect_uri.clone()),
("scope".to_string(), oidc.scopes.join(" ")),
("state".to_string(), state.clone()),
];
if oidc.validate_nonce {
params.push(("nonce".to_string(), nonce.clone()));
}
if oidc.use_pkce {
params.push(("code_challenge".to_string(), pkce.code_challenge.clone()));
params.push(("code_challenge_method".to_string(), "S256".to_string()));
}
let query_string: String = params
.iter()
.map(|(k, v)| format!("{}={}", k, urlencoded(v)))
.collect::<Vec<_>>()
.join("&");
let authorization_url = format!(
"{}/authorize?{}",
oidc.discovery_url
.trim_end_matches("/.well-known/openid-configuration"),
query_string
);
let auth_request = OidcAuthRequest {
state: state.clone(),
nonce,
authorization_url,
provider_id: provider_id.to_string(),
created_at: Utc::now(),
};
self.pending_pkce.insert(state, pkce);
Ok(auth_request)
}
pub fn process_saml_response(
&mut self,
provider_id: &str,
assertion: SamlAssertion,
) -> SsoResult<SsoSession> {
let provider = self.get_provider(provider_id)?;
let saml = provider.saml_config.as_ref().ok_or_else(|| {
SsoError::ConfigurationError(format!("Provider {} missing SAML config", provider_id))
})?;
if !assertion.is_valid(saml.clock_skew_seconds) {
self.audit(
SsoAuditEvent::new(SsoAuditEventKind::InvalidAssertion)
.with_provider(provider_id.to_string())
.with_detail(
"reason".to_string(),
"Assertion outside validity window".to_string(),
),
);
return Err(SsoError::SamlAssertionError(
"Assertion is outside validity window".to_string(),
));
}
let email = assertion
.get_attribute(&saml.email_attribute)
.ok_or_else(|| {
SsoError::AttributeMappingError(format!(
"Email attribute '{}' not found in assertion",
saml.email_attribute
))
})?
.to_string();
let display_name = assertion
.get_attribute(&saml.display_name_attribute)
.unwrap_or(&assertion.name_id)
.to_string();
let idp_groups: Vec<String> = saml
.groups_attribute
.as_ref()
.and_then(|attr| assertion.attributes.get(attr))
.cloned()
.unwrap_or_default();
let roles: Vec<String> = if idp_groups.is_empty() {
vec![provider.default_role.clone()]
} else {
idp_groups
.iter()
.map(|g| provider.map_role(g))
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect()
};
let session_expires = assertion.not_on_or_after;
let user_profile = SsoUserProfile {
subject_id: assertion.name_id.clone(),
provider_id: provider_id.to_string(),
email: email.clone(),
email_verified: true, display_name,
idp_groups,
roles,
attributes: assertion.attributes.clone(),
authenticated_at: Utc::now(),
session_expires_at: session_expires,
};
let mut session = SsoSession::new(user_profile, self.config.session_ttl_seconds);
session.saml_session_index = assertion.session_index.clone();
info!("SAML SSO login successful for user: {}", email);
self.audit(
SsoAuditEvent::new(SsoAuditEventKind::LoginSuccess)
.with_user(email)
.with_provider(provider_id.to_string())
.with_session(session.id.clone()),
);
self.sessions.insert(session.id.clone(), session.clone());
Ok(session)
}
pub fn process_oidc_callback(
&mut self,
state: &str,
id_token: OidcIdToken,
) -> SsoResult<SsoSession> {
let pkce = self
.pending_pkce
.remove(state)
.ok_or_else(|| SsoError::OidcError(format!("Unknown state: {}", state)))?;
if pkce.is_expired() {
return Err(SsoError::OidcError(
"Authorization request has expired".to_string(),
));
}
let provider_id = pkce.provider_id.clone();
let provider = self.get_provider(&provider_id)?;
let oidc = provider.oidc_config.as_ref().ok_or_else(|| {
SsoError::ConfigurationError("Provider missing OIDC config".to_string())
})?;
if oidc.validate_nonce {
let token_nonce = id_token.nonce.as_deref().unwrap_or("");
if token_nonce != pkce.nonce {
return Err(SsoError::OidcError("Nonce mismatch".to_string()));
}
}
if id_token.is_expired() {
return Err(SsoError::TokenValidationFailed(
"ID token has expired".to_string(),
));
}
let email = id_token.email.clone().unwrap_or_else(|| {
id_token
.get_claim_str(&oidc.email_claim)
.unwrap_or_else(|| id_token.sub.clone())
});
let display_name = id_token.name.clone().unwrap_or_else(|| {
id_token
.get_claim_str(&oidc.name_claim)
.unwrap_or_else(|| email.clone())
});
let idp_groups: Vec<String> = oidc
.groups_claim
.as_ref()
.map(|c| id_token.get_claim_list(c))
.unwrap_or_default();
let roles: Vec<String> = if idp_groups.is_empty() {
vec![provider.default_role.clone()]
} else {
idp_groups
.iter()
.map(|g| provider.map_role(g))
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect()
};
let mut attributes: HashMap<String, Vec<String>> = HashMap::new();
for (k, v) in &id_token.claims {
let val_str = match v {
serde_json::Value::String(s) => Some(vec![s.clone()]),
serde_json::Value::Array(arr) => Some(
arr.iter()
.filter_map(|x| x.as_str().map(|s| s.to_string()))
.collect(),
),
_ => None,
};
if let Some(vs) = val_str {
attributes.insert(k.clone(), vs);
}
}
let user_profile = SsoUserProfile {
subject_id: id_token.sub.clone(),
provider_id: provider_id.clone(),
email: email.clone(),
email_verified: id_token.email_verified.unwrap_or(false),
display_name,
idp_groups,
roles,
attributes,
authenticated_at: Utc::now(),
session_expires_at: Some(id_token.exp),
};
let session = SsoSession::new(user_profile, self.config.session_ttl_seconds);
info!("OIDC SSO login successful for user: {}", email);
self.audit(
SsoAuditEvent::new(SsoAuditEventKind::LoginSuccess)
.with_user(email)
.with_provider(provider_id)
.with_session(session.id.clone()),
);
self.sessions.insert(session.id.clone(), session.clone());
Ok(session)
}
pub fn get_session(&mut self, session_id: &str) -> SsoResult<&SsoSession> {
let session = self
.sessions
.get_mut(session_id)
.ok_or_else(|| SsoError::AuthenticationFailed("Session not found".to_string()))?;
if !session.is_valid() {
let email = session.user_profile.email.clone();
return Err(SsoError::SessionExpired(email));
}
session.touch();
Ok(self
.sessions
.get(session_id)
.expect("session was just inserted"))
}
pub fn revoke_session(&mut self, session_id: &str) -> SsoResult<()> {
let session = self
.sessions
.get_mut(session_id)
.ok_or_else(|| SsoError::AuthenticationFailed("Session not found".to_string()))?;
let email = session.user_profile.email.clone();
let provider_id = session.user_profile.provider_id.clone();
session.revoke();
self.audit(
SsoAuditEvent::new(SsoAuditEventKind::SessionRevoked)
.with_user(email)
.with_provider(provider_id)
.with_session(session_id.to_string()),
);
info!("SSO session {} revoked", session_id);
Ok(())
}
pub fn cleanup_expired_sessions(&mut self) -> usize {
let expired: Vec<String> = self
.sessions
.iter()
.filter(|(_, s)| !s.is_valid())
.map(|(id, _)| id.clone())
.collect();
let count = expired.len();
for id in &expired {
self.sessions.remove(id);
}
if count > 0 {
debug!("Cleaned up {} expired SSO sessions", count);
}
count
}
pub fn active_session_count(&self) -> usize {
self.sessions.values().filter(|s| s.is_valid()).count()
}
pub fn recent_audit_events(&self, limit: usize) -> Vec<&SsoAuditEvent> {
let start = if self.audit_log.len() > limit {
self.audit_log.len() - limit
} else {
0
};
self.audit_log[start..].iter().collect()
}
pub fn statistics(&self) -> SsoStatistics {
let total_sessions = self.sessions.len();
let active_sessions = self.active_session_count();
let revoked_sessions = self.sessions.values().filter(|s| s.revoked).count();
let login_successes = self
.audit_log
.iter()
.filter(|e| matches!(e.kind, SsoAuditEventKind::LoginSuccess))
.count();
let login_failures = self
.audit_log
.iter()
.filter(|e| matches!(e.kind, SsoAuditEventKind::LoginFailure))
.count();
SsoStatistics {
total_providers: self.providers.len(),
enabled_providers: self.providers.values().filter(|p| p.enabled).count(),
total_sessions,
active_sessions,
expired_sessions: total_sessions - active_sessions - revoked_sessions,
revoked_sessions,
total_audit_events: self.audit_log.len(),
login_successes,
login_failures,
}
}
fn audit(&mut self, event: SsoAuditEvent) {
if self.config.enable_audit_log {
if self.audit_log.len() >= self.config.audit_log_buffer_size {
let remove_count = self.config.audit_log_buffer_size / 10;
self.audit_log.drain(0..remove_count);
}
self.audit_log.push(event);
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoStatistics {
pub total_providers: usize,
pub enabled_providers: usize,
pub total_sessions: usize,
pub active_sessions: usize,
pub expired_sessions: usize,
pub revoked_sessions: usize,
pub total_audit_events: usize,
pub login_successes: usize,
pub login_failures: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcAuthRequest {
pub state: String,
pub nonce: String,
pub authorization_url: String,
pub provider_id: String,
pub created_at: DateTime<Utc>,
}
fn urlencoded(s: &str) -> String {
let mut encoded = String::with_capacity(s.len());
for byte in s.bytes() {
match byte {
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
encoded.push(byte as char);
}
b' ' => encoded.push('+'),
b => {
use std::fmt::Write as _;
let _ = write!(encoded, "%{:02X}", b);
}
}
}
encoded
}
#[cfg(test)]
mod tests;
#[cfg(test)]
mod extended_tests;