gluesql_core/ast_builder/expr/aggregate.rs

1use {
2    super::ExprNode,
3    crate::{
4        ast::{Aggregate, CountArgExpr},
5        parse_sql::parse_expr,
6        result::{Error, Result},
7        translate::{NO_PARAMS, translate_expr},
8    },
9};
10
11#[derive(Clone, Debug)]
12pub enum AggregateNode<'a> {
13    Count(CountArgExprNode<'a>, bool), // second field is distinct
14    Sum(ExprNode<'a>, bool),
15    Min(ExprNode<'a>, bool),
16    Max(ExprNode<'a>, bool),
17    Avg(ExprNode<'a>, bool),
18    Variance(ExprNode<'a>, bool),
19    Stdev(ExprNode<'a>, bool),
20}
21
22#[derive(Clone, Debug)]
23pub enum CountArgExprNode<'a> {
24    Text(String),
25    Expr(ExprNode<'a>),
26}
27
28impl<'a> From<&'a str> for CountArgExprNode<'a> {
29    fn from(count_arg_str: &str) -> Self {
30        Self::Text(count_arg_str.to_owned())
31    }
32}
33
34impl<'a> From<ExprNode<'a>> for CountArgExprNode<'a> {
35    fn from(expr_node: ExprNode<'a>) -> Self {
36        Self::Expr(expr_node)
37    }
38}
39
40impl<'a> TryFrom<CountArgExprNode<'a>> for CountArgExpr {
41    type Error = Error;
42
43    fn try_from(count_expr_node: CountArgExprNode<'a>) -> Result<Self> {
44        match count_expr_node {
45            CountArgExprNode::Text(s) if &s == "*" => Ok(CountArgExpr::Wildcard),
46            CountArgExprNode::Text(s) => {
47                let expr = parse_expr(s).and_then(|expr| translate_expr(&expr, NO_PARAMS))?;
48
49                Ok(CountArgExpr::Expr(expr))
50            }
51            CountArgExprNode::Expr(expr_node) => expr_node.try_into().map(CountArgExpr::Expr),
52        }
53    }
54}
55
56impl<'a> TryFrom<AggregateNode<'a>> for Aggregate {
57    type Error = Error;
58
59    fn try_from(aggr_node: AggregateNode<'a>) -> Result<Self> {
60        match aggr_node {
61            AggregateNode::Count(count_arg_expr_node, distinct) => count_arg_expr_node
62                .try_into()
63                .map(|expr| Aggregate::count(expr, distinct)),
64            AggregateNode::Sum(expr_node, distinct) => expr_node
65                .try_into()
66                .map(|expr| Aggregate::sum(expr, distinct)),
67            AggregateNode::Min(expr_node, distinct) => expr_node
68                .try_into()
69                .map(|expr| Aggregate::min(expr, distinct)),
70            AggregateNode::Max(expr_node, distinct) => expr_node
71                .try_into()
72                .map(|expr| Aggregate::max(expr, distinct)),
73            AggregateNode::Avg(expr_node, distinct) => expr_node
74                .try_into()
75                .map(|expr| Aggregate::avg(expr, distinct)),
76            AggregateNode::Variance(expr_node, distinct) => expr_node
77                .try_into()
78                .map(|expr| Aggregate::variance(expr, distinct)),
79            AggregateNode::Stdev(expr_node, distinct) => expr_node
80                .try_into()
81                .map(|expr| Aggregate::stdev(expr, distinct)),
82        }
83    }
84}
85
86impl<'a> ExprNode<'a> {
87    #[must_use]
88    pub fn count(self) -> ExprNode<'a> {
89        ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), false)))
90    }
91
92    #[must_use]
93    pub fn count_distinct(self) -> ExprNode<'a> {
94        ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), true)))
95    }
96
97    #[must_use]
98    pub fn sum(self) -> ExprNode<'a> {
99        ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, false)))
100    }
101
102    #[must_use]
103    pub fn sum_distinct(self) -> ExprNode<'a> {
104        ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, true)))
105    }
106
107    #[must_use]
108    pub fn min(self) -> ExprNode<'a> {
109        ExprNode::Aggregate(Box::new(AggregateNode::Min(self, false)))
110    }
111
112    #[must_use]
113    pub fn min_distinct(self) -> ExprNode<'a> {
114        ExprNode::Aggregate(Box::new(AggregateNode::Min(self, true)))
115    }
116
117    #[must_use]
118    pub fn max(self) -> ExprNode<'a> {
119        ExprNode::Aggregate(Box::new(AggregateNode::Max(self, false)))
120    }
121
122    #[must_use]
123    pub fn max_distinct(self) -> ExprNode<'a> {
124        ExprNode::Aggregate(Box::new(AggregateNode::Max(self, true)))
125    }
126
127    #[must_use]
128    pub fn avg(self) -> ExprNode<'a> {
129        ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, false)))
130    }
131
132    #[must_use]
133    pub fn avg_distinct(self) -> ExprNode<'a> {
134        ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, true)))
135    }
136
137    #[must_use]
138    pub fn variance(self) -> ExprNode<'a> {
139        ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, false)))
140    }
141
142    #[must_use]
143    pub fn variance_distinct(self) -> ExprNode<'a> {
144        ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, true)))
145    }
146
147    #[must_use]
148    pub fn stdev(self) -> ExprNode<'a> {
149        ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, false)))
150    }
151
152    #[must_use]
153    pub fn stdev_distinct(self) -> ExprNode<'a> {
154        ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, true)))
155    }
156}
157
158pub fn count<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
159    ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), false)))
160}
161
162pub fn count_distinct<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
163    ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), true)))
164}
165
166pub fn sum<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
167    ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), false)))
168}
169
170pub fn sum_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
171    ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), true)))
172}
173
174pub fn min<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
175    ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), false)))
176}
177
178pub fn min_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
179    ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), true)))
180}
181
182pub fn max<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
183    ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), false)))
184}
185
186pub fn max_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
187    ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), true)))
188}
189
190pub fn avg<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
191    ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), false)))
192}
193
194pub fn avg_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
195    ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), true)))
196}
197
198pub fn variance<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
199    ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), false)))
200}
201
202pub fn variance_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
203    ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), true)))
204}
205
206pub fn stdev<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
207    ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), false)))
208}
209
210pub fn stdev_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
211    ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), true)))
212}
213
#[cfg(test)]
mod tests {
    use crate::ast_builder::{
        avg, avg_distinct, col, count, count_distinct, max, max_distinct, min, min_distinct, stdev,
        stdev_distinct, sum, sum_distinct, test_expr, variance, variance_distinct,
    };

