fnck_sql/binder/
aggregate.rs

1use ahash::RandomState;
2use itertools::Itertools;
3use sqlparser::ast::{Expr, OrderByExpr};
4use std::collections::HashSet;
5
6use super::{Binder, QueryBindStep};
7use crate::errors::DatabaseError;
8use crate::expression::function::scala::ScalarFunction;
9use crate::planner::LogicalPlan;
10use crate::storage::Transaction;
11use crate::types::value::DataValue;
12use crate::{
13    expression::ScalarExpression,
14    planner::operator::{aggregate::AggregateOperator, sort::SortField},
15};
16
17impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A> {
18    pub fn bind_aggregate(
19        &mut self,
20        children: LogicalPlan,
21        agg_calls: Vec<ScalarExpression>,
22        groupby_exprs: Vec<ScalarExpression>,
23    ) -> LogicalPlan {
24        self.context.step(QueryBindStep::Agg);
25
26        AggregateOperator::build(children, agg_calls, groupby_exprs, false)
27    }
28
29    pub fn extract_select_aggregate(
30        &mut self,
31        select_items: &mut [ScalarExpression],
32    ) -> Result<(), DatabaseError> {
33        for column in select_items {
34            self.visit_column_agg_expr(column)?;
35        }
36        Ok(())
37    }
38
39    pub fn extract_group_by_aggregate(
40        &mut self,
41        select_list: &mut [ScalarExpression],
42        groupby: &[Expr],
43    ) -> Result<(), DatabaseError> {
44        let mut group_by_exprs = Vec::with_capacity(groupby.len());
45        for expr in groupby.iter() {
46            group_by_exprs.push(self.bind_expr(expr)?);
47        }
48
49        self.validate_groupby_illegal_column(select_list, &group_by_exprs)?;
50
51        for expr in group_by_exprs.iter_mut() {
52            self.visit_group_by_expr(select_list, expr);
53        }
54        Ok(())
55    }
56
57    pub fn extract_having_orderby_aggregate(
58        &mut self,
59        having: &Option<Expr>,
60        orderbys: &[OrderByExpr],
61    ) -> Result<(Option<ScalarExpression>, Option<Vec<SortField>>), DatabaseError> {
62        // Extract having expression.
63        let return_having = if let Some(having) = having {
64            let mut having = self.bind_expr(having)?;
65            self.visit_column_agg_expr(&mut having)?;
66
67            Some(having)
68        } else {
69            None
70        };
71
72        // Extract orderby expression.
73        let return_orderby = if !orderbys.is_empty() {
74            let mut return_orderby = vec![];
75            for orderby in orderbys {
76                let OrderByExpr {
77                    expr,
78                    asc,
79                    nulls_first,
80                } = orderby;
81                let mut expr = self.bind_expr(expr)?;
82                self.visit_column_agg_expr(&mut expr)?;
83
84                return_orderby.push(SortField::new(
85                    expr,
86                    asc.map_or(true, |asc| asc),
87                    nulls_first.map_or(false, |first| first),
88                ));
89            }
90            Some(return_orderby)
91        } else {
92            None
93        };
94        Ok((return_having, return_orderby))
95    }
96
97    fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression) -> Result<(), DatabaseError> {
98        match expr {
99            ScalarExpression::AggCall { .. } => {
100                self.context.agg_calls.push(expr.clone());
101            }
102            ScalarExpression::TypeCast { expr, .. } => self.visit_column_agg_expr(expr)?,
103            ScalarExpression::IsNull { expr, .. } => self.visit_column_agg_expr(expr)?,
104            ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr)?,
105            ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr)?,
106            ScalarExpression::Binary {
107                left_expr,
108                right_expr,
109                ..
110            } => {
111                self.visit_column_agg_expr(left_expr)?;
112                self.visit_column_agg_expr(right_expr)?;
113            }
114            ScalarExpression::In { expr, args, .. } => {
115                self.visit_column_agg_expr(expr)?;
116                for arg in args {
117                    self.visit_column_agg_expr(arg)?;
118                }
119            }
120            ScalarExpression::Between {
121                expr,
122                left_expr,
123                right_expr,
124                ..
125            } => {
126                self.visit_column_agg_expr(expr)?;
127                self.visit_column_agg_expr(left_expr)?;
128                self.visit_column_agg_expr(right_expr)?;
129            }
130            ScalarExpression::SubString {
131                expr,
132                for_expr,
133                from_expr,
134            } => {
135                self.visit_column_agg_expr(expr)?;
136                if let Some(expr) = for_expr {
137                    self.visit_column_agg_expr(expr)?;
138                }
139                if let Some(expr) = from_expr {
140                    self.visit_column_agg_expr(expr)?;
141                }
142            }
143            ScalarExpression::Position { expr, in_expr } => {
144                self.visit_column_agg_expr(expr)?;
145                self.visit_column_agg_expr(in_expr)?;
146            }
147            ScalarExpression::Trim {
148                expr,
149                trim_what_expr,
150                ..
151            } => {
152                self.visit_column_agg_expr(expr)?;
153                if let Some(trim_what_expr) = trim_what_expr {
154                    self.visit_column_agg_expr(trim_what_expr)?;
155                }
156            }
157            ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
158            ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
159            ScalarExpression::Tuple(args)
160            | ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
161            | ScalarExpression::Coalesce { exprs: args, .. } => {
162                for expr in args {
163                    self.visit_column_agg_expr(expr)?;
164                }
165            }
166            ScalarExpression::If {
167                condition,
168                left_expr,
169                right_expr,
170                ..
171            } => {
172                self.visit_column_agg_expr(condition)?;
173                self.visit_column_agg_expr(left_expr)?;
174                self.visit_column_agg_expr(right_expr)?;
175            }
176            ScalarExpression::IfNull {
177                left_expr,
178                right_expr,
179                ..
180            }
181            | ScalarExpression::NullIf {
182                left_expr,
183                right_expr,
184                ..
185            } => {
186                self.visit_column_agg_expr(left_expr)?;
187                self.visit_column_agg_expr(right_expr)?;
188            }
189            ScalarExpression::CaseWhen {
190                operand_expr,
191                expr_pairs,
192                else_expr,
193                ..
194            } => {
195                if let Some(expr) = operand_expr {
196                    self.visit_column_agg_expr(expr)?;
197                }
198                for (expr_1, expr_2) in expr_pairs {
199                    self.visit_column_agg_expr(expr_1)?;
200                    self.visit_column_agg_expr(expr_2)?;
201                }
202                if let Some(expr) = else_expr {
203                    self.visit_column_agg_expr(expr)?;
204                }
205            }
206            ScalarExpression::TableFunction(_) => unreachable!(),
207        }
208
209        Ok(())
210    }
211
212    /// Validate select exprs must appear in the GROUP BY clause or be used in
213    /// an aggregate function.
214    /// e.g. SELECT a,count(b) FROM t GROUP BY a. it's ok.
215    ///      SELECT a,b FROM t GROUP BY a.        it's error.
216    ///      SELECT a,count(b) FROM t GROUP BY b. it's error.
217    fn validate_groupby_illegal_column(
218        &mut self,
219        select_items: &[ScalarExpression],
220        groupby: &[ScalarExpression],
221    ) -> Result<(), DatabaseError> {
222        let mut group_raw_exprs = vec![];
223        for expr in groupby {
224            if let ScalarExpression::Alias { alias, .. } = expr {
225                let alias_expr = select_items.iter().find(|column| {
226                    if let ScalarExpression::Alias {
227                        alias: inner_alias, ..
228                    } = &column
229                    {
230                        alias == inner_alias
231                    } else {
232                        false
233                    }
234                });
235
236                if let Some(inner_expr) = alias_expr {
237                    group_raw_exprs.push(inner_expr);
238                }
239            } else {
240                group_raw_exprs.push(expr);
241            }
242        }
243        let mut group_raw_set: HashSet<&ScalarExpression, RandomState> =
244            HashSet::from_iter(group_raw_exprs.iter().copied());
245
246        for expr in select_items {
247            if expr.has_agg_call() {
248                continue;
249            }
250            group_raw_set.remove(expr);
251
252            if !group_raw_exprs.iter().contains(&expr) {
253                return Err(DatabaseError::AggMiss(format!(
254                    "`{}` must appear in the GROUP BY clause or be used in an aggregate function",
255                    expr
256                )));
257            }
258        }
259
260        if !group_raw_set.is_empty() {
261            return Err(DatabaseError::AggMiss(
262                "in the GROUP BY clause the field must be in the select clause".to_string(),
263            ));
264        }
265
266        Ok(())
267    }
268
269    fn visit_group_by_expr(
270        &mut self,
271        select_list: &mut [ScalarExpression],
272        expr: &mut ScalarExpression,
273    ) {
274        if let ScalarExpression::Alias { alias, .. } = expr {
275            if let Some(i) = select_list.iter().position(|inner_expr| {
276                if let ScalarExpression::Alias {
277                    alias: inner_alias, ..
278                } = &inner_expr
279                {
280                    alias == inner_alias
281                } else {
282                    false
283                }
284            }) {
285                self.context.group_by_exprs.push(select_list[i].clone());
286                return;
287            }
288        }
289
290        if let Some(i) = select_list.iter().position(|column| column == expr) {
291            self.context.group_by_exprs.push(select_list[i].clone())
292        }
293    }
294
295    /// Validate having or orderby clause is valid, if SQL has group by clause.
296    pub fn validate_having_orderby(&self, expr: &ScalarExpression) -> Result<(), DatabaseError> {
297        if self.context.group_by_exprs.is_empty() {
298            return Ok(());
299        }
300
301        match expr {
302            ScalarExpression::AggCall { .. } => {
303                if self.context.group_by_exprs.contains(expr)
304                    || self.context.agg_calls.contains(expr)
305                {
306                    return Ok(());
307                }
308
309                Err(DatabaseError::AggMiss(
310                    format!(
311                        "expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
312                        expr
313                    )
314                ))
315            }
316            ScalarExpression::ColumnRef { .. } | ScalarExpression::Alias { .. } => {
317                if self.context.group_by_exprs.contains(expr) {
318                    return Ok(());
319                }
320                if matches!(expr, ScalarExpression::Alias { .. }) {
321                    return self.validate_having_orderby(expr.unpack_alias_ref());
322                }
323
324                Err(DatabaseError::AggMiss(
325                    format!(
326                        "expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
327                        expr
328                    )
329                ))
330            }
331
332            ScalarExpression::TypeCast { expr, .. } => self.validate_having_orderby(expr),
333            ScalarExpression::IsNull { expr, .. } => self.validate_having_orderby(expr),
334            ScalarExpression::Unary { expr, .. } => self.validate_having_orderby(expr),
335            ScalarExpression::In { expr, args, .. } => {
336                self.validate_having_orderby(expr)?;
337                for arg in args {
338                    self.validate_having_orderby(arg)?;
339                }
340                Ok(())
341            }
342            ScalarExpression::Binary {
343                left_expr,
344                right_expr,
345                ..
346            } => {
347                self.validate_having_orderby(left_expr)?;
348                self.validate_having_orderby(right_expr)?;
349                Ok(())
350            }
351            ScalarExpression::Between {
352                expr,
353                left_expr,
354                right_expr,
355                ..
356            } => {
357                self.validate_having_orderby(expr)?;
358                self.validate_having_orderby(left_expr)?;
359                self.validate_having_orderby(right_expr)?;
360                Ok(())
361            }
362            ScalarExpression::SubString {
363                expr,
364                for_expr,
365                from_expr,
366            } => {
367                self.validate_having_orderby(expr)?;
368                if let Some(expr) = for_expr {
369                    self.validate_having_orderby(expr)?;
370                }
371                if let Some(expr) = from_expr {
372                    self.validate_having_orderby(expr)?;
373                }
374                Ok(())
375            }
376            ScalarExpression::Position { expr, in_expr } => {
377                self.validate_having_orderby(expr)?;
378                self.validate_having_orderby(in_expr)?;
379                Ok(())
380            }
381            ScalarExpression::Trim {
382                expr,
383                trim_what_expr,
384                ..
385            } => {
386                self.validate_having_orderby(expr)?;
387                if let Some(trim_what_expr) = trim_what_expr {
388                    self.validate_having_orderby(trim_what_expr)?;
389                }
390                Ok(())
391            }
392            ScalarExpression::Constant(_) => Ok(()),
393            ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
394            ScalarExpression::Tuple(args)
395            | ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
396            | ScalarExpression::Coalesce { exprs: args, .. } => {
397                for expr in args {
398                    self.validate_having_orderby(expr)?;
399                }
400                Ok(())
401            }
402            ScalarExpression::If {
403                condition,
404                left_expr,
405                right_expr,
406                ..
407            } => {
408                self.validate_having_orderby(condition)?;
409                self.validate_having_orderby(left_expr)?;
410                self.validate_having_orderby(right_expr)?;
411
412                Ok(())
413            }
414            ScalarExpression::IfNull {
415                left_expr,
416                right_expr,
417                ..
418            }
419            | ScalarExpression::NullIf {
420                left_expr,
421                right_expr,
422                ..
423            } => {
424                self.validate_having_orderby(left_expr)?;
425                self.validate_having_orderby(right_expr)?;
426
427                Ok(())
428            }
429            ScalarExpression::CaseWhen {
430                operand_expr,
431                expr_pairs,
432                else_expr,
433                ..
434            } => {
435                if let Some(expr) = operand_expr {
436                    self.validate_having_orderby(expr)?;
437                }
438                for (expr_1, expr_2) in expr_pairs {
439                    self.validate_having_orderby(expr_1)?;
440                    self.validate_having_orderby(expr_2)?;
441                }
442                if let Some(expr) = else_expr {
443                    self.validate_having_orderby(expr)?;
444                }
445
446                Ok(())
447            }
448            ScalarExpression::TableFunction(_) => unreachable!(),
449        }
450    }
451}