Skip to main content

rustauth_sso/
store.rs

1use rustauth_core::crypto::random::generate_random_string;
2use rustauth_core::db::{
3    Create, DbAdapter, DbRecord, DbValue, Delete, FindMany, FindOne, Update, Where,
4};
5use rustauth_core::error::RustAuthError;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9#[cfg(feature = "oidc")]
10use crate::oidc_impl::flow::oidc_redirect_uri;
11use crate::options::OidcConfig;
12#[cfg(feature = "saml")]
13use crate::options::SamlConfig;
14use crate::schema::SSO_PROVIDER_MODEL;
15#[cfg(feature = "saml")]
16use crate::utils::certificate_metadata;
17use crate::utils::client_id_last_four;
18
/// Columns selected for provider reads and creates when the store was built
/// without domain-verification support.
const SSO_PROVIDER_FIELDS: [&str; 9] = [
    "id", "issuer", "oidc_config", "saml_config", "user_id",
    "provider_id", "organization_id", "domain", "created_at",
];
30
/// Columns selected for provider reads and creates when domain verification is
/// enabled: the base projection plus `domain_verified` (placed before `created_at`).
const SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED: [&str; 10] = [
    "id", "issuer", "oidc_config", "saml_config", "user_id",
    "provider_id", "organization_id", "domain", "domain_verified", "created_at",
];
43
/// Raw SSO provider record loaded from the adapter.
///
/// `oidc_config` and `saml_config` hold the stored configuration JSON verbatim,
/// including the secret/key material that the sanitized types strip out. Convert
/// with [`SsoProviderRecord::sanitized`] or
/// [`SsoProviderRecord::sanitized_with_options`] before returning a provider to
/// clients.
/// NOTE(review): this type derives `Debug` and `Serialize`, so logging or
/// serializing it directly would expose that material — confirm no call site does.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoProviderRecord {
    /// Database id (generated with `generate_random_string(32)` by
    /// [`SsoProviderStore::create`]).
    pub id: String,
    /// Provider issuer.
    pub issuer: String,
    /// Serialized OIDC config JSON; `None` when the column is NULL or absent.
    pub oidc_config: Option<String>,
    /// Serialized SAML config JSON; `None` when the column is NULL or absent.
    pub saml_config: Option<String>,
    /// Owner user id.
    pub user_id: String,
    /// Stable provider id (the key used by the store's lookups, updates and deletes).
    pub provider_id: String,
    /// Optional organization id assigned to provider users.
    pub organization_id: Option<String>,
    /// Comma-separated provider domains.
    pub domain: String,
    /// Domain verification state; `None` when the column was not selected or is NULL.
    pub domain_verified: Option<bool>,
    /// Creation timestamp; `None` when the column is NULL or absent.
    pub created_at: Option<OffsetDateTime>,
}
69
/// Provider representation returned by public read endpoints.
///
/// Built by [`SsoProviderRecord::sanitized_with_options`]. It carries only the
/// sanitized OIDC/SAML views, never the raw stored config JSON. Keys are
/// serialized in camelCase unless a field overrides its name.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSsoProvider {
    /// Stable provider id.
    pub provider_id: String,
    /// Preferred provider protocol label, serialized as `providerType`.
    /// `sanitized_with_options` sets it to `"saml"` or `"oidc"`.
    pub provider_type: String,
    /// Upstream-compatible provider protocol label. It has the same value as
    /// `provider_type` but is serialized under the `type` key.
    #[serde(rename = "type")]
    pub upstream_type: String,
    /// Provider issuer.
    pub issuer: String,
    /// Provider domains.
    pub domain: String,
    /// Optional organization id.
    pub organization_id: Option<String>,
    /// Whether the provider domain has been verified (`false` when the stored flag is unset).
    pub domain_verified: bool,
    /// Sanitized OIDC config, if one is stored and parses.
    pub oidc_config: Option<SanitizedOidcConfig>,
    /// Sanitized SAML config, if one is stored and parses.
    pub saml_config: Option<SanitizedSamlConfig>,
    /// Shared OIDC redirect URI shown to clients. Serialized as `redirectURI`
    /// and omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "redirectURI")]
    pub redirect_uri: Option<String>,
    /// SAML service provider metadata URL. It is populated for every provider,
    /// including OIDC ones.
    pub sp_metadata_url: String,
}
99
/// OIDC provider config with secret material removed.
///
/// Only public endpoint settings are kept; the client id is reduced to its last
/// four characters. Keys are serialized in camelCase.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedOidcConfig {
    /// Discovery endpoint URL.
    pub discovery_endpoint: String,
    /// Last four characters of the client id.
    pub client_id_last_four: String,
    /// Whether PKCE is enabled.
    pub pkce: bool,
    /// Authorization endpoint URL.
    pub authorization_endpoint: Option<String>,
    /// Token endpoint URL.
    pub token_endpoint: Option<String>,
    /// UserInfo endpoint URL.
    pub user_info_endpoint: Option<String>,
    /// JWKS endpoint URL.
    pub jwks_endpoint: Option<String>,
    /// OAuth token revocation endpoint URL.
    pub revocation_endpoint: Option<String>,
    /// OIDC end-session endpoint URL.
    pub end_session_endpoint: Option<String>,
    /// OAuth token introspection endpoint URL.
    pub introspection_endpoint: Option<String>,
    /// Client authentication method selected for the token endpoint.
    pub token_endpoint_authentication: Option<crate::options::TokenEndpointAuthentication>,
    /// Configured default scopes.
    pub scopes: Option<Vec<String>>,
}
129
/// SAML provider config with private key material removed.
///
/// Instead of the raw IdP certificate, it exposes metadata derived from it
/// (fingerprint, validity window, key algorithm). Keys are serialized in camelCase.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSamlConfig {
    /// IdP entry point.
    pub entry_point: String,
    /// Callback URL.
    pub callback_url: String,
    /// Assertion consumer service URL.
    pub acs_url: Option<String>,
    /// Expected audience.
    pub audience: Option<String>,
    /// Whether assertion signatures are required.
    pub want_assertions_signed: bool,
    /// Whether outbound AuthnRequests are signed.
    pub authn_requests_signed: bool,
    /// Requested NameID format.
    pub identifier_format: Option<String>,
    /// Signature algorithm.
    pub signature_algorithm: Option<String>,
    /// Digest algorithm.
    pub digest_algorithm: Option<String>,
    /// SHA-256 fingerprint of the IdP certificate.
    pub certificate_sha256_fingerprint: String,
    /// Certificate validity start, when parseable.
    pub certificate_not_before: Option<String>,
    /// Certificate validity end, when parseable.
    pub certificate_not_after: Option<String>,
    /// Certificate public key algorithm, when parseable.
    pub certificate_public_key_algorithm: Option<String>,
    /// Certificate parse error, when metadata could not be extracted; omitted
    /// from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate_error: Option<String>,
}
164
/// Adapter-backed store for SSO provider records.
///
/// The store only borrows the adapter and model name, so it is `Copy`. When
/// `include_domain_verified` is set, reads and creates also select the
/// `domain_verified` column.
#[derive(Clone, Copy)]
pub struct SsoProviderStore<'a> {
    // Adapter every query is sent through.
    adapter: &'a dyn DbAdapter,
    // Logical model name passed to each query builder.
    model_name: &'a str,
    // Chooses between SSO_PROVIDER_FIELDS and SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED.
    include_domain_verified: bool,
}
172
173impl<'a> SsoProviderStore<'a> {
174    /// Create a provider store over an RustAuth adapter.
175    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
176        Self::new_with_model(adapter, SSO_PROVIDER_MODEL)
177    }
178
179    /// Create a provider store using a custom logical model name.
180    pub fn new_with_model(adapter: &'a dyn DbAdapter, model_name: &'a str) -> Self {
181        Self {
182            adapter,
183            model_name,
184            include_domain_verified: false,
185        }
186    }
187
188    /// Create a provider store from plugin options.
189    pub fn new_with_options(
190        adapter: &'a dyn DbAdapter,
191        options: &'a crate::options::SsoOptions,
192    ) -> Self {
193        Self::new_with_model_and_domain_verification(
194            adapter,
195            &options.model_name,
196            options.domain_verification.enabled,
197        )
198    }
199
200    /// Create a provider store with explicit model and domain verification field support.
201    pub fn new_with_model_and_domain_verification(
202        adapter: &'a dyn DbAdapter,
203        model_name: &'a str,
204        include_domain_verified: bool,
205    ) -> Self {
206        Self {
207            adapter,
208            model_name,
209            include_domain_verified,
210        }
211    }
212
213    /// List all SSO providers.
214    pub async fn list(&self) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
215        let query = self.select_find_many(FindMany::new(self.model_name));
216        self.adapter
217            .find_many(query)
218            .await?
219            .into_iter()
220            .map(record_from_db)
221            .collect()
222    }
223
224    /// List SSO providers owned by a user.
225    pub async fn list_by_user(
226        &self,
227        user_id: &str,
228    ) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
229        let query = FindMany::new(self.model_name)
230            .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())));
231        self.adapter
232            .find_many(self.select_find_many(query))
233            .await?
234            .into_iter()
235            .map(record_from_db)
236            .collect()
237    }
238
239    /// Find an SSO provider by stable provider id.
240    pub async fn find_by_provider_id(
241        &self,
242        provider_id: &str,
243    ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
244        let query = FindOne::new(self.model_name).where_clause(provider_id_where(provider_id));
245        self.adapter
246            .find_one(self.select_find_one(query))
247            .await?
248            .map(record_from_db)
249            .transpose()
250    }
251
252    /// Find the first SSO provider assigned to an organization.
253    pub async fn find_by_organization_id(
254        &self,
255        organization_id: &str,
256    ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
257        let query = FindOne::new(self.model_name).where_clause(Where::new(
258            "organization_id",
259            DbValue::String(organization_id.to_owned()),
260        ));
261        self.adapter
262            .find_one(self.select_find_one(query))
263            .await?
264            .map(record_from_db)
265            .transpose()
266    }
267
268    /// Create an SSO provider record.
269    pub async fn create(
270        &self,
271        input: CreateSsoProviderInput,
272    ) -> Result<SsoProviderRecord, RustAuthError> {
273        let now = OffsetDateTime::now_utc();
274        let mut query = Create::new(self.model_name)
275            .data("id", DbValue::String(generate_random_string(32)))
276            .data("issuer", DbValue::String(input.issuer))
277            .data("oidc_config", optional_string(input.oidc_config))
278            .data("saml_config", optional_string(input.saml_config))
279            .data("user_id", DbValue::String(input.user_id))
280            .data("provider_id", DbValue::String(input.provider_id))
281            .data("organization_id", optional_string(input.organization_id))
282            .data("domain", DbValue::String(input.domain))
283            .data("created_at", DbValue::Timestamp(now))
284            .data("updated_at", DbValue::Timestamp(now))
285            .force_allow_id();
286        query = self.select_create(query);
287        if let Some(domain_verified) = input.domain_verified {
288            query = query.data("domain_verified", DbValue::Boolean(domain_verified));
289        }
290
291        record_from_db(self.adapter.create(query).await?)
292    }
293
294    /// Update a provider domain verification flag.
295    pub async fn update_domain_verified(
296        &self,
297        provider_id: &str,
298        verified: bool,
299    ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
300        self.adapter
301            .update(
302                Update::new(self.model_name)
303                    .where_clause(provider_id_where(provider_id))
304                    .data("domain_verified", DbValue::Boolean(verified))
305                    .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
306            )
307            .await?
308            .map(record_from_db)
309            .transpose()
310    }
311
312    /// Partially update an SSO provider record.
313    pub async fn update(
314        &self,
315        provider_id: &str,
316        input: UpdateSsoProviderInput,
317    ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
318        let mut query = Update::new(self.model_name)
319            .where_clause(provider_id_where(provider_id))
320            .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
321
322        if let Some(issuer) = input.issuer {
323            query = query.data("issuer", DbValue::String(issuer));
324        }
325        if let Some(domain) = input.domain {
326            query = query.data("domain", DbValue::String(domain));
327        }
328        if let Some(organization_id) = input.organization_id {
329            query = query.data("organization_id", DbValue::String(organization_id));
330        }
331        if let Some(oidc_config) = input.oidc_config {
332            query = query.data("oidc_config", optional_string(oidc_config));
333        }
334        if let Some(saml_config) = input.saml_config {
335            query = query.data("saml_config", optional_string(saml_config));
336        }
337        if let Some(domain_verified) = input.domain_verified {
338            query = query.data("domain_verified", DbValue::Boolean(domain_verified));
339        }
340
341        self.adapter
342            .update(query)
343            .await?
344            .map(record_from_db)
345            .transpose()
346    }
347
348    /// Delete an SSO provider by provider id.
349    pub async fn delete(&self, provider_id: &str) -> Result<(), RustAuthError> {
350        self.adapter
351            .delete(Delete::new(self.model_name).where_clause(provider_id_where(provider_id)))
352            .await
353    }
354
355    fn select_create(&self, query: Create) -> Create {
356        if self.include_domain_verified {
357            query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
358        } else {
359            query.select(SSO_PROVIDER_FIELDS)
360        }
361    }
362
363    fn select_find_one(&self, query: FindOne) -> FindOne {
364        if self.include_domain_verified {
365            query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
366        } else {
367            query.select(SSO_PROVIDER_FIELDS)
368        }
369    }
370
371    fn select_find_many(&self, query: FindMany) -> FindMany {
372        if self.include_domain_verified {
373            query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
374        } else {
375            query.select(SSO_PROVIDER_FIELDS)
376        }
377    }
378}
379
/// Input used to create an SSO provider record via [`SsoProviderStore::create`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateSsoProviderInput {
    /// Stable provider id.
    pub provider_id: String,
    /// Provider issuer.
    pub issuer: String,
    /// Provider domains, stored verbatim in the `domain` column.
    pub domain: String,
    /// Owner user id.
    pub user_id: String,
    /// Optional organization id; `None` is stored as NULL.
    pub organization_id: Option<String>,
    /// Serialized OIDC configuration; `None` is stored as NULL.
    pub oidc_config: Option<String>,
    /// Serialized SAML configuration; `None` is stored as NULL.
    pub saml_config: Option<String>,
    /// Initial domain verification state; when `None` the column is not written.
    pub domain_verified: Option<bool>,
}
400
/// Partial provider update input used by route handlers.
///
/// [`SsoProviderStore::update`] leaves every `None` field unchanged.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateSsoProviderInput {
    /// Updated issuer.
    pub issuer: Option<String>,
    /// Updated domains.
    pub domain: Option<String>,
    /// Updated organization id. This field cannot clear an existing value.
    pub organization_id: Option<String>,
    /// Updated serialized OIDC config; `Some(None)` clears it.
    pub oidc_config: Option<Option<String>>,
    /// Updated serialized SAML config; `Some(None)` clears it.
    pub saml_config: Option<Option<String>>,
    /// Updated domain verification state.
    pub domain_verified: Option<bool>,
}
417
418impl SsoProviderRecord {
419    /// Convert the raw provider record into the public sanitized shape.
420    pub fn sanitized_with_options(
421        &self,
422        base_url: &str,
423        options: Option<&crate::options::SsoOptions>,
424    ) -> SanitizedSsoProvider {
425        let oidc_config = self
426            .oidc_config
427            .as_deref()
428            .and_then(|value| serde_json::from_str::<OidcConfig>(value).ok())
429            .map(|config| SanitizedOidcConfig {
430                discovery_endpoint: config.discovery_endpoint,
431                client_id_last_four: client_id_last_four(&config.client_id),
432                pkce: config.pkce,
433                authorization_endpoint: config.authorization_endpoint,
434                token_endpoint: config.token_endpoint,
435                user_info_endpoint: config.user_info_endpoint,
436                jwks_endpoint: config.jwks_endpoint,
437                revocation_endpoint: config.revocation_endpoint,
438                end_session_endpoint: config.end_session_endpoint,
439                introspection_endpoint: config.introspection_endpoint,
440                token_endpoint_authentication: config.token_endpoint_authentication,
441                scopes: config.scopes,
442            });
443        #[cfg(feature = "saml")]
444        let saml_config = self
445            .saml_config
446            .as_deref()
447            .and_then(|value| serde_json::from_str::<SamlConfig>(value).ok())
448            .map(|config| {
449                let certificate = certificate_metadata(&config.cert);
450                SanitizedSamlConfig {
451                    entry_point: config.entry_point,
452                    callback_url: config.callback_url,
453                    acs_url: config.acs_url,
454                    audience: config.audience,
455                    want_assertions_signed: config.want_assertions_signed,
456                    authn_requests_signed: config.authn_requests_signed,
457                    identifier_format: config.identifier_format,
458                    signature_algorithm: config.signature_algorithm,
459                    digest_algorithm: config.digest_algorithm,
460                    certificate_sha256_fingerprint: certificate.sha256_fingerprint,
461                    certificate_not_before: certificate.not_before,
462                    certificate_not_after: certificate.not_after,
463                    certificate_public_key_algorithm: certificate.public_key_algorithm,
464                    certificate_error: certificate.parse_error,
465                }
466            });
467        #[cfg(not(feature = "saml"))]
468        let saml_config = None;
469        let provider_type = if saml_config.is_some() {
470            "saml"
471        } else {
472            "oidc"
473        }
474        .to_owned();
475        #[cfg(feature = "oidc")]
476        let redirect_uri = oidc_config.as_ref().and_then(|_| {
477            options.map(|options| oidc_redirect_uri(base_url, &self.provider_id, options))
478        });
479        #[cfg(not(feature = "oidc"))]
480        let redirect_uri = None;
481        SanitizedSsoProvider {
482            provider_id: self.provider_id.clone(),
483            provider_type: provider_type.clone(),
484            upstream_type: provider_type,
485            issuer: self.issuer.clone(),
486            domain: self.domain.clone(),
487            organization_id: self.organization_id.clone(),
488            domain_verified: self.domain_verified.unwrap_or(false),
489            oidc_config,
490            saml_config,
491            redirect_uri,
492            sp_metadata_url: format!(
493                "{}/sso/saml2/sp/metadata?providerId={}",
494                base_url.trim_end_matches('/'),
495                url::form_urlencoded::byte_serialize(self.provider_id.as_bytes())
496                    .collect::<String>()
497            ),
498        }
499    }
500
501    /// Convert the raw provider record into the public sanitized shape.
502    pub fn sanitized(&self, base_url: &str) -> SanitizedSsoProvider {
503        self.sanitized_with_options(base_url, None)
504    }
505}
506
507fn provider_id_where(provider_id: &str) -> Where {
508    Where::new("provider_id", DbValue::String(provider_id.to_owned()))
509}
510
511fn optional_string(value: Option<String>) -> DbValue {
512    value.map(DbValue::String).unwrap_or(DbValue::Null)
513}
514
515fn record_from_db(record: DbRecord) -> Result<SsoProviderRecord, RustAuthError> {
516    Ok(SsoProviderRecord {
517        id: required_string(&record, "id")?.to_owned(),
518        issuer: required_string(&record, "issuer")?.to_owned(),
519        oidc_config: optional_string_field(&record, "oidc_config")?,
520        saml_config: optional_string_field(&record, "saml_config")?,
521        user_id: required_string(&record, "user_id")?.to_owned(),
522        provider_id: required_string(&record, "provider_id")?.to_owned(),
523        organization_id: optional_string_field(&record, "organization_id")?,
524        domain: required_string(&record, "domain")?.to_owned(),
525        domain_verified: optional_bool_field(&record, "domain_verified")?,
526        created_at: optional_timestamp_field(&record, "created_at")?,
527    })
528}
529
530fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
531    match record.get(field) {
532        Some(DbValue::String(value)) => Ok(value),
533        Some(_) => Err(invalid_field(field, "string")),
534        None => Err(missing_field(field)),
535    }
536}
537
538fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
539    match record.get(field) {
540        Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
541        Some(DbValue::Json(value)) => serde_json::to_string(value)
542            .map(Some)
543            .map_err(|error| RustAuthError::Adapter(format!("invalid JSON in `{field}`: {error}"))),
544        Some(DbValue::Null) | None => Ok(None),
545        Some(_) => Err(invalid_field(field, "string, JSON, or null")),
546    }
547}
548
549fn optional_bool_field(record: &DbRecord, field: &str) -> Result<Option<bool>, RustAuthError> {
550    match record.get(field) {
551        Some(DbValue::Boolean(value)) => Ok(Some(*value)),
552        Some(DbValue::Null) | None => Ok(None),
553        Some(_) => Err(invalid_field(field, "boolean or null")),
554    }
555}
556
557fn optional_timestamp_field(
558    record: &DbRecord,
559    field: &str,
560) -> Result<Option<OffsetDateTime>, RustAuthError> {
561    match record.get(field) {
562        Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
563        Some(DbValue::Null) | None => Ok(None),
564        Some(_) => Err(invalid_field(field, "timestamp or null")),
565    }
566}
567
568fn missing_field(field: &str) -> RustAuthError {
569    RustAuthError::Adapter(format!("sso provider record is missing `{field}`"))
570}
571
572fn invalid_field(field: &str, expected: &str) -> RustAuthError {
573    RustAuthError::Adapter(format!(
574        "sso provider record field `{field}` must be {expected}"
575    ))
576}