mod audit;
mod errors;
mod hooks;
#[path = "linking.rs"]
mod linking_impl;
mod openapi;
mod options;
mod org;
mod routes;
mod schema;
mod secrets;
mod state;
mod store;
mod utils;
#[cfg(feature = "oidc")]
pub(crate) use rustauth_oidc as oidc_impl;
#[cfg(feature = "saml")]
pub(crate) use rustauth_saml as saml_impl;
pub mod linking {
pub use crate::linking_impl::{
assign_organization_by_domain, assign_organization_from_provider,
provider_matches_email_domain, validate_provider_domains, NormalizedSsoProfile,
};
}
#[cfg(feature = "oidc")]
pub use rustauth_oidc as oidc;
pub use rustauth_oidc::{OidcProfileMapping, OidcProviderConfig};
#[cfg(feature = "saml")]
pub use rustauth_saml as saml;
pub use errors::{sso_error_category, sso_error_descriptors, SsoErrorCategory, SsoErrorDescriptor};
pub use linking::NormalizedSsoProfile;
#[cfg(not(feature = "saml"))]
pub use options::DeprecatedAlgorithmBehavior;
pub use options::{
DnsTxtResolver, DomainVerificationOptions, OidcConfig, OidcMapping, OidcOptions,
OrganizationProvisioningOptions, OrganizationRoleInput, OrganizationRoleResolver,
ProvidersLimitResolver, ProvisionUserInput, ProvisionUserResolver, SamlAlgorithmOptions,
SamlConfig, SamlIdpMetadata, SamlMapping, SamlOptions, SamlService, SamlSpMetadata,
SsoAuditEvent, SsoAuditEventKind, SsoAuditEventResolver, SsoAuditSeverity, SsoOptions,
SsoProvider, SsoRateLimitOptions, TokenEndpointAuthentication, DEFAULT_MAX_SAML_METADATA_SIZE,
DEFAULT_MAX_SAML_RESPONSE_SIZE,
};
#[cfg(feature = "saml")]
pub use saml::DeprecatedAlgorithmBehavior;
pub use secrets::SecretString;
pub use store::{
CreateSsoProviderInput, SanitizedSsoProvider, SsoProviderRecord, SsoProviderStore,
};
use rustauth_core::plugin::{AuthPlugin, PluginRateLimitRule};
use std::sync::Arc;
/// Identifier passed to `AuthPlugin::new` when [`sso`] builds the plugin.
pub const UPSTREAM_PLUGIN_ID: &str = "sso";
/// Version string attached to the plugin by [`sso`]; taken from this crate's
/// `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Builds the SSO [`AuthPlugin`] from the supplied [`SsoOptions`].
///
/// The options are moved into an [`Arc`] so endpoints and hooks can share them.
/// Registration happens in this order:
/// 1. schema contributions derived from the options,
/// 2. the plugin's error codes,
/// 3. the HTTP endpoints,
/// 4. the rate-limit rules built from `options.rate_limit`,
/// 5. (feature `saml`) a before hook (`hooks::capture_sign_out_session`) and an
///    after hook (`hooks::cleanup_sign_out_session`) on `/sign-out`,
/// 6. an after hook invoking `hooks::assign_domain_organization_after_auth` on
///    each email sign-up/sign-in, social and OAuth2 sign-in path, and `/callback/:id`.
#[must_use]
pub fn sso(options: SsoOptions) -> AuthPlugin {
    /// Feeds every item of `items` into `plugin` through `attach`, in iteration order.
    fn attach_each<I, F>(plugin: AuthPlugin, items: I, attach: F) -> AuthPlugin
    where
        I: IntoIterator,
        F: FnMut(AuthPlugin, I::Item) -> AuthPlugin,
    {
        items.into_iter().fold(plugin, attach)
    }

    // Paths that receive the `assign_domain_organization_after_auth` after hook.
    const DOMAIN_ASSIGNMENT_PATHS: [&str; 5] = [
        "/sign-up/email",
        "/sign-in/email",
        "/sign-in/social",
        "/sign-in/oauth2",
        "/callback/:id",
    ];

    let shared = Arc::new(options);

    let plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID).with_version(VERSION);
    let plugin = attach_each(plugin, schema::contributions(&shared), |acc, contribution| {
        acc.with_schema(contribution)
    });
    let plugin = attach_each(plugin, errors::plugin_error_codes(), |acc, code| {
        acc.with_error_code(code)
    });
    let plugin = attach_each(plugin, routes::endpoints(Arc::clone(&shared)), |acc, endpoint| {
        acc.with_endpoint(endpoint)
    });
    let plugin = attach_each(plugin, rate_limit_rules(&shared.rate_limit), |acc, rule| {
        acc.with_rate_limit(rule)
    });

    // Without the `saml` feature this binding is compiled out and the previous
    // `plugin` flows through unchanged.
    #[cfg(feature = "saml")]
    let plugin = plugin
        .with_async_before_hook("/sign-out", |context, request| {
            Box::pin(hooks::capture_sign_out_session(context, request))
        })
        .with_async_after_hook("/sign-out", |context, request, response| {
            Box::pin(hooks::cleanup_sign_out_session(context, request, response))
        });

    attach_each(plugin, DOMAIN_ASSIGNMENT_PATHS, |acc, path| {
        // Each hook owns its own handle to the shared options.
        let hook_options = Arc::clone(&shared);
        acc.with_async_after_hook(path, move |context, request, response| {
            Box::pin(hooks::assign_domain_organization_after_auth(
                context,
                request,
                response,
                Arc::clone(&hook_options),
            ))
        })
    })
}
/// Builds the per-endpoint rate-limit rules for the SSO plugin.
///
/// Returns an empty list when `options.enabled` is `false`. Otherwise the list
/// always contains the registration rule and the two domain-verification rules.
/// With feature `oidc` it adds the OIDC callback rules (`options.oidc_callback`).
/// With feature `saml` it adds the SAML callback/ACS/SLO/logout rules (`options.saml`).
/// Rules are emitted in that order.
fn rate_limit_rules(options: &SsoRateLimitOptions) -> Vec<PluginRateLimitRule> {
    if !options.enabled {
        return Vec::new();
    }

    // Start empty and always `extend`, so `rules` is mutated in every feature
    // combination. Initializing with `vec![...]` triggers `unused_mut` when both
    // `oidc` and `saml` are disabled.
    let mut rules = Vec::new();
    rules.extend([
        PluginRateLimitRule::new("/sso/register", options.registration.clone()),
        PluginRateLimitRule::new(
            "/sso/request-domain-verification",
            options.domain_verification.clone(),
        ),
        PluginRateLimitRule::new("/sso/verify-domain", options.domain_verification.clone()),
    ]);

    #[cfg(feature = "oidc")]
    {
        rules.extend(
            ["/sso/callback", "/sso/callback/:providerId"]
                .iter()
                .map(|&path| PluginRateLimitRule::new(path, options.oidc_callback.clone())),
        );
    }

    #[cfg(feature = "saml")]
    {
        rules.extend(
            [
                "/sso/saml2/callback/:providerId",
                "/sso/saml2/sp/acs/:providerId",
                "/sso/saml2/sp/slo/:providerId",
                "/sso/saml2/logout/:providerId",
            ]
            .iter()
            .map(|&path| PluginRateLimitRule::new(path, options.saml.clone())),
        );
    }

    rules
}