openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
use super::*;

pub fn create_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &Create,
) -> Result<SqlStatement, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let mut columns = Vec::new();
    let mut placeholders = Vec::new();
    let mut params = Vec::new();

    for (field, value) in &query.data {
        let (_, metadata) = resolve_field(table, field)?;
        columns.push(dialect.quote_identifier(&metadata.name)?);
        params.push(SqlParam::new(metadata, value.clone()));
        placeholders.push(dialect.placeholder(params.len()));
    }

    Ok(SqlStatement {
        sql: format!(
            "INSERT INTO {} ({}) VALUES ({})",
            dialect.quote_identifier(&table.name)?,
            columns.join(", "),
            placeholders.join(", ")
        ),
        params,
    })
}

pub fn find_one_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &FindOne,
) -> Result<SqlReadStatement, OpenAuthError> {
    let mut find_many = FindMany::new(query.model.clone());
    find_many.where_clauses = query.where_clauses.clone();
    find_many.limit = Some(1);
    find_many.select = query.select.clone();
    find_many.joins = query.joins.clone();
    find_many_statement(dialect, schema, &find_many)
}

pub fn find_many_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &FindMany,
) -> Result<SqlReadStatement, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let selection = selected_fields(table, &query.select)?;
    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
    let sql = format!(
        "SELECT {} FROM {}{}{}",
        selection
            .iter()
            .map(|selected| dialect.quote_identifier(&selected.field.name))
            .collect::<Result<Vec<_>, _>>()?
            .join(", "),
        dialect.quote_identifier(&table.name)?,
        where_sql.sql,
        dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
    );

    Ok(SqlReadStatement {
        statement: SqlStatement {
            sql,
            params: where_sql.params,
        },
        selection,
    })
}

pub fn find_many_with_joins_statement<'a>(
    dialect: SqlDialect,
    schema: &'a DbSchema,
    query: &FindMany,
) -> Result<SqlJoinReadStatement<'a>, OpenAuthError> {
    let (base_logical, table) = resolve_table_with_logical(schema, &query.model)?;
    let joins = resolve_native_joins(schema, base_logical, table, &query.joins, 100)?;
    let base_selection = internal_base_selection(table, &query.select, &joins)?;
    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
    let base_columns = base_selection
        .iter()
        .map(|(_, field)| dialect.quote_identifier(&field.name))
        .collect::<Result<Vec<_>, _>>()?;
    let base_sql = format!(
        "SELECT {} FROM {}{}{}",
        base_columns.join(", "),
        dialect.quote_identifier(&table.name)?,
        where_sql.sql,
        dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
    );

    let mut selects = vec![format!(
        "{}.{} AS {}",
        dialect.quote_identifier("base")?,
        dialect.quote_identifier(&resolve_field_from_selection(&base_selection, "id")?.name)?,
        dialect.quote_identifier("__base_id")?
    )];
    for (index, (_, field)) in base_selection.iter().enumerate() {
        selects.push(format!(
            "{}.{} AS {}",
            dialect.quote_identifier("base")?,
            dialect.quote_identifier(&field.name)?,
            dialect.quote_identifier(&base_alias(index))?
        ));
    }
    for (join_index, join) in joins.iter().enumerate() {
        for (field_index, (_, field)) in join.selection.iter().enumerate() {
            selects.push(format!(
                "{}.{} AS {}",
                dialect.quote_identifier(&join_alias(join_index))?,
                dialect.quote_identifier(&field.name)?,
                dialect.quote_identifier(&join_field_alias(join_index, field_index))?
            ));
        }
    }

    let mut sql = format!(
        "SELECT {} FROM ({}) AS {}",
        selects.join(", "),
        base_sql,
        dialect.quote_identifier("base")?
    );
    for (index, join) in joins.iter().enumerate() {
        sql.push_str(" LEFT JOIN ");
        sql.push_str(&dialect.quote_identifier(&join.table.name)?);
        sql.push_str(" AS ");
        sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
        sql.push_str(" ON ");
        sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
        sql.push('.');
        sql.push_str(&dialect.quote_identifier(&join.to)?);
        sql.push_str(" = ");
        sql.push_str(&dialect.quote_identifier("base")?);
        sql.push('.');
        sql.push_str(&dialect.quote_identifier(&join.from)?);
    }

    Ok(SqlJoinReadStatement {
        statement: SqlStatement {
            sql,
            params: where_sql.params,
        },
        base_selection,
        joins,
    })
}

pub fn count_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &Count,
) -> Result<SqlStatement, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
    Ok(SqlStatement {
        sql: format!(
            "SELECT COUNT(*) FROM {}{}",
            dialect.quote_identifier(&table.name)?,
            where_sql.sql
        ),
        params: where_sql.params,
    })
}

