Skip to main content

rustauth_sso/
linking.rs

1use rustauth_core::context::AuthContext;
2use rustauth_core::crypto::random::generate_random_string;
3use rustauth_core::db::{Create, DbAdapter, DbValue, FindOne, User, Where};
4use rustauth_core::error::RustAuthError;
5use rustauth_oauth::oauth2::OAuth2Tokens;
6use rustauth_plugins::organization::{
7    organization_options_from_context, provision_organization_member,
8    ProvisionOrganizationMemberInput,
9};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use time::OffsetDateTime;
13
14use crate::options::{
15    DomainVerificationOptions, OrganizationProvisioningOptions, OrganizationRoleInput,
16    ProvisionUserInput, SsoOptions,
17};
18use crate::store::{SsoProviderRecord, SsoProviderStore};
19
/// Normalized identity profile produced by an OIDC or SAML SSO login.
///
/// Serialized with camelCase field names (e.g. `providerType`, `emailVerified`)
/// because of the `serde(rename_all = "camelCase")` attribute.
///
/// Besides real logins, `assign_organization_by_domain_with_model` in this
/// module synthesizes a profile from a stored `User` for domain-based
/// organization assignment; fields whose meaning differs in that case are
/// noted below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NormalizedSsoProfile {
    /// Provider protocol, such as `oidc` or `saml`.
    pub provider_type: String,
    /// Stable RustAuth SSO provider id.
    pub provider_id: String,
    /// External account id from the identity provider.
    ///
    /// In profiles synthesized for domain-based assignment this holds the
    /// RustAuth user id instead.
    pub account_id: String,
    /// Normalized email address.
    pub email: String,
    /// Whether the identity provider marked the email as verified.
    ///
    /// In synthesized profiles this is copied from the stored user record.
    pub email_verified: bool,
    /// Display name, when available.
    pub name: Option<String>,
    /// Avatar URL, when available.
    pub image: Option<String>,
    /// Extra mapped claims or attributes requested by provider mapping.
    pub raw_attributes: Option<Value>,
    /// OIDC token data; `None` for SAML (and for synthesized profiles).
    pub token_data: Option<OAuth2Tokens>,
}
43
44pub fn provider_matches_email_domain(provider: &SsoProviderRecord, email: &str) -> bool {
45    let Some((_, email_domain)) = email.rsplit_once('@') else {
46        return false;
47    };
48    let email_domain = normalize_domain(email_domain);
49    if email_domain.is_empty() {
50        return false;
51    }
52    provider.domain.split(',').any(|domain| {
53        let domain = normalize_domain(domain);
54        if domain.is_empty() || is_public_suffix(&domain) {
55            return false;
56        }
57        email_domain == domain || email_domain.ends_with(&format!(".{domain}"))
58    })
59}
60
61pub fn validate_provider_domains(domains: &str) -> bool {
62    let mut has_domain = false;
63    for domain in domains.split(',') {
64        let domain = normalize_domain(domain);
65        if domain.is_empty() || is_public_suffix(&domain) {
66            return false;
67        }
68        has_domain = true;
69    }
70    has_domain
71}
72
73pub async fn assign_organization_from_provider(
74    context: &AuthContext,
75    adapter: &dyn DbAdapter,
76    provisioning_options: &OrganizationProvisioningOptions,
77    user: &User,
78    profile: &NormalizedSsoProfile,
79    provider: &SsoProviderRecord,
80    token: Option<OAuth2Tokens>,
81) -> Result<(), RustAuthError> {
82    let Some(organization_id) = provider.organization_id.as_deref() else {
83        return Ok(());
84    };
85    if provisioning_options.disabled || !context.has_plugin("organization") {
86        return Ok(());
87    }
88    if organization_member(adapter, organization_id, &user.id)
89        .await?
90        .is_some()
91    {
92        return Ok(());
93    }
94
95    let role = provisioning_options
96        .resolve_role(OrganizationRoleInput {
97            user: user.clone(),
98            profile: profile.clone(),
99            provider: provider.clone(),
100            token,
101        })
102        .await?;
103    if let Some(options) = organization_options_from_context(context) {
104        provision_organization_member(
105            adapter,
106            &options,
107            ProvisionOrganizationMemberInput {
108                organization_id,
109                user,
110                role: &role,
111            },
112        )
113        .await?;
114    } else {
115        create_org_membership_direct(adapter, organization_id, &user.id, &role).await?;
116    }
117    Ok(())
118}
119
120pub async fn provision_sso_user(
121    options: &SsoOptions,
122    user: &User,
123    profile: &NormalizedSsoProfile,
124    provider: &SsoProviderRecord,
125    token: Option<OAuth2Tokens>,
126    is_register: bool,
127) -> Result<(), RustAuthError> {
128    let Some(provision_user) = &options.provision_user else {
129        return Ok(());
130    };
131    if !is_register && !options.provision_user_on_every_login {
132        return Ok(());
133    }
134    provision_user
135        .resolve(ProvisionUserInput {
136            user: user.clone(),
137            profile: profile.clone(),
138            provider: provider.clone(),
139            token,
140            is_register,
141        })
142        .await
143}
144
145pub async fn assign_organization_by_domain(
146    context: &AuthContext,
147    adapter: &dyn DbAdapter,
148    provisioning_options: &OrganizationProvisioningOptions,
149    domain_verification: &DomainVerificationOptions,
150    user: &User,
151) -> Result<(), RustAuthError> {
152    assign_organization_by_domain_with_model(
153        context,
154        adapter,
155        crate::schema::SSO_PROVIDER_MODEL,
156        provisioning_options,
157        domain_verification,
158        user,
159    )
160    .await
161}
162
163pub(crate) async fn assign_organization_by_domain_with_model(
164    context: &AuthContext,
165    adapter: &dyn DbAdapter,
166    model_name: &str,
167    provisioning_options: &OrganizationProvisioningOptions,
168    domain_verification: &DomainVerificationOptions,
169    user: &User,
170) -> Result<(), RustAuthError> {
171    if provisioning_options.disabled || !context.has_plugin("organization") {
172        return Ok(());
173    }
174
175    let Some((_, email_domain)) = user.email.rsplit_once('@') else {
176        return Ok(());
177    };
178    let email_domain = normalize_domain(email_domain);
179    if email_domain.is_empty() {
180        return Ok(());
181    }
182
183    let providers = SsoProviderStore::new_with_model_and_domain_verification(
184        adapter,
185        model_name,
186        domain_verification.enabled,
187    )
188    .list()
189    .await?;
190    let provider = providers.into_iter().find(|provider| {
191        provider.organization_id.is_some()
192            && provider_matches_email_domain(provider, &user.email)
193            && (!domain_verification.enabled || provider.domain_verified.unwrap_or(false))
194    });
195    let Some(provider) = provider else {
196        return Ok(());
197    };
198
199    let provider_type = if provider.saml_config.is_some() {
200        "saml"
201    } else {
202        "oidc"
203    };
204    assign_organization_from_provider(
205        context,
206        adapter,
207        provisioning_options,
208        user,
209        &NormalizedSsoProfile {
210            provider_type: provider_type.to_owned(),
211            provider_id: provider.provider_id.clone(),
212            account_id: user.id.clone(),
213            email: user.email.clone(),
214            email_verified: user.email_verified,
215            name: Some(user.name.clone()),
216            image: user.image.clone(),
217            raw_attributes: None,
218            token_data: None,
219        },
220        &provider,
221        None,
222    )
223    .await
224}
225
226async fn organization_member(
227    adapter: &dyn DbAdapter,
228    organization_id: &str,
229    user_id: &str,
230) -> Result<Option<rustauth_core::db::DbRecord>, RustAuthError> {
231    adapter
232        .find_one(
233            FindOne::new("member")
234                .where_clause(Where::new(
235                    "organization_id",
236                    DbValue::String(organization_id.to_owned()),
237                ))
238                .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
239        )
240        .await
241}
242
243async fn create_org_membership_direct(
244    adapter: &dyn DbAdapter,
245    organization_id: &str,
246    user_id: &str,
247    role: &str,
248) -> Result<(), RustAuthError> {
249    adapter
250        .create(
251            Create::new("member")
252                .data("id", DbValue::String(generate_random_string(32)))
253                .data(
254                    "organization_id",
255                    DbValue::String(organization_id.to_owned()),
256                )
257                .data("user_id", DbValue::String(user_id.to_owned()))
258                .data("role", DbValue::String(role.to_owned()))
259                .data("created_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
260                .force_allow_id(),
261        )
262        .await?;
263    Ok(())
264}
265
/// Normalizes a domain-like string for comparison.
///
/// Steps, in order:
/// 1. trim surrounding whitespace;
/// 2. ASCII-lowercase;
/// 3. strip any leading `http://` / `https://` prefixes;
/// 4. keep only the text before the first `/`;
/// 5. drop trailing dots (the fully-qualified-name dot).
///
/// Lowercasing happens before the scheme is stripped, so `HTTPS://Example.com`
/// becomes `example.com` rather than `https:`. Trailing dots are removed after
/// the path is cut, so `example.com./` also yields `example.com`.
///
/// Not handled: `:port` suffixes, query strings not preceded by `/`, and case
/// folding of non-ASCII (IDN) labels.
fn normalize_domain(value: &str) -> String {
    let lowered = value.trim().to_ascii_lowercase();
    let host_and_path = lowered
        .trim_start_matches("http://")
        .trim_start_matches("https://");
    // `split` always yields at least one item, so the fallback is never used.
    let host = host_and_path.split('/').next().unwrap_or_default();
    host.trim_end_matches('.').to_owned()
}
278
279fn is_public_suffix(domain: &str) -> bool {
280    publicsuffix2::List::global()
281        .tld(domain, publicsuffix2::MatchOpts::default())
282        .is_some_and(|suffix| suffix == domain)
283}