Skip to main content

cratestack_sqlx/delegate/
scoped_aggregate.rs

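//! Context-bound aggregate builders. [`ScopedAggregate`] pairs an aggregate
//! request with a `CoolContext` and, once an aggregate is chosen, becomes a
//! [`ScopedAggregateCount`] (`count`) or a [`ScopedAggregateColumn`]
//! (`sum` / `avg` / `min` / `max`) that carries the same context.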
3
4use cratestack_core::{CoolContext, CoolError};
5
6use crate::{Aggregate, AggregateColumn, AggregateCount, Filter, FilterExpr, sqlx};
7
8#[derive(Debug, Clone)]
9pub struct ScopedAggregate<'a, M: 'static, PK: 'static> {
10    request: Aggregate<'a, M, PK>,
11    ctx: CoolContext,
12}
13
14impl<'a, M: 'static, PK: 'static> ScopedAggregate<'a, M, PK> {
15    pub(super) fn new(request: Aggregate<'a, M, PK>, ctx: CoolContext) -> Self {
16        Self { request, ctx }
17    }
18
19    pub fn count(self) -> ScopedAggregateCount<'a, M, PK> {
20        ScopedAggregateCount {
21            request: self.request.count(),
22            ctx: self.ctx,
23        }
24    }
25
26    pub fn sum<C: cratestack_sql::IntoColumnName>(
27        self,
28        column: C,
29    ) -> ScopedAggregateColumn<'a, M, PK> {
30        ScopedAggregateColumn {
31            request: self.request.sum(column),
32            ctx: self.ctx,
33        }
34    }
35
36    pub fn avg<C: cratestack_sql::IntoColumnName>(
37        self,
38        column: C,
39    ) -> ScopedAggregateColumn<'a, M, PK> {
40        ScopedAggregateColumn {
41            request: self.request.avg(column),
42            ctx: self.ctx,
43        }
44    }
45
46    pub fn min<C: cratestack_sql::IntoColumnName>(
47        self,
48        column: C,
49    ) -> ScopedAggregateColumn<'a, M, PK> {
50        ScopedAggregateColumn {
51            request: self.request.min(column),
52            ctx: self.ctx,
53        }
54    }
55
56    pub fn max<C: cratestack_sql::IntoColumnName>(
57        self,
58        column: C,
59    ) -> ScopedAggregateColumn<'a, M, PK> {
60        ScopedAggregateColumn {
61            request: self.request.max(column),
62            ctx: self.ctx,
63        }
64    }
65}
66
67#[derive(Debug, Clone)]
68pub struct ScopedAggregateCount<'a, M: 'static, PK: 'static> {
69    request: AggregateCount<'a, M, PK>,
70    ctx: CoolContext,
71}
72
73impl<'a, M: 'static, PK: 'static> ScopedAggregateCount<'a, M, PK> {
74    pub fn where_(mut self, filter: Filter) -> Self {
75        self.request = self.request.where_(filter);
76        self
77    }
78
79    pub fn where_expr(mut self, filter: FilterExpr) -> Self {
80        self.request = self.request.where_expr(filter);
81        self
82    }
83
84    pub fn where_any(mut self, filters: impl IntoIterator<Item = FilterExpr>) -> Self {
85        self.request = self.request.where_any(filters);
86        self
87    }
88
89    pub fn where_optional<F>(mut self, filter: Option<F>) -> Self
90    where
91        F: Into<FilterExpr>,
92    {
93        self.request = self.request.where_optional(filter);
94        self
95    }
96
97    pub async fn run(self) -> Result<i64, CoolError> {
98        self.request.run(&self.ctx).await
99    }
100
101    pub async fn run_in_tx<'tx>(
102        self,
103        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
104    ) -> Result<i64, CoolError> {
105        self.request.run_in_tx(tx, &self.ctx).await
106    }
107}
108
109#[derive(Debug, Clone)]
110pub struct ScopedAggregateColumn<'a, M: 'static, PK: 'static> {
111    request: AggregateColumn<'a, M, PK>,
112    ctx: CoolContext,
113}
114
115impl<'a, M: 'static, PK: 'static> ScopedAggregateColumn<'a, M, PK> {
116    pub fn where_(mut self, filter: Filter) -> Self {
117        self.request = self.request.where_(filter);
118        self
119    }
120
121    pub fn where_expr(mut self, filter: FilterExpr) -> Self {
122        self.request = self.request.where_expr(filter);
123        self
124    }
125
126    pub fn where_any(mut self, filters: impl IntoIterator<Item = FilterExpr>) -> Self {
127        self.request = self.request.where_any(filters);
128        self
129    }
130
131    pub fn where_optional<F>(mut self, filter: Option<F>) -> Self
132    where
133        F: Into<FilterExpr>,
134    {
135        self.request = self.request.where_optional(filter);
136        self
137    }
138
139    pub async fn run<T>(self) -> Result<Option<T>, CoolError>
140    where
141        T: Send + Unpin + for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
142    {
143        self.request.run::<T>(&self.ctx).await
144    }
145
146    pub async fn run_in_tx<'tx, T>(
147        self,
148        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
149    ) -> Result<Option<T>, CoolError>
150    where
151        T: Send + Unpin + for<'r> sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
152    {
153        self.request.run_in_tx::<T>(tx, &self.ctx).await
154    }
155}