Skip to main content

cratestack_sqlx/query/read/
aggregate_column.rs

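//! `aggregate.sum/avg/min/max(col)` — single-column scalar aggregates
//! with filter + read policy. The caller picks the decode type at the
//! call site because Postgres's result type depends on both the function
//! and the column type: `SUM(int4)` returns `int8` (i64), `SUM(int8)` and
//! `AVG` over any integer type return `numeric` (a decimal), `AVG(float8)`
//! returns `float8` (f64), and `MIN`/`MAX` keep the column's own type.
//! All four yield SQL `NULL` when no rows match, hence the `Option<T>` result.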
5
6use cratestack_core::{CoolContext, CoolError};
7use cratestack_sql::IntoColumnName;
8
9use crate::query::support::{ReadPolicyKind, push_scoped_conditions};
10use crate::{FilterExpr, ModelDescriptor, SqlxRuntime, sqlx};
11
12use super::aggregate::AggregateOp;
13
14#[derive(Debug, Clone)]
15pub struct AggregateColumn<'a, M: 'static, PK: 'static> {
16    runtime: &'a SqlxRuntime,
17    descriptor: &'static ModelDescriptor<M, PK>,
18    op: AggregateOp,
19    column: &'static str,
20    filters: Vec<FilterExpr>,
21}
22
23impl<'a, M: 'static, PK: 'static> AggregateColumn<'a, M, PK> {
24    pub(super) fn new<C: IntoColumnName>(
25        runtime: &'a SqlxRuntime,
26        descriptor: &'static ModelDescriptor<M, PK>,
27        op: AggregateOp,
28        column: C,
29    ) -> Self {
30        Self {
31            runtime,
32            descriptor,
33            op,
34            column: column.into_column_name(),
35            filters: Vec::new(),
36        }
37    }
38
39    pub fn where_(mut self, filter: crate::Filter) -> Self {
40        self.filters.push(FilterExpr::from(filter));
41        self
42    }
43
44    pub fn where_expr(mut self, filter: FilterExpr) -> Self {
45        self.filters.push(filter);
46        self
47    }
48
49    pub fn where_any(mut self, filters: impl IntoIterator<Item = FilterExpr>) -> Self {
50        self.filters.push(FilterExpr::any(filters));
51        self
52    }
53
54    pub fn where_optional<F>(mut self, filter: Option<F>) -> Self
55    where
56        F: Into<FilterExpr>,
57    {
58        if let Some(filter) = filter {
59            self.filters.push(filter.into());
60        }
61        self
62    }
63
64    fn build_query<'q>(&self, ctx: &CoolContext) -> sqlx::QueryBuilder<'q, sqlx::Postgres> {
65        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
66        query
67            .push(self.op.function_name())
68            .push("(")
69            .push(self.column)
70            .push(") FROM ")
71            .push(self.descriptor.table_name);
72        push_scoped_conditions(
73            &mut query,
74            self.descriptor,
75            &self.filters,
76            None::<(&'static str, i64)>,
77            ctx,
78            ReadPolicyKind::List,
79        );
80        query
81    }
82
83    pub async fn run<T>(self, ctx: &CoolContext) -> Result<Option<T>, CoolError>
84    where
85        T: Send + Unpin + for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
86    {
87        let mut query = self.build_query(ctx);
88        let value: (Option<T>,) = query
89            .build_query_as::<(Option<T>,)>()
90            .fetch_one(self.runtime.pool())
91            .await
92            .map_err(|error| CoolError::Database(error.to_string()))?;
93        Ok(value.0)
94    }
95
96    pub async fn run_in_tx<'tx, T>(
97        self,
98        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
99        ctx: &CoolContext,
100    ) -> Result<Option<T>, CoolError>
101    where
102        T: Send + Unpin + for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
103    {
104        let mut query = self.build_query(ctx);
105        let value: (Option<T>,) = query
106            .build_query_as::<(Option<T>,)>()
107            .fetch_one(&mut **tx)
108            .await
109            .map_err(|error| CoolError::Database(error.to_string()))?;
110        Ok(value.0)
111    }
112}