use serde::{Deserialize, Serialize};
#[cfg(not(feature = "saml"))]
use std::collections::BTreeMap;
use std::future::Future;
use time::Duration;
use rustauth_core::db::User;
use rustauth_core::error::RustAuthError;
use rustauth_core::options::RateLimitRule;
#[cfg(not(feature = "saml"))]
use rustauth_core::secret::SecretString;
pub use rustauth_oidc::{OidcConfig, OidcMapping, TokenEndpointAuthentication};
#[cfg(feature = "saml")]
pub use rustauth_saml::{
DeprecatedAlgorithmBehavior, SamlConfig, SamlIdpMetadata, SamlMapping, SamlService,
SamlSpMetadata,
};
#[path = "options/audit.rs"]
mod audit;
#[path = "options/callbacks.rs"]
mod callbacks;
pub use audit::*;
pub use callbacks::*;
/// Settings for adding SSO-authenticated users to an organization.
///
/// Construct it via [`Default`] and the builder methods on this type. The
/// role callback is runtime-only and is never (de)serialized.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OrganizationProvisioningOptions {
    /// When `true`, organization provisioning is turned off (default `false`).
    pub disabled: bool,
    /// Role returned by `resolve_role` when no `get_role` callback is set
    /// (default `"member"`).
    pub default_role: String,
    /// Optional async role callback; when present it takes precedence over
    /// `default_role`. Skipped by serde, so it deserializes as `None`.
    #[serde(skip)]
    pub get_role: Option<OrganizationRoleResolver>,
}
impl Default for OrganizationProvisioningOptions {
    /// Provisioning enabled, new members get the `"member"` role, and no role
    /// callback is installed.
    fn default() -> Self {
        let default_role = String::from("member");
        Self {
            get_role: None,
            disabled: false,
            default_role,
        }
    }
}
impl OrganizationProvisioningOptions {
    /// Turns organization provisioning off (`true`) or back on (`false`).
    #[must_use]
    pub fn disabled(self, disabled: bool) -> Self {
        Self { disabled, ..self }
    }

    /// Replaces the role handed out when no role callback is installed.
    #[must_use]
    pub fn default_role(self, role: impl Into<String>) -> Self {
        Self {
            default_role: role.into(),
            ..self
        }
    }

    /// Installs an async callback that chooses the role for each provisioned
    /// member; it takes precedence over `default_role` in
    /// [`Self::resolve_role`].
    #[must_use]
    pub fn get_role<F, Fut>(self, resolver: F) -> Self
    where
        F: Fn(OrganizationRoleInput) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<String, RustAuthError>> + Send + 'static,
    {
        let callback = OrganizationRoleResolver::new(resolver);
        Self {
            get_role: Some(callback),
            ..self
        }
    }

