1mod audit;
25mod errors;
26mod hooks;
27#[path = "linking.rs"]
28mod linking_impl;
29mod openapi;
30mod options;
31mod org;
32mod routes;
33mod schema;
34mod secrets;
35mod state;
36mod store;
37mod utils;
38
39#[cfg(feature = "oidc")]
40pub(crate) use rustauth_oidc as oidc_impl;
41#[cfg(feature = "saml")]
42pub(crate) use rustauth_saml as saml_impl;
43
/// Public facade over the private `linking_impl` module (loaded from
/// `linking.rs` via `#[path]`).
///
/// It contains no code of its own; it only re-exports the organization
/// assignment / domain validation helpers and the `NormalizedSsoProfile`
/// type named below so callers can reach them as `linking::…`.
pub mod linking {
    pub use crate::linking_impl::{
        assign_organization_by_domain, assign_organization_from_provider,
        provider_matches_email_domain, validate_provider_domains, NormalizedSsoProfile,
    };
}
51
52#[cfg(feature = "oidc")]
53pub use rustauth_oidc as oidc;
54pub use rustauth_oidc::{OidcProfileMapping, OidcProviderConfig};
55#[cfg(feature = "saml")]
56pub use rustauth_saml as saml;
57
58pub use errors::{sso_error_category, sso_error_descriptors, SsoErrorCategory, SsoErrorDescriptor};
59pub use linking::NormalizedSsoProfile;
60#[cfg(not(feature = "saml"))]
61pub use options::DeprecatedAlgorithmBehavior;
62pub use options::{
63 DnsTxtResolver, DomainVerificationOptions, OidcConfig, OidcMapping, OidcOptions,
64 OrganizationProvisioningOptions, OrganizationRoleInput, OrganizationRoleResolver,
65 ProvidersLimitResolver, ProvisionUserInput, ProvisionUserResolver, SamlAlgorithmOptions,
66 SamlConfig, SamlIdpMetadata, SamlMapping, SamlOptions, SamlService, SamlSpMetadata,
67 SsoAuditEvent, SsoAuditEventKind, SsoAuditEventResolver, SsoAuditSeverity, SsoOptions,
68 SsoProvider, SsoRateLimitOptions, TokenEndpointAuthentication, DEFAULT_MAX_SAML_METADATA_SIZE,
69 DEFAULT_MAX_SAML_RESPONSE_SIZE,
70};
71#[cfg(feature = "saml")]
72pub use saml::DeprecatedAlgorithmBehavior;
73pub use secrets::SecretString;
74pub use store::{
75 CreateSsoProviderInput, SanitizedSsoProvider, SsoProviderRecord, SsoProviderStore,
76};
77
78use rustauth_core::plugin::{AuthPlugin, PluginRateLimitRule};
79use std::sync::Arc;
80
/// Plugin identifier handed to `AuthPlugin::new` when [`sso`] builds the plugin.
// NOTE(review): the name suggests this mirrors the id used by an upstream
// implementation — confirm before changing the value.
pub const UPSTREAM_PLUGIN_ID: &str = "sso";
83
/// Version string of this crate, read from `CARGO_PKG_VERSION` at compile
/// time; [`sso`] attaches it to the plugin through `with_version`.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
86
87#[must_use]
93pub fn sso(options: SsoOptions) -> AuthPlugin {
94 let options = Arc::new(options);
95 let mut plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID).with_version(VERSION);
96
97 for contribution in schema::contributions(&options) {
98 plugin = plugin.with_schema(contribution);
99 }
100 for code in errors::plugin_error_codes() {
101 plugin = plugin.with_error_code(code);
102 }
103 for endpoint in routes::endpoints(Arc::clone(&options)) {
104 plugin = plugin.with_endpoint(endpoint);
105 }
106 for rule in rate_limit_rules(&options.rate_limit) {
107 plugin = plugin.with_rate_limit(rule);
108 }
109
110 #[cfg(feature = "saml")]
111 {
112 plugin = plugin
113 .with_async_before_hook("/sign-out", |context, request| {
114 Box::pin(hooks::capture_sign_out_session(context, request))
115 })
116 .with_async_after_hook("/sign-out", |context, request, response| {
117 Box::pin(hooks::cleanup_sign_out_session(context, request, response))
118 });
119 }
120
121 for path in [
122 "/sign-up/email",
123 "/sign-in/email",
124 "/sign-in/social",
125 "/sign-in/oauth2",
126 "/callback/:id",
127 ] {
128 let hook_options = Arc::clone(&options);
129 plugin = plugin.with_async_after_hook(path, move |context, request, response| {
130 Box::pin(hooks::assign_domain_organization_after_auth(
131 context,
132 request,
133 response,
134 Arc::clone(&hook_options),
135 ))
136 });
137 }
138
139 plugin
140}
141
142fn rate_limit_rules(options: &SsoRateLimitOptions) -> Vec<PluginRateLimitRule> {
143 if !options.enabled {
144 return Vec::new();
145 }
146 let mut rules = vec![
147 PluginRateLimitRule::new("/sso/register", options.registration.clone()),
148 PluginRateLimitRule::new(
149 "/sso/request-domain-verification",
150 options.domain_verification.clone(),
151 ),
152 PluginRateLimitRule::new("/sso/verify-domain", options.domain_verification.clone()),
153 ];
154 #[cfg(feature = "oidc")]
155 {
156 rules.push(PluginRateLimitRule::new(
157 "/sso/callback",
158 options.oidc_callback.clone(),
159 ));
160 rules.push(PluginRateLimitRule::new(
161 "/sso/callback/:providerId",
162 options.oidc_callback.clone(),
163 ));
164 }
165 #[cfg(feature = "saml")]
166 {
167 rules.push(PluginRateLimitRule::new(
168 "/sso/saml2/callback/:providerId",
169 options.saml.clone(),
170 ));
171 rules.push(PluginRateLimitRule::new(
172 "/sso/saml2/sp/acs/:providerId",
173 options.saml.clone(),
174 ));
175 rules.push(PluginRateLimitRule::new(
176 "/sso/saml2/sp/slo/:providerId",
177 options.saml.clone(),
178 ));
179 rules.push(PluginRateLimitRule::new(
180 "/sso/saml2/logout/:providerId",
181 options.saml.clone(),
182 ));
183 }
184 rules
185}