1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::{self};
5use utoipa::{IntoParams, ToSchema};
6
7use crate::{
8 error::ParseError, malformed_query, query_metadata::FromClauseIdentifier, unsupported,
9};
10
11use super::support::{case_fold_identifier, extract_qualified_column, remove_outer_parens};
12
/// The single aggregate function call extracted from a query's `SELECT` clause
/// by [`Aggregation::extract`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema, IntoParams)]
pub struct Aggregation {
    /// Which supported function the `SELECT` clause applies.
    pub function: KoronFunction,
    /// Column passed as the function's only argument. Plain identifiers are
    /// passed through `case_fold_identifier`; compound identifiers (e.g.
    /// `t.col`) go through `extract_qualified_column`, which presumably
    /// resolves the qualifier against the FROM clause — confirm in `support`.
    pub column: String,
    /// Case-folded alias from `SELECT f(col) AS alias`; `None` when the
    /// projection item carries no alias.
    pub alias: Option<String>,
}
26
27impl Aggregation {
28 pub(crate) fn extract(
29 from_clause_identifier: FromClauseIdentifier<'_>,
30 projection: &[ast::SelectItem],
31 ) -> Result<Self, ParseError> {
32 let multiple_aggregations = || {
33 Err(unsupported!("the SELECT clause must contain exactly one aggregation / analytic function. Nothing else is accepted.".to_string()))
34 };
35 let (expr, alias) = match projection {
37 [ast::SelectItem::UnnamedExpr(expr)] => (expr, None),
38 [ast::SelectItem::ExprWithAlias { expr, alias }] => {
39 (expr, Some(case_fold_identifier(alias)))
40 }
41 _ => {
42 return multiple_aggregations();
43 }
44 };
45 let ast::Expr::Function(function) = remove_outer_parens(expr) else {
47 return multiple_aggregations();
48 };
49
50 let ast::Function {
52 name,
53 args,
54 over,
55 distinct,
56 special: _,
57 order_by,
58 filter,
59 null_treatment,
60 } = function;
61 if over.is_some() {
62 return Err(unsupported!("window functions (OVER).".to_string()));
63 }
64 if *distinct {
65 return Err(unsupported!("DISTINCT.".to_string()));
66 }
67 if !order_by.is_empty() {
68 return Err(unsupported!("ORDER BY.".to_string()));
69 }
70 if filter.is_some() {
71 return Err(unsupported!("FILTER.".to_string()));
72 }
73 if null_treatment.is_some() {
74 return Err(unsupported!("IGNORE NULLS.".to_string()));
75 }
76 let (function, column) =
78 Self::validate_function_and_arguments(from_clause_identifier, name, args)?;
79
80 Ok(Self {
81 function,
82 column,
83 alias,
84 })
85 }
86
87 fn validate_function_and_arguments(
88 from_clause_identifier: FromClauseIdentifier<'_>,
89 function_name: &ast::ObjectName,
90 args: &[ast::FunctionArg],
91 ) -> Result<(KoronFunction, String), ParseError> {
92 let only_column_arg = |function| {
94 let column =
95 Self::extract_only_column_argument(from_clause_identifier, function_name, args)?;
96 Ok((function, column))
97 };
98
99 let ast::ObjectName(name_parts) = function_name;
100 if let [unqualified_name] = &name_parts[..] {
101 match &case_fold_identifier(unqualified_name)[..] {
103 "sum" => return only_column_arg(KoronFunction::Sum),
104 "count" => return only_column_arg(KoronFunction::Count),
105 "avg" => return only_column_arg(KoronFunction::Average),
106 "median" => return only_column_arg(KoronFunction::Median),
107 "variance" => return only_column_arg(KoronFunction::Variance),
108 "stddev" => return only_column_arg(KoronFunction::StandardDeviation),
109 "min" => return only_column_arg(KoronFunction::Min),
110 "max" => return only_column_arg(KoronFunction::Max),
111 _ => (),
112 }
113 }
114 Err(unsupported!(format!(
115 "unrecognized or unsupported function: {function_name}."
116 )))
117 }
118
119 fn extract_only_column_argument(
120 from_clause_identifier: FromClauseIdentifier<'_>,
121 function_name: &ast::ObjectName,
122 args: &[ast::FunctionArg],
123 ) -> Result<String, ParseError> {
124 match args {
126 [arg] => {
127 let arg_expr = Self::extract_unnamed_argument(arg)?;
128 Self::extract_aggregated_column(from_clause_identifier, function_name, arg_expr, "")
129 }
130 _ => Err(malformed_query!(format!(
131 "the {function_name} function takes exactly 1 argument, but {} {verb} provided.",
132 args.len(),
133 verb = if args.len() == 1 { "is" } else { "are" },
134 ))),
135 }
136 }
137
138 fn extract_unnamed_argument(
139 arg: &ast::FunctionArg,
140 ) -> Result<&ast::FunctionArgExpr, ParseError> {
141 match arg {
142 ast::FunctionArg::Named { .. } => Err(unsupported!(format!(
143 "named function arguments (such as {arg})."
144 ))),
145 ast::FunctionArg::Unnamed(arg_expr) => Ok(arg_expr),
146 }
147 }
148
149 fn extract_aggregated_column(
150 from_clause_identifier: FromClauseIdentifier<'_>,
151 function_name: &ast::ObjectName,
152 arg_expr: &ast::FunctionArgExpr,
153 which_arg: &str,
154 ) -> Result<String, ParseError> {
155 if let ast::FunctionArgExpr::Expr(expr) = arg_expr {
156 match remove_outer_parens(expr) {
157 ast::Expr::Identifier(ident) => return Ok(case_fold_identifier(ident)),
158 compound_identifier @ ast::Expr::CompoundIdentifier(name_parts) => {
159 return extract_qualified_column(
160 from_clause_identifier,
161 compound_identifier,
162 name_parts,
163 );
164 }
165 _ => (),
166 }
167 }
168 Err(unsupported!(format!(
169 "only a column name is supported as the {which_arg}{space}argument of the {function_name} function.",
170 space = if which_arg.is_empty() { "" } else { " " },
171 )))
172 }
173}
174
/// Aggregate functions recognized in the `SELECT` clause.
///
/// The `Display` impl renders each variant as its upper-case SQL name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
pub enum KoronFunction {
    /// `SUM(column)`.
    Sum,
    /// `COUNT(column)`; also the `Default` variant.
    #[default]
    Count,
    /// `AVG(column)`.
    Average,
    /// `MEDIAN(column)`.
    Median,
    /// `VARIANCE(column)`.
    Variance,
    /// `STDDEV(column)`.
    StandardDeviation,
    /// `MIN(column)`.
    Min,
    /// `MAX(column)`.
    Max,
}
196
197impl Display for KoronFunction {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 match self {
200 Self::Sum => write!(f, "SUM"),
201 Self::Count => write!(f, "COUNT"),
202 Self::Average => write!(f, "AVG"),
203 Self::Median => write!(f, "MEDIAN"),
204 Self::Variance => write!(f, "VARIANCE"),
205 Self::StandardDeviation => write!(f, "STDDEV"),
206 Self::Min => write!(f, "MIN"),
207 Self::Max => write!(f, "MAX"),
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::KoronFunction;

    /// Every variant renders as its upper-case SQL function name.
    #[test]
    fn koron_fn_display() {
        let expected_names: [(KoronFunction, &str); 8] = [
            (KoronFunction::Count, "COUNT"),
            (KoronFunction::Sum, "SUM"),
            (KoronFunction::Variance, "VARIANCE"),
            (KoronFunction::Median, "MEDIAN"),
            (KoronFunction::Average, "AVG"),
            (KoronFunction::StandardDeviation, "STDDEV"),
            (KoronFunction::Min, "MIN"),
            (KoronFunction::Max, "MAX"),
        ];
        for &(koron_fn, expected_name) in expected_names.iter() {
            let rendered = koron_fn.to_string();
            assert_eq!(rendered, expected_name);
        }
    }
}