pub fn update_one_plan(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &Update,
) -> Result<SqlUpdateOnePlan, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let selection = selected_fields(table, &[])?;

    match dialect {
        SqlDialect::Postgres | SqlDialect::Sqlite => {
            let assignment = update_assignment(dialect, table, &query.data, 1)?;
            let where_sql =
                dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
            let row_id = match dialect {
                SqlDialect::Postgres => "ctid",
                SqlDialect::Sqlite => "rowid",
                SqlDialect::MySql => unreachable!("mysql handled by outer match"),
            };
            let mut params = assignment.params;
            params.extend(where_sql.params);
            Ok(SqlUpdateOnePlan::Returning(SqlReadStatement {
                statement: SqlStatement {
                    sql: format!(
                        "UPDATE {} SET {} WHERE {row_id} IN (SELECT {row_id} FROM {}{} LIMIT 1) RETURNING {}",
                        dialect.quote_identifier(&table.name)?,
                        assignment.sql.join(", "),
                        dialect.quote_identifier(&table.name)?,
                        where_sql.sql,
                        selection
                            .iter()
                            .map(|selected| dialect.quote_identifier(&selected.field.name))
                            .collect::<Result<Vec<_>, _>>()?
                            .join(", ")
                    ),
                    params,
                },
                selection,
            }))
        }
        SqlDialect::MySql => {
            let mut select_query = FindMany::new(query.model.clone());
            select_query.where_clauses = query.where_clauses.clone();
            select_query.limit = Some(1);
            let select = find_many_statement(dialect, schema, &select_query)?;
            let assignment = update_assignment(dialect, table, &query.data, 1)?;
            let where_sql =
                dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
            let mut params = assignment.params;
            params.extend(where_sql.params);
            Ok(SqlUpdateOnePlan::PreselectThenUpdate {
                select,
                update: SqlStatement {
                    sql: format!(
                        "UPDATE {} SET {}{} LIMIT 1",
                        dialect.quote_identifier(&table.name)?,
                        assignment.sql.join(", "),
                        where_sql.sql
                    ),
                    params,
                },
                data: query.data.clone(),
            })
        }
    }
}

pub fn update_many_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &UpdateMany,
) -> Result<SqlStatement, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let assignment = update_assignment(dialect, table, &query.data, 1)?;
    let where_sql =
        dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
    let mut params = assignment.params;
    params.extend(where_sql.params);
    Ok(SqlStatement {
        sql: format!(
            "UPDATE {} SET {}{}",
            dialect.quote_identifier(&table.name)?,
            assignment.sql.join(", "),
            where_sql.sql
        ),
        params,
    })
}

pub fn delete_one_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &Delete,
) -> Result<SqlDeleteOnePlan, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
    let statement = match dialect {
        SqlDialect::Postgres => SqlStatement {
            sql: format!(
                "DELETE FROM {} WHERE ctid IN (SELECT ctid FROM {}{} LIMIT 1)",
                dialect.quote_identifier(&table.name)?,
                dialect.quote_identifier(&table.name)?,
                where_sql.sql
            ),
            params: where_sql.params,
        },
        SqlDialect::Sqlite => SqlStatement {
            sql: format!(
                "DELETE FROM {} WHERE rowid IN (SELECT rowid FROM {}{} LIMIT 1)",
                dialect.quote_identifier(&table.name)?,
                dialect.quote_identifier(&table.name)?,
                where_sql.sql
            ),
            params: where_sql.params,
        },
        SqlDialect::MySql => SqlStatement {
            sql: format!(
                "DELETE FROM {}{} LIMIT 1",
                dialect.quote_identifier(&table.name)?,
                where_sql.sql
            ),
            params: where_sql.params,
        },
    };
    let strategy = match dialect {
        SqlDialect::Postgres | SqlDialect::Sqlite => DeleteOneStrategy::NestedId,
        SqlDialect::MySql => DeleteOneStrategy::Limit,
    };
    Ok(SqlDeleteOnePlan {
        statement,
        strategy,
    })
}

pub fn delete_many_statement(
    dialect: SqlDialect,
    schema: &DbSchema,
    query: &DeleteMany,
) -> Result<SqlStatement, OpenAuthError> {
    let table = resolve_table(schema, &query.model)?;
    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
    Ok(SqlStatement {
        sql: format!(
            "DELETE FROM {}{}",
            dialect.quote_identifier(&table.name)?,
            where_sql.sql
        ),
        params: where_sql.params,
    })
}

struct UpdateAssignment {
    sql: Vec<String>,
    params: Vec<SqlParam>,
    next: usize,
}

fn update_assignment(
    dialect: SqlDialect,
    table: &DbTable,
    data: &DbRecord,
    first_placeholder: usize,
) -> Result<UpdateAssignment, OpenAuthError> {
    let mut sql = Vec::new();
    let mut params = Vec::new();
    for (field, value) in data {
        let (_, metadata) = resolve_field(table, field)?;
        params.push(SqlParam::new(metadata, value.clone()));
        sql.push(format!(
            "{} = {}",
            dialect.quote_identifier(&metadata.name)?,
            dialect.placeholder(first_placeholder + params.len() - 1)
        ));
    }
    Ok(UpdateAssignment {
        sql,
        next: first_placeholder + params.len(),
        params,
    })
}