openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
use std::collections::BTreeMap;

use serde_json::Value;

use crate::db::{DbAdapter, DbFieldType, DbRecord, DbValue, FindOne, User, Where};
use crate::error::OpenAuthError;
use crate::options::{SessionAdditionalField, UserAdditionalField};

/// Read-only view of the configuration of a single additional field.
///
/// The helper functions in this file are generic over this trait so the same
/// logic can serve both `UserAdditionalField` and `SessionAdditionalField`.
pub trait AdditionalField {
    /// Declared type of the field; `json_to_db_value` uses it to decide how an
    /// incoming JSON value is converted.
    fn field_type(&self) -> &DbFieldType;
    /// Whether `create_values` fails when the field is absent from the body
    /// and no default value is configured.
    fn required(&self) -> bool;
    /// Whether the field may be supplied in the body; `create_values` and
    /// `update_values` return `AdditionalFieldError::NotInput` when it is
    /// supplied but this is `false`.
    fn input(&self) -> bool;
    /// Whether `insert_returned_fields` copies the field into the output object.
    fn returned(&self) -> bool;
    /// Optional fallback value, used by `create_values` when the field is
    /// absent from the body and by `insert_returned_fields` when the record
    /// does not contain it.
    fn default_value(&self) -> Option<&DbValue>;
    /// Optional key under which the value is stored; when `None`, the logical
    /// field name is used as the storage key (see `storage_name`).
    fn db_name(&self) -> Option<&str>;
}

/// Exposes the configuration stored on a `UserAdditionalField` through the
/// `AdditionalField` trait. Every method is a plain accessor for the field of
/// the same name.
impl AdditionalField for UserAdditionalField {
    fn field_type(&self) -> &DbFieldType {
        &self.field_type
    }

    fn required(&self) -> bool {
        self.required
    }

    fn input(&self) -> bool {
        self.input
    }

    fn returned(&self) -> bool {
        self.returned
    }

    fn default_value(&self) -> Option<&DbValue> {
        // Hand out a borrow of the optional default rather than a clone.
        self.default_value.as_ref()
    }

    fn db_name(&self) -> Option<&str> {
        // Borrow the optional storage name as a string slice.
        self.db_name.as_deref()
    }
}

/// Exposes the configuration stored on a `SessionAdditionalField` through the
/// `AdditionalField` trait. Every method is a plain accessor for the field of
/// the same name.
impl AdditionalField for SessionAdditionalField {
    fn field_type(&self) -> &DbFieldType {
        &self.field_type
    }

    fn required(&self) -> bool {
        self.required
    }

    fn input(&self) -> bool {
        self.input
    }

    fn returned(&self) -> bool {
        self.returned
    }

    fn default_value(&self) -> Option<&DbValue> {
        // Hand out a borrow of the optional default rather than a clone.
        self.default_value.as_ref()
    }

    fn db_name(&self) -> Option<&str> {
        // Borrow the optional storage name as a string slice.
        self.db_name.as_deref()
    }
}

/// Builds the storage record for the additional fields of a newly created row.
///
/// Every declared field in `fields` produces exactly one entry, keyed by its
/// storage name (`db_name` when set, otherwise the logical name):
///
/// * present in `body` – converted with [`json_to_db_value`];
/// * absent, with a default – a clone of the default value;
/// * absent, no default, not required – `DbValue::Null`.
///
/// Keys in `body` that are not declared in `fields` are ignored.
///
/// Note: a key that is present with an explicit JSON `null` counts as
/// "present"; `json_to_db_value` maps it to `DbValue::Null`, and neither the
/// default value nor the `required()` flag is consulted in that case.
///
/// # Errors
///
/// * [`AdditionalFieldError::NotInput`] – the body supplies a field whose
///   `input()` is `false`.
/// * [`AdditionalFieldError::InvalidType`] – the supplied value does not match
///   the declared field type.
/// * [`AdditionalFieldError::MissingRequired`] – a required field without a
///   default is absent from the body.
pub fn create_values<F>(
    fields: &BTreeMap<String, F>,
    body: &serde_json::Map<String, Value>,
) -> Result<DbRecord, AdditionalFieldError>
where
    F: AdditionalField,
{
    let mut record = DbRecord::new();
    for (name, field) in fields {
        // Resolve the value first; a single insert below covers every case.
        let stored = if let Some(raw) = body.get(name) {
            if !field.input() {
                return Err(AdditionalFieldError::NotInput(name.clone()));
            }
            json_to_db_value(name, field.field_type(), raw)
                .map_err(AdditionalFieldError::InvalidType)?
        } else if let Some(default) = field.default_value() {
            default.clone()
        } else if field.required() {
            return Err(AdditionalFieldError::MissingRequired(name.clone()));
        } else {
            DbValue::Null
        };
        record.insert(storage_name(name, field), stored);
    }
    Ok(record)
}

