openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
use crate::error::OpenAuthError;
use crate::plugin::PluginInitOutput;

use super::builder::insert_social_provider;
use super::origins::{push_trusted_origin, push_unique};
use super::AuthContext;

pub(super) fn initialize_plugins(context: &mut AuthContext) -> Result<(), OpenAuthError> {
    let plugins = context.plugins.clone();
    for plugin in plugins {
        apply_plugin_output(
            context,
            &plugin.id,
            PluginInitOutput {
                schema: plugin.schema.clone(),
                rate_limit: plugin.rate_limit.clone(),
                error_codes: plugin.error_codes.clone(),
                database_hooks: plugin.database_hooks.clone(),
                migrations: plugin.migrations.clone(),
                social_providers: plugin.social_providers.clone(),
                ..PluginInitOutput::default()
            },
        )?;
        if let Some(init) = &plugin.init {
            let output = init(context)?;
            apply_plugin_output(context, &plugin.id, output)?;
        }
    }
    Ok(())
}

pub(super) fn apply_plugin_output(
    context: &mut AuthContext,
    plugin_id: &str,
    output: PluginInitOutput,
) -> Result<(), OpenAuthError> {
    for origin in output.trusted_origins {
        push_trusted_origin(&mut context.trusted_origins, origin);
    }
    for path in output.disabled_paths {
        push_unique(&mut context.disabled_paths, path);
    }
    for contribution in output.schema {
        contribution.apply(&mut context.db_schema)?;
    }
    context.rate_limit.plugin_rules.extend(output.rate_limit);
    for error_code in output.error_codes {
        error_code.validate()?;
        if let Some(existing) = context.plugin_error_codes.get(&error_code.code) {
            if existing != &error_code {
                return Err(OpenAuthError::InvalidConfig(format!(
                    "plugin `{plugin_id}` tried to register conflicting error code `{}`",
                    error_code.code
                )));
            }
            continue;
        }
        context
            .plugin_error_codes
            .insert(error_code.code.clone(), error_code);
    }
    for hook in output.database_hooks {
        let hook = hook.with_plugin_id(plugin_id);
        if context
            .plugin_database_hooks
            .iter()
            .any(|existing| existing.has_overlapping_phase(&hook))
        {
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin `{plugin_id}` tried to register duplicate database hook `{}` for {:?}",
                hook.name, hook.operation
            )));
        }
        context.plugin_database_hooks.push(hook);
    }
    for migration in output.migrations {
        if context
            .plugin_migrations
            .iter()
            .any(|existing| existing.name == migration.name)
        {
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin `{plugin_id}` tried to register duplicate migration `{}`",
                migration.name
            )));
        }
        context.plugin_migrations.push(migration);
    }
    for provider in output.social_providers {
        insert_social_provider(&mut context.social_providers, provider)?;
    }
    for (name, field) in output.user_additional_fields {
        insert_runtime_field(
            plugin_id,
            "user",
            &mut context.options.user.additional_fields,
            name,
            field,
        )?;
    }
    for (name, field) in output.session_additional_fields {
        insert_runtime_field(
            plugin_id,
            "session",
            &mut context.options.session.additional_fields,
            name,
            field,
        )?;
    }
    Ok(())
}

fn insert_runtime_field<T>(
    plugin_id: &str,
    table: &str,
    fields: &mut std::collections::BTreeMap<String, T>,
    name: String,
    field: T,
) -> Result<(), OpenAuthError>
where
    T: PartialEq,
{
    if let Some(existing) = fields.get(&name) {
        if existing == &field {
            return Ok(());
        }
        return Err(OpenAuthError::InvalidConfig(format!(
            "plugin `{plugin_id}` tried to register conflicting additional field `{name}` on `{table}`"
        )));
    }
    fields.insert(name, field);
    Ok(())
}