gluesql_core/ast_builder/expr/
aggregate.rs1use {
2 super::ExprNode,
3 crate::{
4 ast::{Aggregate, CountArgExpr},
5 parse_sql::parse_expr,
6 result::{Error, Result},
7 translate::{NO_PARAMS, translate_expr},
8 },
9};
10
11#[derive(Clone, Debug)]
12pub enum AggregateNode<'a> {
13 Count(CountArgExprNode<'a>, bool), Sum(ExprNode<'a>, bool),
15 Min(ExprNode<'a>, bool),
16 Max(ExprNode<'a>, bool),
17 Avg(ExprNode<'a>, bool),
18 Variance(ExprNode<'a>, bool),
19 Stdev(ExprNode<'a>, bool),
20}
21
22#[derive(Clone, Debug)]
23pub enum CountArgExprNode<'a> {
24 Text(String),
25 Expr(ExprNode<'a>),
26}
27
28impl<'a> From<&'a str> for CountArgExprNode<'a> {
29 fn from(count_arg_str: &str) -> Self {
30 Self::Text(count_arg_str.to_owned())
31 }
32}
33
34impl<'a> From<ExprNode<'a>> for CountArgExprNode<'a> {
35 fn from(expr_node: ExprNode<'a>) -> Self {
36 Self::Expr(expr_node)
37 }
38}
39
40impl<'a> TryFrom<CountArgExprNode<'a>> for CountArgExpr {
41 type Error = Error;
42
43 fn try_from(count_expr_node: CountArgExprNode<'a>) -> Result<Self> {
44 match count_expr_node {
45 CountArgExprNode::Text(s) if &s == "*" => Ok(CountArgExpr::Wildcard),
46 CountArgExprNode::Text(s) => {
47 let expr = parse_expr(s).and_then(|expr| translate_expr(&expr, NO_PARAMS))?;
48
49 Ok(CountArgExpr::Expr(expr))
50 }
51 CountArgExprNode::Expr(expr_node) => expr_node.try_into().map(CountArgExpr::Expr),
52 }
53 }
54}
55
56impl<'a> TryFrom<AggregateNode<'a>> for Aggregate {
57 type Error = Error;
58
59 fn try_from(aggr_node: AggregateNode<'a>) -> Result<Self> {
60 match aggr_node {
61 AggregateNode::Count(count_arg_expr_node, distinct) => count_arg_expr_node
62 .try_into()
63 .map(|expr| Aggregate::count(expr, distinct)),
64 AggregateNode::Sum(expr_node, distinct) => expr_node
65 .try_into()
66 .map(|expr| Aggregate::sum(expr, distinct)),
67 AggregateNode::Min(expr_node, distinct) => expr_node
68 .try_into()
69 .map(|expr| Aggregate::min(expr, distinct)),
70 AggregateNode::Max(expr_node, distinct) => expr_node
71 .try_into()
72 .map(|expr| Aggregate::max(expr, distinct)),
73 AggregateNode::Avg(expr_node, distinct) => expr_node
74 .try_into()
75 .map(|expr| Aggregate::avg(expr, distinct)),
76 AggregateNode::Variance(expr_node, distinct) => expr_node
77 .try_into()
78 .map(|expr| Aggregate::variance(expr, distinct)),
79 AggregateNode::Stdev(expr_node, distinct) => expr_node
80 .try_into()
81 .map(|expr| Aggregate::stdev(expr, distinct)),
82 }
83 }
84}
85
86impl<'a> ExprNode<'a> {
87 #[must_use]
88 pub fn count(self) -> ExprNode<'a> {
89 ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), false)))
90 }
91
92 #[must_use]
93 pub fn count_distinct(self) -> ExprNode<'a> {
94 ExprNode::Aggregate(Box::new(AggregateNode::Count(self.into(), true)))
95 }
96
97 #[must_use]
98 pub fn sum(self) -> ExprNode<'a> {
99 ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, false)))
100 }
101
102 #[must_use]
103 pub fn sum_distinct(self) -> ExprNode<'a> {
104 ExprNode::Aggregate(Box::new(AggregateNode::Sum(self, true)))
105 }
106
107 #[must_use]
108 pub fn min(self) -> ExprNode<'a> {
109 ExprNode::Aggregate(Box::new(AggregateNode::Min(self, false)))
110 }
111
112 #[must_use]
113 pub fn min_distinct(self) -> ExprNode<'a> {
114 ExprNode::Aggregate(Box::new(AggregateNode::Min(self, true)))
115 }
116
117 #[must_use]
118 pub fn max(self) -> ExprNode<'a> {
119 ExprNode::Aggregate(Box::new(AggregateNode::Max(self, false)))
120 }
121
122 #[must_use]
123 pub fn max_distinct(self) -> ExprNode<'a> {
124 ExprNode::Aggregate(Box::new(AggregateNode::Max(self, true)))
125 }
126
127 #[must_use]
128 pub fn avg(self) -> ExprNode<'a> {
129 ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, false)))
130 }
131
132 #[must_use]
133 pub fn avg_distinct(self) -> ExprNode<'a> {
134 ExprNode::Aggregate(Box::new(AggregateNode::Avg(self, true)))
135 }
136
137 #[must_use]
138 pub fn variance(self) -> ExprNode<'a> {
139 ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, false)))
140 }
141
142 #[must_use]
143 pub fn variance_distinct(self) -> ExprNode<'a> {
144 ExprNode::Aggregate(Box::new(AggregateNode::Variance(self, true)))
145 }
146
147 #[must_use]
148 pub fn stdev(self) -> ExprNode<'a> {
149 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, false)))
150 }
151
152 #[must_use]
153 pub fn stdev_distinct(self) -> ExprNode<'a> {
154 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(self, true)))
155 }
156}
157
158pub fn count<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
159 ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), false)))
160}
161
162pub fn count_distinct<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
163 ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into(), true)))
164}
165
166pub fn sum<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
167 ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), false)))
168}
169
170pub fn sum_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
171 ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into(), true)))
172}
173
174pub fn min<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
175 ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), false)))
176}
177
178pub fn min_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
179 ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into(), true)))
180}
181
182pub fn max<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
183 ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), false)))
184}
185
186pub fn max_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
187 ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into(), true)))
188}
189
190pub fn avg<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
191 ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), false)))
192}
193
194pub fn avg_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
195 ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into(), true)))
196}
197
198pub fn variance<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
199 ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), false)))
200}
201
202pub fn variance_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
203 ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into(), true)))
204}
205
206pub fn stdev<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
207 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), false)))
208}
209
210pub fn stdev_distinct<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
211 ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into(), true)))
212}
213
#[cfg(test)]
mod tests {
    use crate::ast_builder::{
        avg, avg_distinct, col, count, count_distinct, max, max_distinct, min, min_distinct, stdev,
        stdev_distinct, sum, sum_distinct, test_expr, variance, variance_distinct,
    };

