openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use std::sync::Arc;

use http::{header, StatusCode};
use openauth_core::api::{parse_request_body, ApiRequest};
use openauth_core::db::{Create, DbRecord, DbValue, Update};
use openauth_core::error::OpenAuthError;
use openauth_core::plugin::{
    PluginBeforeHookAction, PluginDatabaseBeforeAction, PluginDatabaseBeforeInput,
    PluginDatabaseHook,
};
use serde_json::Value;

use super::errors;
use super::options::{UsernameOptions, UsernameValidationError};

/// Builds a database before-create hook that normalizes the `username` and
/// `display_username` fields of `user` records before they are created.
///
/// Records for other models pass through unchanged; the hook always continues.
pub fn normalize_create_user_hook(options: Arc<UsernameOptions>) -> PluginDatabaseHook {
    PluginDatabaseHook::before_create("username-normalize-create", move |_ctx, create| {
        let normalized = normalize_create_query(&options, create);
        let input = PluginDatabaseBeforeInput::Create(normalized);
        Ok(PluginDatabaseBeforeAction::Continue(input))
    })
}

/// Builds a database before-update hook that normalizes the `username` and
/// `display_username` fields of `user` records before they are updated.
///
/// Updates for other models pass through unchanged; the hook always continues.
pub fn normalize_update_user_hook(options: Arc<UsernameOptions>) -> PluginDatabaseHook {
    PluginDatabaseHook::before_update("username-normalize-update", move |_ctx, update| {
        let normalized = normalize_update_query(&options, update);
        let input = PluginDatabaseBeforeInput::Update(normalized);
        Ok(PluginDatabaseBeforeAction::Continue(input))
    })
}

/// Builds the request before-hook used for sign-up (judging by its name).
///
/// The returned closure validates and normalizes the username fields of the
/// JSON body and, because sign-up fallbacks are enabled, fills a missing
/// `username` or `displayUsername` from the other one first.
pub fn sign_up_before_hook(
    options: Arc<UsernameOptions>,
) -> impl Fn(
    &openauth_core::context::AuthContext,
    ApiRequest,
) -> Result<PluginBeforeHookAction, OpenAuthError>
       + Send
       + Sync
       + 'static {
    let apply_sign_up_fallbacks = true;
    move |_auth_context, api_request| {
        validate_and_rewrite_body(&options, api_request, apply_sign_up_fallbacks)
    }
}

/// Builds the request before-hook used for user updates (judging by its name).
///
/// The returned closure validates and normalizes any username fields present
/// in the JSON body. Sign-up fallbacks are disabled, so neither field is ever
/// derived from the other.
pub fn update_user_before_hook(
    options: Arc<UsernameOptions>,
) -> impl Fn(
    &openauth_core::context::AuthContext,
    ApiRequest,
) -> Result<PluginBeforeHookAction, OpenAuthError>
       + Send
       + Sync
       + 'static {
    let apply_sign_up_fallbacks = false;
    move |_auth_context, api_request| {
        validate_and_rewrite_body(&options, api_request, apply_sign_up_fallbacks)
    }
}

/// Normalizes username fields in a create query targeting the `user` model.
///
/// Queries for any other model are returned untouched.
fn normalize_create_query(options: &UsernameOptions, mut query: Create) -> Create {
    if query.model == "user" {
        normalize_record(options, &mut query.data);
    }
    query
}

/// Normalizes username fields in an update query targeting the `user` model.
///
/// Queries for any other model are returned untouched.
fn normalize_update_query(options: &UsernameOptions, mut query: Update) -> Update {
    if query.model == "user" {
        normalize_record(options, &mut query.data);
    }
    query
}