    /// Determines the role for `input`: the callback's answer when one is
    /// installed, otherwise a copy of `default_role`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the installed callback; the
    /// `default_role` fallback never fails.
    pub async fn resolve_role(
        &self,
        input: OrganizationRoleInput,
    ) -> Result<String, RustAuthError> {
        if let Some(resolver) = self.get_role.as_ref() {
            return resolver.resolve(input).await;
        }
        Ok(self.default_role.clone())
    }
}
/// Top-level configuration for the SSO plugin.
///
/// Create it with [`SsoOptions::new`] or [`Default`], then chain the builder
/// methods. Fields marked `#[serde(skip)]` hold callbacks or runtime-only
/// settings. They are never serialized, and they take their `Default` value
/// when deserializing. Only `oidc` is optional on input; every other
/// non-skipped field must be present.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoOptions {
    /// Model name (default `"sso_provider"`).
    pub model_name: String,
    /// Table for registered SSO providers (default `"sso_providers"`).
    pub provider_table: String,
    /// Static provider limit (default `10`); returned by
    /// `resolve_providers_limit` when no callback is installed.
    pub providers_limit: usize,
    /// Optional async per-`User` override for `providers_limit`; checked first
    /// by `resolve_providers_limit`.
    #[serde(skip)]
    pub providers_limit_callback: Option<ProvidersLimitResolver>,
    /// Domain verification settings (disabled by default).
    pub domain_verification: DomainVerificationOptions,
    /// Redirect URI exposed to the OIDC flow through the `OidcFlowOptions`
    /// impl (only with the `oidc` feature).
    pub redirect_uri: Option<String>,
    /// Default `false`. NOTE(review): enforced by sign-in handlers outside this file.
    pub disable_implicit_sign_up: bool,
    /// Default `false`. NOTE(review): enforced by handlers outside this file.
    pub trust_email_verified: bool,
    /// Default `false`. NOTE(review): enforced by handlers outside this file.
    pub default_override_user_info: bool,
    /// OIDC hardening switches; may be omitted when deserializing.
    #[serde(default)]
    pub oidc: OidcOptions,
    /// Optional async user-provisioning callback (runtime-only).
    #[serde(skip)]
    pub provision_user: Option<ProvisionUserResolver>,
    /// Default `false`; NOTE(review): presumably re-runs `provision_user` on
    /// every login rather than only the first, so confirm this at the call site.
    pub provision_user_on_every_login: bool,
    /// Organization membership provisioning settings.
    pub organization_provisioning: OrganizationProvisioningOptions,
    /// SAML behaviour and limits.
    pub saml: SamlOptions,
    /// Per-endpoint rate limits (runtime-only; enabled by default).
    #[serde(skip)]
    pub rate_limit: SsoRateLimitOptions,
    /// Optional async audit-event sink (runtime-only).
    #[serde(skip)]
    pub audit_event: Option<SsoAuditEventResolver>,
    /// Providers supplied through configuration (empty by default).
    pub default_sso: Vec<SsoProvider>,
}
impl Default for SsoOptions {
fn default() -> Self {
Self {
model_name: "sso_provider".to_owned(),
provider_table: "sso_providers".to_owned(),
providers_limit: 10,
providers_limit_callback: None,
domain_verification: DomainVerificationOptions::default(),
redirect_uri: None,
disable_implicit_sign_up: false,
trust_email_verified: false,
default_override_user_info: false,
oidc: OidcOptions::default(),
provision_user: None,
provision_user_on_every_login: false,
organization_provisioning: OrganizationProvisioningOptions::default(),
saml: SamlOptions::default(),
rate_limit: SsoRateLimitOptions::default(),
audit_event: None,
default_sso: Vec::new(),
}
}
}
impl SsoOptions {
    /// Creates an `SsoOptions` populated with the [`Default`] values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the table that stores registered SSO providers.
    #[must_use]
    pub fn provider_table(self, table: impl Into<String>) -> Self {
        Self {
            provider_table: table.into(),
            ..self
        }
    }

    /// Sets the static provider limit that [`Self::resolve_providers_limit`]
    /// returns when no callback is installed.
    #[must_use]
    pub fn providers_limit(self, limit: usize) -> Self {
        Self {
            providers_limit: limit,
            ..self
        }
    }

    /// Installs an async callback that computes the provider limit for a
    /// given [`User`], overriding the static `providers_limit`.
    #[must_use]
    pub fn providers_limit_callback<F, Fut>(self, resolver: F) -> Self
    where
        F: Fn(User) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<usize, RustAuthError>> + Send + 'static,
    {
        let callback = ProvidersLimitResolver::new(resolver);
        Self {
            providers_limit_callback: Some(callback),
            ..self
        }
    }

    /// Returns the provider limit for `user`: the callback's answer when one
    /// is installed, otherwise the static `providers_limit`.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the installed callback.
    pub async fn resolve_providers_limit(&self, user: User) -> Result<usize, RustAuthError> {
        if let Some(resolver) = self.providers_limit_callback.as_ref() {
            return resolver.resolve(user).await;
        }
        Ok(self.providers_limit)
    }

    /// Enables or disables domain verification.
    #[must_use]
    pub fn domain_verification_enabled(mut self, enabled: bool) -> Self {
        let verification = &mut self.domain_verification;
        verification.enabled = enabled;
        self
    }

    /// Installs the async TXT-record lookup stored on the domain verification
    /// settings. The callback receives a `String` (presumably the DNS name to
    /// query) and returns the TXT record strings.
    #[must_use]
    pub fn domain_txt_resolver<F, Fut>(mut self, resolver: F) -> Self
    where
        F: Fn(String) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Vec<String>, RustAuthError>> + Send + 'static,
    {
        let lookup = DnsTxtResolver::new(resolver);
        self.domain_verification.txt_resolver = Some(lookup);
        self
    }