    #[test]
    fn aggregate() {
        // (builder expression, expected SQL). Method-style and free-function
        // builders are checked side by side and must render identically.
        let cases = [
            (col("id").count(), "COUNT(id)"),
            (count("id"), "COUNT(id)"),
            (count("*"), "COUNT(*)"),
            (count_distinct("*"), "COUNT(DISTINCT *)"),
            (col("id").count_distinct(), "COUNT(DISTINCT id)"),
            (count_distinct("id"), "COUNT(DISTINCT id)"),
            (col("amount").sum(), "SUM(amount)"),
            (sum("amount"), "SUM(amount)"),
            (col("amount").sum_distinct(), "SUM(DISTINCT amount)"),
            (sum_distinct("amount"), "SUM(DISTINCT amount)"),
            (col("budget").min(), "MIN(budget)"),
            (min("budget"), "MIN(budget)"),
            (col("budget").min_distinct(), "MIN(DISTINCT budget)"),
            (min_distinct("budget"), "MIN(DISTINCT budget)"),
            (col("score").max(), "MAX(score)"),
            (max("score"), "MAX(score)"),
            (col("grade").max_distinct(), "MAX(DISTINCT grade)"),
            (max_distinct("grade"), "MAX(DISTINCT grade)"),
            (col("grade").avg(), "AVG(grade)"),
            (avg("grade"), "AVG(grade)"),
            (col("grade").avg_distinct(), "AVG(DISTINCT grade)"),
            (avg_distinct("grade"), "AVG(DISTINCT grade)"),
            (col("statistic").variance(), "VARIANCE(statistic)"),
            (variance("statistic"), "VARIANCE(statistic)"),
            (col("statistic").variance_distinct(), "VARIANCE(DISTINCT statistic)"),
            (variance_distinct("statistic"), "VARIANCE(DISTINCT statistic)"),
            (col("scatterplot").stdev(), "STDEV(scatterplot)"),
            (stdev("scatterplot"), "STDEV(scatterplot)"),
            (col("scatterplot").stdev_distinct(), "STDEV(DISTINCT scatterplot)"),
            (stdev_distinct("scatterplot"), "STDEV(DISTINCT scatterplot)"),
        ];

        for (actual, expected) in cases {
            test_expr(actual, expected);
        }
    }
}