use rustauth_core::context::AuthContext;
use rustauth_core::crypto::random::generate_random_string;
use rustauth_core::db::{Create, DbAdapter, DbValue, FindOne, User, Where};
use rustauth_core::error::RustAuthError;
use rustauth_oauth::oauth2::OAuth2Tokens;
use rustauth_plugins::organization::{
organization_options_from_context, provision_organization_member,
ProvisionOrganizationMemberInput,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use time::OffsetDateTime;
use crate::options::{
DomainVerificationOptions, OrganizationProvisioningOptions, OrganizationRoleInput,
ProvisionUserInput, SsoOptions,
};
use crate::store::{SsoProviderRecord, SsoProviderStore};
/// Provider-agnostic view of the user profile obtained through an SSO provider.
///
/// Serialized with camelCase field names (`providerType`, `emailVerified`, ...).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NormalizedSsoProfile {
    /// Kind of provider that produced the profile; this module sets `"saml"` or `"oidc"`.
    pub provider_type: String,
    /// Identifier of the SSO provider record the profile came from.
    pub provider_id: String,
    /// Account identifier for the user. NOTE(review): the domain-based assignment path in
    /// this module fills it with the local user id — confirm what other producers store.
    pub account_id: String,
    /// Email address reported for the user.
    pub email: String,
    /// Whether the email is considered verified.
    pub email_verified: bool,
    /// Display name, if known.
    pub name: Option<String>,
    /// Avatar / picture URL, if known.
    pub image: Option<String>,
    /// Raw provider attributes or claims, when retained.
    pub raw_attributes: Option<Value>,
    /// OAuth2 tokens from the sign-in, when available; `None` when the profile is
    /// synthesized without a token exchange (e.g. domain-based assignment in this module).
    pub token_data: Option<OAuth2Tokens>,
}
pub fn provider_matches_email_domain(provider: &SsoProviderRecord, email: &str) -> bool {
let Some((_, email_domain)) = email.rsplit_once('@') else {
return false;
};
let email_domain = normalize_domain(email_domain);
if email_domain.is_empty() {
return false;
}
provider.domain.split(',').any(|domain| {
let domain = normalize_domain(domain);
if domain.is_empty() || is_public_suffix(&domain) {
return false;
}
email_domain == domain || email_domain.ends_with(&format!(".{domain}"))
})
}
pub fn validate_provider_domains(domains: &str) -> bool {
let mut has_domain = false;
for domain in domains.split(',') {
let domain = normalize_domain(domain);
if domain.is_empty() || is_public_suffix(&domain) {
return false;
}
has_domain = true;
}
has_domain
}
/// Adds `user` as a member of the organization linked to `provider`.
///
/// The member role comes from `provisioning_options.resolve_role`. When the organization
/// plugin options are available from `context`, the member is created through
/// `provision_organization_member`. Otherwise a `member` row is inserted directly.
///
/// Returns `Ok(())` without doing anything when:
/// - the provider has no organization;
/// - organization provisioning is disabled;
/// - the `organization` plugin is not registered;
/// - the user is already a member.
///
/// # Errors
///
/// Propagates errors from the membership lookup, role resolution, and member creation.
pub async fn assign_organization_from_provider(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    provisioning_options: &OrganizationProvisioningOptions,
    user: &User,
    profile: &NormalizedSsoProfile,
    provider: &SsoProviderRecord,
    token: Option<OAuth2Tokens>,
) -> Result<(), RustAuthError> {
    let organization_id = match provider.organization_id.as_deref() {
        Some(id) => id,
        None => return Ok(()),
    };

    let provisioning_enabled =
        !provisioning_options.disabled && context.has_plugin("organization");
    if !provisioning_enabled {
        return Ok(());
    }

    // Never create a duplicate membership.
    let existing_membership = organization_member(adapter, organization_id, &user.id).await?;
    if existing_membership.is_some() {
        return Ok(());
    }

    let role_input = OrganizationRoleInput {
        user: user.clone(),
        profile: profile.clone(),
        provider: provider.clone(),
        token,
    };
    let role = provisioning_options.resolve_role(role_input).await?;

    match organization_options_from_context(context) {
        Some(organization_options) => {
            let member_input = ProvisionOrganizationMemberInput {
                organization_id,
                user,
                role: &role,
            };
            provision_organization_member(adapter, &organization_options, member_input).await?;
        }
        None => create_org_membership_direct(adapter, organization_id, &user.id, &role).await?,
    }

    Ok(())
}
/// Runs the user-provisioning hook configured in `options.provision_user`, if any.
///
/// The hook runs on registration (`is_register`). On subsequent logins it runs only when
/// `options.provision_user_on_every_login` is set. In every other case this is a no-op
/// returning `Ok(())`.
///
/// # Errors
///
/// Returns whatever error the provisioning hook reports.
pub async fn provision_sso_user(
    options: &SsoOptions,
    user: &User,
    profile: &NormalizedSsoProfile,
    provider: &SsoProviderRecord,
    token: Option<OAuth2Tokens>,
    is_register: bool,
) -> Result<(), RustAuthError> {
    let should_provision = is_register || options.provision_user_on_every_login;
    match &options.provision_user {
        Some(provision_user) if should_provision => {
            let input = ProvisionUserInput {
                user: user.clone(),
                profile: profile.clone(),
                provider: provider.clone(),
                token,
                is_register,
            };
            provision_user.resolve(input).await
        }
        _ => Ok(()),
    }
}
/// Assigns `user` to an organization by matching their email domain against the SSO
/// providers stored in the default provider model (`crate::schema::SSO_PROVIDER_MODEL`).
///
/// # Errors
///
/// Propagates errors from `assign_organization_by_domain_with_model`.
pub async fn assign_organization_by_domain(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    provisioning_options: &OrganizationProvisioningOptions,
    domain_verification: &DomainVerificationOptions,
    user: &User,
) -> Result<(), RustAuthError> {
    let default_model = crate::schema::SSO_PROVIDER_MODEL;
    assign_organization_by_domain_with_model(
        context,
        adapter,
        default_model,
        provisioning_options,
        domain_verification,
        user,
    )
    .await
}
pub(crate) async fn assign_organization_by_domain_with_model(
context: &AuthContext,
adapter: &dyn DbAdapter,
model_name: &str,
provisioning_options: &OrganizationProvisioningOptions,
domain_verification: &DomainVerificationOptions,
user: &User,
) -> Result<(), RustAuthError> {
if provisioning_options.disabled || !context.has_plugin("organization") {
return Ok(());
}
let Some((_, email_domain)) = user.email.rsplit_once('@') else {
return Ok(());
};
let email_domain = normalize_domain(email_domain);
if email_domain.is_empty() {
return Ok(());
}
let providers = SsoProviderStore::new_with_model_and_domain_verification(
adapter,
model_name,
domain_verification.enabled,
)
.list()
.await?;
let provider = providers.into_iter().find(|provider| {
provider.organization_id.is_some()
&& provider_matches_email_domain(provider, &user.email)
&& (!domain_verification.enabled || provider.domain_verified.unwrap_or(false))
});
let Some(provider) = provider else {
return Ok(());
};
let provider_type = if provider.saml_config.is_some() {
"saml"
} else {
"oidc"
};
assign_organization_from_provider(
context,
adapter,
provisioning_options,
user,
&NormalizedSsoProfile {
provider_type: provider_type.to_owned(),
provider_id: provider.provider_id.clone(),
account_id: user.id.clone(),
email: user.email.clone(),
email_verified: user.email_verified,
name: Some(user.name.clone()),
image: user.image.clone(),
raw_attributes: None,
token_data: None,
},
&provider,
None,
)
.await
}
/// Looks up the `member` row linking `user_id` to `organization_id`.
///
/// Returns `Ok(None)` when the user is not a member of that organization.
async fn organization_member(
    adapter: &dyn DbAdapter,
    organization_id: &str,
    user_id: &str,
) -> Result<Option<rustauth_core::db::DbRecord>, RustAuthError> {
    let by_organization = Where::new(
        "organization_id",
        DbValue::String(organization_id.to_owned()),
    );
    let by_user = Where::new("user_id", DbValue::String(user_id.to_owned()));
    let query = FindOne::new("member")
        .where_clause(by_organization)
        .where_clause(by_user);
    adapter.find_one(query).await
}
/// Inserts a `member` row directly through the adapter.
///
/// This is the fallback used when organization plugin options are not available. The row
/// gets a random 32-character id (written explicitly via `force_allow_id`) and a
/// `created_at` timestamp of the current UTC time.
async fn create_org_membership_direct(
    adapter: &dyn DbAdapter,
    organization_id: &str,
    user_id: &str,
    role: &str,
) -> Result<(), RustAuthError> {
    let member_id = generate_random_string(32);
    let created_at = OffsetDateTime::now_utc();
    let record = Create::new("member")
        .data("id", DbValue::String(member_id))
        .data("organization_id", DbValue::String(organization_id.to_owned()))
        .data("user_id", DbValue::String(user_id.to_owned()))
        .data("role", DbValue::String(role.to_owned()))
        .data("created_at", DbValue::Timestamp(created_at))
        .force_allow_id();
    adapter.create(record).await?;
    Ok(())
}
/// Normalizes a configured domain, or the domain part of an email, for comparison.
///
/// The steps are:
/// 1. Trim surrounding whitespace and lowercase (ASCII).
/// 2. Strip a leading `http://` / `https://` scheme (case-insensitive, because lowercasing
///    happens first).
/// 3. Keep only the host portion before the first `/`, `?` or `#`.
/// 4. Drop trailing dots, so a fully-qualified `example.com.` compares equal to
///    `example.com`.
///
/// Returns an empty string when nothing remains.
fn normalize_domain(value: &str) -> String {
    let lowered = value.trim().to_ascii_lowercase();
    // Lowercase before stripping so `HTTPS://Example.com` loses its scheme too.
    let without_scheme = lowered
        .trim_start_matches("http://")
        .trim_start_matches("https://");
    // Cut at the first path/query/fragment delimiter *before* trimming dots; trimming first
    // would keep the dot in inputs such as `example.com./`.
    let host = without_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or_default();
    host.trim_end_matches('.').to_owned()
}
/// Returns `true` when `domain` is exactly a public suffix (e.g. a bare TLD), according to
/// `publicsuffix2`'s global list queried with default match options.
fn is_public_suffix(domain: &str) -> bool {
    let list = publicsuffix2::List::global();
    match list.tld(domain, publicsuffix2::MatchOpts::default()) {
        Some(suffix) => suffix == domain,
        None => false,
    }
}