    /// Sets the redirect URI handed to the OIDC flow.
    #[must_use]
    pub fn redirect_uri(self, redirect_uri: impl Into<String>) -> Self {
        Self {
            redirect_uri: Some(redirect_uri.into()),
            ..self
        }
    }

    /// Replaces the organization provisioning settings wholesale.
    #[must_use]
    pub fn organization_provisioning(
        self,
        provisioning: OrganizationProvisioningOptions,
    ) -> Self {
        Self {
            organization_provisioning: provisioning,
            ..self
        }
    }

    /// Installs an async user-provisioning callback.
    #[must_use]
    pub fn provision_user<F, Fut>(self, resolver: F) -> Self
    where
        F: Fn(ProvisionUserInput) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), RustAuthError>> + Send + 'static,
    {
        let callback = ProvisionUserResolver::new(resolver);
        Self {
            provision_user: Some(callback),
            ..self
        }
    }

    /// Sets the `provision_user_on_every_login` flag.
    #[must_use]
    pub fn provision_user_on_every_login(self, enabled: bool) -> Self {
        Self {
            provision_user_on_every_login: enabled,
            ..self
        }
    }

    /// Replaces all rate-limit settings at once.
    #[must_use]
    pub fn rate_limit(self, rate_limit: SsoRateLimitOptions) -> Self {
        Self { rate_limit, ..self }
    }

    /// Toggles rate limiting without touching the individual rules.
    #[must_use]
    pub fn rate_limit_enabled(mut self, enabled: bool) -> Self {
        let limits = &mut self.rate_limit;
        limits.enabled = enabled;
        self
    }

    /// Installs an async sink that receives every [`SsoAuditEvent`].
    #[must_use]
    pub fn audit_event<F, Fut>(self, resolver: F) -> Self
    where
        F: Fn(SsoAuditEvent) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let sink = SsoAuditEventResolver::new(resolver);
        Self {
            audit_event: Some(sink),
            ..self
        }
    }

    /// Sets `oidc.strict_manual_endpoint_origins`.
    #[must_use]
    pub fn strict_oidc_manual_endpoint_origins(mut self, enabled: bool) -> Self {
        let oidc = &mut self.oidc;
        oidc.strict_manual_endpoint_origins = enabled;
        self
    }

    /// Sets `oidc.allow_private_endpoint_ips`.
    #[must_use]
    pub fn allow_private_endpoint_ips(mut self, enabled: bool) -> Self {
        let oidc = &mut self.oidc;
        oidc.allow_private_endpoint_ips = enabled;
        self
    }
}
#[cfg(feature = "oidc")]
impl rustauth_oidc::OidcFlowOptions for SsoOptions {
    /// Lends the configured SSO redirect URI (if any) to the OIDC flow as a
    /// borrowed `&str`.
    fn redirect_uri(&self) -> Option<&str> {
        Option::as_deref(&self.redirect_uri)
    }
}
/// OIDC hardening switches for SSO providers.
///
/// Every field is optional when deserializing. Missing keys fall back to
/// `OidcOptions::default()` (all `false`), so a partial object such as
/// `{"allowPrivateEndpointIps": true}` is accepted. This matches the
/// `#[serde(default)]` already applied to the `oidc` field of `SsoOptions`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase", default)]
pub struct OidcOptions {
    /// Set through `SsoOptions::strict_oidc_manual_endpoint_origins`.
    /// NOTE(review): presumably restricts manually configured OIDC endpoints
    /// to the issuer's origin; confirm in the OIDC endpoint validation code.
    pub strict_manual_endpoint_origins: bool,
    /// Set through `SsoOptions::allow_private_endpoint_ips`.
    /// NOTE(review): presumably permits OIDC endpoints resolving to private IP
    /// ranges; confirm in the OIDC endpoint validation code.
    pub allow_private_endpoint_ips: bool,
}
/// Rate-limit rules for the SSO endpoints.
///
/// Not serde-serializable; `SsoOptions` marks its `rate_limit` field
/// `#[serde(skip)]`, so it always comes from code or `Default`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SsoRateLimitOptions {
    /// Master switch for all rules below (default `true`).
    pub enabled: bool,
    /// Rule for provider registration.
    pub registration: RateLimitRule,
    /// Rule for domain verification.
    pub domain_verification: RateLimitRule,
    /// Rule for the OIDC callback.
    pub oidc_callback: RateLimitRule,
    /// Rule for SAML endpoints.
    pub saml: RateLimitRule,
}
impl Default for SsoRateLimitOptions {
fn default() -> Self {
Self {
enabled: true,
registration: RateLimitRule::new(time::Duration::seconds(60), 10),
domain_verification: RateLimitRule::new(time::Duration::seconds(60), 5),
oidc_callback: RateLimitRule::new(time::Duration::seconds(60), 30),
saml: RateLimitRule::new(time::Duration::seconds(60), 30),
}
}
}
/// Settings for SSO domain verification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DomainVerificationOptions {
    /// Whether domain verification is enabled (default `false`).
    pub enabled: bool,
    /// Prefix for verification tokens (default `"better-auth-token"`).
    pub token_prefix: String,
    /// Token lifetime in seconds (default one week: 604 800).
    pub token_ttl_seconds: u64,
    /// Optional async TXT-record lookup callback; skipped by serde, so it
    /// deserializes as `None`.
    #[serde(skip)]
    pub txt_resolver: Option<DnsTxtResolver>,
}
impl Default for DomainVerificationOptions {
    /// Verification off, one-week tokens, and no TXT lookup callback.
    fn default() -> Self {
        const ONE_WEEK_IN_SECONDS: u64 = 7 * 24 * 60 * 60;
        Self {
            txt_resolver: None,
            enabled: false,
            token_ttl_seconds: ONE_WEEK_IN_SECONDS,
            token_prefix: String::from("better-auth-token"),
        }
    }
}
/// Default for `SamlOptions::max_response_size`: 256 * 1024 (presumably bytes,
/// i.e. 256 KiB).
pub const DEFAULT_MAX_SAML_RESPONSE_SIZE: usize = 256 * 1024;
/// Default for `SamlOptions::max_metadata_size`: 100 * 1024 (presumably bytes,
/// i.e. 100 KiB).
pub const DEFAULT_MAX_SAML_METADATA_SIZE: usize = 100 * 1024;
/// SAML behaviour, timing, and size limits for the SSO plugin.
///
/// The defaults noted on each field come from `SamlOptions::default()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlOptions {
    /// Default `true`.
    pub enable_in_response_to_validation: bool,
    /// Default `true`.
    pub allow_idp_initiated: bool,
    /// Default 5 minutes.
    pub request_ttl: Duration,
    /// Default 5 minutes.
    pub clock_skew: Duration,
    /// Default `false`.
    pub require_timestamps: bool,
    /// Default `DEFAULT_MAX_SAML_RESPONSE_SIZE`.
    pub max_response_size: usize,
    /// Default `DEFAULT_MAX_SAML_METADATA_SIZE`.
    pub max_metadata_size: usize,
    /// Default `false`.
    pub enable_single_logout: bool,
    /// Default 5 minutes.
    pub logout_request_ttl: Duration,
    /// Default `false`.
    pub want_logout_request_signed: bool,
    /// Default `false`.
    pub want_logout_response_signed: bool,
    /// Algorithm allow-lists and deprecated-algorithm handling.
    pub algorithms: SamlAlgorithmOptions,
}
impl Default for SamlOptions {
    /// InResponseTo validation and IdP-initiated logins enabled, 5-minute
    /// windows, built-in size caps, and single logout off.
    fn default() -> Self {
        // `Duration` is `time::Duration`, which is `Copy`, so one value can be
        // reused for every 5-minute setting.
        let five_minutes = Duration::minutes(5);
        Self {
            enable_in_response_to_validation: true,
            allow_idp_initiated: true,
            request_ttl: five_minutes,
            clock_skew: five_minutes,
            require_timestamps: false,
            max_response_size: DEFAULT_MAX_SAML_RESPONSE_SIZE,
            max_metadata_size: DEFAULT_MAX_SAML_METADATA_SIZE,
            enable_single_logout: false,
            logout_request_ttl: five_minutes,
            want_logout_request_signed: false,
            want_logout_response_signed: false,
            algorithms: Default::default(),
        }
    }
}
/// Cryptographic algorithm policy for SAML messages.
///
/// Each `allowed_*` list is `None` by default. NOTE(review): `None` presumably
/// means "no explicit allow-list"; confirm in the SAML validation code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlAlgorithmOptions {
    /// Reaction to deprecated algorithms (default `Warn`).
    pub on_deprecated: DeprecatedAlgorithmBehavior,
    /// Optional signature-algorithm allow-list.
    pub allowed_signature_algorithms: Option<Vec<String>>,
    /// Optional digest-algorithm allow-list.
    pub allowed_digest_algorithms: Option<Vec<String>>,
    /// Optional key-encryption-algorithm allow-list.
    pub allowed_key_encryption_algorithms: Option<Vec<String>>,
    /// Optional data-encryption-algorithm allow-list.
    pub allowed_data_encryption_algorithms: Option<Vec<String>>,
}
/// Fallback for `rustauth_saml::DeprecatedAlgorithmBehavior`, compiled only
/// without the `saml` feature. It serializes as `"warn"` / `"reject"`.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeprecatedAlgorithmBehavior {
    /// Accept deprecated algorithms; the name suggests a warning is emitted.
    Warn,
    /// Refuse deprecated algorithms.
    Reject,
}
impl Default for SamlAlgorithmOptions {
    /// Warn on deprecated algorithms and configure no allow-lists.
    fn default() -> Self {
        let on_deprecated = DeprecatedAlgorithmBehavior::Warn;
        Self {
            on_deprecated,
            allowed_signature_algorithms: Option::None,
            allowed_digest_algorithms: Option::None,
            allowed_key_encryption_algorithms: Option::None,
            allowed_data_encryption_algorithms: Option::None,
        }
    }
}
/// A registered SSO provider, with an optional OIDC and/or SAML configuration.
///
/// `None` options are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoProvider {
    /// Provider identifier.
    pub provider_id: String,
    /// Issuer identifier.
    pub issuer: String,
    /// Domain associated with the provider.
    pub domain: String,
    /// Owning organization, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub organization_id: Option<String>,
    /// OIDC settings, when this is an OIDC provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oidc_config: Option<OidcConfig>,
    /// SAML settings, when this is a SAML provider. `SamlConfig` is the
    /// `rustauth_saml` type with the `saml` feature, or the local fallback
    /// without it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub saml_config: Option<SamlConfig>,
}
/// Alias for the per-provider SAML configuration type; with the `saml`
/// feature this is the re-exported `rustauth_saml::SamlConfig`.
/// NOTE(review): the non-`saml` build declares an identical alias further down
/// this file; the two could be merged into one unconditional alias. The reason
/// for `allow(dead_code)` is not visible here.
#[cfg(feature = "saml")]
#[allow(dead_code)]
pub type SamlProviderConfig = SamlConfig;
/// Fallback per-provider SAML configuration, compiled only without the `saml`
/// feature. It lets `SsoProvider::saml_config` still (de)serialize.
///
/// Required on input are `issuer`, `cert`, `callbackUrl`, `spMetadata`, and
/// `authnRequestsSigned`. `entryPoint` defaults to `""`,
/// `wantAssertionsSigned` defaults to `true`, and `Option` fields default to
/// `None` and are omitted when serializing.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlConfig {
    /// Issuer identifier.
    pub issuer: String,
    /// Entry point URL; empty string when omitted.
    #[serde(default)]
    pub entry_point: String,
    /// Certificate string.
    pub cert: String,
    /// Callback URL.
    pub callback_url: String,
    /// Optional ACS URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acs_url: Option<String>,
    /// Optional audience.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audience: Option<String>,
    /// Optional IdP metadata block.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idp_metadata: Option<SamlIdpMetadata>,
    /// SP metadata block (required on input).
    pub sp_metadata: SamlSpMetadata,
    /// Optional attribute mapping.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mapping: Option<SamlMapping>,
    /// Defaults to `true` when omitted, so unsigned assertions must be opted
    /// into explicitly.
    #[serde(default = "default_want_assertions_signed")]
    pub want_assertions_signed: bool,
    /// Required on input (no serde default).
    pub authn_requests_signed: bool,
    /// Optional signature algorithm identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature_algorithm: Option<String>,
    /// Optional digest algorithm identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub digest_algorithm: Option<String>,
    /// Optional NameID identifier format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identifier_format: Option<String>,
    /// Optional private key, held as a `SecretString`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<SecretString>,
    /// Optional decryption key, held as a `SecretString`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decryption_pvk: Option<SecretString>,
    /// Optional free-form extra parameters, kept in key order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<BTreeMap<String, serde_json::Value>>,
}
/// Serde default for `SamlConfig::want_assertions_signed`: signed assertions
/// are wanted unless the input explicitly says otherwise.
#[cfg(not(feature = "saml"))]
const fn default_want_assertions_signed() -> bool {
    true
}
/// Alias for the per-provider SAML configuration type; without the `saml`
/// feature this is the local fallback `SamlConfig`.
/// NOTE(review): the `saml` build declares an identical alias earlier in this
/// file. The reason for `allow(dead_code)` is not visible here.
#[cfg(not(feature = "saml"))]
#[allow(dead_code)]
pub type SamlProviderConfig = SamlConfig;
/// Fallback identity-provider metadata, compiled only without the `saml`
/// feature.
///
/// Serializes the upstream acronym keys (`entityID`, `entityURL`,
/// `redirectURL`) and also accepts the camel-case aliases (`entityId`,
/// `entityUrl`, `redirectUrl`) on input. Missing fields deserialize as `None`.
/// `None` fields are serialized as `null`.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlIdpMetadata {
    /// Raw metadata document, if supplied.
    pub metadata: Option<String>,
    /// IdP entity identifier.
    #[serde(rename = "entityID", alias = "entityId")]
    pub entity_id: Option<String>,
    /// IdP entity URL.
    #[serde(rename = "entityURL", alias = "entityUrl")]
    pub entity_url: Option<String>,
    /// IdP redirect URL.
    #[serde(rename = "redirectURL", alias = "redirectUrl")]
    pub redirect_url: Option<String>,
    /// IdP certificate.
    pub cert: Option<String>,
    /// Private key, held as a `SecretString`.
    pub private_key: Option<SecretString>,
    /// Passphrase for `private_key`.
    pub private_key_pass: Option<SecretString>,
    /// Whether assertions are encrypted.
    pub is_assertion_encrypted: Option<bool>,
    /// Encryption private key.
    pub enc_private_key: Option<SecretString>,
    /// Passphrase for `enc_private_key`.
    pub enc_private_key_pass: Option<SecretString>,
    /// Single sign-on service endpoints.
    pub single_sign_on_service: Option<Vec<SamlService>>,
    /// Single logout service endpoints.
    pub single_logout_service: Option<Vec<SamlService>>,
}
/// Fallback SAML service endpoint, compiled only without the `saml` feature.
/// Uses the capitalized wire keys `Binding` and `Location`.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SamlService {
    /// Binding identifier for this endpoint.
    #[serde(rename = "Binding")]
    pub binding: String,
    /// Endpoint location.
    #[serde(rename = "Location")]
    pub location: String,
}
/// Fallback service-provider metadata, compiled only without the `saml`
/// feature.
///
/// Serializes `entityID` and also accepts the `entityId` alias on input.
/// Missing fields deserialize as `None`.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlSpMetadata {
    /// Raw metadata document, if supplied.
    pub metadata: Option<String>,
    /// SP entity identifier.
    #[serde(rename = "entityID", alias = "entityId")]
    pub entity_id: Option<String>,
    /// Binding identifier.
    pub binding: Option<String>,
    /// Private key, held as a `SecretString`.
    pub private_key: Option<SecretString>,
    /// Passphrase for `private_key`.
    pub private_key_pass: Option<SecretString>,
    /// Whether assertions are encrypted.
    pub is_assertion_encrypted: Option<bool>,
    /// Encryption private key.
    pub enc_private_key: Option<SecretString>,
    /// Passphrase for `enc_private_key`.
    pub enc_private_key_pass: Option<SecretString>,
}
/// Fallback SAML attribute mapping, compiled only without the `saml` feature.
///
/// Each field presumably names the assertion attribute that supplies the
/// corresponding user property; confirm against the SAML sign-in handler.
#[cfg(not(feature = "saml"))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamlMapping {
    pub id: Option<String>,
    pub email: Option<String>,
    pub email_verified: Option<String>,
    pub name: Option<String>,
    pub first_name: Option<String>,
    pub last_name: Option<String>,
    /// Additional field mappings, kept in key order.
    pub extra_fields: Option<BTreeMap<String, String>>,
}
#[cfg(all(test, not(feature = "saml")))]
mod fallback_saml_tests {
    use super::*;

