gluesql_core/ast_builder/expr/
aggregate.rs1use {
2 super::ExprNode,
3 crate::{
4 ast::{Aggregate, CountArgExpr},
5 parse_sql::parse_expr,
6 result::{Error, Result},
7 translate::translate_expr,
8 },
9};
10
11#[derive(Clone, Debug)]
12pub enum AggregateNode<'a> {
13 Count(CountArgExprNode<'a>, bool), Sum(ExprNode<'a>, bool),
15 Min(ExprNode<'a>, bool),
16 Max(ExprNode<'a>, bool),
17 Avg(ExprNode<'a>, bool),
18 Variance(ExprNode<'a>, bool),
19 Stdev(ExprNode<'a>, bool),
20}
21
22#[derive(Clone, Debug)]
23pub enum CountArgExprNode<'a> {
24 Text(String),
25 Expr(ExprNode<'a>),
26}
27
28impl<'a> From<&'a str> for CountArgExprNode<'a> {
29 fn from(count_arg_str: &str) -> Self {
30 Self::Text(count_arg_str.to_owned())
31 }
32}
33
34impl<'a> From<ExprNode<'a>> for CountArgExprNode<'a> {
35 fn from(expr_node: ExprNode<'a>) -> Self {
36 Self::Expr(expr_node)
37 }
38}
39
40impl<'a> TryFrom<CountArgExprNode<'a>> for CountArgExpr {
41 type Error = Error;
42
43 fn try_from(count_expr_node: CountArgExprNode<'a>) -> Result<Self> {
44 match count_expr_node {
45 CountArgExprNode::Text(s) if &s == "*" => Ok(CountArgExpr::Wildcard),
46 CountArgExprNode::Text(s) => {
47 let expr = parse_expr(s).and_then(|expr| translate_expr(&expr))?;
48
49 Ok(CountArgExpr::Expr(expr))
50 }
51 CountArgExprNode::Expr(expr_node) => expr_node.try_into().map(CountArgExpr::Expr),
52 }
53 }
54}
55
56impl<'a> TryFrom<AggregateNode<'a>> for Aggregate {
57 type Error = Error;
58
59 fn try_from(aggr_node: AggregateNode<'a>) -> Result<Self> {
60 match aggr_node {
61 AggregateNode::Count(count_arg_expr_node, distinct) => count_arg_expr_node
62 .try_into()
63 .map(|expr| Aggregate::count(expr, distinct)),
64 AggregateNode::Sum(expr_node, distinct) => expr_node
65 .try_into()
66 .map(|expr| Aggregate::sum(expr, distinct)),
67 AggregateNode::Min(expr_node, distinct) => expr_node
68 .try_into()
69 .map(|expr| Aggregate::min(expr, distinct)),
70 AggregateNode::Max(expr_node, distinct) => expr_node
71 .try_into()
72 .map(|expr| Aggregate::max(expr, distinct)),
73 AggregateNode::Avg(expr_node, distinct) => expr_node
74 .try_into()
75 .map(|expr| Aggregate::avg(expr, distinct)),
76 AggregateNode::Variance(expr_node, distinct) => expr_node
77 .try_into()
78 .map(|expr| Aggregate::variance(expr, distinct)),
79 AggregateNode::Stdev(expr_node, distinct) => expr_node
80 .try_into()
81 .map(|expr| Aggregate::stdev(expr, distinct)),
82 }
83 }
84}
85
86impl<'a> ExprNode<'a> {
87 pub fn count(self) -> ExprNode<'a> {
88 ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), false)))
89 }
90
91 pub fn count_distinct(self) -> ExprNode<'a> {
92 ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), true)))
93 }
94
95 pub fn sum(self) -> ExprNode<'a> {
96 ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, false)))
97 }
98
99 pub fn sum_distinct(self) -> ExprNode<'a> {
100 ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, true)))
101 }
102
103 pub fn min(self) -> ExprNode<'a> {
104 ExprNode::Aggregate(Box::new(AggregateNode::Min(self, false)))
105 }
106
107 pub fn min_distinct(self) -> ExprNode<'a> {
108 ExprNode::Aggregate(Box::new(AggregateNode::Min(self, true)))
109 }
110
111 pub fn max(self) -> ExprNode<'a> {
112 ExprNode::Aggregate(Box::new(AggregateNode::Max(self, false)))
113 }
114
115 pub fn max_distinct(self) -> ExprNode<'a> {
116 ExprNode::Aggregate(Box::new(AggregateNode::Max(self, true)))
117 }
118
119 pub fn avg(self) -> ExprNode<'a> {
120 ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, false)))
121 }
122
123 pub fn avg_distinct(self) -> ExprNode<'a> {
124 ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, true)))
125 }
126
127 pub fn variance(self) -> ExprNode<'a> {
128 ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, false)))
129 }
130
131 pub fn variance_distinct(self) -> ExprNode<'a> {
132 ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, true)))
133 }
134
135 pub fn stdev(self) -> ExprNode<'a> {
136 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, false)))
137 }
138
139 pub fn stdev_distinct(self) -> ExprNode<'a> {
140 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, true)))
141 }
142}
143
144pub fn count<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
145 ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), false)))
146}
147
148pub fn count_distinct<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
149 ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), true)))
150}
151
152pub fn sum<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
153 ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), false)))
154}
155
156pub fn sum_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
157 ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), true)))
158}
159
160pub fn min<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
161 ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), false)))
162}
163
164pub fn min_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
165 ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), true)))
166}
167
168pub fn max<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
169 ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), false)))
170}
171
172pub fn max_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
173 ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), true)))
174}
175
176pub fn avg<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
177 ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), false)))
178}
179
180pub fn avg_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
181 ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), true)))
182}
183
184pub fn variance<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
185 ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), false)))
186}
187
188pub fn variance_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
189 ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), true)))
190}
191
192pub fn stdev<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
193 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), false)))
194}
195
196pub fn stdev_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
197 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), true)))
198}
199
#[cfg(test)]
mod tests {
    use crate::ast_builder::{
        avg, avg_distinct, col, count, count_distinct, max, max_distinct, min, min_distinct, stdev,
        stdev_distinct, sum, sum_distinct, test_expr, variance, variance_distinct,
    };

