koron_query_parser/
aggregation.rs

1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::{self};
5use utoipa::{IntoParams, ToSchema};
6
7use crate::{
8    error::ParseError, malformed_query, query_metadata::FromClauseIdentifier, unsupported,
9};
10
11use super::support::{case_fold_identifier, extract_qualified_column, remove_outer_parens};
12
13/// An aggregation that's computed over the values of a column.
14///
15/// Represents an occurrence of an aggregation such as `function(column)`
16/// within the `SELECT` clause of a query.
17#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema, IntoParams)]
18pub struct Aggregation {
19    /// The function used as aggregator of column's values.
20    pub function: KoronFunction,
21    /// The name of the column on which the function is executed.
22    pub column: String,
23    /// The alias that's assigned to the result of the function: `function(column) AS alias`.
24    pub alias: Option<String>,
25}
26
27impl Aggregation {
28    pub(crate) fn extract(
29        from_clause_identifier: FromClauseIdentifier<'_>,
30        projection: &[ast::SelectItem],
31    ) -> Result<Self, ParseError> {
32        let multiple_aggregations = || {
33            Err(unsupported!("the SELECT clause must contain exactly one aggregation / analytic function. Nothing else is accepted.".to_string()))
34        };
35        //check if single operation in the projection
36        let (expr, alias) = match projection {
37            [ast::SelectItem::UnnamedExpr(expr)] => (expr, None),
38            [ast::SelectItem::ExprWithAlias { expr, alias }] => {
39                (expr, Some(case_fold_identifier(alias)))
40            }
41            _ => {
42                return multiple_aggregations();
43            }
44        };
45        //remove outer parens if any and check if the contained expression is a single function
46        let ast::Expr::Function(function) = remove_outer_parens(expr) else {
47            return multiple_aggregations();
48        };
49
50        //destructure function
51        let ast::Function {
52            name,
53            args,
54            over,
55            distinct,
56            special: _,
57            order_by,
58            filter,
59            null_treatment,
60        } = function;
61        if over.is_some() {
62            return Err(unsupported!("window functions (OVER).".to_string()));
63        }
64        if *distinct {
65            return Err(unsupported!("DISTINCT.".to_string()));
66        }
67        if !order_by.is_empty() {
68            return Err(unsupported!("ORDER BY.".to_string()));
69        }
70        if filter.is_some() {
71            return Err(unsupported!("FILTER.".to_string()));
72        }
73        if null_treatment.is_some() {
74            return Err(unsupported!("IGNORE NULLS.".to_string()));
75        }
76        //check if it is a supported function
77        let (function, column) =
78            Self::validate_function_and_arguments(from_clause_identifier, name, args)?;
79
80        Ok(Self {
81            function,
82            column,
83            alias,
84        })
85    }
86
87    fn validate_function_and_arguments(
88        from_clause_identifier: FromClauseIdentifier<'_>,
89        function_name: &ast::ObjectName,
90        args: &[ast::FunctionArg],
91    ) -> Result<(KoronFunction, String), ParseError> {
92        //closure that extracts column information from the statement
93        let only_column_arg = |function| {
94            let column =
95                Self::extract_only_column_argument(from_clause_identifier, function_name, args)?;
96            Ok((function, column))
97        };
98
99        let ast::ObjectName(name_parts) = function_name;
100        if let [unqualified_name] = &name_parts[..] {
101            //currently only these four functions are supported by Koron
102            match &case_fold_identifier(unqualified_name)[..] {
103                "sum" => return only_column_arg(KoronFunction::Sum),
104                "count" => return only_column_arg(KoronFunction::Count),
105                "avg" => return only_column_arg(KoronFunction::Average),
106                "median" => return only_column_arg(KoronFunction::Median),
107                "variance" => return only_column_arg(KoronFunction::Variance),
108                "stddev" => return only_column_arg(KoronFunction::StandardDeviation),
109                "min" => return only_column_arg(KoronFunction::Min),
110                "max" => return only_column_arg(KoronFunction::Max),
111                _ => (),
112            }
113        }
114        Err(unsupported!(format!(
115            "unrecognized or unsupported function: {function_name}."
116        )))
117    }
118
119    fn extract_only_column_argument(
120        from_clause_identifier: FromClauseIdentifier<'_>,
121        function_name: &ast::ObjectName,
122        args: &[ast::FunctionArg],
123    ) -> Result<String, ParseError> {
124        //currently only functions that takes as input a single column are supported (i.e. a single argument)
125        match args {
126            [arg] => {
127                let arg_expr = Self::extract_unnamed_argument(arg)?;
128                Self::extract_aggregated_column(from_clause_identifier, function_name, arg_expr, "")
129            }
130            _ => Err(malformed_query!(format!(
131                "the {function_name} function takes exactly 1 argument, but {} {verb} provided.",
132                args.len(),
133                verb = if args.len() == 1 { "is" } else { "are" },
134            ))),
135        }
136    }
137
138    fn extract_unnamed_argument(
139        arg: &ast::FunctionArg,
140    ) -> Result<&ast::FunctionArgExpr, ParseError> {
141        match arg {
142            ast::FunctionArg::Named { .. } => Err(unsupported!(format!(
143                "named function arguments (such as {arg})."
144            ))),
145            ast::FunctionArg::Unnamed(arg_expr) => Ok(arg_expr),
146        }
147    }
148
149    fn extract_aggregated_column(
150        from_clause_identifier: FromClauseIdentifier<'_>,
151        function_name: &ast::ObjectName,
152        arg_expr: &ast::FunctionArgExpr,
153        which_arg: &str,
154    ) -> Result<String, ParseError> {
155        if let ast::FunctionArgExpr::Expr(expr) = arg_expr {
156            match remove_outer_parens(expr) {
157                ast::Expr::Identifier(ident) => return Ok(case_fold_identifier(ident)),
158                compound_identifier @ ast::Expr::CompoundIdentifier(name_parts) => {
159                    return extract_qualified_column(
160                        from_clause_identifier,
161                        compound_identifier,
162                        name_parts,
163                    );
164                }
165                _ => (),
166            }
167        }
168        Err(unsupported!(format!(
169                "only a column name is supported as the {which_arg}{space}argument of the {function_name} function.",
170                space = if which_arg.is_empty() { "" } else { " " },
171            )))
172    }
173}
174
175/// Represents a Koron aggregation / analytic function.
176#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
177pub enum KoronFunction {
178    /// The `sum` aggregation function.
179    Sum,
180    /// The `count` aggregation function.
181    #[default]
182    Count,
183    /// The `average` aggregation function.
184    Average,
185    /// The `median` aggregation function.
186    Median,
187    /// The `variance` aggregation function.
188    Variance,
189    /// The `stddev` aggregation function.
190    StandardDeviation,
191    /// The `min` aggregation function.
192    Min,
193    /// The `max` aggregation function.
194    Max,
195}
196
197impl Display for KoronFunction {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        match self {
200            Self::Sum => write!(f, "SUM"),
201            Self::Count => write!(f, "COUNT"),
202            Self::Average => write!(f, "AVG"),
203            Self::Median => write!(f, "MEDIAN"),
204            Self::Variance => write!(f, "VARIANCE"),
205            Self::StandardDeviation => write!(f, "STDDEV"),
206            Self::Min => write!(f, "MIN"),
207            Self::Max => write!(f, "MAX"),
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::KoronFunction;

    /// Every `KoronFunction` variant must be displayed as its upper-case SQL name.
    #[test]
    fn koron_fn_display() {
        let expectations = [
            ("COUNT", KoronFunction::Count),
            ("SUM", KoronFunction::Sum),
            ("VARIANCE", KoronFunction::Variance),
            ("MEDIAN", KoronFunction::Median),
            ("AVG", KoronFunction::Average),
            ("STDDEV", KoronFunction::StandardDeviation),
            ("MIN", KoronFunction::Min),
            ("MAX", KoronFunction::Max),
        ];
        for (expected, koron_fn) in expectations {
            assert_eq!(koron_fn.to_string(), expected);
        }
    }
}