openauth-sqlx 0.0.2

SQLx database adapters for OpenAuth.
Documentation
use openauth_core::db::{
    Connector, DbField, DbFieldType, DbTable, DbValue, Where, WhereMode, WhereOperator,
};
use openauth_core::error::OpenAuthError;
use sqlx::sqlite::SqliteArguments;
use sqlx::Arguments;
use time::format_description::well_known::Rfc3339;

use super::errors::{argument_error, json_error, time_error};
use super::support::{quote_identifier, resolve_field};

/// Builds the full ` WHERE ...` fragment for a list of filter clauses.
///
/// Returns an empty string when `clauses` is empty. Otherwise the result
/// starts with `" WHERE "` (note the leading space) and contains one predicate
/// per clause, rendered by [`clause_sql`] in order. Every clause after the
/// first is preceded by its own connector (`AND` / `OR`); the connector of the
/// first clause is not used.
///
/// Values are appended to `args` in clause order, matching the order of the
/// `?` placeholders in the returned SQL.
///
/// # Errors
///
/// Propagates any error returned by [`clause_sql`].
pub(super) fn where_sql(
    table: &DbTable,
    clauses: &[Where],
    args: &mut SqliteArguments<'_>,
) -> Result<String, OpenAuthError> {
    let Some((first, rest)) = clauses.split_first() else {
        return Ok(String::new());
    };

    let mut sql = format!(" WHERE {}", clause_sql(table, first, args)?);
    for clause in rest {
        let keyword = match clause.connector {
            Connector::And => " AND ",
            Connector::Or => " OR ",
        };
        sql.push_str(keyword);
        sql.push_str(&clause_sql(table, clause, args)?);
    }
    Ok(sql)
}

/// Renders a single filter clause as a SQL predicate and binds its value(s).
///
/// The clause's field is resolved against `table` and its column name is
/// quoted before being placed in the SQL text. Values never appear in the SQL
/// text; they are appended to `args` and referenced through `?` placeholders.
///
/// Supported shapes:
/// - `Null` value: `IS NULL` for `Eq`, `IS NOT NULL` for `Ne`; nothing is bound.
/// - `Eq`/`Ne`/`Lt`/`Lte`/`Gt`/`Gte`: `column <op> ?` with one bound value.
/// - `In`/`NotIn`: `column [NOT] IN (?, ...)` with one bound value per element.
/// - `Contains`/`StartsWith`/`EndsWith`: `column LIKE ? ESCAPE '\'` (wrapped in
///   `LOWER(..)` on both sides when the clause mode is `Insensitive`). The
///   caller's string is matched literally: `%`, `_` and `\` in it are escaped
///   so they cannot act as wildcards.
///
/// # Errors
///
/// Returns `OpenAuthError::Adapter` when a `Null` value is combined with an
/// operator other than `Eq`/`Ne`, or when a pattern operator is given a
/// non-string value. Errors from field resolution, identifier quoting and
/// argument binding are propagated.
pub(super) fn clause_sql(
    table: &DbTable,
    clause: &Where,
    args: &mut SqliteArguments<'_>,
) -> Result<String, OpenAuthError> {
    let (_, field) = resolve_field(table, &clause.field)?;
    let column = quote_identifier(&field.name)?;

    // NULL cannot be compared with `=` / `!=` in SQL, so it gets dedicated
    // predicates and binds nothing.
    if clause.value == DbValue::Null {
        return Ok(match clause.operator {
            WhereOperator::Eq => format!("{column} IS NULL"),
            WhereOperator::Ne => format!("{column} IS NOT NULL"),
            _ => {
                return Err(OpenAuthError::Adapter(
                    "null only supports Eq and Ne operators".to_owned(),
                ))
            }
        });
    }

    match clause.operator {
        WhereOperator::Eq
        | WhereOperator::Ne
        | WhereOperator::Lt
        | WhereOperator::Lte
        | WhereOperator::Gt
        | WhereOperator::Gte => {
            let operator = match clause.operator {
                WhereOperator::Eq => "=",
                WhereOperator::Ne => "!=",
                WhereOperator::Lt => "<",
                WhereOperator::Lte => "<=",
                WhereOperator::Gt => ">",
                WhereOperator::Gte => ">=",
                _ => unreachable!("operator matched by outer arm"),
            };
            bind_value(args, field, &clause.value)?;
            Ok(format!("{column} {operator} ?"))
        }
        WhereOperator::In | WhereOperator::NotIn => {
            let placeholders = bind_array_values(args, field, &clause.value)?;
            let operator = if clause.operator == WhereOperator::In {
                "IN"
            } else {
                "NOT IN"
            };
            Ok(format!("{column} {operator} ({})", placeholders.join(", ")))
        }
        WhereOperator::Contains | WhereOperator::StartsWith | WhereOperator::EndsWith => {
            let DbValue::String(value) = &clause.value else {
                return Err(OpenAuthError::Adapter(
                    "string pattern operators require string values".to_owned(),
                ));
            };
            // Escape LIKE metacharacters so the caller's text is matched
            // literally; only the wildcards added below are active.
            let literal = escape_like_pattern(value);
            let pattern = match clause.operator {
                WhereOperator::Contains => format!("%{literal}%"),
                WhereOperator::StartsWith => format!("{literal}%"),
                WhereOperator::EndsWith => format!("%{literal}"),
                _ => unreachable!("operator matched by outer arm"),
            };
            args.add(pattern).map_err(argument_error)?;
            // SQLite has no default LIKE escape character, so it must be
            // declared explicitly; `'\\'` here emits the SQL literal `'\'`.
            if clause.mode == WhereMode::Insensitive {
                Ok(format!("LOWER({column}) LIKE LOWER(?) ESCAPE '\\'"))
            } else {
                // NOTE(review): SQLite's built-in LIKE ignores ASCII case unless
                // `PRAGMA case_sensitive_like` is on, so this branch may still
                // match across ASCII case — confirm that is acceptable.
                Ok(format!("{column} LIKE ? ESCAPE '\\'"))
            }
        }
    }
}

