// cratestack-sqlx 0.2.1
//
// Rust-native schema-first framework for typed HTTP APIs, generated clients, and backend services.
// Documentation
use cratestack_core::{CoolContext, CoolError};

use crate::{
    FilterExpr, ModelDescriptor, OrderClause, SqlxRuntime, render::render_read_policy_sql,
    render::render_scoped_select_sql,
};

use super::support::{push_order_and_paging, push_scoped_conditions};

/// Builder for a multi-row `SELECT` against the table named by `descriptor`.
///
/// Collects filters, ordering clauses, and optional paging values; the
/// accumulated state can be rendered as preview SQL or executed.
#[derive(Debug, Clone)]
pub struct FindMany<'a, M: 'static, PK: 'static> {
    /// Runtime whose connection pool is used when the query is executed.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata; supplies the select projection, table name, and
    /// primary-key column.
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Filter expressions in insertion order; the unscoped preview joins them
    /// with `AND`.
    pub(crate) filters: Vec<FilterExpr>,
    /// Ordering clauses in insertion order.
    pub(crate) order_by: Vec<OrderClause>,
    /// Optional `LIMIT` value.
    pub(crate) limit: Option<i64>,
    /// Optional `OFFSET` value.
    pub(crate) offset: Option<i64>,
}

impl<'a, M: 'static, PK: 'static> FindMany<'a, M, PK> {
    pub fn where_(mut self, filter: crate::Filter) -> Self {
        self.filters.push(FilterExpr::from(filter));
        self
    }

    pub fn where_expr(mut self, filter: FilterExpr) -> Self {
        self.filters.push(filter);
        self
    }

    pub fn where_any(mut self, filters: impl IntoIterator<Item = FilterExpr>) -> Self {
        self.filters.push(FilterExpr::any(filters));
        self
    }

    pub fn order_by(mut self, clause: OrderClause) -> Self {
        self.order_by.push(clause);
        self
    }

    pub fn limit(mut self, limit: i64) -> Self {
        self.limit = Some(limit);
        self
    }

    pub fn offset(mut self, offset: i64) -> Self {
        self.offset = Some(offset);
        self
    }

    pub fn preview_sql(&self) -> String {
        let mut sql = format!(
            "SELECT {} FROM {}",
            self.descriptor.select_projection(),
            self.descriptor.table_name,
        );
        let order_by = self.effective_order_by();

        let mut bind_index = 1usize;
        if !self.filters.is_empty() {
            sql.push_str(" WHERE ");
            for (index, filter) in self.filters.iter().enumerate() {
                if index > 0 {
                    sql.push_str(" AND ");
                }
                crate::render::render_filter_expr_sql(filter, &mut sql, &mut bind_index);
            }
        }

        if !order_by.is_empty() {
            sql.push_str(" ORDER BY ");
            for (index, clause) in order_by.iter().enumerate() {
                if index > 0 {
                    sql.push_str(", ");
                }
                crate::render::render_order_clause_sql(clause, &mut sql);
            }
        }

        match (self.limit, self.offset) {
            (Some(_), Some(_)) => {
                sql.push_str(&format!(" LIMIT ${bind_index} OFFSET ${}", bind_index + 1));
            }
            (Some(_), None) => {
                sql.push_str(&format!(" LIMIT ${bind_index}"));
            }
            (None, Some(_)) => {
                sql.push_str(&format!(" OFFSET ${bind_index}"));
            }
            (None, None) => {}
        }

        sql
    }

    pub fn preview_scoped_sql(&self, ctx: &CoolContext) -> String {
        let order_by = self.effective_order_by();
        render_scoped_select_sql(
            self.descriptor,
            &self.filters,
            &order_by,
            self.limit,
            self.offset,
            ctx,
        )
    }

    pub async fn run(self, ctx: &CoolContext) -> Result<Vec<M>, CoolError>
    where
        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
    {
        let order_by = self.effective_order_by();
        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
        query
            .push(self.descriptor.select_projection())
            .push(" FROM ")
            .push(self.descriptor.table_name);

        push_scoped_conditions(
            &mut query,
            self.descriptor,
            &self.filters,
            None::<(&'static str, i64)>,
            ctx,
        );
        push_order_and_paging(&mut query, &order_by, self.limit, self.offset);

        query
            .build_query_as::<M>()
            .fetch_all(self.runtime.pool())
            .await
            .map_err(|error| CoolError::Database(error.to_string()))
    }

    fn effective_order_by(&self) -> Vec<OrderClause> {
        let mut order_by = self.order_by.clone();
        let Some(direction) = order_by
            .iter()
            .find(|clause| clause.is_relation_scalar())
            .map(OrderClause::direction)
        else {
            return order_by;
        };

        if order_by
            .iter()
            .any(|clause| clause.targets_column(self.descriptor.primary_key))
        {
            return order_by;
        }

        order_by.push(OrderClause::column(self.descriptor.primary_key, direction));
        order_by
    }
}

/// Query for at most one row of the described model, selected by primary key.
#[derive(Debug, Clone)]
pub struct FindUnique<'a, M: 'static, PK: 'static> {
    /// Runtime whose connection pool is used when the query is executed.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata; supplies the select projection, table name,
    /// primary-key column, and the detail allow/deny policies.
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value to look up.
    pub(crate) id: PK,
}

impl<'a, M: 'static, PK: 'static> FindUnique<'a, M, PK> {
    pub fn preview_sql(&self) -> String {
        format!(
            "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
            self.descriptor.select_projection(),
            self.descriptor.table_name,
            self.descriptor.primary_key,
        )
    }

    pub fn preview_scoped_sql(&self, ctx: &CoolContext) -> String {
        let mut sql = format!(
            "SELECT {} FROM {}",
            self.descriptor.select_projection(),
            self.descriptor.table_name,
        );
        let mut bind_index = 1usize;
        if let Some(policy_clause) = render_read_policy_sql(
            self.descriptor.detail_allow_policies,
            self.descriptor.detail_deny_policies,
            ctx,
            &mut bind_index,
        ) {
            sql.push_str(&format!(
                " WHERE ({policy_clause}) AND {} = ${bind_index} LIMIT 1",
                self.descriptor.primary_key
            ));
        } else {
            sql.push_str(&format!(
                " WHERE {} = ${bind_index} LIMIT 1",
                self.descriptor.primary_key
            ));
        }
        sql
    }

    pub async fn run(self, ctx: &CoolContext) -> Result<Option<M>, CoolError>
    where
        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
    {
        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
        query
            .push(self.descriptor.select_projection())
            .push(" FROM ")
            .push(self.descriptor.table_name);
        push_scoped_conditions(
            &mut query,
            self.descriptor,
            &[],
            Some((self.descriptor.primary_key, self.id)),
            ctx,
        );
        query.push(" LIMIT 1");

        query
            .build_query_as::<M>()
            .fetch_optional(self.runtime.pool())
            .await
            .map_err(|error| CoolError::Database(error.to_string()))
    }
}