saas-rs-sdk 0.6.0

The SaaS RS SDK
use super::{deserialize::deserialize, schema::Table};
use crate::storage::Error;
use crate::storage::config_store::config_store::FindOptions;
use change_case::snake_case;
use futures_util::TryStreamExt;
use pbbson::Model;
use pbbson::bson::Bson;
use sqlx::{Pool, Postgres};

pub(super) async fn perform(pool: &Pool<Postgres>, table: &Table, options: FindOptions) -> Result<Vec<Model>, Error> {
    // SELECT <field_names> FROM <table>
    let table_name = &table.name;
    let field_names = get_field_names(table, &options.projection);
    let mut sql = format!("SELECT {field_names} FROM {table_name}");

    // WHERE ...
    for (index, (k, v)) in options.filter.iter().enumerate() {
        if index == 0 {
            sql = format!("{sql} WHERE ");
        } else if index > 1 {
            sql = format!("{sql} AND ");
        }
        let field_name = snake_case(k);
        let param_num = index + 1;
        match v {
            Bson::Null => {
                sql = format!("{sql}{field_name} = NULL");
            }
            _ => {
                sql = format!("{sql}{field_name} = ${param_num}");
            }
        }
    }

    // OFFSET start
    if let (Some(offset), _) = options.pagination {
        sql = format!("{sql} OFFSET {offset}");
    }

    // LIMIT count
    if let (_, Some(limit)) = options.pagination {
        sql = format!("{sql} LIMIT {limit}");
    }

    // Bind params
    let mut query = sqlx::query(&sql);
    for (_k, v) in options.filter.iter() {
        match v {
            Bson::Null => {
                // already handled above in WHERE loop
            }
            Bson::String(s) => {
                query = query.bind(s);
            }
            _ => {
                query = query.bind(v.to_string());
            }
        }
    }

    // Perform query
    let mut rows = query.fetch(pool);

    // Collect results
    let mut models = vec![];
    while let Some(row) = rows.try_next().await.map_err(|e| Error::internal(e.to_string()))? {
        let model = deserialize(row, table)?;
        models.push(model);
    }
    Ok(models)
}

/// Builds the comma-separated column list for the `SELECT` clause from an optional projection.
///
/// * `None` selects every column (`*`).
/// * If any projection value is the `i32` `0` (see `has_subtractive_projections`), the projection
///   is exclusive. Every key of `table.fields_by_name` is selected except those mapped to `0`.
/// * Otherwise the projection is inclusive, and only keys mapped to the `i32` `1` are selected.
///   If no key is mapped to `1` (e.g. an empty projection document), every column (`*`) is
///   selected instead of emitting an empty select list.
///
/// Projection keys are converted with `snake_case` to match column names. Values that are not
/// `i32` are ignored.
fn get_field_names(table: &Table, projection: &Option<Model>) -> String {
    let Some(projection) = projection else {
        return "*".to_string();
    };

    if has_subtractive_projections(projection) {
        // Start from every known column and drop each one explicitly excluded.
        let mut fields: Vec<_> = table.fields_by_name.keys().cloned().collect();
        for (k, v) in projection.iter() {
            if v.as_i32() == Some(0) {
                let excluded = snake_case(k);
                fields.retain(|f| f != &excluded);
            }
        }
        fields.join(",")
    } else {
        let fields: Vec<_> = projection
            .iter()
            .filter(|(_, v)| v.as_i32() == Some(1))
            .map(|(k, _)| snake_case(k))
            .collect();
        if fields.is_empty() {
            // An empty select list would yield zero-column rows; treat it as "no projection".
            "*".to_string()
        } else {
            fields.join(",")
        }
    }
}

fn has_subtractive_projections(projection: &Model) -> bool {
    for (_k, v) in projection.iter() {
        if let Some(n) = v.as_i32()
            && n == 0
        {
            return true;
        }
    }
    false
}