    /// Legacy camel-case keys (`entityId`, `entityUrl`, `redirectUrl`) must
    /// still deserialize. Serialization must emit only the upstream acronym
    /// spellings (`entityID`, `entityURL`, `redirectURL`).
    #[test]
    fn fallback_saml_config_uses_upstream_acronym_wire_names_and_accepts_legacy_aliases(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let config: SamlConfig = serde_json::from_value(serde_json::json!({
            "issuer": "https://sp.example.com/metadata",
            "entryPoint": "https://idp.example.com/sso",
            "cert": "CERTIFICATE",
            "callbackUrl": "https://sp.example.com/acs",
            "spMetadata": {
                "entityId": "https://sp.example.com/legacy"
            },
            "idpMetadata": {
                "entityId": "https://idp.example.com/legacy",
                "entityUrl": "https://idp.example.com/legacy-metadata",
                "redirectUrl": "https://idp.example.com/legacy-redirect"
            },
            "wantAssertionsSigned": false,
            "authnRequestsSigned": false
        }))?;
        // An explicit `false` must override the `true` serde default.
        assert!(!config.want_assertions_signed);

        let serialized = serde_json::to_value(&config)?;
        assert_eq!(
            serialized["spMetadata"]["entityID"],
            "https://sp.example.com/legacy"
        );
        assert_eq!(
            serialized["idpMetadata"]["entityID"],
            "https://idp.example.com/legacy"
        );
        assert_eq!(
            serialized["idpMetadata"]["entityURL"],
            "https://idp.example.com/legacy-metadata"
        );
        assert_eq!(
            serialized["idpMetadata"]["redirectURL"],
            "https://idp.example.com/legacy-redirect"
        );
        // The legacy aliases are input-only and must not leak into the output.
        assert!(serialized["spMetadata"].get("entityId").is_none());
        assert!(serialized["idpMetadata"].get("entityId").is_none());
        assert!(serialized["idpMetadata"].get("entityUrl").is_none());
        assert!(serialized["idpMetadata"].get("redirectUrl").is_none());
        Ok(())
    }