/// Builds the storage record for the additional fields touched by an update.
///
/// Only keys that appear in `body` *and* are declared in `fields` end up in
/// the result; undeclared keys are skipped silently and declared fields that
/// are absent from the body are left out (no defaults, no required check).
/// Entries are keyed by storage name (`db_name` when set, otherwise the
/// logical name).
///
/// # Errors
///
/// * [`AdditionalFieldError::NotInput`] – the body supplies a field whose
///   `input()` is `false`.
/// * [`AdditionalFieldError::InvalidType`] – the supplied value does not match
///   the declared field type.
pub fn update_values<F>(
    fields: &BTreeMap<String, F>,
    body: &serde_json::Map<String, Value>,
) -> Result<DbRecord, AdditionalFieldError>
where
    F: AdditionalField,
{
    // Pair each body entry with its declaration, dropping undeclared keys.
    let declared = body
        .iter()
        .filter_map(|(name, raw)| fields.get(name).map(|field| (name, field, raw)));

    let mut record = DbRecord::new();
    for (name, field, raw) in declared {
        if !field.input() {
            return Err(AdditionalFieldError::NotInput(name.clone()));
        }
        let converted = json_to_db_value(name, field.field_type(), raw)
            .map_err(AdditionalFieldError::InvalidType)?;
        record.insert(storage_name(name, field), converted);
    }
    Ok(record)
}

/// Copies the additional fields marked as returned from `record` into `object`.
///
/// For each field whose `returned()` is `true`, the value is looked up in this
/// order: the record entry under the logical name, the record entry under the
/// field's `db_name`, the field's default value, and finally `DbValue::Null`.
/// The result is inserted into `object` under the logical name, replacing any
/// entry already there.
///
/// # Errors
///
/// Propagates the error from [`db_value_to_json`] if a value cannot be
/// converted to JSON.
pub fn insert_returned_fields<F>(
    object: &mut serde_json::Map<String, Value>,
    fields: &BTreeMap<String, F>,
    record: &DbRecord,
) -> Result<(), OpenAuthError>
where
    F: AdditionalField,
{
    for (name, field) in fields.iter().filter(|(_, field)| field.returned()) {
        // Prefer the logical name; fall back to the storage name override.
        let from_record = match record.get(name) {
            found @ Some(_) => found,
            None => field.db_name().and_then(|db_name| record.get(db_name)),
        };
        let value = from_record
            .or_else(|| field.default_value())
            .unwrap_or(&DbValue::Null);
        object.insert(name.clone(), db_value_to_json(value)?);
    }
    Ok(())
}

/// Returns the key under which a field's value is stored: the field's
/// `db_name` when one is configured, otherwise `logical_name`.
fn storage_name<F>(logical_name: &str, field: &F) -> String
where
    F: AdditionalField,
{
    field.db_name().unwrap_or(logical_name).to_owned()
}

pub fn db_value_to_json(value: &DbValue) -> Result<Value, OpenAuthError> {
    match value {
        DbValue::String(value) => Ok(Value::String(value.clone())),
        DbValue::Number(value) => Ok(Value::Number((*value).into())),
        DbValue::Boolean(value) => Ok(Value::Bool(*value)),
        DbValue::Timestamp(value) => {
            serde_json::to_value(value).map_err(|error| OpenAuthError::Api(error.to_string()))
        }
        DbValue::Json(value) => Ok(value.clone()),
        DbValue::StringArray(values) => Ok(Value::Array(
            values.iter().cloned().map(Value::String).collect(),
        )),
        DbValue::NumberArray(values) => Ok(Value::Array(
            values
                .iter()
                .map(|value| Value::Number((*value).into()))
                .collect(),
        )),
        DbValue::Record(record) => db_record_to_json(record),
        DbValue::RecordArray(records) => records
            .iter()
            .map(db_record_to_json)
            .collect::<Result<Vec<_>, _>>()
            .map(Value::Array),
        DbValue::Null => Ok(Value::Null),
    }
}