    /// Every aggregate builder must render to the expected SQL. This covers both the method
    /// form (`col(..).sum()`) and the free-function form (`sum(..)`), each with and without
    /// `DISTINCT`. Cases are checked in order.
    #[test]
    fn aggregate() {
        let cases = [
            // COUNT
            (col("id").count(), "COUNT(id)"),
            (count("id"), "COUNT(id)"),
            (count("*"), "COUNT(*)"),
            (count_distinct("*"), "COUNT(DISTINCT *)"),
            (col("id").count_distinct(), "COUNT(DISTINCT id)"),
            (count_distinct("id"), "COUNT(DISTINCT id)"),
            // SUM
            (col("amount").sum(), "SUM(amount)"),
            (sum("amount"), "SUM(amount)"),
            (col("amount").sum_distinct(), "SUM(DISTINCT amount)"),
            (sum_distinct("amount"), "SUM(DISTINCT amount)"),
            // MIN
            (col("budget").min(), "MIN(budget)"),
            (min("budget"), "MIN(budget)"),
            (col("budget").min_distinct(), "MIN(DISTINCT budget)"),
            (min_distinct("budget"), "MIN(DISTINCT budget)"),
            // MAX
            (col("score").max(), "MAX(score)"),
            (max("score"), "MAX(score)"),
            (col("grade").max_distinct(), "MAX(DISTINCT grade)"),
            (max_distinct("grade"), "MAX(DISTINCT grade)"),
            // AVG
            (col("grade").avg(), "AVG(grade)"),
            (avg("grade"), "AVG(grade)"),
            (col("grade").avg_distinct(), "AVG(DISTINCT grade)"),
            (avg_distinct("grade"), "AVG(DISTINCT grade)"),
            // VARIANCE
            (col("statistic").variance(), "VARIANCE(statistic)"),
            (variance("statistic"), "VARIANCE(statistic)"),
            (col("statistic").variance_distinct(), "VARIANCE(DISTINCT statistic)"),
            (variance_distinct("statistic"), "VARIANCE(DISTINCT statistic)"),
            // STDEV
            (col("scatterplot").stdev(), "STDEV(scatterplot)"),
            (stdev("scatterplot"), "STDEV(scatterplot)"),
            (col("scatterplot").stdev_distinct(), "STDEV(DISTINCT scatterplot)"),
            (stdev_distinct("scatterplot"), "STDEV(DISTINCT scatterplot)"),
        ];

        for (actual, expected) in cases {
            test_expr(actual, expected);
        }
    }
}