/// Rewrites the `username` and `display_username` columns of a record in place
/// using the configured normalizers.
///
/// Only string-valued entries are touched; missing keys and non-string values
/// are left as they are.
fn normalize_record(options: &UsernameOptions, data: &mut DbRecord) {
    if let Some(DbValue::String(stored)) = data.get_mut("username") {
        let normalized = options.normalize_username(stored);
        *stored = normalized;
    }

    if let Some(DbValue::String(stored)) = data.get_mut("display_username") {
        let normalized = options.normalize_display_username(stored);
        *stored = normalized;
    }
}

fn validate_and_rewrite_body(
    options: &UsernameOptions,
    request: ApiRequest,
    apply_sign_up_fallbacks: bool,
) -> Result<PluginBeforeHookAction, OpenAuthError> {
    let mut body: Value = parse_request_body(&request)?;
    let Some(object) = body.as_object_mut() else {
        return Ok(PluginBeforeHookAction::Continue(request));
    };

    if apply_sign_up_fallbacks {
        let username = string_value(object.get("username")).map(str::to_owned);
        let display_username = string_value(object.get("displayUsername"))
            .map(str::to_owned)
            .or_else(|| string_value(object.get("display_username")).map(str::to_owned));
        if username.is_some() && display_username.is_none() {
            object.insert(
                "displayUsername".to_owned(),
                Value::String(username.unwrap_or_default()),
            );
        } else if username.is_none() {
            if let Some(display_username) = display_username {
                object.insert("username".to_owned(), Value::String(display_username));
            }
        }
    }

    if let Some(username) = string_value(object.get("username")) {
        let username_for_validation = options.username_for_validation(username);
        if let Err(error) =
            options.validate_username(&username_for_validation, options.validation_order.username)
        {
            return validation_error(error, StatusCode::BAD_REQUEST)
                .map(PluginBeforeHookAction::Respond);
        }
        object.insert(
            "username".to_owned(),
            Value::String(options.normalize_username(username)),
        );
    }

    let display_username = string_value(object.get("displayUsername"))
        .or_else(|| string_value(object.get("display_username")));
    if let Some(display_username) = display_username {
        let display_username_for_validation =
            options.display_username_for_validation(display_username);
        if let Err(error) = options.validate_display_username(&display_username_for_validation) {
            return validation_error(error, StatusCode::BAD_REQUEST)
                .map(PluginBeforeHookAction::Respond);
        }
        object.insert(
            "displayUsername".to_owned(),
            Value::String(options.normalize_display_username(display_username)),
        );
        object.remove("display_username");
    }

    let (mut parts, _) = request.into_parts();
    parts.headers.insert(
        header::CONTENT_TYPE,
        http::HeaderValue::from_static("application/json"),
    );
    let next_body =
        serde_json::to_vec(&body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
    Ok(PluginBeforeHookAction::Continue(http::Request::from_parts(
        parts, next_body,
    )))
}

/// Converts a [`UsernameValidationError`] into an error response with the
/// given HTTP `status`.
///
/// Each variant maps to one of the plugin's error codes plus a fixed message.
///
/// # Errors
///
/// Propagates any error returned by `errors::error_response`.
pub fn validation_error(
    error: UsernameValidationError,
    status: StatusCode,
) -> Result<openauth_core::api::ApiResponse, OpenAuthError> {
    let (code, message) = match error {
        UsernameValidationError::TooShort => {
            (errors::USERNAME_TOO_SHORT, "Username is too short")
        }
        UsernameValidationError::TooLong => (errors::USERNAME_TOO_LONG, "Username is too long"),
        UsernameValidationError::Invalid => (errors::INVALID_USERNAME, "Username is invalid"),
        UsernameValidationError::InvalidDisplay => {
            (errors::INVALID_DISPLAY_USERNAME, "Display username is invalid")
        }
    };
    errors::error_response(status, code, message)
}

/// Borrows the inner text of a JSON value when it is a string.
///
/// Returns `None` for a missing value or any non-string JSON value.
fn string_value(value: Option<&Value>) -> Option<&str> {
    value.and_then(Value::as_str)
}