/// Converts an incoming JSON value into a `DbValue` of the declared type.
///
/// JSON `null` is accepted for every field type and maps to `DbValue::Null`.
/// Otherwise the value must match `field_type`: strings, integers readable via
/// `Value::as_i64`, booleans, arbitrary JSON (cloned), or arrays whose
/// elements are all strings / all integers. `DbFieldType::Timestamp` has no
/// conversion here, so any non-null value for it is rejected.
///
/// # Errors
///
/// Returns a message naming the field (`name`) when the value does not match
/// the declared type.
pub fn json_to_db_value(
    name: &str,
    field_type: &DbFieldType,
    value: &Value,
) -> Result<DbValue, String> {
    if value.is_null() {
        return Ok(DbValue::Null);
    }

    let converted = match field_type {
        DbFieldType::Timestamp => None,
        DbFieldType::Json => Some(DbValue::Json(value.clone())),
        DbFieldType::Boolean => value.as_bool().map(DbValue::Boolean),
        DbFieldType::Number => value.as_i64().map(DbValue::Number),
        DbFieldType::String => value.as_str().map(|text| DbValue::String(text.to_owned())),
        DbFieldType::StringArray => value.as_array().and_then(|items| {
            // A single non-string element invalidates the whole array.
            let mut strings = Vec::with_capacity(items.len());
            for item in items {
                strings.push(item.as_str()?.to_owned());
            }
            Some(DbValue::StringArray(strings))
        }),
        DbFieldType::NumberArray => value.as_array().and_then(|items| {
            // A single non-integer element invalidates the whole array.
            let mut numbers = Vec::with_capacity(items.len());
            for item in items {
                numbers.push(item.as_i64()?);
            }
            Some(DbValue::NumberArray(numbers))
        }),
    };

    match converted {
        Some(db_value) => Ok(db_value),
        None => Err(format!("invalid value for additional field `{name}`")),
    }
}

/// Reasons why a body could not be turned into additional-field values.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// formatted with `{}`, boxed as `dyn Error`, and used as an error source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdditionalFieldError {
    /// A required field without a default was absent; carries the field name.
    MissingRequired(String),
    /// A field that is not accepted as input was supplied; carries the field name.
    NotInput(String),
    /// A supplied value did not match the declared type; carries the full
    /// human-readable message.
    InvalidType(String),
}

impl AdditionalFieldError {
    /// Returns the human-readable description of the error.
    ///
    /// Equivalent to formatting the error with `Display`; kept so existing
    /// callers of `message()` continue to work unchanged.
    pub fn message(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for AdditionalFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingRequired(name) => {
                write!(f, "missing required additional field `{name}`")
            }
            Self::NotInput(name) => {
                write!(f, "additional field `{name}` is not accepted as input")
            }
            // The payload already is the complete message.
            Self::InvalidType(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for AdditionalFieldError {}

pub async fn user_response_value(
    adapter: &dyn DbAdapter,
    fields: &BTreeMap<String, UserAdditionalField>,
    user: &User,
) -> Result<Value, OpenAuthError> {
    if fields.is_empty() {
        return serde_json::to_value(user).map_err(|error| OpenAuthError::Api(error.to_string()));
    }
    let record = adapter
        .find_one(
            FindOne::new("user").where_clause(Where::new("id", DbValue::String(user.id.clone()))),
        )
        .await?;
    let mut value =
        serde_json::to_value(user).map_err(|error| OpenAuthError::Api(error.to_string()))?;
    let Some(object) = value.as_object_mut() else {
        return Err(OpenAuthError::Api(
            "could not serialize user as an object".to_owned(),
        ));
    };
    if let Some(record) = record {
        insert_returned_fields(object, fields, &record)?;
    }
    Ok(value)
}

/// Converts a database record into a JSON object, converting each value with
/// [`db_value_to_json`] and stopping at the first value that fails.
fn db_record_to_json(record: &DbRecord) -> Result<Value, OpenAuthError> {
    let mut object = serde_json::Map::new();
    for (field, value) in record.iter() {
        object.insert(field.clone(), db_value_to_json(value)?);
    }
    Ok(Value::Object(object))
}