1use serde::{Deserialize, Serialize};
2#[cfg(not(feature = "saml"))]
3use std::collections::BTreeMap;
4use std::future::Future;
5use time::Duration;
6
7use rustauth_core::db::User;
8use rustauth_core::error::RustAuthError;
9use rustauth_core::options::RateLimitRule;
10#[cfg(not(feature = "saml"))]
11use rustauth_core::secret::SecretString;
12
13pub use rustauth_oidc::{OidcConfig, OidcMapping, TokenEndpointAuthentication};
14
15#[cfg(feature = "saml")]
16pub use rustauth_saml::{
17 DeprecatedAlgorithmBehavior, SamlConfig, SamlIdpMetadata, SamlMapping, SamlService,
18 SamlSpMetadata,
19};
20
21#[path = "options/audit.rs"]
22mod audit;
23#[path = "options/callbacks.rs"]
24mod callbacks;
25
26pub use audit::*;
27pub use callbacks::*;
28
/// Controls how SSO sign-ins are provisioned into organizations.
///
/// (De)serialized with camelCase keys. `get_role` is never (de)serialized, so
/// it is always `None` after deserialization and must be re-attached in code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OrganizationProvisioningOptions {
    /// Turns organization provisioning off when `true` (default `false`).
    pub disabled: bool,
    /// Role returned by `resolve_role` when no `get_role` resolver is
    /// registered (default `"member"`).
    pub default_role: String,
    /// Optional async resolver that takes precedence over `default_role`.
    /// Skipped by serde because closures cannot be serialized.
    #[serde(skip)]
    pub get_role: Option<OrganizationRoleResolver>,
}
41
42impl Default for OrganizationProvisioningOptions {
43 fn default() -> Self {
44 Self {
45 disabled: false,
46 default_role: "member".to_owned(),
47 get_role: None,
48 }
49 }
50}
51
52impl OrganizationProvisioningOptions {
53 #[must_use]
54 pub fn disabled(mut self, disabled: bool) -> Self {
56 self.disabled = disabled;
57 self
58 }
59
60 #[must_use]
61 pub fn default_role(mut self, role: impl Into<String>) -> Self {
63 self.default_role = role.into();
64 self
65 }
66
67 #[must_use]
68 pub fn get_role<F, Fut>(mut self, resolver: F) -> Self
70 where
71 F: Fn(OrganizationRoleInput) -> Fut + Send + Sync + 'static,
72 Fut: Future<Output = Result<String, RustAuthError>> + Send + 'static,
73 {
74 self.get_role = Some(OrganizationRoleResolver::new(resolver));
75 self
76 }
77
78 pub async fn resolve_role(
80 &self,
81 input: OrganizationRoleInput,
82 ) -> Result<String, RustAuthError> {
83 match &self.get_role {
84 Some(resolver) => resolver.resolve(input).await,
85 None => Ok(self.default_role.clone()),
86 }
87 }
88}
89
/// Top-level configuration for the SSO plugin.
///
/// (De)serialized with camelCase keys. Fields marked `#[serde(skip)]` hold
/// callbacks or runtime-only settings. They come back as their `Default`
/// values after deserialization.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoOptions {
    /// Model name for SSO provider records (default `"sso_provider"`).
    pub model_name: String,
    /// Storage table for SSO providers (default `"sso_providers"`).
    pub provider_table: String,
    /// Static providers cap returned by `resolve_providers_limit` when no
    /// callback is set (default `10`).
    pub providers_limit: usize,
    /// Optional per-user async override of `providers_limit`; not serialized.
    #[serde(skip)]
    pub providers_limit_callback: Option<ProvidersLimitResolver>,
    /// Domain-ownership verification settings.
    pub domain_verification: DomainVerificationOptions,
    /// Optional redirect URI override (also exposed via `OidcFlowOptions`
    /// when the `oidc` feature is enabled).
    pub redirect_uri: Option<String>,
    /// NOTE(review): presumably blocks creating new users on first SSO
    /// sign-in — confirm at call sites.
    pub disable_implicit_sign_up: bool,
    /// NOTE(review): presumably trusts the IdP's email-verified claim —
    /// confirm at call sites.
    pub trust_email_verified: bool,
    /// NOTE(review): presumably the default for overriding stored user info
    /// from IdP data — confirm at call sites.
    pub default_override_user_info: bool,
    /// OIDC hardening switches. Defaults apply when the key is absent, so
    /// configs written before this field existed still load.
    #[serde(default)]
    pub oidc: OidcOptions,
    /// Optional async provisioning hook; not serialized.
    #[serde(skip)]
    pub provision_user: Option<ProvisionUserResolver>,
    /// Whether `provision_user` should run on every login (presumably
    /// otherwise only on first sign-up — confirm).
    pub provision_user_on_every_login: bool,
    /// Organization membership provisioning settings.
    pub organization_provisioning: OrganizationProvisioningOptions,
    /// SAML protocol settings.
    pub saml: SamlOptions,
    /// Rate-limit rules. `SsoRateLimitOptions` is not serde-enabled, so this
    /// is skipped and reset to defaults on deserialization.
    #[serde(skip)]
    pub rate_limit: SsoRateLimitOptions,
    /// Optional async audit-event sink; not serialized.
    #[serde(skip)]
    pub audit_event: Option<SsoAuditEventResolver>,
    /// Providers supplied directly in configuration (default empty).
    pub default_sso: Vec<SsoProvider>,
}
134
135impl Default for SsoOptions {
136 fn default() -> Self {
137 Self {
138 model_name: "sso_provider".to_owned(),
139 provider_table: "sso_providers".to_owned(),
140 providers_limit: 10,
141 providers_limit_callback: None,
142 domain_verification: DomainVerificationOptions::default(),
143 redirect_uri: None,
144 disable_implicit_sign_up: false,
145 trust_email_verified: false,
146 default_override_user_info: false,
147 oidc: OidcOptions::default(),
148 provision_user: None,
149 provision_user_on_every_login: false,
150 organization_provisioning: OrganizationProvisioningOptions::default(),
151 saml: SamlOptions::default(),
152 rate_limit: SsoRateLimitOptions::default(),
153 audit_event: None,
154 default_sso: Vec::new(),
155 }
156 }
157}
158
159impl SsoOptions {
160 pub fn new() -> Self {
162 Self::default()
163 }
164
165 #[must_use]
166 pub fn provider_table(mut self, table: impl Into<String>) -> Self {
168 self.provider_table = table.into();
169 self
170 }
171
172 #[must_use]
173 pub fn providers_limit(mut self, limit: usize) -> Self {
175 self.providers_limit = limit;
176 self
177 }
178
179 #[must_use]
180 pub fn providers_limit_callback<F, Fut>(mut self, resolver: F) -> Self
182 where
183 F: Fn(User) -> Fut + Send + Sync + 'static,
184 Fut: Future<Output = Result<usize, RustAuthError>> + Send + 'static,
185 {
186 self.providers_limit_callback = Some(ProvidersLimitResolver::new(resolver));
187 self
188 }
189
190 pub async fn resolve_providers_limit(&self, user: User) -> Result<usize, RustAuthError> {
192 match &self.providers_limit_callback {
193 Some(resolver) => resolver.resolve(user).await,
194 None => Ok(self.providers_limit),
195 }
196 }
197
198 #[must_use]
199 pub fn domain_verification_enabled(mut self, enabled: bool) -> Self {
201 self.domain_verification.enabled = enabled;
202 self
203 }
204
205 #[must_use]
206 pub fn domain_txt_resolver<F, Fut>(mut self, resolver: F) -> Self
208 where
209 F: Fn(String) -> Fut + Send + Sync + 'static,
210 Fut: Future<Output = Result<Vec<String>, RustAuthError>> + Send + 'static,
211 {
212 self.domain_verification.txt_resolver = Some(DnsTxtResolver::new(resolver));
213 self
214 }
215
216 #[must_use]
217 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
219 self.redirect_uri = Some(redirect_uri.into());
220 self
221 }
222
223 #[must_use]
224 pub fn organization_provisioning(
226 mut self,
227 provisioning: OrganizationProvisioningOptions,
228 ) -> Self {
229 self.organization_provisioning = provisioning;
230 self
231 }
232
233 #[must_use]
234 pub fn provision_user<F, Fut>(mut self, resolver: F) -> Self
236 where
237 F: Fn(ProvisionUserInput) -> Fut + Send + Sync + 'static,
238 Fut: Future<Output = Result<(), RustAuthError>> + Send + 'static,
239 {
240 self.provision_user = Some(ProvisionUserResolver::new(resolver));
241 self
242 }
243
244 #[must_use]
245 pub fn provision_user_on_every_login(mut self, enabled: bool) -> Self {
247 self.provision_user_on_every_login = enabled;
248 self
249 }
250
251 #[must_use]
252 pub fn rate_limit(mut self, rate_limit: SsoRateLimitOptions) -> Self {
254 self.rate_limit = rate_limit;
255 self
256 }
257
258 #[must_use]
259 pub fn rate_limit_enabled(mut self, enabled: bool) -> Self {
261 self.rate_limit.enabled = enabled;
262 self
263 }
264
265 #[must_use]
266 pub fn audit_event<F, Fut>(mut self, resolver: F) -> Self
268 where
269 F: Fn(SsoAuditEvent) -> Fut + Send + Sync + 'static,
270 Fut: Future<Output = ()> + Send + 'static,
271 {
272 self.audit_event = Some(SsoAuditEventResolver::new(resolver));
273 self
274 }
275
276 #[must_use]
277 pub fn strict_oidc_manual_endpoint_origins(mut self, enabled: bool) -> Self {
279 self.oidc.strict_manual_endpoint_origins = enabled;
280 self
281 }
282
283 #[must_use]
284 pub fn allow_private_endpoint_ips(mut self, enabled: bool) -> Self {
289 self.oidc.allow_private_endpoint_ips = enabled;
290 self
291 }
292}
293
/// Lets the OIDC flow read the configured redirect URI override.
#[cfg(feature = "oidc")]
impl rustauth_oidc::OidcFlowOptions for SsoOptions {
    /// Borrows `redirect_uri` as `&str`, or `None` when no override is set.
    fn redirect_uri(&self) -> Option<&str> {
        Option::as_deref(&self.redirect_uri)
    }
}
300
301#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
302#[serde(rename_all = "camelCase")]
303pub struct OidcOptions {
305 pub strict_manual_endpoint_origins: bool,
310 #[serde(default)]
319 pub allow_private_endpoint_ips: bool,
320}
321
/// Per-flow rate-limit rules for SSO endpoints.
///
/// This type is not serde-enabled. `SsoOptions` skips it, so deserialized
/// options always carry `SsoRateLimitOptions::default()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SsoRateLimitOptions {
    /// Toggle for the rules below (default `true`). Presumably disables all
    /// of them when `false` — confirm where the rules are applied.
    pub enabled: bool,
    /// Rule for provider registration (default `RateLimitRule::new(60 s, 10)`).
    pub registration: RateLimitRule,
    /// Rule for domain verification (default `RateLimitRule::new(60 s, 5)`).
    pub domain_verification: RateLimitRule,
    /// Rule for the OIDC callback (default `RateLimitRule::new(60 s, 30)`).
    pub oidc_callback: RateLimitRule,
    /// Rule for SAML endpoints (default `RateLimitRule::new(60 s, 30)`).
    pub saml: RateLimitRule,
}
336
337impl Default for SsoRateLimitOptions {
338 fn default() -> Self {
339 Self {
340 enabled: true,
341 registration: RateLimitRule::new(time::Duration::seconds(60), 10),
342 domain_verification: RateLimitRule::new(time::Duration::seconds(60), 5),
343 oidc_callback: RateLimitRule::new(time::Duration::seconds(60), 30),
344 saml: RateLimitRule::new(time::Duration::seconds(60), 30),
345 }
346 }
347}
348
/// Settings for proving domain ownership through DNS TXT lookups.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DomainVerificationOptions {
    /// Whether domain verification is enabled (default `false`).
    pub enabled: bool,
    /// Prefix for verification tokens (default `"better-auth-token"`).
    pub token_prefix: String,
    /// Token lifetime in seconds (default 604 800, i.e. 7 days).
    pub token_ttl_seconds: u64,
    /// Optional custom async TXT resolver; not serialized.
    /// NOTE(review): when `None` a built-in lookup is presumably used — confirm.
    #[serde(skip)]
    pub txt_resolver: Option<DnsTxtResolver>,
}
363
364impl Default for DomainVerificationOptions {
365 fn default() -> Self {
366 Self {
367 enabled: false,
368 token_prefix: "better-auth-token".to_owned(),
369 token_ttl_seconds: 60 * 60 * 24 * 7,
370 txt_resolver: None,
371 }
372 }
373}
374
/// Default cap for `SamlOptions::max_response_size`: 256 × 1024 (256 KiB,
/// assuming the unit is bytes as the name suggests).
pub const DEFAULT_MAX_SAML_RESPONSE_SIZE: usize = 262_144;
/// Default cap for `SamlOptions::max_metadata_size`: 100 × 1024 (100 KiB,
/// assuming the unit is bytes).
pub const DEFAULT_MAX_SAML_METADATA_SIZE: usize = 102_400;
379
/// Protocol-level SAML settings shared by all SAML providers.
///
/// The `Duration` fields are `time::Duration`, so (de)serialization relies on
/// the `time` crate's serde support.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlOptions {
    /// Check `InResponseTo` against outstanding requests (default `true`).
    pub enable_in_response_to_validation: bool,
    /// Accept responses not preceded by an SP-issued request (default `true`).
    pub allow_idp_initiated: bool,
    /// Lifetime of an outstanding authentication request (default 5 minutes).
    pub request_ttl: Duration,
    /// Tolerance applied to timestamp checks (default 5 minutes).
    pub clock_skew: Duration,
    /// Require timestamp conditions to be present (default `false`).
    pub require_timestamps: bool,
    /// Maximum accepted response size (default `DEFAULT_MAX_SAML_RESPONSE_SIZE`).
    pub max_response_size: usize,
    /// Maximum accepted metadata size (default `DEFAULT_MAX_SAML_METADATA_SIZE`).
    pub max_metadata_size: usize,
    /// Enable SAML single logout (default `false`).
    pub enable_single_logout: bool,
    /// Lifetime of an outstanding logout request (default 5 minutes).
    pub logout_request_ttl: Duration,
    /// Require incoming logout requests to be signed (default `false`).
    pub want_logout_request_signed: bool,
    /// Require incoming logout responses to be signed (default `false`).
    pub want_logout_response_signed: bool,
    /// Algorithm allow-lists and deprecation policy.
    pub algorithms: SamlAlgorithmOptions,
}
409
410impl Default for SamlOptions {
411 fn default() -> Self {
412 Self {
413 enable_in_response_to_validation: true,
414 allow_idp_initiated: true,
415 request_ttl: Duration::minutes(5),
416 clock_skew: Duration::minutes(5),
417 require_timestamps: false,
418 max_response_size: DEFAULT_MAX_SAML_RESPONSE_SIZE,
419 max_metadata_size: DEFAULT_MAX_SAML_METADATA_SIZE,
420 enable_single_logout: false,
421 logout_request_ttl: Duration::minutes(5),
422 want_logout_request_signed: false,
423 want_logout_response_signed: false,
424 algorithms: SamlAlgorithmOptions::default(),
425 }
426 }
427}
428
/// Cryptographic algorithm policy for SAML messages.
///
/// Each allow-list is optional. `None` presumably means "no extra
/// restriction beyond the library's built-in set" — confirm in the SAML flow.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlAlgorithmOptions {
    /// Policy for deprecated algorithms (default `Warn`).
    pub on_deprecated: DeprecatedAlgorithmBehavior,
    /// Allowed signature algorithm identifiers (default `None`).
    pub allowed_signature_algorithms: Option<Vec<String>>,
    /// Allowed digest algorithm identifiers (default `None`).
    pub allowed_digest_algorithms: Option<Vec<String>>,
    /// Allowed key-encryption algorithm identifiers (default `None`).
    pub allowed_key_encryption_algorithms: Option<Vec<String>>,
    /// Allowed data-encryption algorithm identifiers (default `None`).
    pub allowed_data_encryption_algorithms: Option<Vec<String>>,
}
444
/// Policy applied when a SAML message uses a deprecated algorithm
/// (see `SamlAlgorithmOptions::on_deprecated`).
///
/// Fallback definition compiled only without the `saml` feature. It is
/// serialized as `"warn"` / `"reject"`. Keep the variant order stable,
/// because serde's derived variant index (used by non-self-describing
/// formats) follows declaration order.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeprecatedAlgorithmBehavior {
    /// Presumably accepts the message and emits a warning — confirm in the SAML flow.
    Warn,
    /// Presumably rejects the message — confirm in the SAML flow.
    Reject,
}
455
456impl Default for SamlAlgorithmOptions {
457 fn default() -> Self {
458 Self {
459 on_deprecated: DeprecatedAlgorithmBehavior::Warn,
460 allowed_signature_algorithms: None,
461 allowed_digest_algorithms: None,
462 allowed_key_encryption_algorithms: None,
463 allowed_data_encryption_algorithms: None,
464 }
465 }
466}
467
/// One SSO identity provider, carrying optional OIDC and/or SAML settings.
///
/// `None` options are omitted when serializing. On input, missing option
/// keys deserialize to `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoProvider {
    /// Identifier of this provider (presumably unique — confirm at storage).
    pub provider_id: String,
    /// Issuer identifier of the provider.
    pub issuer: String,
    /// Domain associated with the provider.
    /// NOTE(review): presumably used to route sign-ins by email domain — confirm.
    pub domain: String,
    /// Owning organization, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub organization_id: Option<String>,
    /// OIDC configuration, when this provider speaks OIDC.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oidc_config: Option<OidcConfig>,
    /// SAML configuration, when this provider speaks SAML.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub saml_config: Option<SamlConfig>,
}
488
/// Alternate name for [`SamlConfig`] when the `saml` feature is enabled.
///
/// NOTE(review): `allow(dead_code)` presumably silences an unused-item warning
/// when nothing in the crate refers to the alias. Confirm it is still needed.
#[cfg(feature = "saml")]
#[allow(dead_code)]
pub type SamlProviderConfig = SamlConfig;
492
/// Fallback per-provider SAML configuration used when the `saml` feature is off.
///
/// It keeps a camelCase wire shape so stored provider rows still
/// (de)serialize without the SAML implementation. NOTE(review): presumably
/// mirrors `rustauth_saml::SamlConfig` — keep both in sync.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlConfig {
    /// Issuer / entity identifier used by this service provider.
    pub issuer: String,
    /// IdP single sign-on URL. It becomes an empty string when the key is absent.
    #[serde(default)]
    pub entry_point: String,
    /// IdP certificate (format is not validated here).
    pub cert: String,
    /// Callback URL for SAML responses.
    pub callback_url: String,
    /// Optional explicit assertion consumer service URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acs_url: Option<String>,
    /// Optional expected audience.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audience: Option<String>,
    /// Optional identity-provider metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idp_metadata: Option<SamlIdpMetadata>,
    /// Service-provider metadata (required key).
    pub sp_metadata: SamlSpMetadata,
    /// Optional attribute-to-user-field mapping.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mapping: Option<SamlMapping>,
    /// Require signed assertions. Defaults to `true` when the key is absent.
    #[serde(default = "default_want_assertions_signed")]
    pub want_assertions_signed: bool,
    /// Sign outgoing authentication requests (required key).
    pub authn_requests_signed: bool,
    /// Optional signature algorithm identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature_algorithm: Option<String>,
    /// Optional digest algorithm identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub digest_algorithm: Option<String>,
    /// Optional NameID format identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identifier_format: Option<String>,
    /// Optional SP signing key, held as a `SecretString`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<SecretString>,
    /// Optional key for decrypting assertions, held as a `SecretString`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decryption_pvk: Option<SecretString>,
    /// Optional free-form extra parameters, kept in key order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<BTreeMap<String, serde_json::Value>>,
}
545
/// Serde default for `SamlConfig::want_assertions_signed`. Assertions must
/// be signed unless the stored config explicitly opts out.
#[cfg(not(feature = "saml"))]
const fn default_want_assertions_signed() -> bool {
    const SIGNED_BY_DEFAULT: bool = true;
    SIGNED_BY_DEFAULT
}
550
/// Alternate name for the fallback [`SamlConfig`] when the `saml` feature is off.
///
/// NOTE(review): `allow(dead_code)` presumably silences an unused-item warning
/// when nothing in the crate refers to the alias. Confirm it is still needed.
#[cfg(not(feature = "saml"))]
#[allow(dead_code)]
pub type SamlProviderConfig = SamlConfig;
554
/// Fallback identity-provider metadata used when the `saml` feature is off.
///
/// Acronym keys (`entityID`, `entityURL`, `redirectURL`) are written on
/// output. The camelCase spellings are still accepted on input as aliases.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlIdpMetadata {
    /// Raw metadata document, if supplied.
    pub metadata: Option<String>,
    /// IdP entity ID.
    #[serde(rename = "entityID", alias = "entityId")]
    pub entity_id: Option<String>,
    /// URL of the IdP entity / metadata.
    #[serde(rename = "entityURL", alias = "entityUrl")]
    pub entity_url: Option<String>,
    /// IdP redirect URL.
    #[serde(rename = "redirectURL", alias = "redirectUrl")]
    pub redirect_url: Option<String>,
    /// IdP certificate.
    pub cert: Option<String>,
    /// Private key, held as a `SecretString`.
    pub private_key: Option<SecretString>,
    /// Passphrase for `private_key`.
    pub private_key_pass: Option<SecretString>,
    /// Whether assertions are encrypted.
    pub is_assertion_encrypted: Option<bool>,
    /// Encryption private key, held as a `SecretString`.
    pub enc_private_key: Option<SecretString>,
    /// Passphrase for `enc_private_key`.
    pub enc_private_key_pass: Option<SecretString>,
    /// Single sign-on service endpoints.
    pub single_sign_on_service: Option<Vec<SamlService>>,
    /// Single logout service endpoints.
    pub single_logout_service: Option<Vec<SamlService>>,
}
588
/// One SAML service endpoint (fallback used when the `saml` feature is off).
///
/// Keys are serialized as `Binding` / `Location`, matching the attribute
/// names on SAML metadata service elements.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SamlService {
    /// Protocol binding identifier.
    #[serde(rename = "Binding")]
    pub binding: String,
    /// Endpoint URL.
    #[serde(rename = "Location")]
    pub location: String,
}
600
/// Fallback service-provider metadata used when the `saml` feature is off.
///
/// `entityID` is written on output. `entityId` is accepted on input as an alias.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlSpMetadata {
    /// Raw metadata document, if supplied.
    pub metadata: Option<String>,
    /// SP entity ID.
    #[serde(rename = "entityID", alias = "entityId")]
    pub entity_id: Option<String>,
    /// Preferred protocol binding.
    pub binding: Option<String>,
    /// SP private key, held as a `SecretString`.
    pub private_key: Option<SecretString>,
    /// Passphrase for `private_key`.
    pub private_key_pass: Option<SecretString>,
    /// Whether assertions are encrypted.
    pub is_assertion_encrypted: Option<bool>,
    /// Encryption private key, held as a `SecretString`.
    pub enc_private_key: Option<SecretString>,
    /// Passphrase for `enc_private_key`.
    pub enc_private_key_pass: Option<SecretString>,
}
624
/// Fallback mapping from SAML attributes to user fields (used when the `saml`
/// feature is off). Each value presumably names the SAML attribute to read —
/// confirm in the SAML flow.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlMapping {
    /// Attribute for the user identifier.
    pub id: Option<String>,
    /// Attribute for the email address.
    pub email: Option<String>,
    /// Attribute for the email-verified flag.
    pub email_verified: Option<String>,
    /// Attribute for the display name.
    pub name: Option<String>,
    /// Attribute for the first name.
    pub first_name: Option<String>,
    /// Attribute for the last name.
    pub last_name: Option<String>,
    /// Additional user-field to attribute mappings, kept in key order.
    pub extra_fields: Option<BTreeMap<String, String>>,
}
645
#[cfg(all(test, not(feature = "saml")))]
mod fallback_saml_tests {
    use super::*;

