// gluesql_core/ast_builder/expr/aggregate.rs

1use {
2    super::ExprNode,
3    crate::{
4        ast::{Aggregate, CountArgExpr},
5        parse_sql::parse_expr,
6        result::{Error, Result},
7        translate::translate_expr,
8    },
9};
10
11#[derive(Clone, Debug)]
12pub enum AggregateNode<'a> {
13    Count(CountArgExprNode<'a>),
14    Sum(ExprNode<'a>),
15    Min(ExprNode<'a>),
16    Max(ExprNode<'a>),
17    Avg(ExprNode<'a>),
18    Variance(ExprNode<'a>),
19    Stdev(ExprNode<'a>),
20}
21
22#[derive(Clone, Debug)]
23pub enum CountArgExprNode<'a> {
24    Text(String),
25    Expr(ExprNode<'a>),
26}
27
28impl<'a> From<&'a str> for CountArgExprNode<'a> {
29    fn from(count_arg_str: &str) -> Self {
30        Self::Text(count_arg_str.to_owned())
31    }
32}
33
34impl<'a> From<ExprNode<'a>> for CountArgExprNode<'a> {
35    fn from(expr_node: ExprNode<'a>) -> Self {
36        Self::Expr(expr_node)
37    }
38}
39
40impl<'a> TryFrom<CountArgExprNode<'a>> for CountArgExpr {
41    type Error = Error;
42
43    fn try_from(count_expr_node: CountArgExprNode<'a>) -> Result<Self> {
44        match count_expr_node {
45            CountArgExprNode::Text(s) if &s == "*" => Ok(CountArgExpr::Wildcard),
46            CountArgExprNode::Text(s) => {
47                let expr = parse_expr(s).and_then(|expr| translate_expr(&expr))?;
48
49                Ok(CountArgExpr::Expr(expr))
50            }
51            CountArgExprNode::Expr(expr_node) => expr_node.try_into().map(CountArgExpr::Expr),
52        }
53    }
54}
55
56impl<'a> TryFrom<AggregateNode<'a>> for Aggregate {
57    type Error = Error;
58
59    fn try_from(aggr_node: AggregateNode<'a>) -> Result<Self> {
60        match aggr_node {
61            AggregateNode::Count(count_arg_expr_node) => {
62                count_arg_expr_node.try_into().map(Aggregate::Count)
63            }
64            AggregateNode::Sum(expr_node) => expr_node.try_into().map(Aggregate::Sum),
65            AggregateNode::Min(expr_node) => expr_node.try_into().map(Aggregate::Min),
66            AggregateNode::Max(expr_node) => expr_node.try_into().map(Aggregate::Max),
67            AggregateNode::Avg(expr_node) => expr_node.try_into().map(Aggregate::Avg),
68            AggregateNode::Variance(expr_node) => expr_node.try_into().map(Aggregate::Variance),
69            AggregateNode::Stdev(expr_node) => expr_node.try_into().map(Aggregate::Stdev),
70        }
71    }
72}
73
74impl<'a> ExprNode<'a> {
75    pub fn count(self) -> Self {
76        count(self)
77    }
78
79    pub fn sum(self) -> Self {
80        sum(self)
81    }
82
83    pub fn min(self) -> Self {
84        min(self)
85    }
86
87    pub fn max(self) -> Self {
88        max(self)
89    }
90
91    pub fn avg(self) -> Self {
92        avg(self)
93    }
94
95    pub fn variance(self) -> Self {
96        variance(self)
97    }
98
99    pub fn stdev(self) -> Self {
100        stdev(self)
101    }
102}
103
104pub fn count<'a, T: Into<CountArgExprNode<'a>>>(expr: T) -> ExprNode<'a> {
105    ExprNode::Aggregate(Box::new(AggregateNode::Count(expr.into())))
106}
107
108pub fn sum<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
109    ExprNode::Aggregate(Box::new(AggregateNode::Sum(expr.into())))
110}
111
112pub fn min<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
113    ExprNode::Aggregate(Box::new(AggregateNode::Min(expr.into())))
114}
115
116pub fn max<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
117    ExprNode::Aggregate(Box::new(AggregateNode::Max(expr.into())))
118}
119
120pub fn avg<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
121    ExprNode::Aggregate(Box::new(AggregateNode::Avg(expr.into())))
122}
123
124pub fn variance<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
125    ExprNode::Aggregate(Box::new(AggregateNode::Variance(expr.into())))
126}
127
128pub fn stdev<'a, T: Into<ExprNode<'a>>>(expr: T) -> ExprNode<'a> {
129    ExprNode::Aggregate(Box::new(AggregateNode::Stdev(expr.into())))
130}
131
#[cfg(test)]
mod tests {
    use crate::ast_builder::{avg, col, count, max, min, stdev, sum, test_expr, variance};

    #[test]
    fn aggregate() {
        // Each case pairs a builder expression with the SQL text it must render
        // to; both the postfix-method and free-function forms are covered.
        let cases = [
            (col("id").count(), "COUNT(id)"),
            (count("id"), "COUNT(id)"),
            (count("*"), "COUNT(*)"),
            (col("amount").sum(), "SUM(amount)"),
            (sum("amount"), "SUM(amount)"),
            (col("budget").min(), "MIN(budget)"),
            (min("budget"), "MIN(budget)"),
            (col("score").max(), "MAX(score)"),
            (max("score"), "MAX(score)"),
            (col("grade").avg(), "AVG(grade)"),
            (avg("grade"), "AVG(grade)"),
            (col("statistic").variance(), "VARIANCE(statistic)"),
            (variance("statistic"), "VARIANCE(statistic)"),
            (col("scatterplot").stdev(), "STDEV(scatterplot)"),
            (stdev("scatterplot"), "STDEV(scatterplot)"),
        ];

        for (actual, expected) in cases {
            test_expr(actual, expected);
        }
    }
}