1use alloc::boxed::Box;
24use alloc::collections::{BTreeMap, BTreeSet};
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Returns `true` when `name` is one of the aggregate function names this
/// module recognises, compared ASCII-case-insensitively.
///
/// `count_star` is accepted as a name in its own right, and `every` is
/// accepted alongside `bool_and`. The comparison does not allocate.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATE_NAMES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
128
/// Running accumulator for one aggregate within one group.
///
/// A single struct serves every aggregate kind; each kind only touches the
/// fields relevant to it (see `update_state` and `finalize`).
#[derive(Debug, Default, Clone)]
struct AggState {
    // Number of inputs accepted so far: every row for `count_star`, non-null
    // inputs for `count`/`sum`/`avg`, collected items for `string_agg` and
    // `array_agg`.
    count: i64,
    // Running total of the integer inputs to `sum`/`avg`.
    sum_int: i64,
    // Running total of the float inputs to `sum`/`avg`.
    sum_float: f64,
    // Current best candidate for `min`/`max`; `None` until a non-null input arrives.
    extreme: Option<Value>,
    // Set once any float input has been summed; `sum` then finalizes as a float.
    use_float: bool,
    // Values collected by `string_agg` / `array_agg`, in arrival order.
    items: Vec<Value>,
    // Encoded argument values already consumed, used to drop repeats for DISTINCT aggregates.
    seen: BTreeSet<String>,
    // ORDER BY key values recorded alongside each collected item; used to sort
    // `items` at finalization when it has the same length as `items`.
    item_keys: Vec<Vec<Value>>,
    // Most recent text separator seen by `string_agg`.
    separator: Option<String>,
    // Accumulator for `bool_and` / `bool_or`; `None` until a non-null input arrives.
    bool_acc: Option<bool>,
}
163
/// Description of one distinct aggregate call found in the statement.
///
/// Two calls that agree on every field share one accumulator slot.
#[derive(Debug, Clone)]
struct AggSpec {
    // Lower-cased function name, with `every` canonicalised to `bool_and`.
    name: String,
    // First argument expression; `None` for a plain `count_star` call.
    arg: Option<Expr>,
    // Second argument expression; only populated for `string_agg` (the separator).
    arg2: Option<Expr>,
    // Whether repeated argument values are skipped.
    distinct: bool,
    // In-aggregate ORDER BY keys; empty for a plain function call.
    order_by: Vec<spg_sql::ast::OrderBy>,
}
186
/// Output of [`run`]: the projected column schema and one row per surviving group.
#[derive(Debug)]
pub struct AggResult {
    /// Output columns, one per select-list expression, in select-list order.
    pub columns: Vec<ColumnSchema>,
    /// Result rows whose values line up with `columns`.
    pub rows: Vec<Row>,
}
194
// NOTE(review): this lint allowance is attached to the type alias below, where
// it has nothing to suppress; it looks like it was meant for `run` — confirm.
#[allow(clippy::too_many_lines)]
/// Callback used by [`run`] to evaluate an expression against a row and
/// evaluation context.
///
/// `run` only invokes it for expressions that `crate::expr_has_subquery`
/// reports as containing a subquery; every other expression goes through
/// `eval::eval_expr` directly.
pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
205
206pub fn run(
207 stmt: &SelectStatement,
208 rows: &[&Row],
209 schema_cols: &[ColumnSchema],
210 table_alias: Option<&str>,
211 correlated_eval: Option<CorrelatedEval<'_>>,
212) -> Result<AggResult, EvalError> {
213 let ctx = EvalContext::new(schema_cols, table_alias);
214 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
215
216 let mut agg_specs: Vec<AggSpec> = Vec::new();
218 for item in &stmt.items {
219 if let SelectItem::Expr { expr, .. } = item {
220 collect_aggregates(expr, &mut agg_specs);
221 }
222 }
223 for o in &stmt.order_by {
224 collect_aggregates(&o.expr, &mut agg_specs);
225 }
226 if let Some(h) = &stmt.having {
227 collect_aggregates(h, &mut agg_specs);
228 }
229 validate_agg_arities(stmt, &agg_specs)?;
235
236 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
239 let mut key_order: Vec<String> = Vec::new();
240 if rows.is_empty() && group_exprs.is_empty() {
243 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
245 groups.insert(String::new(), (Vec::new(), init));
246 key_order.push(String::new());
247 }
248
249 for row in rows {
250 let group_vals: Vec<Value> = group_exprs
251 .iter()
252 .map(|g| eval::eval_expr(g, row, &ctx))
253 .collect::<Result<_, _>>()?;
254 let mut key_vals = group_vals.clone();
260 for (i, g) in group_exprs.iter().enumerate() {
261 if matches!(
262 eval::column_collation(g, &ctx),
263 Some(spg_storage::Collation::CaseInsensitive)
264 ) {
265 if let Value::Text(s) = &key_vals[i] {
266 key_vals[i] = Value::Text(s.to_ascii_lowercase());
267 }
268 }
269 }
270 let key = encode_key(&key_vals);
271 let entry = groups.entry(key.clone()).or_insert_with(|| {
272 key_order.push(key.clone());
273 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
274 (group_vals.clone(), init)
275 });
276 for (i, spec) in agg_specs.iter().enumerate() {
277 let arg_val = match &spec.arg {
278 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
280 };
281 let arg2_val = match &spec.arg2 {
287 None => None,
288 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
289 };
290 let order_keys = if spec.order_by.is_empty() {
293 None
294 } else {
295 let mut keys = Vec::with_capacity(spec.order_by.len());
296 for o in &spec.order_by {
297 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
298 }
299 Some(keys)
300 };
301 if spec.distinct {
306 let key = encode_key(core::slice::from_ref(&arg_val));
307 if !entry.1[i].seen.insert(key) {
308 continue;
309 }
310 }
311 update_state(
312 &mut entry.1[i],
313 &spec.name,
314 &arg_val,
315 arg2_val.as_ref(),
316 order_keys,
317 )?;
318 }
319 }
320
321 let group_types: Vec<DataType> = if rows.is_empty() {
323 group_exprs.iter().map(|_| DataType::Text).collect()
326 } else {
327 let probe = rows[0];
328 group_exprs
329 .iter()
330 .map(|g| {
331 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
332 })
333 .collect::<Result<_, _>>()?
334 };
335 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
336 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
337 for (i, ty) in group_types.iter().enumerate() {
338 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
339 }
340 for (i, ty) in agg_types.iter().enumerate() {
341 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
342 }
343
344 let mut synth_rows: Vec<Row> = Vec::new();
346 for k in &key_order {
347 let (gvals, states) = &groups[k];
348 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
349 values.extend(gvals.iter().cloned());
350 for (i, st) in states.iter().enumerate() {
351 let st_sorted;
355 let st_final: &AggState =
356 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
357 let mut idx: Vec<usize> = (0..st.items.len()).collect();
358 let ob = &agg_specs[i].order_by;
359 idx.sort_by(|&x, &y| {
360 for (k, o) in ob.iter().enumerate() {
361 let cmp = crate::order_by_value_cmp(
362 o.desc,
363 o.nulls_first,
364 &st.item_keys[x][k],
365 &st.item_keys[y][k],
366 );
367 if cmp != core::cmp::Ordering::Equal {
368 return cmp;
369 }
370 }
371 core::cmp::Ordering::Equal
372 });
373 let mut sorted = st.clone();
374 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
375 st_sorted = sorted;
376 &st_sorted
377 } else {
378 st
379 };
380 values.push(finalize(&agg_specs[i].name, st_final));
381 }
382 synth_rows.push(Row::new(values));
383 }
384
385 let columns: Vec<ColumnSchema> = stmt
390 .items
391 .iter()
392 .map(|item| match item {
393 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
394 detail: "SELECT * with aggregates is not supported".into(),
395 }),
396 SelectItem::Expr { expr, alias } => {
397 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
398 let name = alias.clone().unwrap_or_else(|| expr.to_string());
399 Ok(ColumnSchema::new(
400 name,
401 agg_or_group_type(&rewritten, &synth_schema),
402 true,
403 ))
404 }
405 })
406 .collect::<Result<_, _>>()?;
407
408 let synth_ctx = EvalContext::new(&synth_schema, None);
413 let having_rewritten = stmt
414 .having
415 .as_ref()
416 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
417 let mut kept_synth: Vec<Row> = Vec::new();
418 let mut out_rows: Vec<Row> = Vec::new();
419 for srow in synth_rows {
420 if let Some(h) = &having_rewritten {
421 let cond = match correlated_eval {
422 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
423 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
424 };
425 if !matches!(cond, Value::Bool(true)) {
426 continue;
427 }
428 }
429 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
430 for item in &stmt.items {
431 if let SelectItem::Expr { expr, .. } = item {
432 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
433 values.push(match correlated_eval {
434 Some(f) if crate::expr_has_subquery(&rewritten) => {
435 f(&rewritten, &srow, &synth_ctx)?
436 }
437 _ => eval::eval_expr(&rewritten, &srow, &synth_ctx)?,
438 });
439 }
440 }
441 kept_synth.push(srow);
442 out_rows.push(Row::new(values));
443 }
444
445 if !stmt.order_by.is_empty() {
448 let rewritten: Vec<Expr> = stmt
451 .order_by
452 .iter()
453 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
454 .collect();
455 let keys_meta: Vec<(bool, Option<bool>)> = stmt
456 .order_by
457 .iter()
458 .map(|o| (o.desc, o.nulls_first))
459 .collect();
460 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
461 .into_iter()
462 .zip(out_rows)
463 .map(|(s, o)| {
464 let mut keys = Vec::with_capacity(rewritten.len());
465 for e in &rewritten {
466 keys.push(match correlated_eval {
467 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
468 _ => eval::eval_expr(e, &s, &synth_ctx)?,
469 });
470 }
471 Ok::<_, EvalError>((keys, o))
472 })
473 .collect::<Result<_, _>>()?;
474 tagged.sort_by(|a, b| {
475 use core::cmp::Ordering;
476 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
477 let (desc, nf) = keys_meta[i];
478 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
479 if cmp != Ordering::Equal {
480 return cmp;
481 }
482 }
483 Ordering::Equal
484 });
485 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
486 }
487
488 Ok(AggResult {
489 columns,
490 rows: out_rows,
491 })
492}
493
494fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
500 fn walk(e: &Expr) -> Result<(), EvalError> {
501 if let Expr::FunctionCall { name, args } = e {
502 let lower = name.to_ascii_lowercase();
503 let expected: Option<usize> = match lower.as_str() {
504 "count_star" => Some(0),
505 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
506 | "bool_and" | "bool_or" | "every" => Some(1),
510 "string_agg" => Some(2),
511 _ => None,
512 };
513 if let Some(want) = expected
514 && args.len() != want
515 {
516 return Err(EvalError::TypeMismatch {
517 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
518 });
519 }
520 for a in args {
521 walk(a)?;
522 }
523 } else if let Expr::Binary { lhs, rhs, .. } = e {
524 walk(lhs)?;
525 walk(rhs)?;
526 } else if let Expr::Unary { expr, .. }
527 | Expr::Cast { expr, .. }
528 | Expr::IsNull { expr, .. } = e
529 {
530 walk(expr)?;
531 }
532 Ok(())
533 }
534 for item in &stmt.items {
535 if let SelectItem::Expr { expr, .. } = item {
536 walk(expr)?;
537 }
538 }
539 for o in &stmt.order_by {
540 walk(&o.expr)?;
541 }
542 if let Some(h) = &stmt.having {
543 walk(h)?;
544 }
545 Ok(())
546}
547
548fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
549 match e {
550 Expr::AggregateOrdered {
553 call,
554 order_by,
555 distinct,
556 } => {
557 if let Expr::FunctionCall { name, args } = call.as_ref() {
558 let lower = name.to_ascii_lowercase();
559 if is_aggregate_name(&lower) {
560 let canonical = if lower == "every" {
561 "bool_and".to_string()
562 } else {
563 lower
564 };
565 let spec = AggSpec {
566 name: canonical,
567 arg: args.first().cloned(),
568 arg2: if name.eq_ignore_ascii_case("string_agg") {
569 args.get(1).cloned()
570 } else {
571 None
572 },
573 distinct: *distinct,
574 order_by: order_by.clone(),
575 };
576 if !out.iter().any(|s| {
577 s.name == spec.name
578 && s.arg == spec.arg
579 && s.arg2 == spec.arg2
580 && s.distinct == spec.distinct
581 && s.order_by == spec.order_by
582 }) {
583 out.push(spec);
584 }
585 return;
586 }
587 }
588 collect_aggregates(call, out);
589 for o in order_by {
590 collect_aggregates(&o.expr, out);
591 }
592 }
593 Expr::FunctionCall { name, args } => {
594 let lower = name.to_ascii_lowercase();
595 if is_aggregate_name(&lower) {
596 let arg = if lower == "count_star" {
597 None
598 } else {
599 args.first().cloned()
600 };
601 let arg2 = if lower == "string_agg" {
605 args.get(1).cloned()
606 } else {
607 None
608 };
609 let canonical = if lower == "every" {
613 "bool_and".to_string()
614 } else {
615 lower
616 };
617 let spec = AggSpec {
618 name: canonical,
619 arg: arg.clone(),
620 arg2: arg2.clone(),
621 distinct: false,
622 order_by: Vec::new(),
623 };
624 if !out.iter().any(|s| {
625 s.name == spec.name
626 && s.arg == spec.arg
627 && s.arg2 == spec.arg2
628 && !s.distinct
629 && s.order_by == spec.order_by
630 }) {
631 out.push(spec);
632 }
633 } else {
636 for a in args {
637 collect_aggregates(a, out);
638 }
639 }
640 }
641 Expr::Binary { lhs, rhs, .. } => {
642 collect_aggregates(lhs, out);
643 collect_aggregates(rhs, out);
644 }
645 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
646 collect_aggregates(expr, out);
647 }
648 Expr::Like { expr, pattern, .. } => {
649 collect_aggregates(expr, out);
650 collect_aggregates(pattern, out);
651 }
652 Expr::Extract { source, .. } => collect_aggregates(source, out),
653 Expr::ScalarSubquery(_)
656 | Expr::Exists { .. }
657 | Expr::InSubquery { .. }
658 | Expr::WindowFunction { .. }
659 | Expr::Literal(_)
660 | Expr::Placeholder(_)
661 | Expr::Column(_) => {}
662 Expr::Array(items) => {
665 for elem in items {
666 collect_aggregates(elem, out);
667 }
668 }
669 Expr::ArraySubscript { target, index } => {
670 collect_aggregates(target, out);
671 collect_aggregates(index, out);
672 }
673 Expr::AnyAll { expr, array, .. } => {
674 collect_aggregates(expr, out);
675 collect_aggregates(array, out);
676 }
677 Expr::Case {
678 operand,
679 branches,
680 else_branch,
681 } => {
682 if let Some(o) = operand {
683 collect_aggregates(o, out);
684 }
685 for (w, t) in branches {
686 collect_aggregates(w, out);
687 collect_aggregates(t, out);
688 }
689 if let Some(e) = else_branch {
690 collect_aggregates(e, out);
691 }
692 }
693 }
694}
695
696fn update_state(
697 st: &mut AggState,
698 name: &str,
699 v: &Value,
700 arg2: Option<&Value>,
701 order_keys: Option<Vec<Value>>,
702) -> Result<(), EvalError> {
703 let is_null = matches!(v, Value::Null);
704 match name {
705 "count_star" => st.count += 1,
706 "count" => {
707 if !is_null {
708 st.count += 1;
709 }
710 }
711 "sum" | "avg" => {
712 if is_null {
713 return Ok(());
714 }
715 st.count += 1;
716 match v {
717 Value::Int(n) => st.sum_int += i64::from(*n),
718 Value::BigInt(n) => st.sum_int += *n,
719 Value::Float(x) => {
720 st.use_float = true;
721 st.sum_float += *x;
722 }
723 other => {
724 return Err(EvalError::TypeMismatch {
725 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
726 });
727 }
728 }
729 }
730 "min" => {
731 if is_null {
732 return Ok(());
733 }
734 match &st.extreme {
735 None => st.extreme = Some(v.clone()),
736 Some(cur) => {
737 if value_cmp(v, cur) == core::cmp::Ordering::Less {
738 st.extreme = Some(v.clone());
739 }
740 }
741 }
742 }
743 "max" => {
744 if is_null {
745 return Ok(());
746 }
747 match &st.extreme {
748 None => st.extreme = Some(v.clone()),
749 Some(cur) => {
750 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
751 st.extreme = Some(v.clone());
752 }
753 }
754 }
755 }
756 "string_agg" => {
764 if let Some(sep) = arg2
765 && let Value::Text(s) = sep
766 {
767 st.separator = Some(s.clone());
768 }
769 if is_null {
770 return Ok(());
771 }
772 if let Value::Text(s) = v {
773 st.items.push(Value::Text(s.clone()));
774 if let Some(k) = order_keys {
775 st.item_keys.push(k);
776 }
777 st.count += 1;
778 } else {
779 return Err(EvalError::TypeMismatch {
780 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
781 });
782 }
783 }
784 "array_agg" => {
790 st.items.push(v.clone());
791 if let Some(k) = order_keys {
792 st.item_keys.push(k);
793 }
794 st.count += 1;
795 }
796 "bool_and" => {
800 if is_null {
801 return Ok(());
802 }
803 let b = match v {
804 Value::Bool(b) => *b,
805 other => {
806 return Err(EvalError::TypeMismatch {
807 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
808 });
809 }
810 };
811 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
812 }
813 "bool_or" => {
816 if is_null {
817 return Ok(());
818 }
819 let b = match v {
820 Value::Bool(b) => *b,
821 other => {
822 return Err(EvalError::TypeMismatch {
823 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
824 });
825 }
826 };
827 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
828 }
829 _ => unreachable!("non-aggregate {name} in update_state"),
830 }
831 Ok(())
832}
833
834#[allow(clippy::cast_precision_loss)]
835fn finalize(name: &str, st: &AggState) -> Value {
836 match name {
837 "count" | "count_star" => Value::BigInt(st.count),
838 "sum" => {
839 if st.count == 0 {
840 Value::Null
841 } else if st.use_float {
842 Value::Float(st.sum_float + (st.sum_int as f64))
843 } else {
844 Value::BigInt(st.sum_int)
845 }
846 }
847 "avg" => {
848 if st.count == 0 {
849 Value::Null
850 } else {
851 let total = if st.use_float {
852 st.sum_float + (st.sum_int as f64)
853 } else {
854 st.sum_int as f64
855 };
856 Value::Float(total / (st.count as f64))
857 }
858 }
859 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
860 "string_agg" => {
864 if st.items.is_empty() {
865 return Value::Null;
866 }
867 let sep = st.separator.clone().unwrap_or_default();
868 let mut out = String::new();
869 for (i, item) in st.items.iter().enumerate() {
870 if i > 0 {
871 out.push_str(&sep);
872 }
873 if let Value::Text(s) = item {
874 out.push_str(s);
875 }
876 }
877 Value::Text(out)
878 }
879 "array_agg" => {
886 if st.items.is_empty() {
887 return Value::Null;
888 }
889 let probe = st.items.iter().find(|v| !v.is_null());
890 match probe.and_then(spg_storage::Value::data_type) {
891 Some(DataType::Int) | Some(DataType::SmallInt) => {
892 let items: Vec<Option<i32>> = st
893 .items
894 .iter()
895 .map(|v| match v {
896 Value::Int(n) => Some(*n),
897 Value::SmallInt(n) => Some(i32::from(*n)),
898 _ => None,
899 })
900 .collect();
901 Value::IntArray(items)
902 }
903 Some(DataType::BigInt) => {
904 let items: Vec<Option<i64>> = st
905 .items
906 .iter()
907 .map(|v| match v {
908 Value::BigInt(n) => Some(*n),
909 _ => None,
910 })
911 .collect();
912 Value::BigIntArray(items)
913 }
914 _ => {
915 let items: Vec<Option<String>> = st
916 .items
917 .iter()
918 .map(|v| match v {
919 Value::Text(s) => Some(s.clone()),
920 Value::Null => None,
921 other => Some(format!("{other:?}")),
922 })
923 .collect();
924 Value::TextArray(items)
925 }
926 }
927 }
928 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
932 _ => unreachable!(),
933 }
934}
935
936fn infer_agg_type(spec: &AggSpec) -> DataType {
937 match spec.name.as_str() {
938 "count" | "count_star" | "sum" => DataType::BigInt,
942 "avg" => DataType::Float,
943 "string_agg" => DataType::Text,
945 "array_agg" => DataType::TextArray,
952 "bool_and" | "bool_or" => DataType::Bool,
955 _ => DataType::Text,
958 }
959}
960
961fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
962 if let Expr::Column(c) = e
963 && let Some(s) = synth.iter().find(|s| s.name == c.name)
964 {
965 return s.ty;
966 }
967 DataType::Text
970}
971
972fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
973 if let Expr::AggregateOrdered {
976 call,
977 order_by,
978 distinct,
979 } = e
980 && let Expr::FunctionCall { name, args } = call.as_ref()
981 {
982 let lower = name.to_ascii_lowercase();
983 if is_aggregate_name(&lower) {
984 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
985 let arg = args.first().cloned();
986 let arg2 = if lower == "string_agg" {
987 args.get(1).cloned()
988 } else {
989 None
990 };
991 for (i, spec) in aggs.iter().enumerate() {
992 if spec.name == canonical
993 && spec.arg == arg
994 && spec.arg2 == arg2
995 && spec.distinct == *distinct
996 && spec.order_by == *order_by
997 {
998 return Expr::Column(spg_sql::ast::ColumnName {
999 qualifier: None,
1000 name: format!("__agg_{i}"),
1001 });
1002 }
1003 }
1004 }
1005 }
1006 if let Expr::FunctionCall { name, args } = e {
1008 let lower = name.to_ascii_lowercase();
1009 if is_aggregate_name(&lower) {
1010 let arg = if lower == "count_star" {
1011 None
1012 } else {
1013 args.first().cloned()
1014 };
1015 let arg2 = if lower == "string_agg" {
1018 args.get(1).cloned()
1019 } else {
1020 None
1021 };
1022 let canonical: &str = if lower == "every" {
1026 "bool_and"
1027 } else {
1028 lower.as_str()
1029 };
1030 for (i, spec) in aggs.iter().enumerate() {
1031 if spec.name == canonical
1032 && spec.arg == arg
1033 && spec.arg2 == arg2
1034 && !spec.distinct
1035 && spec.order_by.is_empty()
1036 {
1037 return Expr::Column(spg_sql::ast::ColumnName {
1038 qualifier: None,
1039 name: format!("__agg_{i}"),
1040 });
1041 }
1042 }
1043 }
1044 }
1045 for (i, g) in group_exprs.iter().enumerate() {
1047 if g == e {
1048 return Expr::Column(spg_sql::ast::ColumnName {
1049 qualifier: None,
1050 name: format!("__grp_{i}"),
1051 });
1052 }
1053 }
1054 match e {
1056 Expr::AggregateOrdered {
1057 call,
1058 order_by,
1059 distinct,
1060 } => Expr::AggregateOrdered {
1061 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1062 distinct: *distinct,
1063 order_by: order_by
1064 .iter()
1065 .map(|o| spg_sql::ast::OrderBy {
1066 expr: rewrite_expr(&o.expr, group_exprs, aggs),
1067 desc: o.desc,
1068 nulls_first: o.nulls_first,
1069 })
1070 .collect(),
1071 },
1072 Expr::Binary { lhs, op, rhs } => Expr::Binary {
1073 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1074 op: *op,
1075 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1076 },
1077 Expr::Unary { op, expr } => Expr::Unary {
1078 op: *op,
1079 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1080 },
1081 Expr::Cast { expr, target } => Expr::Cast {
1082 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1083 target: *target,
1084 },
1085 Expr::IsNull { expr, negated } => Expr::IsNull {
1086 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1087 negated: *negated,
1088 },
1089 Expr::FunctionCall { name, args } => Expr::FunctionCall {
1090 name: name.clone(),
1091 args: args
1092 .iter()
1093 .map(|a| rewrite_expr(a, group_exprs, aggs))
1094 .collect(),
1095 },
1096 Expr::Like {
1097 expr,
1098 pattern,
1099 negated,
1100 case_insensitive,
1101 } => Expr::Like {
1102 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1103 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1104 negated: *negated,
1105 case_insensitive: *case_insensitive,
1106 },
1107 Expr::Extract { field, source } => Expr::Extract {
1108 field: *field,
1109 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1110 },
1111 Expr::ScalarSubquery(s) => {
1117 Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
1118 }
1119 Expr::Exists { subquery, negated } => Expr::Exists {
1120 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1121 negated: *negated,
1122 },
1123 Expr::InSubquery {
1124 expr,
1125 subquery,
1126 negated,
1127 } => Expr::InSubquery {
1128 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1129 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1130 negated: *negated,
1131 },
1132 Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
1135 e.clone()
1136 }
1137 Expr::Array(items) => Expr::Array(
1139 items
1140 .iter()
1141 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1142 .collect(),
1143 ),
1144 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1145 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1146 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1147 },
1148 Expr::AnyAll {
1149 expr,
1150 op,
1151 array,
1152 is_any,
1153 } => Expr::AnyAll {
1154 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1155 op: *op,
1156 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1157 is_any: *is_any,
1158 },
1159 Expr::Case {
1160 operand,
1161 branches,
1162 else_branch,
1163 } => Expr::Case {
1164 operand: operand
1165 .as_deref()
1166 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1167 branches: branches
1168 .iter()
1169 .map(|(w, t)| {
1170 (
1171 rewrite_expr(w, group_exprs, aggs),
1172 rewrite_expr(t, group_exprs, aggs),
1173 )
1174 })
1175 .collect(),
1176 else_branch: else_branch
1177 .as_deref()
1178 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1179 },
1180 }
1181}
1182
1183fn rewrite_group_keys_in_select(
1188 s: &spg_sql::ast::SelectStatement,
1189 group_exprs: &[Expr],
1190) -> spg_sql::ast::SelectStatement {
1191 let mut out = s.clone();
1192 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
1193 *e = rewrite_expr(e, group_exprs, &[]);
1194 Ok(())
1195 });
1196 out
1197}
1198
1199fn encode_key(vals: &[Value]) -> String {
1201 let mut out = String::new();
1202 for v in vals {
1203 match v {
1204 Value::Null => out.push_str("N|"),
1205 Value::SmallInt(n) => {
1206 out.push('s');
1207 out.push_str(&n.to_string());
1208 out.push('|');
1209 }
1210 Value::Int(n) => {
1211 out.push('I');
1212 out.push_str(&n.to_string());
1213 out.push('|');
1214 }
1215 Value::BigInt(n) => {
1216 out.push('B');
1217 out.push_str(&n.to_string());
1218 out.push('|');
1219 }
1220 Value::Float(x) => {
1221 out.push('F');
1222 out.push_str(&x.to_string());
1223 out.push('|');
1224 }
1225 Value::Bool(b) => {
1226 out.push(if *b { 'T' } else { 'f' });
1227 out.push('|');
1228 }
1229 Value::Text(s) => {
1230 out.push('S');
1231 out.push_str(s);
1232 out.push('|');
1233 }
1234 Value::Vector(v) => {
1235 out.push('V');
1236 for x in v {
1237 out.push_str(&x.to_string());
1238 out.push(',');
1239 }
1240 out.push('|');
1241 }
1242 Value::Sq8Vector(q) => {
1248 out.push('Q');
1249 out.push_str(&q.min.to_string());
1250 out.push('@');
1251 out.push_str(&q.max.to_string());
1252 out.push(':');
1253 for b in &q.bytes {
1254 out.push_str(&b.to_string());
1255 out.push(',');
1256 }
1257 out.push('|');
1258 }
1259 Value::HalfVector(h) => {
1263 out.push('H');
1264 for b in &h.bytes {
1265 out.push_str(&b.to_string());
1266 out.push(',');
1267 }
1268 out.push('|');
1269 }
1270 Value::Numeric { scaled, scale } => {
1271 out.push('D');
1272 out.push_str(&scaled.to_string());
1273 out.push('@');
1274 out.push_str(&scale.to_string());
1275 out.push('|');
1276 }
1277 Value::Date(d) => {
1278 out.push('d');
1279 out.push_str(&d.to_string());
1280 out.push('|');
1281 }
1282 Value::Timestamp(t) => {
1283 out.push('t');
1284 out.push_str(&t.to_string());
1285 out.push('|');
1286 }
1287 Value::Interval { months, micros } => {
1288 out.push('i');
1289 out.push_str(&months.to_string());
1290 out.push('m');
1291 out.push_str(µs.to_string());
1292 out.push('|');
1293 }
1294 Value::Json(s) => {
1295 out.push('j');
1296 out.push_str(s);
1297 out.push('|');
1298 }
1299 _ => {
1304 out.push('?');
1305 out.push_str(&format!("{v:?}"));
1306 out.push('|');
1307 }
1308 }
1309 }
1310 out
1311}
1312
1313#[allow(clippy::cast_precision_loss)]
1314fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1315 use core::cmp::Ordering::Equal;
1316 match (a, b) {
1317 (Value::Null, Value::Null) => Equal,
1318 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1320 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1321 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1322 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1323 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1324 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1325 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1326 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1327 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1328 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1329 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1330 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1331 _ => Equal,
1332 }
1333}