    /// Legacy camelCase keys (`entityId`, `entityUrl`, `redirectUrl`) must be
    /// accepted on input. The acronym spellings (`entityID`, `entityURL`,
    /// `redirectURL`) must be what gets written back out.
    #[test]
    fn fallback_saml_config_uses_upstream_acronym_wire_names_and_accepts_legacy_aliases(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let legacy_input = serde_json::json!({
            "issuer": "https://sp.example.com/metadata",
            "entryPoint": "https://idp.example.com/sso",
            "cert": "CERTIFICATE",
            "callbackUrl": "https://sp.example.com/acs",
            "spMetadata": {
                "entityId": "https://sp.example.com/legacy"
            },
            "idpMetadata": {
                "entityId": "https://idp.example.com/legacy",
                "entityUrl": "https://idp.example.com/legacy-metadata",
                "redirectUrl": "https://idp.example.com/legacy-redirect"
            },
            "wantAssertionsSigned": false,
            "authnRequestsSigned": false
        });
        let config: SamlConfig = serde_json::from_value(legacy_input)?;
        let round_tripped = serde_json::to_value(&config)?;

        // (section, emitted key, expected value) for every renamed field.
        let expectations = [
            ("spMetadata", "entityID", "https://sp.example.com/legacy"),
            ("idpMetadata", "entityID", "https://idp.example.com/legacy"),
            (
                "idpMetadata",
                "entityURL",
                "https://idp.example.com/legacy-metadata",
            ),
            (
                "idpMetadata",
                "redirectURL",
                "https://idp.example.com/legacy-redirect",
            ),
        ];
        for (section, field, expected) in expectations {
            assert_eq!(round_tripped[section][field], expected, "{section}.{field}");
        }
        Ok(())
    }
}