1use alloc::boxed::Box;
24use alloc::collections::{BTreeMap, BTreeSet};
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Returns `true` if `name` is one of the aggregate function names this
/// module implements, compared ASCII case-insensitively.
pub fn is_aggregate_name(name: &str) -> bool {
    /// Every recognised aggregate name, in lowercase.
    const AGGREGATES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
128
/// Running accumulator for one aggregate call within one group.
///
/// A single struct serves every aggregate kind; each kind reads and writes
/// only the fields relevant to it (see `update_state` and `finalize`).
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Number of inputs accumulated so far: every row for `count_star`,
    /// non-NULL inputs for `count`/`sum`/`avg`/`string_agg`, and every input
    /// (NULLs included) for `array_agg`.
    count: i64,
    /// Running total of the integer inputs to `sum`/`avg`.
    sum_int: i64,
    /// Running total of the float inputs to `sum`/`avg`.
    sum_float: f64,
    /// Current minimum or maximum for `min`/`max`; `None` until the first
    /// non-NULL input.
    extreme: Option<Value>,
    /// Set once any float input has been seen by `sum`/`avg`; makes the
    /// finalised `sum` a float.
    use_float: bool,
    /// Values collected by `string_agg`/`array_agg`, in arrival order.
    items: Vec<Value>,
    /// Encoded keys of argument values already accumulated; consulted by
    /// `run` to skip duplicates for DISTINCT aggregates.
    seen: BTreeSet<String>,
    /// ORDER BY key tuples, pushed alongside `items` when the aggregate
    /// carries its own ORDER BY.
    item_keys: Vec<Vec<Value>>,
    /// Most recent text separator supplied to `string_agg`.
    separator: Option<String>,
    /// Running result of `bool_and`/`bool_or`; `None` until the first
    /// non-NULL input.
    bool_acc: Option<bool>,
}
163
/// Description of one distinct aggregate call found in the statement.
///
/// The position of a spec in the collected list determines the name of the
/// synthetic `__agg_{i}` column that replaces the call after rewriting.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercased function name, with `every` canonicalised to `bool_and`.
    name: String,
    /// First argument of the call; `None` for a plain `count_star` call.
    arg: Option<Expr>,
    /// Second argument, recorded only for `string_agg` (its separator).
    arg2: Option<Expr>,
    /// Whether the call was written with DISTINCT.
    distinct: bool,
    /// The aggregate's own ORDER BY list; empty when it has none.
    order_by: Vec<spg_sql::ast::OrderBy>,
}
186
/// Output of [`run`]: the projected result set of an aggregate query.
#[derive(Debug)]
pub struct AggResult {
    /// One schema entry per select item, in select-list order.
    pub columns: Vec<ColumnSchema>,
    /// Result rows, already filtered by HAVING and sorted by ORDER BY.
    pub rows: Vec<Row>,
}
194
195#[allow(clippy::too_many_lines)]
198pub fn run(
199 stmt: &SelectStatement,
200 rows: &[&Row],
201 schema_cols: &[ColumnSchema],
202 table_alias: Option<&str>,
203) -> Result<AggResult, EvalError> {
204 let ctx = EvalContext::new(schema_cols, table_alias);
205 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
206
207 let mut agg_specs: Vec<AggSpec> = Vec::new();
209 for item in &stmt.items {
210 if let SelectItem::Expr { expr, .. } = item {
211 collect_aggregates(expr, &mut agg_specs);
212 }
213 }
214 for o in &stmt.order_by {
215 collect_aggregates(&o.expr, &mut agg_specs);
216 }
217 if let Some(h) = &stmt.having {
218 collect_aggregates(h, &mut agg_specs);
219 }
220 validate_agg_arities(stmt, &agg_specs)?;
226
227 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
230 let mut key_order: Vec<String> = Vec::new();
231 if rows.is_empty() && group_exprs.is_empty() {
234 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
236 groups.insert(String::new(), (Vec::new(), init));
237 key_order.push(String::new());
238 }
239
240 for row in rows {
241 let group_vals: Vec<Value> = group_exprs
242 .iter()
243 .map(|g| eval::eval_expr(g, row, &ctx))
244 .collect::<Result<_, _>>()?;
245 let mut key_vals = group_vals.clone();
251 for (i, g) in group_exprs.iter().enumerate() {
252 if matches!(
253 eval::column_collation(g, &ctx),
254 Some(spg_storage::Collation::CaseInsensitive)
255 ) {
256 if let Value::Text(s) = &key_vals[i] {
257 key_vals[i] = Value::Text(s.to_ascii_lowercase());
258 }
259 }
260 }
261 let key = encode_key(&key_vals);
262 let entry = groups.entry(key.clone()).or_insert_with(|| {
263 key_order.push(key.clone());
264 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
265 (group_vals.clone(), init)
266 });
267 for (i, spec) in agg_specs.iter().enumerate() {
268 let arg_val = match &spec.arg {
269 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
271 };
272 let arg2_val = match &spec.arg2 {
278 None => None,
279 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
280 };
281 let order_keys = if spec.order_by.is_empty() {
284 None
285 } else {
286 let mut keys = Vec::with_capacity(spec.order_by.len());
287 for o in &spec.order_by {
288 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
289 }
290 Some(keys)
291 };
292 if spec.distinct {
297 let key = encode_key(core::slice::from_ref(&arg_val));
298 if !entry.1[i].seen.insert(key) {
299 continue;
300 }
301 }
302 update_state(
303 &mut entry.1[i],
304 &spec.name,
305 &arg_val,
306 arg2_val.as_ref(),
307 order_keys,
308 )?;
309 }
310 }
311
312 let group_types: Vec<DataType> = if rows.is_empty() {
314 group_exprs.iter().map(|_| DataType::Text).collect()
317 } else {
318 let probe = rows[0];
319 group_exprs
320 .iter()
321 .map(|g| {
322 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
323 })
324 .collect::<Result<_, _>>()?
325 };
326 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
327 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
328 for (i, ty) in group_types.iter().enumerate() {
329 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
330 }
331 for (i, ty) in agg_types.iter().enumerate() {
332 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
333 }
334
335 let mut synth_rows: Vec<Row> = Vec::new();
337 for k in &key_order {
338 let (gvals, states) = &groups[k];
339 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
340 values.extend(gvals.iter().cloned());
341 for (i, st) in states.iter().enumerate() {
342 let st_sorted;
346 let st_final: &AggState =
347 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
348 let mut idx: Vec<usize> = (0..st.items.len()).collect();
349 let ob = &agg_specs[i].order_by;
350 idx.sort_by(|&x, &y| {
351 for (k, o) in ob.iter().enumerate() {
352 let cmp = crate::order_by_value_cmp(
353 o.desc,
354 o.nulls_first,
355 &st.item_keys[x][k],
356 &st.item_keys[y][k],
357 );
358 if cmp != core::cmp::Ordering::Equal {
359 return cmp;
360 }
361 }
362 core::cmp::Ordering::Equal
363 });
364 let mut sorted = st.clone();
365 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
366 st_sorted = sorted;
367 &st_sorted
368 } else {
369 st
370 };
371 values.push(finalize(&agg_specs[i].name, st_final));
372 }
373 synth_rows.push(Row::new(values));
374 }
375
376 let columns: Vec<ColumnSchema> = stmt
381 .items
382 .iter()
383 .map(|item| match item {
384 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
385 detail: "SELECT * with aggregates is not supported".into(),
386 }),
387 SelectItem::Expr { expr, alias } => {
388 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
389 let name = alias.clone().unwrap_or_else(|| expr.to_string());
390 Ok(ColumnSchema::new(
391 name,
392 agg_or_group_type(&rewritten, &synth_schema),
393 true,
394 ))
395 }
396 })
397 .collect::<Result<_, _>>()?;
398
399 let synth_ctx = EvalContext::new(&synth_schema, None);
404 let having_rewritten = stmt
405 .having
406 .as_ref()
407 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
408 let mut kept_synth: Vec<Row> = Vec::new();
409 let mut out_rows: Vec<Row> = Vec::new();
410 for srow in synth_rows {
411 if let Some(h) = &having_rewritten {
412 let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
413 if !matches!(cond, Value::Bool(true)) {
414 continue;
415 }
416 }
417 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
418 for item in &stmt.items {
419 if let SelectItem::Expr { expr, .. } = item {
420 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
421 values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
422 }
423 }
424 kept_synth.push(srow);
425 out_rows.push(Row::new(values));
426 }
427
428 if !stmt.order_by.is_empty() {
431 let rewritten: Vec<Expr> = stmt
434 .order_by
435 .iter()
436 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
437 .collect();
438 let keys_meta: Vec<(bool, Option<bool>)> = stmt
439 .order_by
440 .iter()
441 .map(|o| (o.desc, o.nulls_first))
442 .collect();
443 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
444 .into_iter()
445 .zip(out_rows)
446 .map(|(s, o)| {
447 let mut keys = Vec::with_capacity(rewritten.len());
448 for e in &rewritten {
449 keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
450 }
451 Ok::<_, EvalError>((keys, o))
452 })
453 .collect::<Result<_, _>>()?;
454 tagged.sort_by(|a, b| {
455 use core::cmp::Ordering;
456 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
457 let (desc, nf) = keys_meta[i];
458 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
459 if cmp != Ordering::Equal {
460 return cmp;
461 }
462 }
463 Ordering::Equal
464 });
465 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
466 }
467
468 Ok(AggResult {
469 columns,
470 rows: out_rows,
471 })
472}
473
474fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
480 fn walk(e: &Expr) -> Result<(), EvalError> {
481 if let Expr::FunctionCall { name, args } = e {
482 let lower = name.to_ascii_lowercase();
483 let expected: Option<usize> = match lower.as_str() {
484 "count_star" => Some(0),
485 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
486 | "bool_and" | "bool_or" | "every" => Some(1),
490 "string_agg" => Some(2),
491 _ => None,
492 };
493 if let Some(want) = expected
494 && args.len() != want
495 {
496 return Err(EvalError::TypeMismatch {
497 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
498 });
499 }
500 for a in args {
501 walk(a)?;
502 }
503 } else if let Expr::Binary { lhs, rhs, .. } = e {
504 walk(lhs)?;
505 walk(rhs)?;
506 } else if let Expr::Unary { expr, .. }
507 | Expr::Cast { expr, .. }
508 | Expr::IsNull { expr, .. } = e
509 {
510 walk(expr)?;
511 }
512 Ok(())
513 }
514 for item in &stmt.items {
515 if let SelectItem::Expr { expr, .. } = item {
516 walk(expr)?;
517 }
518 }
519 for o in &stmt.order_by {
520 walk(&o.expr)?;
521 }
522 if let Some(h) = &stmt.having {
523 walk(h)?;
524 }
525 Ok(())
526}
527
528fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
529 match e {
530 Expr::AggregateOrdered {
533 call,
534 order_by,
535 distinct,
536 } => {
537 if let Expr::FunctionCall { name, args } = call.as_ref() {
538 let lower = name.to_ascii_lowercase();
539 if is_aggregate_name(&lower) {
540 let canonical = if lower == "every" {
541 "bool_and".to_string()
542 } else {
543 lower
544 };
545 let spec = AggSpec {
546 name: canonical,
547 arg: args.first().cloned(),
548 arg2: if name.eq_ignore_ascii_case("string_agg") {
549 args.get(1).cloned()
550 } else {
551 None
552 },
553 distinct: *distinct,
554 order_by: order_by.clone(),
555 };
556 if !out.iter().any(|s| {
557 s.name == spec.name
558 && s.arg == spec.arg
559 && s.arg2 == spec.arg2
560 && s.distinct == spec.distinct
561 && s.order_by == spec.order_by
562 }) {
563 out.push(spec);
564 }
565 return;
566 }
567 }
568 collect_aggregates(call, out);
569 for o in order_by {
570 collect_aggregates(&o.expr, out);
571 }
572 }
573 Expr::FunctionCall { name, args } => {
574 let lower = name.to_ascii_lowercase();
575 if is_aggregate_name(&lower) {
576 let arg = if lower == "count_star" {
577 None
578 } else {
579 args.first().cloned()
580 };
581 let arg2 = if lower == "string_agg" {
585 args.get(1).cloned()
586 } else {
587 None
588 };
589 let canonical = if lower == "every" {
593 "bool_and".to_string()
594 } else {
595 lower
596 };
597 let spec = AggSpec {
598 name: canonical,
599 arg: arg.clone(),
600 arg2: arg2.clone(),
601 distinct: false,
602 order_by: Vec::new(),
603 };
604 if !out.iter().any(|s| {
605 s.name == spec.name
606 && s.arg == spec.arg
607 && s.arg2 == spec.arg2
608 && !s.distinct
609 && s.order_by == spec.order_by
610 }) {
611 out.push(spec);
612 }
613 } else {
616 for a in args {
617 collect_aggregates(a, out);
618 }
619 }
620 }
621 Expr::Binary { lhs, rhs, .. } => {
622 collect_aggregates(lhs, out);
623 collect_aggregates(rhs, out);
624 }
625 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
626 collect_aggregates(expr, out);
627 }
628 Expr::Like { expr, pattern, .. } => {
629 collect_aggregates(expr, out);
630 collect_aggregates(pattern, out);
631 }
632 Expr::Extract { source, .. } => collect_aggregates(source, out),
633 Expr::ScalarSubquery(_)
636 | Expr::Exists { .. }
637 | Expr::InSubquery { .. }
638 | Expr::WindowFunction { .. }
639 | Expr::Literal(_)
640 | Expr::Placeholder(_)
641 | Expr::Column(_) => {}
642 Expr::Array(items) => {
645 for elem in items {
646 collect_aggregates(elem, out);
647 }
648 }
649 Expr::ArraySubscript { target, index } => {
650 collect_aggregates(target, out);
651 collect_aggregates(index, out);
652 }
653 Expr::AnyAll { expr, array, .. } => {
654 collect_aggregates(expr, out);
655 collect_aggregates(array, out);
656 }
657 Expr::Case {
658 operand,
659 branches,
660 else_branch,
661 } => {
662 if let Some(o) = operand {
663 collect_aggregates(o, out);
664 }
665 for (w, t) in branches {
666 collect_aggregates(w, out);
667 collect_aggregates(t, out);
668 }
669 if let Some(e) = else_branch {
670 collect_aggregates(e, out);
671 }
672 }
673 }
674}
675
676fn update_state(
677 st: &mut AggState,
678 name: &str,
679 v: &Value,
680 arg2: Option<&Value>,
681 order_keys: Option<Vec<Value>>,
682) -> Result<(), EvalError> {
683 let is_null = matches!(v, Value::Null);
684 match name {
685 "count_star" => st.count += 1,
686 "count" => {
687 if !is_null {
688 st.count += 1;
689 }
690 }
691 "sum" | "avg" => {
692 if is_null {
693 return Ok(());
694 }
695 st.count += 1;
696 match v {
697 Value::Int(n) => st.sum_int += i64::from(*n),
698 Value::BigInt(n) => st.sum_int += *n,
699 Value::Float(x) => {
700 st.use_float = true;
701 st.sum_float += *x;
702 }
703 other => {
704 return Err(EvalError::TypeMismatch {
705 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
706 });
707 }
708 }
709 }
710 "min" => {
711 if is_null {
712 return Ok(());
713 }
714 match &st.extreme {
715 None => st.extreme = Some(v.clone()),
716 Some(cur) => {
717 if value_cmp(v, cur) == core::cmp::Ordering::Less {
718 st.extreme = Some(v.clone());
719 }
720 }
721 }
722 }
723 "max" => {
724 if is_null {
725 return Ok(());
726 }
727 match &st.extreme {
728 None => st.extreme = Some(v.clone()),
729 Some(cur) => {
730 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
731 st.extreme = Some(v.clone());
732 }
733 }
734 }
735 }
736 "string_agg" => {
744 if let Some(sep) = arg2
745 && let Value::Text(s) = sep
746 {
747 st.separator = Some(s.clone());
748 }
749 if is_null {
750 return Ok(());
751 }
752 if let Value::Text(s) = v {
753 st.items.push(Value::Text(s.clone()));
754 if let Some(k) = order_keys {
755 st.item_keys.push(k);
756 }
757 st.count += 1;
758 } else {
759 return Err(EvalError::TypeMismatch {
760 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
761 });
762 }
763 }
764 "array_agg" => {
770 st.items.push(v.clone());
771 if let Some(k) = order_keys {
772 st.item_keys.push(k);
773 }
774 st.count += 1;
775 }
776 "bool_and" => {
780 if is_null {
781 return Ok(());
782 }
783 let b = match v {
784 Value::Bool(b) => *b,
785 other => {
786 return Err(EvalError::TypeMismatch {
787 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
788 });
789 }
790 };
791 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
792 }
793 "bool_or" => {
796 if is_null {
797 return Ok(());
798 }
799 let b = match v {
800 Value::Bool(b) => *b,
801 other => {
802 return Err(EvalError::TypeMismatch {
803 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
804 });
805 }
806 };
807 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
808 }
809 _ => unreachable!("non-aggregate {name} in update_state"),
810 }
811 Ok(())
812}
813
814#[allow(clippy::cast_precision_loss)]
815fn finalize(name: &str, st: &AggState) -> Value {
816 match name {
817 "count" | "count_star" => Value::BigInt(st.count),
818 "sum" => {
819 if st.count == 0 {
820 Value::Null
821 } else if st.use_float {
822 Value::Float(st.sum_float + (st.sum_int as f64))
823 } else {
824 Value::BigInt(st.sum_int)
825 }
826 }
827 "avg" => {
828 if st.count == 0 {
829 Value::Null
830 } else {
831 let total = if st.use_float {
832 st.sum_float + (st.sum_int as f64)
833 } else {
834 st.sum_int as f64
835 };
836 Value::Float(total / (st.count as f64))
837 }
838 }
839 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
840 "string_agg" => {
844 if st.items.is_empty() {
845 return Value::Null;
846 }
847 let sep = st.separator.clone().unwrap_or_default();
848 let mut out = String::new();
849 for (i, item) in st.items.iter().enumerate() {
850 if i > 0 {
851 out.push_str(&sep);
852 }
853 if let Value::Text(s) = item {
854 out.push_str(s);
855 }
856 }
857 Value::Text(out)
858 }
859 "array_agg" => {
866 if st.items.is_empty() {
867 return Value::Null;
868 }
869 let probe = st.items.iter().find(|v| !v.is_null());
870 match probe.and_then(spg_storage::Value::data_type) {
871 Some(DataType::Int) | Some(DataType::SmallInt) => {
872 let items: Vec<Option<i32>> = st
873 .items
874 .iter()
875 .map(|v| match v {
876 Value::Int(n) => Some(*n),
877 Value::SmallInt(n) => Some(i32::from(*n)),
878 _ => None,
879 })
880 .collect();
881 Value::IntArray(items)
882 }
883 Some(DataType::BigInt) => {
884 let items: Vec<Option<i64>> = st
885 .items
886 .iter()
887 .map(|v| match v {
888 Value::BigInt(n) => Some(*n),
889 _ => None,
890 })
891 .collect();
892 Value::BigIntArray(items)
893 }
894 _ => {
895 let items: Vec<Option<String>> = st
896 .items
897 .iter()
898 .map(|v| match v {
899 Value::Text(s) => Some(s.clone()),
900 Value::Null => None,
901 other => Some(format!("{other:?}")),
902 })
903 .collect();
904 Value::TextArray(items)
905 }
906 }
907 }
908 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
912 _ => unreachable!(),
913 }
914}
915
916fn infer_agg_type(spec: &AggSpec) -> DataType {
917 match spec.name.as_str() {
918 "count" | "count_star" | "sum" => DataType::BigInt,
922 "avg" => DataType::Float,
923 "string_agg" => DataType::Text,
925 "array_agg" => DataType::TextArray,
932 "bool_and" | "bool_or" => DataType::Bool,
935 _ => DataType::Text,
938 }
939}
940
941fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
942 if let Expr::Column(c) = e
943 && let Some(s) = synth.iter().find(|s| s.name == c.name)
944 {
945 return s.ty;
946 }
947 DataType::Text
950}
951
952fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
953 if let Expr::AggregateOrdered {
956 call,
957 order_by,
958 distinct,
959 } = e
960 && let Expr::FunctionCall { name, args } = call.as_ref()
961 {
962 let lower = name.to_ascii_lowercase();
963 if is_aggregate_name(&lower) {
964 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
965 let arg = args.first().cloned();
966 let arg2 = if lower == "string_agg" {
967 args.get(1).cloned()
968 } else {
969 None
970 };
971 for (i, spec) in aggs.iter().enumerate() {
972 if spec.name == canonical
973 && spec.arg == arg
974 && spec.arg2 == arg2
975 && spec.distinct == *distinct
976 && spec.order_by == *order_by
977 {
978 return Expr::Column(spg_sql::ast::ColumnName {
979 qualifier: None,
980 name: format!("__agg_{i}"),
981 });
982 }
983 }
984 }
985 }
986 if let Expr::FunctionCall { name, args } = e {
988 let lower = name.to_ascii_lowercase();
989 if is_aggregate_name(&lower) {
990 let arg = if lower == "count_star" {
991 None
992 } else {
993 args.first().cloned()
994 };
995 let arg2 = if lower == "string_agg" {
998 args.get(1).cloned()
999 } else {
1000 None
1001 };
1002 let canonical: &str = if lower == "every" {
1006 "bool_and"
1007 } else {
1008 lower.as_str()
1009 };
1010 for (i, spec) in aggs.iter().enumerate() {
1011 if spec.name == canonical
1012 && spec.arg == arg
1013 && spec.arg2 == arg2
1014 && !spec.distinct
1015 && spec.order_by.is_empty()
1016 {
1017 return Expr::Column(spg_sql::ast::ColumnName {
1018 qualifier: None,
1019 name: format!("__agg_{i}"),
1020 });
1021 }
1022 }
1023 }
1024 }
1025 for (i, g) in group_exprs.iter().enumerate() {
1027 if g == e {
1028 return Expr::Column(spg_sql::ast::ColumnName {
1029 qualifier: None,
1030 name: format!("__grp_{i}"),
1031 });
1032 }
1033 }
1034 match e {
1036 Expr::AggregateOrdered {
1037 call,
1038 order_by,
1039 distinct,
1040 } => Expr::AggregateOrdered {
1041 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1042 distinct: *distinct,
1043 order_by: order_by
1044 .iter()
1045 .map(|o| spg_sql::ast::OrderBy {
1046 expr: rewrite_expr(&o.expr, group_exprs, aggs),
1047 desc: o.desc,
1048 nulls_first: o.nulls_first,
1049 })
1050 .collect(),
1051 },
1052 Expr::Binary { lhs, op, rhs } => Expr::Binary {
1053 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1054 op: *op,
1055 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1056 },
1057 Expr::Unary { op, expr } => Expr::Unary {
1058 op: *op,
1059 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1060 },
1061 Expr::Cast { expr, target } => Expr::Cast {
1062 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1063 target: *target,
1064 },
1065 Expr::IsNull { expr, negated } => Expr::IsNull {
1066 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1067 negated: *negated,
1068 },
1069 Expr::FunctionCall { name, args } => Expr::FunctionCall {
1070 name: name.clone(),
1071 args: args
1072 .iter()
1073 .map(|a| rewrite_expr(a, group_exprs, aggs))
1074 .collect(),
1075 },
1076 Expr::Like {
1077 expr,
1078 pattern,
1079 negated,
1080 case_insensitive,
1081 } => Expr::Like {
1082 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1083 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1084 negated: *negated,
1085 case_insensitive: *case_insensitive,
1086 },
1087 Expr::Extract { field, source } => Expr::Extract {
1088 field: *field,
1089 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1090 },
1091 Expr::ScalarSubquery(_)
1094 | Expr::Exists { .. }
1095 | Expr::InSubquery { .. }
1096 | Expr::WindowFunction { .. }
1097 | Expr::Literal(_)
1098 | Expr::Placeholder(_)
1099 | Expr::Column(_) => e.clone(),
1100 Expr::Array(items) => Expr::Array(
1102 items
1103 .iter()
1104 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1105 .collect(),
1106 ),
1107 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1108 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1109 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1110 },
1111 Expr::AnyAll {
1112 expr,
1113 op,
1114 array,
1115 is_any,
1116 } => Expr::AnyAll {
1117 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1118 op: *op,
1119 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1120 is_any: *is_any,
1121 },
1122 Expr::Case {
1123 operand,
1124 branches,
1125 else_branch,
1126 } => Expr::Case {
1127 operand: operand
1128 .as_deref()
1129 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1130 branches: branches
1131 .iter()
1132 .map(|(w, t)| {
1133 (
1134 rewrite_expr(w, group_exprs, aggs),
1135 rewrite_expr(t, group_exprs, aggs),
1136 )
1137 })
1138 .collect(),
1139 else_branch: else_branch
1140 .as_deref()
1141 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1142 },
1143 }
1144}
1145
1146fn encode_key(vals: &[Value]) -> String {
1148 let mut out = String::new();
1149 for v in vals {
1150 match v {
1151 Value::Null => out.push_str("N|"),
1152 Value::SmallInt(n) => {
1153 out.push('s');
1154 out.push_str(&n.to_string());
1155 out.push('|');
1156 }
1157 Value::Int(n) => {
1158 out.push('I');
1159 out.push_str(&n.to_string());
1160 out.push('|');
1161 }
1162 Value::BigInt(n) => {
1163 out.push('B');
1164 out.push_str(&n.to_string());
1165 out.push('|');
1166 }
1167 Value::Float(x) => {
1168 out.push('F');
1169 out.push_str(&x.to_string());
1170 out.push('|');
1171 }
1172 Value::Bool(b) => {
1173 out.push(if *b { 'T' } else { 'f' });
1174 out.push('|');
1175 }
1176 Value::Text(s) => {
1177 out.push('S');
1178 out.push_str(s);
1179 out.push('|');
1180 }
1181 Value::Vector(v) => {
1182 out.push('V');
1183 for x in v {
1184 out.push_str(&x.to_string());
1185 out.push(',');
1186 }
1187 out.push('|');
1188 }
1189 Value::Sq8Vector(q) => {
1195 out.push('Q');
1196 out.push_str(&q.min.to_string());
1197 out.push('@');
1198 out.push_str(&q.max.to_string());
1199 out.push(':');
1200 for b in &q.bytes {
1201 out.push_str(&b.to_string());
1202 out.push(',');
1203 }
1204 out.push('|');
1205 }
1206 Value::HalfVector(h) => {
1210 out.push('H');
1211 for b in &h.bytes {
1212 out.push_str(&b.to_string());
1213 out.push(',');
1214 }
1215 out.push('|');
1216 }
1217 Value::Numeric { scaled, scale } => {
1218 out.push('D');
1219 out.push_str(&scaled.to_string());
1220 out.push('@');
1221 out.push_str(&scale.to_string());
1222 out.push('|');
1223 }
1224 Value::Date(d) => {
1225 out.push('d');
1226 out.push_str(&d.to_string());
1227 out.push('|');
1228 }
1229 Value::Timestamp(t) => {
1230 out.push('t');
1231 out.push_str(&t.to_string());
1232 out.push('|');
1233 }
1234 Value::Interval { months, micros } => {
1235 out.push('i');
1236 out.push_str(&months.to_string());
1237 out.push('m');
1238 out.push_str(µs.to_string());
1239 out.push('|');
1240 }
1241 Value::Json(s) => {
1242 out.push('j');
1243 out.push_str(s);
1244 out.push('|');
1245 }
1246 _ => {
1251 out.push('?');
1252 out.push_str(&format!("{v:?}"));
1253 out.push('|');
1254 }
1255 }
1256 }
1257 out
1258}
1259
1260#[allow(clippy::cast_precision_loss)]
1261fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1262 use core::cmp::Ordering::Equal;
1263 match (a, b) {
1264 (Value::Null, Value::Null) => Equal,
1265 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1267 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1268 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1269 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1270 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1271 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1272 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1273 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1274 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1275 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1276 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1277 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1278 _ => Equal,
1279 }
1280}