/// Escapes `\`, `%` and `_` with a leading backslash for use in a
/// `LIKE ... ESCAPE '\'` pattern, so the text is matched literally.
fn escape_like_pattern(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}

/// Binds every element of an array value individually and returns one `"?"`
/// placeholder per bound element, for use in an `IN (...)` list.
///
/// Each element of a `StringArray` is bound as a `DbValue::String` and each
/// element of a `NumberArray` as a `DbValue::Number`, via [`bind_value`]. An
/// empty array binds nothing and yields an empty placeholder list.
///
/// # Errors
///
/// Returns `OpenAuthError::Adapter` when `value` is not an array variant, and
/// propagates the first error raised while binding an element.
pub(super) fn bind_array_values(
    args: &mut SqliteArguments<'_>,
    field: &DbField,
    value: &DbValue,
) -> Result<Vec<String>, OpenAuthError> {
    match value {
        DbValue::StringArray(items) => items
            .iter()
            .map(|item| {
                bind_value(args, field, &DbValue::String(item.clone())).map(|()| "?".to_owned())
            })
            .collect(),
        DbValue::NumberArray(items) => items
            .iter()
            .map(|item| bind_value(args, field, &DbValue::Number(*item)).map(|()| "?".to_owned()))
            .collect(),
        _ => Err(OpenAuthError::Adapter(
            "IN and NOT IN require array values".to_owned(),
        )),
    }
}

pub(super) fn bind_value(
    args: &mut SqliteArguments<'_>,
    field: &DbField,
    value: &DbValue,
) -> Result<(), OpenAuthError> {
    match value {
        DbValue::String(value) => args.add(value.clone()).map_err(argument_error),
        DbValue::Number(value) => args.add(*value).map_err(argument_error),
        DbValue::Boolean(value) => args.add(i64::from(*value)).map_err(argument_error),
        DbValue::Timestamp(value) => args
            .add(value.format(&Rfc3339).map_err(time_error)?)
            .map_err(argument_error),
        DbValue::Json(value) => args.add(value.to_string()).map_err(argument_error),
        DbValue::StringArray(value) => args
            .add(serde_json::to_string(value).map_err(json_error)?)
            .map_err(argument_error),
        DbValue::NumberArray(value) => args
            .add(serde_json::to_string(value).map_err(json_error)?)
            .map_err(argument_error),
        DbValue::Record(_) | DbValue::RecordArray(_) => Err(OpenAuthError::Adapter(
            "joined records cannot be bound as SQL values".to_owned(),
        )),
        DbValue::Null => match field.field_type {
            DbFieldType::String
            | DbFieldType::Timestamp
            | DbFieldType::Json
            | DbFieldType::StringArray
            | DbFieldType::NumberArray => args.add(Option::<String>::None).map_err(argument_error),
            DbFieldType::Number | DbFieldType::Boolean => {
                args.add(Option::<i64>::None).map_err(argument_error)
            }
        },
    }
}