    /// The canonical acronym keys deserialize directly. Omitted `entryPoint`
    /// and `wantAssertionsSigned` fall back to their serde defaults.
    #[test]
    fn fallback_saml_config_accepts_acronym_wire_names_and_applies_defaults(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let config: SamlConfig = serde_json::from_value(serde_json::json!({
            "issuer": "https://sp.example.com/metadata",
            "cert": "CERTIFICATE",
            "callbackUrl": "https://sp.example.com/acs",
            "spMetadata": {
                "entityID": "https://sp.example.com/entity"
            },
            "idpMetadata": {
                "entityID": "https://idp.example.com/entity",
                "entityURL": "https://idp.example.com/metadata",
                "redirectURL": "https://idp.example.com/redirect"
            },
            "authnRequestsSigned": false
        }))?;
        assert!(config.want_assertions_signed);
        assert!(config.entry_point.is_empty());
        assert_eq!(
            config.sp_metadata.entity_id.as_deref(),
            Some("https://sp.example.com/entity")
        );

        let idp = config
            .idp_metadata
            .ok_or("idpMetadata should deserialize")?;
        assert_eq!(
            idp.entity_id.as_deref(),
            Some("https://idp.example.com/entity")
        );
        assert_eq!(
            idp.entity_url.as_deref(),
            Some("https://idp.example.com/metadata")
        );
        assert_eq!(
            idp.redirect_url.as_deref(),
            Some("https://idp.example.com/redirect")
        );
        Ok(())
    }
}