saas-rs-sdk 0.6.4

The SaaS RS SDK
use super::{create::find_field, deserialize::deserialize, schema::Table};
use crate::storage::Error;
use crate::storage::config_store::config_store::FindOptions;
use change_case::snake_case;
use futures_util::TryStreamExt;
use pbbson::Model;
use pbbson::bson::Bson;
use sqlx::{Execute, Pool, Postgres, QueryBuilder};

/// Runs a `SELECT` against `table` built from `options` and deserializes every
/// returned row into a [`Model`].
///
/// * `options.projection` selects the columns (via `get_field_names`).
/// * Each `options.filter` entry becomes a predicate on the snake_cased key, and
///   predicates are joined with `AND`. A `Null` value is rendered as `IS NULL`,
///   because `col = NULL` never matches in SQL. Any other value is bound as a
///   text parameter. When `find_field` knows the column, the parameter is cast to
///   the column's `data_type`.
/// * `options.pagination` is `(offset, limit)`. Each part is emitted only when set.
///
/// # Errors
/// Returns [`Error`] if the query fails or a row cannot be deserialized.
pub(super) async fn perform(pool: &Pool<Postgres>, table: &Table, options: FindOptions) -> Result<Vec<Model>, Error> {
    // SELECT <field_names> FROM <table>
    let field_names = get_field_names(table, &options.projection);
    let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT ");
    qb.push(field_names);
    qb.push(" FROM ");
    qb.push(&table.name);

    // WHERE <predicate> AND <predicate> ...
    // NOTE(review): table and column names are interpolated into the SQL, not bound.
    // If filter keys can come from untrusted callers, validate them against the schema.
    for (index, (k, v)) in options.filter.iter().enumerate() {
        // Every predicate after the first one must be preceded by AND.
        qb.push(if index == 0 { " WHERE " } else { " AND " });
        let field_name = snake_case(k);
        qb.push(&field_name);

        // `col = NULL` evaluates to NULL (never true), so nulls need IS NULL.
        if matches!(v, Bson::Null) {
            qb.push(" IS NULL");
            continue;
        }

        qb.push(" = ");
        match v {
            Bson::String(s) => qb.push_bind(s),
            other => qb.push_bind(other.to_string()),
        };
        // The parameter is bound as text; cast it to the column's declared type when known.
        if let Some(field) = find_field(table, &field_name) {
            qb.push("::");
            qb.push(&field.data_type);
        }
    }

    // OFFSET start
    if let (Some(offset), _) = options.pagination {
        qb.push(format!(" OFFSET {offset}"));
    }

    // LIMIT count
    if let (_, Some(limit)) = options.pagination {
        qb.push(format!(" LIMIT {limit}"));
    }

    let query = qb.build();
    let sql = query.sql();
    log::debug!(sql; "Finding many");

    // Perform query
    let mut rows = query.fetch(pool);

    // Collect results
    let mut models = vec![];
    while let Some(row) = rows.try_next().await? {
        let model = deserialize(row, table)?;
        models.push(model);
    }
    Ok(models)
}

/// Builds the comma-separated column list used in the `SELECT` clause.
///
/// * `None` selects every column (`*`).
/// * If any projection value is `0` (see `has_subtractive_projections`), the
///   projection is subtractive. Every key of `table.fields_by_name` is selected
///   except those marked `0`, and entries marked `1` are ignored in this mode.
/// * Otherwise the projection is additive, and only keys marked `1` are selected.
///   An additive projection that selects nothing (e.g. an empty projection)
///   falls back to `*`, because an empty column list is invalid SQL.
///
/// Projection keys are converted with `snake_case` before they are compared with
/// or emitted as column names. Only values that `as_i32` reads as `0` or `1` have
/// any effect; every other value is ignored.
fn get_field_names(table: &Table, projection: &Option<Model>) -> String {
    let Some(projection) = projection else {
        return "*".to_string();
    };

    if has_subtractive_projections(projection) {
        // Start from every known column and drop each excluded one.
        // NOTE(review): excluding every column still yields an empty list (invalid SQL);
        // falling back to `*` there would leak excluded columns, so it is left as-is.
        let mut fields: Vec<_> = table.fields_by_name.keys().cloned().collect();
        for (k, v) in projection.iter() {
            if matches!(v.as_i32(), Some(0)) {
                let field_name = snake_case(k);
                fields.retain(|f| f != &field_name);
            }
        }
        return fields.join(",");
    }

    let fields: Vec<_> = projection
        .iter()
        .filter(|(_, v)| matches!(v.as_i32(), Some(1)))
        .map(|(k, _)| snake_case(k))
        .collect();
    if fields.is_empty() {
        // Avoid rendering `SELECT  FROM <table>` when nothing was explicitly included.
        "*".to_string()
    } else {
        fields.join(",")
    }
}

fn has_subtractive_projections(projection: &Model) -> bool {
    for (_k, v) in projection.iter() {
        if let Some(n) = v.as_i32()
            && n == 0
        {
            return true;
        }
    }
    false
}