1use rustauth_core::crypto::random::generate_random_string;
2use rustauth_core::db::{
3 Create, DbAdapter, DbRecord, DbValue, Delete, FindMany, FindOne, Update, Where,
4};
5use rustauth_core::error::RustAuthError;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9#[cfg(feature = "oidc")]
10use crate::oidc_impl::flow::oidc_redirect_uri;
11use crate::options::OidcConfig;
12#[cfg(feature = "saml")]
13use crate::options::SamlConfig;
14use crate::schema::SSO_PROVIDER_MODEL;
15#[cfg(feature = "saml")]
16use crate::utils::certificate_metadata;
17use crate::utils::client_id_last_four;
18
/// Columns selected when reading SSO provider rows while domain verification is disabled.
///
/// `domain_verified` is intentionally absent. `updated_at` is written by the store on
/// create/update but is not selected here, so it never reaches [`SsoProviderRecord`].
const SSO_PROVIDER_FIELDS: [&str; 9] = [
    "id",
    "issuer",
    "oidc_config",
    "saml_config",
    "user_id",
    "provider_id",
    "organization_id",
    "domain",
    "created_at",
];
30
/// Columns selected when reading SSO provider rows while domain verification is enabled.
///
/// Identical to `SSO_PROVIDER_FIELDS` plus `domain_verified`.
const SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED: [&str; 10] = [
    "id",
    "issuer",
    "oidc_config",
    "saml_config",
    "user_id",
    "provider_id",
    "organization_id",
    "domain",
    // Only present in this variant; see `SsoProviderStore::include_domain_verified`.
    "domain_verified",
    "created_at",
];
43
/// An SSO provider row as decoded from the database adapter.
///
/// Serialized with camelCase keys. Configuration columns are kept as raw JSON text;
/// use [`SsoProviderRecord::sanitized`] to build a view that is safe to expose.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoProviderRecord {
    /// Primary key of the row.
    pub id: String,
    /// Issuer configured for the provider.
    pub issuer: String,
    /// OIDC configuration as JSON text (a JSON column value is re-serialized to text).
    pub oidc_config: Option<String>,
    /// SAML configuration as JSON text (a JSON column value is re-serialized to text).
    pub saml_config: Option<String>,
    /// Id of the owning user.
    pub user_id: String,
    /// Provider identifier used as the lookup key by `SsoProviderStore`.
    pub provider_id: String,
    /// Organization the provider belongs to, if any.
    pub organization_id: Option<String>,
    /// Domain associated with the provider.
    pub domain: String,
    /// Domain verification flag; `None` when the column was null or not selected.
    pub domain_verified: Option<bool>,
    /// Creation time; `None` when the column was null or missing.
    pub created_at: Option<OffsetDateTime>,
}
69
/// Public view of an SSO provider in which credentials are reduced to summaries.
///
/// Produced by `SsoProviderRecord::sanitized_with_options`. Serialized with camelCase
/// keys, except for the explicitly renamed fields below.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSsoProvider {
    /// Provider identifier.
    pub provider_id: String,
    /// `"saml"` or `"oidc"`; serialized as `providerType`.
    pub provider_type: String,
    /// Same value as `provider_type`, serialized under the key `type`.
    #[serde(rename = "type")]
    pub upstream_type: String,
    /// Issuer configured for the provider.
    pub issuer: String,
    /// Domain associated with the provider.
    pub domain: String,
    /// Organization the provider belongs to, if any.
    pub organization_id: Option<String>,
    /// Domain verification flag; a missing value on the record is reported as `false`.
    pub domain_verified: bool,
    /// Summary of the OIDC configuration, when it could be parsed.
    pub oidc_config: Option<SanitizedOidcConfig>,
    /// Summary of the SAML configuration, when it could be parsed.
    pub saml_config: Option<SanitizedSamlConfig>,
    /// OIDC redirect URI; serialized as `redirectURI` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "redirectURI")]
    pub redirect_uri: Option<String>,
    /// URL of the SP metadata endpoint for this provider.
    pub sp_metadata_url: String,
}
99
/// OIDC configuration with the client id reduced to a short suffix.
///
/// Every field except `client_id_last_four` is copied from the stored `OidcConfig`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedOidcConfig {
    /// Discovery endpoint from the stored configuration.
    pub discovery_endpoint: String,
    /// Output of `client_id_last_four` for the configured client id.
    pub client_id_last_four: String,
    /// Whether PKCE is enabled.
    pub pkce: bool,
    /// Explicit authorization endpoint, if configured.
    pub authorization_endpoint: Option<String>,
    /// Explicit token endpoint, if configured.
    pub token_endpoint: Option<String>,
    /// Explicit user-info endpoint, if configured.
    pub user_info_endpoint: Option<String>,
    /// Explicit JWKS endpoint, if configured.
    pub jwks_endpoint: Option<String>,
    /// Explicit revocation endpoint, if configured.
    pub revocation_endpoint: Option<String>,
    /// Explicit end-session endpoint, if configured.
    pub end_session_endpoint: Option<String>,
    /// Explicit introspection endpoint, if configured.
    pub introspection_endpoint: Option<String>,
    /// Token endpoint authentication method, if configured.
    pub token_endpoint_authentication: Option<crate::options::TokenEndpointAuthentication>,
    /// Requested scopes, if configured.
    pub scopes: Option<Vec<String>>,
}
129
/// SAML configuration with the certificate replaced by metadata about it.
///
/// The `certificate_*` fields come from `certificate_metadata`; the remaining fields are
/// copied from the stored `SamlConfig`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSamlConfig {
    /// IdP entry point.
    pub entry_point: String,
    /// Callback URL.
    pub callback_url: String,
    /// Assertion consumer service URL, if configured.
    pub acs_url: Option<String>,
    /// Audience, if configured.
    pub audience: Option<String>,
    /// Whether assertions must be signed.
    pub want_assertions_signed: bool,
    /// Whether authn requests are signed.
    pub authn_requests_signed: bool,
    /// Name-ID format, if configured.
    pub identifier_format: Option<String>,
    /// Signature algorithm, if configured.
    pub signature_algorithm: Option<String>,
    /// Digest algorithm, if configured.
    pub digest_algorithm: Option<String>,
    /// `sha256_fingerprint` reported by `certificate_metadata`.
    pub certificate_sha256_fingerprint: String,
    /// `not_before` reported by `certificate_metadata`.
    pub certificate_not_before: Option<String>,
    /// `not_after` reported by `certificate_metadata`.
    pub certificate_not_after: Option<String>,
    /// `public_key_algorithm` reported by `certificate_metadata`.
    pub certificate_public_key_algorithm: Option<String>,
    /// `parse_error` reported by `certificate_metadata`; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate_error: Option<String>,
}
164
/// Typed access to SSO provider rows through a [`DbAdapter`].
///
/// The store only holds borrowed references and a flag, so it is `Copy`.
#[derive(Clone, Copy)]
pub struct SsoProviderStore<'a> {
    /// Adapter used for every query.
    adapter: &'a dyn DbAdapter,
    /// Model name passed to every query builder.
    model_name: &'a str,
    /// When `true`, `domain_verified` is included in the columns selected by reads and
    /// creates.
    include_domain_verified: bool,
}
172
173impl<'a> SsoProviderStore<'a> {
174 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
176 Self::new_with_model(adapter, SSO_PROVIDER_MODEL)
177 }
178
179 pub fn new_with_model(adapter: &'a dyn DbAdapter, model_name: &'a str) -> Self {
181 Self {
182 adapter,
183 model_name,
184 include_domain_verified: false,
185 }
186 }
187
188 pub fn new_with_options(
190 adapter: &'a dyn DbAdapter,
191 options: &'a crate::options::SsoOptions,
192 ) -> Self {
193 Self::new_with_model_and_domain_verification(
194 adapter,
195 &options.model_name,
196 options.domain_verification.enabled,
197 )
198 }
199
200 pub fn new_with_model_and_domain_verification(
202 adapter: &'a dyn DbAdapter,
203 model_name: &'a str,
204 include_domain_verified: bool,
205 ) -> Self {
206 Self {
207 adapter,
208 model_name,
209 include_domain_verified,
210 }
211 }
212
213 pub async fn list(&self) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
215 let query = self.select_find_many(FindMany::new(self.model_name));
216 self.adapter
217 .find_many(query)
218 .await?
219 .into_iter()
220 .map(record_from_db)
221 .collect()
222 }
223
224 pub async fn list_by_user(
226 &self,
227 user_id: &str,
228 ) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
229 let query = FindMany::new(self.model_name)
230 .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())));
231 self.adapter
232 .find_many(self.select_find_many(query))
233 .await?
234 .into_iter()
235 .map(record_from_db)
236 .collect()
237 }
238
239 pub async fn find_by_provider_id(
241 &self,
242 provider_id: &str,
243 ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
244 let query = FindOne::new(self.model_name).where_clause(provider_id_where(provider_id));
245 self.adapter
246 .find_one(self.select_find_one(query))
247 .await?
248 .map(record_from_db)
249 .transpose()
250 }
251
252 pub async fn find_by_organization_id(
254 &self,
255 organization_id: &str,
256 ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
257 let query = FindOne::new(self.model_name).where_clause(Where::new(
258 "organization_id",
259 DbValue::String(organization_id.to_owned()),
260 ));
261 self.adapter
262 .find_one(self.select_find_one(query))
263 .await?
264 .map(record_from_db)
265 .transpose()
266 }
267
268 pub async fn create(
270 &self,
271 input: CreateSsoProviderInput,
272 ) -> Result<SsoProviderRecord, RustAuthError> {
273 let now = OffsetDateTime::now_utc();
274 let mut query = Create::new(self.model_name)
275 .data("id", DbValue::String(generate_random_string(32)))
276 .data("issuer", DbValue::String(input.issuer))
277 .data("oidc_config", optional_string(input.oidc_config))
278 .data("saml_config", optional_string(input.saml_config))
279 .data("user_id", DbValue::String(input.user_id))
280 .data("provider_id", DbValue::String(input.provider_id))
281 .data("organization_id", optional_string(input.organization_id))
282 .data("domain", DbValue::String(input.domain))
283 .data("created_at", DbValue::Timestamp(now))
284 .data("updated_at", DbValue::Timestamp(now))
285 .force_allow_id();
286 query = self.select_create(query);
287 if let Some(domain_verified) = input.domain_verified {
288 query = query.data("domain_verified", DbValue::Boolean(domain_verified));
289 }
290
291 record_from_db(self.adapter.create(query).await?)
292 }
293
294 pub async fn update_domain_verified(
296 &self,
297 provider_id: &str,
298 verified: bool,
299 ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
300 self.adapter
301 .update(
302 Update::new(self.model_name)
303 .where_clause(provider_id_where(provider_id))
304 .data("domain_verified", DbValue::Boolean(verified))
305 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
306 )
307 .await?
308 .map(record_from_db)
309 .transpose()
310 }
311
312 pub async fn update(
314 &self,
315 provider_id: &str,
316 input: UpdateSsoProviderInput,
317 ) -> Result<Option<SsoProviderRecord>, RustAuthError> {
318 let mut query = Update::new(self.model_name)
319 .where_clause(provider_id_where(provider_id))
320 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
321
322 if let Some(issuer) = input.issuer {
323 query = query.data("issuer", DbValue::String(issuer));
324 }
325 if let Some(domain) = input.domain {
326 query = query.data("domain", DbValue::String(domain));
327 }
328 if let Some(organization_id) = input.organization_id {
329 query = query.data("organization_id", DbValue::String(organization_id));
330 }
331 if let Some(oidc_config) = input.oidc_config {
332 query = query.data("oidc_config", optional_string(oidc_config));
333 }
334 if let Some(saml_config) = input.saml_config {
335 query = query.data("saml_config", optional_string(saml_config));
336 }
337 if let Some(domain_verified) = input.domain_verified {
338 query = query.data("domain_verified", DbValue::Boolean(domain_verified));
339 }
340
341 self.adapter
342 .update(query)
343 .await?
344 .map(record_from_db)
345 .transpose()
346 }
347
348 pub async fn delete(&self, provider_id: &str) -> Result<(), RustAuthError> {
350 self.adapter
351 .delete(Delete::new(self.model_name).where_clause(provider_id_where(provider_id)))
352 .await
353 }
354
355 fn select_create(&self, query: Create) -> Create {
356 if self.include_domain_verified {
357 query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
358 } else {
359 query.select(SSO_PROVIDER_FIELDS)
360 }
361 }
362
363 fn select_find_one(&self, query: FindOne) -> FindOne {
364 if self.include_domain_verified {
365 query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
366 } else {
367 query.select(SSO_PROVIDER_FIELDS)
368 }
369 }
370
371 fn select_find_many(&self, query: FindMany) -> FindMany {
372 if self.include_domain_verified {
373 query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
374 } else {
375 query.select(SSO_PROVIDER_FIELDS)
376 }
377 }
378}
379
/// Values for a new SSO provider row, consumed by `SsoProviderStore::create`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateSsoProviderInput {
    /// Provider identifier; the store's lookup, update and delete key.
    pub provider_id: String,
    /// Issuer to store.
    pub issuer: String,
    /// Domain to associate with the provider.
    pub domain: String,
    /// Owning user id; `SsoProviderStore::list_by_user` filters on it.
    pub user_id: String,
    /// Organization id; `None` is stored as null.
    pub organization_id: Option<String>,
    /// OIDC configuration JSON text; `None` is stored as null.
    pub oidc_config: Option<String>,
    /// SAML configuration JSON text; `None` is stored as null.
    pub saml_config: Option<String>,
    /// Written only when `Some`; when `None` the store does not write the column at all.
    pub domain_verified: Option<bool>,
}
400
/// Partial update applied by `SsoProviderStore::update`.
///
/// `None` in any field leaves that column untouched. `Default` yields an update that only
/// refreshes `updated_at`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateSsoProviderInput {
    /// New issuer, if changing.
    pub issuer: Option<String>,
    /// New domain, if changing.
    pub domain: Option<String>,
    /// New organization id, if changing. This shape cannot clear the column to null.
    pub organization_id: Option<String>,
    /// `Some(Some(json))` replaces the OIDC config, `Some(None)` writes null, and `None`
    /// leaves it unchanged.
    pub oidc_config: Option<Option<String>>,
    /// Same three-state semantics as `oidc_config`, for the SAML config.
    pub saml_config: Option<Option<String>>,
    /// New `domain_verified` flag, if changing.
    pub domain_verified: Option<bool>,
}
417
418impl SsoProviderRecord {
419 pub fn sanitized_with_options(
421 &self,
422 base_url: &str,
423 options: Option<&crate::options::SsoOptions>,
424 ) -> SanitizedSsoProvider {
425 let oidc_config = self
426 .oidc_config
427 .as_deref()
428 .and_then(|value| serde_json::from_str::<OidcConfig>(value).ok())
429 .map(|config| SanitizedOidcConfig {
430 discovery_endpoint: config.discovery_endpoint,
431 client_id_last_four: client_id_last_four(&config.client_id),
432 pkce: config.pkce,
433 authorization_endpoint: config.authorization_endpoint,
434 token_endpoint: config.token_endpoint,
435 user_info_endpoint: config.user_info_endpoint,
436 jwks_endpoint: config.jwks_endpoint,
437 revocation_endpoint: config.revocation_endpoint,
438 end_session_endpoint: config.end_session_endpoint,
439 introspection_endpoint: config.introspection_endpoint,
440 token_endpoint_authentication: config.token_endpoint_authentication,
441 scopes: config.scopes,
442 });
443 #[cfg(feature = "saml")]
444 let saml_config = self
445 .saml_config
446 .as_deref()
447 .and_then(|value| serde_json::from_str::<SamlConfig>(value).ok())
448 .map(|config| {
449 let certificate = certificate_metadata(&config.cert);
450 SanitizedSamlConfig {
451 entry_point: config.entry_point,
452 callback_url: config.callback_url,
453 acs_url: config.acs_url,
454 audience: config.audience,
455 want_assertions_signed: config.want_assertions_signed,
456 authn_requests_signed: config.authn_requests_signed,
457 identifier_format: config.identifier_format,
458 signature_algorithm: config.signature_algorithm,
459 digest_algorithm: config.digest_algorithm,
460 certificate_sha256_fingerprint: certificate.sha256_fingerprint,
461 certificate_not_before: certificate.not_before,
462 certificate_not_after: certificate.not_after,
463 certificate_public_key_algorithm: certificate.public_key_algorithm,
464 certificate_error: certificate.parse_error,
465 }
466 });
467 #[cfg(not(feature = "saml"))]
468 let saml_config = None;
469 let provider_type = if saml_config.is_some() {
470 "saml"
471 } else {
472 "oidc"
473 }
474 .to_owned();
475 #[cfg(feature = "oidc")]
476 let redirect_uri = oidc_config.as_ref().and_then(|_| {
477 options.map(|options| oidc_redirect_uri(base_url, &self.provider_id, options))
478 });
479 #[cfg(not(feature = "oidc"))]
480 let redirect_uri = None;
481 SanitizedSsoProvider {
482 provider_id: self.provider_id.clone(),
483 provider_type: provider_type.clone(),
484 upstream_type: provider_type,
485 issuer: self.issuer.clone(),
486 domain: self.domain.clone(),
487 organization_id: self.organization_id.clone(),
488 domain_verified: self.domain_verified.unwrap_or(false),
489 oidc_config,
490 saml_config,
491 redirect_uri,
492 sp_metadata_url: format!(
493 "{}/sso/saml2/sp/metadata?providerId={}",
494 base_url.trim_end_matches('/'),
495 url::form_urlencoded::byte_serialize(self.provider_id.as_bytes())
496 .collect::<String>()
497 ),
498 }
499 }
500
501 pub fn sanitized(&self, base_url: &str) -> SanitizedSsoProvider {
503 self.sanitized_with_options(base_url, None)
504 }
505}
506
507fn provider_id_where(provider_id: &str) -> Where {
508 Where::new("provider_id", DbValue::String(provider_id.to_owned()))
509}
510
511fn optional_string(value: Option<String>) -> DbValue {
512 value.map(DbValue::String).unwrap_or(DbValue::Null)
513}
514
515fn record_from_db(record: DbRecord) -> Result<SsoProviderRecord, RustAuthError> {
516 Ok(SsoProviderRecord {
517 id: required_string(&record, "id")?.to_owned(),
518 issuer: required_string(&record, "issuer")?.to_owned(),
519 oidc_config: optional_string_field(&record, "oidc_config")?,
520 saml_config: optional_string_field(&record, "saml_config")?,
521 user_id: required_string(&record, "user_id")?.to_owned(),
522 provider_id: required_string(&record, "provider_id")?.to_owned(),
523 organization_id: optional_string_field(&record, "organization_id")?,
524 domain: required_string(&record, "domain")?.to_owned(),
525 domain_verified: optional_bool_field(&record, "domain_verified")?,
526 created_at: optional_timestamp_field(&record, "created_at")?,
527 })
528}
529
530fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
531 match record.get(field) {
532 Some(DbValue::String(value)) => Ok(value),
533 Some(_) => Err(invalid_field(field, "string")),
534 None => Err(missing_field(field)),
535 }
536}
537
538fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
539 match record.get(field) {
540 Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
541 Some(DbValue::Json(value)) => serde_json::to_string(value)
542 .map(Some)
543 .map_err(|error| RustAuthError::Adapter(format!("invalid JSON in `{field}`: {error}"))),
544 Some(DbValue::Null) | None => Ok(None),
545 Some(_) => Err(invalid_field(field, "string, JSON, or null")),
546 }
547}
548
549fn optional_bool_field(record: &DbRecord, field: &str) -> Result<Option<bool>, RustAuthError> {
550 match record.get(field) {
551 Some(DbValue::Boolean(value)) => Ok(Some(*value)),
552 Some(DbValue::Null) | None => Ok(None),
553 Some(_) => Err(invalid_field(field, "boolean or null")),
554 }
555}
556
557fn optional_timestamp_field(
558 record: &DbRecord,
559 field: &str,
560) -> Result<Option<OffsetDateTime>, RustAuthError> {
561 match record.get(field) {
562 Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
563 Some(DbValue::Null) | None => Ok(None),
564 Some(_) => Err(invalid_field(field, "timestamp or null")),
565 }
566}
567
568fn missing_field(field: &str) -> RustAuthError {
569 RustAuthError::Adapter(format!("sso provider record is missing `{field}`"))
570}
571
572fn invalid_field(field: &str, expected: &str) -> RustAuthError {
573 RustAuthError::Adapter(format!(
574 "sso provider record field `{field}` must be {expected}"
575 ))
576}