1use rustauth_core::context::AuthContext;
2use rustauth_core::crypto::random::generate_random_string;
3use rustauth_core::db::{Create, DbAdapter, DbValue, FindOne, User, Where};
4use rustauth_core::error::RustAuthError;
5use rustauth_oauth::oauth2::OAuth2Tokens;
6use rustauth_plugins::organization::{
7 organization_options_from_context, provision_organization_member,
8 ProvisionOrganizationMemberInput,
9};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use time::OffsetDateTime;
13
14use crate::options::{
15 DomainVerificationOptions, OrganizationProvisioningOptions, OrganizationRoleInput,
16 ProvisionUserInput, SsoOptions,
17};
18use crate::store::{SsoProviderRecord, SsoProviderStore};
19
/// Provider-agnostic view of an SSO user profile.
///
/// Cloned into the organization-role (`OrganizationRoleInput`) and
/// user-provisioning (`ProvisionUserInput`) hook inputs. Serialized with
/// camelCase field names (e.g. `providerType`, `emailVerified`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NormalizedSsoProfile {
    /// Protocol of the provider that produced the profile; the domain-based
    /// assignment path in this module sets `"saml"` or `"oidc"`.
    pub provider_type: String,
    /// Identifier of the SSO provider record the profile came from.
    pub provider_id: String,
    /// Account identifier for the user at the provider (presumably the IdP
    /// subject — TODO confirm); the domain-based assignment path fills it with
    /// the local user id instead.
    pub account_id: String,
    /// Email address reported for the user.
    pub email: String,
    /// Whether the email is considered verified.
    pub email_verified: bool,
    /// Display name, if known.
    pub name: Option<String>,
    /// Avatar/image URL, if known.
    pub image: Option<String>,
    /// Unprocessed attributes/claims from the provider, if retained.
    pub raw_attributes: Option<Value>,
    /// OAuth2 tokens obtained during sign-in, if any.
    pub token_data: Option<OAuth2Tokens>,
}
43
44pub fn provider_matches_email_domain(provider: &SsoProviderRecord, email: &str) -> bool {
45 let Some((_, email_domain)) = email.rsplit_once('@') else {
46 return false;
47 };
48 let email_domain = normalize_domain(email_domain);
49 if email_domain.is_empty() {
50 return false;
51 }
52 provider.domain.split(',').any(|domain| {
53 let domain = normalize_domain(domain);
54 if domain.is_empty() || is_public_suffix(&domain) {
55 return false;
56 }
57 email_domain == domain || email_domain.ends_with(&format!(".{domain}"))
58 })
59}
60
61pub fn validate_provider_domains(domains: &str) -> bool {
62 let mut has_domain = false;
63 for domain in domains.split(',') {
64 let domain = normalize_domain(domain);
65 if domain.is_empty() || is_public_suffix(&domain) {
66 return false;
67 }
68 has_domain = true;
69 }
70 has_domain
71}
72
73pub async fn assign_organization_from_provider(
74 context: &AuthContext,
75 adapter: &dyn DbAdapter,
76 provisioning_options: &OrganizationProvisioningOptions,
77 user: &User,
78 profile: &NormalizedSsoProfile,
79 provider: &SsoProviderRecord,
80 token: Option<OAuth2Tokens>,
81) -> Result<(), RustAuthError> {
82 let Some(organization_id) = provider.organization_id.as_deref() else {
83 return Ok(());
84 };
85 if provisioning_options.disabled || !context.has_plugin("organization") {
86 return Ok(());
87 }
88 if organization_member(adapter, organization_id, &user.id)
89 .await?
90 .is_some()
91 {
92 return Ok(());
93 }
94
95 let role = provisioning_options
96 .resolve_role(OrganizationRoleInput {
97 user: user.clone(),
98 profile: profile.clone(),
99 provider: provider.clone(),
100 token,
101 })
102 .await?;
103 if let Some(options) = organization_options_from_context(context) {
104 provision_organization_member(
105 adapter,
106 &options,
107 ProvisionOrganizationMemberInput {
108 organization_id,
109 user,
110 role: &role,
111 },
112 )
113 .await?;
114 } else {
115 create_org_membership_direct(adapter, organization_id, &user.id, &role).await?;
116 }
117 Ok(())
118}
119
120pub async fn provision_sso_user(
121 options: &SsoOptions,
122 user: &User,
123 profile: &NormalizedSsoProfile,
124 provider: &SsoProviderRecord,
125 token: Option<OAuth2Tokens>,
126 is_register: bool,
127) -> Result<(), RustAuthError> {
128 let Some(provision_user) = &options.provision_user else {
129 return Ok(());
130 };
131 if !is_register && !options.provision_user_on_every_login {
132 return Ok(());
133 }
134 provision_user
135 .resolve(ProvisionUserInput {
136 user: user.clone(),
137 profile: profile.clone(),
138 provider: provider.clone(),
139 token,
140 is_register,
141 })
142 .await
143}
144
145pub async fn assign_organization_by_domain(
146 context: &AuthContext,
147 adapter: &dyn DbAdapter,
148 provisioning_options: &OrganizationProvisioningOptions,
149 domain_verification: &DomainVerificationOptions,
150 user: &User,
151) -> Result<(), RustAuthError> {
152 assign_organization_by_domain_with_model(
153 context,
154 adapter,
155 crate::schema::SSO_PROVIDER_MODEL,
156 provisioning_options,
157 domain_verification,
158 user,
159 )
160 .await
161}
162
163pub(crate) async fn assign_organization_by_domain_with_model(
164 context: &AuthContext,
165 adapter: &dyn DbAdapter,
166 model_name: &str,
167 provisioning_options: &OrganizationProvisioningOptions,
168 domain_verification: &DomainVerificationOptions,
169 user: &User,
170) -> Result<(), RustAuthError> {
171 if provisioning_options.disabled || !context.has_plugin("organization") {
172 return Ok(());
173 }
174
175 let Some((_, email_domain)) = user.email.rsplit_once('@') else {
176 return Ok(());
177 };
178 let email_domain = normalize_domain(email_domain);
179 if email_domain.is_empty() {
180 return Ok(());
181 }
182
183 let providers = SsoProviderStore::new_with_model_and_domain_verification(
184 adapter,
185 model_name,
186 domain_verification.enabled,
187 )
188 .list()
189 .await?;
190 let provider = providers.into_iter().find(|provider| {
191 provider.organization_id.is_some()
192 && provider_matches_email_domain(provider, &user.email)
193 && (!domain_verification.enabled || provider.domain_verified.unwrap_or(false))
194 });
195 let Some(provider) = provider else {
196 return Ok(());
197 };
198
199 let provider_type = if provider.saml_config.is_some() {
200 "saml"
201 } else {
202 "oidc"
203 };
204 assign_organization_from_provider(
205 context,
206 adapter,
207 provisioning_options,
208 user,
209 &NormalizedSsoProfile {
210 provider_type: provider_type.to_owned(),
211 provider_id: provider.provider_id.clone(),
212 account_id: user.id.clone(),
213 email: user.email.clone(),
214 email_verified: user.email_verified,
215 name: Some(user.name.clone()),
216 image: user.image.clone(),
217 raw_attributes: None,
218 token_data: None,
219 },
220 &provider,
221 None,
222 )
223 .await
224}
225
226async fn organization_member(
227 adapter: &dyn DbAdapter,
228 organization_id: &str,
229 user_id: &str,
230) -> Result<Option<rustauth_core::db::DbRecord>, RustAuthError> {
231 adapter
232 .find_one(
233 FindOne::new("member")
234 .where_clause(Where::new(
235 "organization_id",
236 DbValue::String(organization_id.to_owned()),
237 ))
238 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
239 )
240 .await
241}
242
243async fn create_org_membership_direct(
244 adapter: &dyn DbAdapter,
245 organization_id: &str,
246 user_id: &str,
247 role: &str,
248) -> Result<(), RustAuthError> {
249 adapter
250 .create(
251 Create::new("member")
252 .data("id", DbValue::String(generate_random_string(32)))
253 .data(
254 "organization_id",
255 DbValue::String(organization_id.to_owned()),
256 )
257 .data("user_id", DbValue::String(user_id.to_owned()))
258 .data("role", DbValue::String(role.to_owned()))
259 .data("created_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
260 .force_allow_id(),
261 )
262 .await?;
263 Ok(())
264}
265
/// Normalizes a user- or config-supplied domain so it can be compared.
///
/// Steps, in order:
/// 1. Trim surrounding whitespace.
/// 2. Lowercase (ASCII).
/// 3. Strip any leading `http://` / `https://` scheme.
/// 4. Drop everything from the first `/` onward.
/// 5. Remove trailing dots (the DNS root label).
///
/// Returns an empty string when nothing remains.
///
/// Lowercasing happens before scheme stripping so that inputs like
/// `HTTPS://Example.com` lose their scheme. Trailing dots are removed from
/// the host itself, not from the whole input, so `example.com./login` still
/// normalizes to `example.com`.
fn normalize_domain(value: &str) -> String {
    let lowered = value.trim().to_ascii_lowercase();
    let without_scheme = lowered
        .trim_start_matches("http://")
        .trim_start_matches("https://");
    // `split` always yields at least one item, so the default is never used.
    let host = without_scheme.split('/').next().unwrap_or_default();
    host.trim_end_matches('.').to_owned()
}
278
279fn is_public_suffix(domain: &str) -> bool {
280 publicsuffix2::List::global()
281 .tld(domain, publicsuffix2::MatchOpts::default())
282 .is_some_and(|suffix| suffix == domain)
283}