use std::collections::BTreeMap;
use std::sync::Arc;
#[cfg(feature = "oauth")]
use rustauth_oauth::oauth2::SocialOAuthProvider;
use time::Duration;
use crate::background::tokio::TokioBackgroundTaskRunner;
use crate::cookies::get_cookies;
use crate::crypto::password::{hash_password, verify_password};
use crate::crypto::{build_secret_config, parse_secrets_env};
use crate::db::RateLimitStorage as DbRateLimitStorage;
use crate::db::{auth_schema, AuthSchemaOptions, DbAdapter, DbField, HookedAdapter};
use crate::env::is_production_posture;
use crate::env::logger::create_logger;
use crate::error::RustAuthError;
use crate::options::hooks::{plugin_after_hooks, plugin_before_hooks};
use crate::options::RateLimitStore;
use crate::options::{
plugin_database_hooks_from_init, BackgroundTaskRunner, ModelSchemaOptions,
RateLimitStorageOption, RustAuthOptions, SessionAdditionalField, UserAdditionalField,
};
use crate::plugin::AuthPlugin;
use crate::rate_limit::{GovernorMemoryRateLimitStore, LegacyRateLimitStorageAdapter};
use super::origins::resolve_trusted_origins;
use super::plugins::initialize_plugins;
use super::secrets::{resolve_legacy_secret, validate_secret, DEFAULT_SECRET};
use super::{
noop_telemetry_publisher, AuthContext, AuthEnvironment, PasswordContext, PasswordPolicy,
RateLimitContext, SecretMaterial, SessionConfig,
};
pub fn create_auth_context(options: RustAuthOptions) -> Result<AuthContext, RustAuthError> {
create_auth_context_with_environment_and_adapter(options, AuthEnvironment::from_process(), None)
}
pub fn create_auth_context_with_adapter(
options: RustAuthOptions,
adapter: Arc<dyn DbAdapter>,
) -> Result<AuthContext, RustAuthError> {
create_auth_context_with_environment_and_adapter(
options,
AuthEnvironment::from_process(),
Some(adapter),
)
}
pub fn create_auth_context_with_environment(
options: RustAuthOptions,
environment: AuthEnvironment,
) -> Result<AuthContext, RustAuthError> {
create_auth_context_with_environment_and_adapter(options, environment, None)
}
pub fn create_auth_context_with_environment_and_adapter(
options: RustAuthOptions,
environment: AuthEnvironment,
adapter: Option<Arc<dyn DbAdapter>>,
) -> Result<AuthContext, RustAuthError> {
let logger = create_logger(options.logger.clone());
let production_posture = is_production_posture(&options);
let env_secrets = parse_secrets_env(environment.rustauth_secrets.as_deref())?;
let secrets = if options.secrets.is_empty() {
env_secrets.unwrap_or_default()
} else {
options.secrets.clone()
};
let legacy_secret = resolve_legacy_secret(&options, &environment);
let (secret, secret_config) = if secrets.is_empty() {
let secret = legacy_secret.unwrap_or_else(|| DEFAULT_SECRET.to_owned());
validate_secret(&secret, &options)?;
(secret.clone(), SecretMaterial::Single(secret))
} else {
let config = build_secret_config(&secrets, legacy_secret.as_deref().unwrap_or(""))?;
let current = config
.keys
.get(&config.current_version)
.cloned()
.ok_or_else(|| {
RustAuthError::InvalidSecretConfig(format!(
"secret version {} not found in keys",
config.current_version
))
})?;
(current, SecretMaterial::Rotating(config))
};
let base_path = options
.base_path
.clone()
.unwrap_or_else(|| "/api/auth".to_owned());
let base_url = options.base_url.clone().unwrap_or_default();
let trusted_origins = resolve_trusted_origins(&base_url, &options, &environment);
let auth_cookies = get_cookies(&options)?;
#[cfg(feature = "oauth")]
let social_providers = resolve_social_providers(&options)?;
let session_config = SessionConfig {
update_age: options.session.update_age.unwrap_or(Duration::hours(24)),
expires_in: options.session.expires_in.unwrap_or(Duration::days(7)),
fresh_age: options.session.fresh_age.unwrap_or(Duration::days(1)),
cookie_refresh_cache: options.session.cookie_cache.refresh_cache,
};
let password = PasswordContext {
config: PasswordPolicy {
min_password_length: options.password.min_password_length,
max_password_length: options.password.max_password_length,
},
hash: options.password.hash_password.unwrap_or(hash_password),
verify: options.password.verify_password.unwrap_or(verify_password),
};
validate_rate_limit_storage(&options)?;
let rate_limit = RateLimitContext {
enabled: options.rate_limit.enabled.unwrap_or(production_posture),
window: options.rate_limit.window,
max: options.rate_limit.max,
storage: options.rate_limit.storage,
custom_rules: options.rate_limit.custom_rules.clone(),
dynamic_rules: options.rate_limit.dynamic_rules.clone(),
plugin_rules: Vec::new(),
custom_store: options.rate_limit.custom_store.clone().or_else(|| {
options.rate_limit.custom_storage.clone().map(|storage| {
Arc::new(LegacyRateLimitStorageAdapter::new(storage)) as Arc<dyn RateLimitStore>
})
}),
hybrid: options.rate_limit.hybrid.clone(),
memory_cleanup_interval: options.rate_limit.memory_cleanup_interval,
memory_store: Arc::new(GovernorMemoryRateLimitStore::with_cleanup_interval(
options.rate_limit.memory_cleanup_interval,
)),
missing_ip_policy: options.rate_limit.missing_ip_policy,
};
let schema_options = schema_options_from_auth_options(&options);
let app_name = options
.app_name
.clone()
.unwrap_or_else(|| "RustAuth".to_owned());
let mut context =
AuthContext {
app_name,
base_url,
base_path,
options: options.clone(),
auth_cookies,
session_config,
secret,
secret_config,
password,
rate_limit,
trusted_origins,
disabled_paths: options.disabled_paths,
plugins: options.plugins,
adapter,
secondary_storage: options.secondary_storage.clone(),
background_tasks: options.advanced.background_tasks.clone().or_else(|| {
Some(Arc::new(TokioBackgroundTaskRunner) as Arc<dyn BackgroundTaskRunner>)
}),
#[cfg(feature = "oauth")]
social_providers,
db_schema: auth_schema(schema_options),
plugin_error_codes: BTreeMap::new(),
plugin_database_hooks: {
let mut hooks = plugin_database_hooks_from_init(&options.init_database_hooks);
hooks.extend(options.database_hooks.clone());
hooks
},
plugin_migrations: Vec::new(),
telemetry_publisher: noop_telemetry_publisher(),
logger,
};
apply_global_hooks(&mut context);
initialize_plugins(&mut context)?;
if !context.plugin_database_hooks.is_empty() {
if let Some(adapter) = context.adapter.clone() {
context.adapter = Some(Arc::new(HookedAdapter::with_logger(
adapter,
context.plugin_database_hooks.clone(),
context.logger.clone(),
)));
}
}
Ok(context)
}
fn apply_global_hooks(context: &mut AuthContext) {
let before = plugin_before_hooks(&context.options.hooks);
let after = plugin_after_hooks(&context.options.hooks);
if before.is_empty() && after.is_empty() {
return;
}
let mut plugin = AuthPlugin::new("__rustauth_global__");
plugin.hooks.before = before;
plugin.hooks.after = after;
context.plugins.insert(0, plugin);
}
#[cfg(feature = "oauth")]
fn resolve_social_providers(
options: &RustAuthOptions,
) -> Result<BTreeMap<String, Arc<dyn SocialOAuthProvider>>, RustAuthError> {
let mut providers = BTreeMap::new();
for provider in &options.social_providers {
insert_social_provider(&mut providers, provider.clone())?;
}
Ok(providers)
}
#[cfg(feature = "oauth")]
pub(super) fn insert_social_provider(
providers: &mut BTreeMap<String, Arc<dyn SocialOAuthProvider>>,
provider: Arc<dyn SocialOAuthProvider>,
) -> Result<(), RustAuthError> {
let id = provider.id().to_owned();
if id.trim().is_empty() {
return Err(RustAuthError::InvalidConfig(
"social provider id cannot be empty".to_owned(),
));
}
if providers.insert(id.clone(), provider).is_some() {
return Err(RustAuthError::InvalidConfig(format!(
"duplicate social provider `{id}`"
)));
}
Ok(())
}
fn validate_rate_limit_storage(options: &RustAuthOptions) -> Result<(), RustAuthError> {
if options.rate_limit.custom_store.is_some() || options.rate_limit.custom_storage.is_some() {
return Ok(());
}
if matches!(
options.rate_limit.storage,
RateLimitStorageOption::Database | RateLimitStorageOption::SecondaryStorage
) {
return Err(RustAuthError::InvalidConfig(
"rate_limit.custom_store or rate_limit.custom_storage is required when using database or secondary-storage rate limiting without a concrete adapter".to_owned(),
));
}
Ok(())
}
fn apply_model_schema(table: &mut crate::db::TableOptions, schema: &ModelSchemaOptions) {
table.name = schema.model_name.clone();
table.field_names = schema
.field_names
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
}
fn schema_options_from_auth_options(options: &RustAuthOptions) -> AuthSchemaOptions {
let mut schema_options = AuthSchemaOptions {
has_secondary_storage: options.secondary_storage.is_some(),
store_session_in_database: options.session.store_session_in_database,
rate_limit_storage: match options.rate_limit.storage {
RateLimitStorageOption::Memory => DbRateLimitStorage::Memory,
RateLimitStorageOption::Database => DbRateLimitStorage::Database,
RateLimitStorageOption::SecondaryStorage => DbRateLimitStorage::SecondaryStorage,
},
..AuthSchemaOptions::default()
};
apply_model_schema(&mut schema_options.user, &options.user.schema);
apply_model_schema(&mut schema_options.session, &options.session.schema);
apply_model_schema(&mut schema_options.account, &options.account.schema);
apply_model_schema(
&mut schema_options.verification,
&options.verification.schema,
);
apply_model_schema(&mut schema_options.rate_limit, &options.rate_limit.schema);
for (name, field) in &options.user.additional_fields {
schema_options
.user
.additional_fields
.insert(name.clone(), user_additional_field_to_db_field(name, field));
}
for (name, field) in &options.session.additional_fields {
schema_options.session.additional_fields.insert(
name.clone(),
session_additional_field_to_db_field(name, field),
);
}
schema_options
}
pub(super) fn user_additional_field_to_db_field(
logical_name: &str,
field: &UserAdditionalField,
) -> DbField {
additional_field_to_db_field(
logical_name,
field.db_name.as_deref(),
field.field_type.clone(),
field.required,
field.input,
field.returned,
)
}
pub(super) fn session_additional_field_to_db_field(
logical_name: &str,
field: &SessionAdditionalField,
) -> DbField {
additional_field_to_db_field(
logical_name,
field.db_name.as_deref(),
field.field_type.clone(),
field.required,
field.input,
field.returned,
)
}
fn additional_field_to_db_field(
logical_name: &str,
db_name: Option<&str>,
field_type: crate::db::DbFieldType,
required: bool,
input: bool,
returned: bool,
) -> DbField {
let mut field = DbField::new(db_name.unwrap_or(logical_name), field_type);
if !required {
field = field.optional();
}
if !input {
field = field.generated();
}
if !returned {
field = field.hidden();
}
field
}