gluesql_core/ast_builder/expr/
aggregate.rs

1use {
2    super::ExprNode,
3    crate::{
4        ast::{Aggregate, CountArgExpr},
5        parse_sql::parse_expr,
6        result::{Error, Result},
7        translate::translate_expr,
8    },
9};
10
11#[derive(Clone, Debug)]
12pub enum AggregateNode<'a> {
13    Count(CountArgExprNode<'a>, bool), // second field is distinct
14    Sum(ExprNode<'a>, bool),
15    Min(ExprNode<'a>, bool),
16    Max(ExprNode<'a>, bool),
17    Avg(ExprNode<'a>, bool),
18    Variance(ExprNode<'a>, bool),
19    Stdev(ExprNode<'a>, bool),
20}
21
22#[derive(Clone, Debug)]
23pub enum CountArgExprNode<'a> {
24    Text(String),
25    Expr(ExprNode<'a>),
26}
27
28impl<'a> From<&'a str> for CountArgExprNode<'a> {
29    fn from(count_arg_str: &str) -> Self {
30        Self::Text(count_arg_str.to_owned())
31    }
32}
33
34impl<'a> From<ExprNode<'a>> for CountArgExprNode<'a> {
35    fn from(expr_node: ExprNode<'a>) -> Self {
36        Self::Expr(expr_node)
37    }
38}
39
40impl<'a> TryFrom<CountArgExprNode<'a>> for CountArgExpr {
41    type Error = Error;
42
43    fn try_from(count_expr_node: CountArgExprNode<'a>) -> Result<Self> {
44        match count_expr_node {
45            CountArgExprNode::Text(s) if &s == "*" => Ok(CountArgExpr::Wildcard),
46            CountArgExprNode::Text(s) => {
47                let expr = parse_expr(s).and_then(|expr| translate_expr(&expr))?;
48
49                Ok(CountArgExpr::Expr(expr))
50            }
51            CountArgExprNode::Expr(expr_node) => expr_node.try_into().map(CountArgExpr::Expr),
52        }
53    }
54}
55
56impl<'a> TryFrom<AggregateNode<'a>> for Aggregate {
57    type Error = Error;
58
59    fn try_from(aggr_node: AggregateNode<'a>) -> Result<Self> {
60        match aggr_node {
61            AggregateNode::Count(count_arg_expr_node, distinct) => count_arg_expr_node
62                .try_into()
63                .map(|expr| Aggregate::count(expr, distinct)),
64            AggregateNode::Sum(expr_node, distinct) => expr_node
65                .try_into()
66                .map(|expr| Aggregate::sum(expr, distinct)),
67            AggregateNode::Min(expr_node, distinct) => expr_node
68                .try_into()
69                .map(|expr| Aggregate::min(expr, distinct)),
70            AggregateNode::Max(expr_node, distinct) => expr_node
71                .try_into()
72                .map(|expr| Aggregate::max(expr, distinct)),
73            AggregateNode::Avg(expr_node, distinct) => expr_node
74                .try_into()
75                .map(|expr| Aggregate::avg(expr, distinct)),
76            AggregateNode::Variance(expr_node, distinct) => expr_node
77                .try_into()
78                .map(|expr| Aggregate::variance(expr, distinct)),
79            AggregateNode::Stdev(expr_node, distinct) => expr_node
80                .try_into()
81                .map(|expr| Aggregate::stdev(expr, distinct)),
82        }
83    }
84}
85
86impl<'a> ExprNode<'a> {
87    pub fn count(self) -> ExprNode<'a> {
88        ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), false)))
89    }
90
91    pub fn count_distinct(self) -> ExprNode<'a> {
92        ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), true)))
93    }
94
95    pub fn sum(self) -> ExprNode<'a> {
96        ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, false)))
97    }
98
99    pub fn sum_distinct(self) -> ExprNode<'a> {
100        ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, true)))
101    }
102
103    pub fn min(self) -> ExprNode<'a> {
104        ExprNode::Aggregate(Box::new(AggregateNode::Min(self, false)))
105    }
106
107    pub fn min_distinct(self) -> ExprNode<'a> {
108        ExprNode::Aggregate(Box::new(AggregateNode::Min(self, true)))
109    }
110
111    pub fn max(self) -> ExprNode<'a> {
112        ExprNode::Aggregate(Box::new(AggregateNode::Max(self, false)))
113    }
114
115    pub fn max_distinct(self) -> ExprNode<'a> {
116        ExprNode::Aggregate(Box::new(AggregateNode::Max(self, true)))
117    }
118
119    pub fn avg(self) -> ExprNode<'a> {
120        ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, false)))
121    }
122
123    pub fn avg_distinct(self) -> ExprNode<'a> {
124        ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, true)))
125    }
126
127    pub fn variance(self) -> ExprNode<'a> {
128        ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, false)))
129    }
130
131    pub fn variance_distinct(self) -> ExprNode<'a> {
132        ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, true)))
133    }
134
135    pub fn stdev(self) -> ExprNode<'a> {
136        ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, false)))
137    }
138
139    pub fn stdev_distinct(self) -> ExprNode<'a> {
140        ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, true)))
141    }
142}
143
144pub fn count<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
145    ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), false)))
146}
147
148pub fn count_distinct<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
149    ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), true)))
150}
151
152pub fn sum<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
153    ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), false)))
154}
155
156pub fn sum_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
157    ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), true)))
158}
159
160pub fn min<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
161    ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), false)))
162}
163
164pub fn min_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
165    ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), true)))
166}
167
168pub fn max<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
169    ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), false)))
170}
171
172pub fn max_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
173    ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), true)))
174}
175
176pub fn avg<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
177    ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), false)))
178}
179
180pub fn avg_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
181    ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), true)))
182}
183
184pub fn variance<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
185    ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), false)))
186}
187
188pub fn variance_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
189    ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), true)))
190}
191
192pub fn stdev<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
193    ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), false)))
194}
195
196pub fn stdev_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
197    ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), true)))
198}
199
#[cfg(test)]
mod tests {
    use crate::ast_builder::{
        avg, avg_distinct, col, count, count_distinct, max, max_distinct, min, min_distinct, stdev,
        stdev_distinct, sum, sum_distinct, test_expr, variance, variance_distinct,
    };

