1use ahash::RandomState;
2use itertools::Itertools;
3use sqlparser::ast::{Expr, OrderByExpr};
4use std::collections::HashSet;
5
6use super::{Binder, QueryBindStep};
7use crate::errors::DatabaseError;
8use crate::expression::function::scala::ScalarFunction;
9use crate::planner::LogicalPlan;
10use crate::storage::Transaction;
11use crate::types::value::DataValue;
12use crate::{
13 expression::ScalarExpression,
14 planner::operator::{aggregate::AggregateOperator, sort::SortField},
15};
16
17impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A> {
18 pub fn bind_aggregate(
19 &mut self,
20 children: LogicalPlan,
21 agg_calls: Vec<ScalarExpression>,
22 groupby_exprs: Vec<ScalarExpression>,
23 ) -> LogicalPlan {
24 self.context.step(QueryBindStep::Agg);
25
26 AggregateOperator::build(children, agg_calls, groupby_exprs, false)
27 }
28
29 pub fn extract_select_aggregate(
30 &mut self,
31 select_items: &mut [ScalarExpression],
32 ) -> Result<(), DatabaseError> {
33 for column in select_items {
34 self.visit_column_agg_expr(column)?;
35 }
36 Ok(())
37 }
38
39 pub fn extract_group_by_aggregate(
40 &mut self,
41 select_list: &mut [ScalarExpression],
42 groupby: &[Expr],
43 ) -> Result<(), DatabaseError> {
44 let mut group_by_exprs = Vec::with_capacity(groupby.len());
45 for expr in groupby.iter() {
46 group_by_exprs.push(self.bind_expr(expr)?);
47 }
48
49 self.validate_groupby_illegal_column(select_list, &group_by_exprs)?;
50
51 for expr in group_by_exprs.iter_mut() {
52 self.visit_group_by_expr(select_list, expr);
53 }
54 Ok(())
55 }
56
57 pub fn extract_having_orderby_aggregate(
58 &mut self,
59 having: &Option<Expr>,
60 orderbys: &[OrderByExpr],
61 ) -> Result<(Option<ScalarExpression>, Option<Vec<SortField>>), DatabaseError> {
62 let return_having = if let Some(having) = having {
64 let mut having = self.bind_expr(having)?;
65 self.visit_column_agg_expr(&mut having)?;
66
67 Some(having)
68 } else {
69 None
70 };
71
72 let return_orderby = if !orderbys.is_empty() {
74 let mut return_orderby = vec![];
75 for orderby in orderbys {
76 let OrderByExpr {
77 expr,
78 asc,
79 nulls_first,
80 } = orderby;
81 let mut expr = self.bind_expr(expr)?;
82 self.visit_column_agg_expr(&mut expr)?;
83
84 return_orderby.push(SortField::new(
85 expr,
86 asc.map_or(true, |asc| asc),
87 nulls_first.map_or(false, |first| first),
88 ));
89 }
90 Some(return_orderby)
91 } else {
92 None
93 };
94 Ok((return_having, return_orderby))
95 }
96
97 fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression) -> Result<(), DatabaseError> {
98 match expr {
99 ScalarExpression::AggCall { .. } => {
100 self.context.agg_calls.push(expr.clone());
101 }
102 ScalarExpression::TypeCast { expr, .. } => self.visit_column_agg_expr(expr)?,
103 ScalarExpression::IsNull { expr, .. } => self.visit_column_agg_expr(expr)?,
104 ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr)?,
105 ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr)?,
106 ScalarExpression::Binary {
107 left_expr,
108 right_expr,
109 ..
110 } => {
111 self.visit_column_agg_expr(left_expr)?;
112 self.visit_column_agg_expr(right_expr)?;
113 }
114 ScalarExpression::In { expr, args, .. } => {
115 self.visit_column_agg_expr(expr)?;
116 for arg in args {
117 self.visit_column_agg_expr(arg)?;
118 }
119 }
120 ScalarExpression::Between {
121 expr,
122 left_expr,
123 right_expr,
124 ..
125 } => {
126 self.visit_column_agg_expr(expr)?;
127 self.visit_column_agg_expr(left_expr)?;
128 self.visit_column_agg_expr(right_expr)?;
129 }
130 ScalarExpression::SubString {
131 expr,
132 for_expr,
133 from_expr,
134 } => {
135 self.visit_column_agg_expr(expr)?;
136 if let Some(expr) = for_expr {
137 self.visit_column_agg_expr(expr)?;
138 }
139 if let Some(expr) = from_expr {
140 self.visit_column_agg_expr(expr)?;
141 }
142 }
143 ScalarExpression::Position { expr, in_expr } => {
144 self.visit_column_agg_expr(expr)?;
145 self.visit_column_agg_expr(in_expr)?;
146 }
147 ScalarExpression::Trim {
148 expr,
149 trim_what_expr,
150 ..
151 } => {
152 self.visit_column_agg_expr(expr)?;
153 if let Some(trim_what_expr) = trim_what_expr {
154 self.visit_column_agg_expr(trim_what_expr)?;
155 }
156 }
157 ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
158 ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
159 ScalarExpression::Tuple(args)
160 | ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
161 | ScalarExpression::Coalesce { exprs: args, .. } => {
162 for expr in args {
163 self.visit_column_agg_expr(expr)?;
164 }
165 }
166 ScalarExpression::If {
167 condition,
168 left_expr,
169 right_expr,
170 ..
171 } => {
172 self.visit_column_agg_expr(condition)?;
173 self.visit_column_agg_expr(left_expr)?;
174 self.visit_column_agg_expr(right_expr)?;
175 }
176 ScalarExpression::IfNull {
177 left_expr,
178 right_expr,
179 ..
180 }
181 | ScalarExpression::NullIf {
182 left_expr,
183 right_expr,
184 ..
185 } => {
186 self.visit_column_agg_expr(left_expr)?;
187 self.visit_column_agg_expr(right_expr)?;
188 }
189 ScalarExpression::CaseWhen {
190 operand_expr,
191 expr_pairs,
192 else_expr,
193 ..
194 } => {
195 if let Some(expr) = operand_expr {
196 self.visit_column_agg_expr(expr)?;
197 }
198 for (expr_1, expr_2) in expr_pairs {
199 self.visit_column_agg_expr(expr_1)?;
200 self.visit_column_agg_expr(expr_2)?;
201 }
202 if let Some(expr) = else_expr {
203 self.visit_column_agg_expr(expr)?;
204 }
205 }
206 ScalarExpression::TableFunction(_) => unreachable!(),
207 }
208
209 Ok(())
210 }
211
212 fn validate_groupby_illegal_column(
218 &mut self,
219 select_items: &[ScalarExpression],
220 groupby: &[ScalarExpression],
221 ) -> Result<(), DatabaseError> {
222 let mut group_raw_exprs = vec![];
223 for expr in groupby {
224 if let ScalarExpression::Alias { alias, .. } = expr {
225 let alias_expr = select_items.iter().find(|column| {
226 if let ScalarExpression::Alias {
227 alias: inner_alias, ..
228 } = &column
229 {
230 alias == inner_alias
231 } else {
232 false
233 }
234 });
235
236 if let Some(inner_expr) = alias_expr {
237 group_raw_exprs.push(inner_expr);
238 }
239 } else {
240 group_raw_exprs.push(expr);
241 }
242 }
243 let mut group_raw_set: HashSet<&ScalarExpression, RandomState> =
244 HashSet::from_iter(group_raw_exprs.iter().copied());
245
246 for expr in select_items {
247 if expr.has_agg_call() {
248 continue;
249 }
250 group_raw_set.remove(expr);
251
252 if !group_raw_exprs.iter().contains(&expr) {
253 return Err(DatabaseError::AggMiss(format!(
254 "`{}` must appear in the GROUP BY clause or be used in an aggregate function",
255 expr
256 )));
257 }
258 }
259
260 if !group_raw_set.is_empty() {
261 return Err(DatabaseError::AggMiss(
262 "in the GROUP BY clause the field must be in the select clause".to_string(),
263 ));
264 }
265
266 Ok(())
267 }
268
269 fn visit_group_by_expr(
270 &mut self,
271 select_list: &mut [ScalarExpression],
272 expr: &mut ScalarExpression,
273 ) {
274 if let ScalarExpression::Alias { alias, .. } = expr {
275 if let Some(i) = select_list.iter().position(|inner_expr| {
276 if let ScalarExpression::Alias {
277 alias: inner_alias, ..
278 } = &inner_expr
279 {
280 alias == inner_alias
281 } else {
282 false
283 }
284 }) {
285 self.context.group_by_exprs.push(select_list[i].clone());
286 return;
287 }
288 }
289
290 if let Some(i) = select_list.iter().position(|column| column == expr) {
291 self.context.group_by_exprs.push(select_list[i].clone())
292 }
293 }
294
295 pub fn validate_having_orderby(&self, expr: &ScalarExpression) -> Result<(), DatabaseError> {
297 if self.context.group_by_exprs.is_empty() {
298 return Ok(());
299 }
300
301 match expr {
302 ScalarExpression::AggCall { .. } => {
303 if self.context.group_by_exprs.contains(expr)
304 || self.context.agg_calls.contains(expr)
305 {
306 return Ok(());
307 }
308
309 Err(DatabaseError::AggMiss(
310 format!(
311 "expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
312 expr
313 )
314 ))
315 }
316 ScalarExpression::ColumnRef { .. } | ScalarExpression::Alias { .. } => {
317 if self.context.group_by_exprs.contains(expr) {
318 return Ok(());
319 }
320 if matches!(expr, ScalarExpression::Alias { .. }) {
321 return self.validate_having_orderby(expr.unpack_alias_ref());
322 }
323
324 Err(DatabaseError::AggMiss(
325 format!(
326 "expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
327 expr
328 )
329 ))
330 }
331
332 ScalarExpression::TypeCast { expr, .. } => self.validate_having_orderby(expr),
333 ScalarExpression::IsNull { expr, .. } => self.validate_having_orderby(expr),
334 ScalarExpression::Unary { expr, .. } => self.validate_having_orderby(expr),
335 ScalarExpression::In { expr, args, .. } => {
336 self.validate_having_orderby(expr)?;
337 for arg in args {
338 self.validate_having_orderby(arg)?;
339 }
340 Ok(())
341 }
342 ScalarExpression::Binary {
343 left_expr,
344 right_expr,
345 ..
346 } => {
347 self.validate_having_orderby(left_expr)?;
348 self.validate_having_orderby(right_expr)?;
349 Ok(())
350 }
351 ScalarExpression::Between {
352 expr,
353 left_expr,
354 right_expr,
355 ..
356 } => {
357 self.validate_having_orderby(expr)?;
358 self.validate_having_orderby(left_expr)?;
359 self.validate_having_orderby(right_expr)?;
360 Ok(())
361 }
362 ScalarExpression::SubString {
363 expr,
364 for_expr,
365 from_expr,
366 } => {
367 self.validate_having_orderby(expr)?;
368 if let Some(expr) = for_expr {
369 self.validate_having_orderby(expr)?;
370 }
371 if let Some(expr) = from_expr {
372 self.validate_having_orderby(expr)?;
373 }
374 Ok(())
375 }
376 ScalarExpression::Position { expr, in_expr } => {
377 self.validate_having_orderby(expr)?;
378 self.validate_having_orderby(in_expr)?;
379 Ok(())
380 }
381 ScalarExpression::Trim {
382 expr,
383 trim_what_expr,
384 ..
385 } => {
386 self.validate_having_orderby(expr)?;
387 if let Some(trim_what_expr) = trim_what_expr {
388 self.validate_having_orderby(trim_what_expr)?;
389 }
390 Ok(())
391 }
392 ScalarExpression::Constant(_) => Ok(()),
393 ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
394 ScalarExpression::Tuple(args)
395 | ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
396 | ScalarExpression::Coalesce { exprs: args, .. } => {
397 for expr in args {
398 self.validate_having_orderby(expr)?;
399 }
400 Ok(())
401 }
402 ScalarExpression::If {
403 condition,
404 left_expr,
405 right_expr,
406 ..
407 } => {
408 self.validate_having_orderby(condition)?;
409 self.validate_having_orderby(left_expr)?;
410 self.validate_having_orderby(right_expr)?;
411
412 Ok(())
413 }
414 ScalarExpression::IfNull {
415 left_expr,
416 right_expr,
417 ..
418 }
419 | ScalarExpression::NullIf {
420 left_expr,
421 right_expr,
422 ..
423 } => {
424 self.validate_having_orderby(left_expr)?;
425 self.validate_having_orderby(right_expr)?;
426
427 Ok(())
428 }
429 ScalarExpression::CaseWhen {
430 operand_expr,
431 expr_pairs,
432 else_expr,
433 ..
434 } => {
435 if let Some(expr) = operand_expr {
436 self.validate_having_orderby(expr)?;
437 }
438 for (expr_1, expr_2) in expr_pairs {
439 self.validate_having_orderby(expr_1)?;
440 self.validate_having_orderby(expr_2)?;
441 }
442 if let Some(expr) = else_expr {
443 self.validate_having_orderby(expr)?;
444 }
445
446 Ok(())
447 }
448 ScalarExpression::TableFunction(_) => unreachable!(),
449 }
450 }
451}