    /// Every aggregate builder, in both its method form (`col(..).sum()`) and
    /// its free-function form (`sum(..)`), must render to the expected SQL.
    #[test]
    fn aggregate() {
        let cases = vec![
            // COUNT, including the `*` wildcard
            (col("id").count(), "COUNT(id)"),
            (count("id"), "COUNT(id)"),
            (count("*"), "COUNT(*)"),
            (count_distinct("*"), "COUNT(DISTINCT *)"),
            (col("id").count_distinct(), "COUNT(DISTINCT id)"),
            (count_distinct("id"), "COUNT(DISTINCT id)"),
            // SUM
            (col("amount").sum(), "SUM(amount)"),
            (sum("amount"), "SUM(amount)"),
            (col("amount").sum_distinct(), "SUM(DISTINCT amount)"),
            (sum_distinct("amount"), "SUM(DISTINCT amount)"),
            // MIN
            (col("budget").min(), "MIN(budget)"),
            (min("budget"), "MIN(budget)"),
            (col("budget").min_distinct(), "MIN(DISTINCT budget)"),
            (min_distinct("budget"), "MIN(DISTINCT budget)"),
            // MAX
            (col("score").max(), "MAX(score)"),
            (max("score"), "MAX(score)"),
            (col("grade").max_distinct(), "MAX(DISTINCT grade)"),
            (max_distinct("grade"), "MAX(DISTINCT grade)"),
            // AVG
            (col("grade").avg(), "AVG(grade)"),
            (avg("grade"), "AVG(grade)"),
            (col("grade").avg_distinct(), "AVG(DISTINCT grade)"),
            (avg_distinct("grade"), "AVG(DISTINCT grade)"),
            // VARIANCE
            (col("statistic").variance(), "VARIANCE(statistic)"),
            (variance("statistic"), "VARIANCE(statistic)"),
            (col("statistic").variance_distinct(), "VARIANCE(DISTINCT statistic)"),
            (variance_distinct("statistic"), "VARIANCE(DISTINCT statistic)"),
            // STDEV
            (col("scatterplot").stdev(), "STDEV(scatterplot)"),
            (stdev("scatterplot"), "STDEV(scatterplot)"),
            (col("scatterplot").stdev_distinct(), "STDEV(DISTINCT scatterplot)"),
            (stdev_distinct("scatterplot"), "STDEV(DISTINCT scatterplot)"),
        ];

        for (actual, expected) in cases {
            test_expr(actual, expected);
        }
    }
}