    /// Runs every aggregate builder, in both method and free-function form,
    /// through `test_expr` together with the SQL text it is expected to match.
    #[test]
    fn aggregate() {
        // (built node, expected SQL expression), checked in declaration order.
        let cases = [
            // COUNT
            (col("id").count(), "COUNT(id)"),
            (count("id"), "COUNT(id)"),
            (count("*"), "COUNT(*)"),
            (count_distinct("*"), "COUNT(DISTINCT *)"),
            (col("id").count_distinct(), "COUNT(DISTINCT id)"),
            (count_distinct("id"), "COUNT(DISTINCT id)"),
            // SUM
            (col("amount").sum(), "SUM(amount)"),
            (sum("amount"), "SUM(amount)"),
            (col("amount").sum_distinct(), "SUM(DISTINCT amount)"),
            (sum_distinct("amount"), "SUM(DISTINCT amount)"),
            // MIN
            (col("budget").min(), "MIN(budget)"),
            (min("budget"), "MIN(budget)"),
            (col("budget").min_distinct(), "MIN(DISTINCT budget)"),
            (min_distinct("budget"), "MIN(DISTINCT budget)"),
            // MAX
            (col("score").max(), "MAX(score)"),
            (max("score"), "MAX(score)"),
            (col("grade").max_distinct(), "MAX(DISTINCT grade)"),
            (max_distinct("grade"), "MAX(DISTINCT grade)"),
            // AVG
            (col("grade").avg(), "AVG(grade)"),
            (avg("grade"), "AVG(grade)"),
            (col("grade").avg_distinct(), "AVG(DISTINCT grade)"),
            (avg_distinct("grade"), "AVG(DISTINCT grade)"),
            // VARIANCE
            (col("statistic").variance(), "VARIANCE(statistic)"),
            (variance("statistic"), "VARIANCE(statistic)"),
            (
                col("statistic").variance_distinct(),
                "VARIANCE(DISTINCT statistic)",
            ),
            (
                variance_distinct("statistic"),
                "VARIANCE(DISTINCT statistic)",
            ),
            // STDEV
            (col("scatterplot").stdev(), "STDEV(scatterplot)"),
            (stdev("scatterplot"), "STDEV(scatterplot)"),
            (
                col("scatterplot").stdev_distinct(),
                "STDEV(DISTINCT scatterplot)",
            ),
            (
                stdev_distinct("scatterplot"),
                "STDEV(DISTINCT scatterplot)",
            ),
        ];

        for (actual, expected) in cases {
            test_expr(actual, expected);
        }
    }
}