1use alloc::boxed::Box;
24use alloc::collections::BTreeSet;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::InList { expr, list, .. } => {
91 contains_aggregate(expr) || list.iter().any(contains_aggregate)
92 }
93 Expr::Case {
96 operand,
97 branches,
98 else_branch,
99 } => {
100 operand.as_deref().is_some_and(contains_aggregate)
101 || branches
102 .iter()
103 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
104 || else_branch.as_deref().is_some_and(contains_aggregate)
105 }
106 }
107}
108
/// Reports whether `name` (compared ASCII case-insensitively) is one of the
/// aggregate functions this module evaluates.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: &[&str] = &[
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES.iter().any(|agg| agg.eq_ignore_ascii_case(name))
}
131
/// Running accumulator for one aggregate within one group.
///
/// Which fields are meaningful depends on the aggregate (see `update_state`
/// and `finalize`); the others stay at their `Default` values.
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Values counted so far: rows for `count_star`, non-null inputs for
    /// `count`/`sum`/`avg`, collected items for `string_agg`/`array_agg`.
    count: i64,
    /// Integer part of a `sum`/`avg` total.
    sum_int: i64,
    /// Floating-point part of a `sum`/`avg` total.
    sum_float: f64,
    /// Current `min`/`max` candidate; `None` until a non-null value arrives.
    extreme: Option<Value>,
    /// Set once a FLOAT input reaches `sum`/`avg`; the result is then a float.
    use_float: bool,
    /// Collected `string_agg`/`array_agg` values, in arrival order.
    items: Vec<Value>,
    /// Encoded keys of argument values already accepted by a DISTINCT aggregate.
    seen: BTreeSet<String>,
    /// Per-item ORDER BY keys, pushed in step with `items` when the
    /// aggregate has its own ORDER BY.
    item_keys: Vec<Vec<Value>>,
    /// Most recent text separator passed to `string_agg`.
    separator: Option<String>,
    /// Running `bool_and`/`bool_or` result; `None` until a non-null input.
    bool_acc: Option<bool>,
}
166
/// One distinct aggregate call found in the statement.
///
/// Its index `i` in the collected list names the synthetic column
/// `__agg_{i}` that replaces the call when expressions are rewritten.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercased aggregate name, with `every` canonicalized to `bool_and`.
    name: String, arg: Option<Expr>, // `arg`: first argument; `None` for `count_star`.
    /// Second argument, collected only for `string_agg` (the separator).
    arg2: Option<Expr>,
    /// DISTINCT flag taken from an `AggregateOrdered` call; `false` for plain calls.
    distinct: bool,
    /// The aggregate's own ORDER BY keys (empty when it has none).
    order_by: Vec<spg_sql::ast::OrderBy>,
}
189
/// Output of the aggregation stage: projected column schema plus result rows.
#[derive(Debug)]
pub struct AggResult {
    /// One column per SELECT item, named by its alias or its expression text.
    pub columns: Vec<ColumnSchema>,
    /// One row per group that survived HAVING, in final output order.
    pub rows: Vec<Row>,
}
197
198#[allow(clippy::too_many_lines)]
201pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
208
209pub fn run(
210 stmt: &SelectStatement,
211 rows: &[&Row],
212 schema_cols: &[ColumnSchema],
213 table_alias: Option<&str>,
214 correlated_eval: Option<CorrelatedEval<'_>>,
215) -> Result<AggResult, EvalError> {
216 let ctx = EvalContext::new(schema_cols, table_alias);
217 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
218
219 let mut agg_specs: Vec<AggSpec> = Vec::new();
221 for item in &stmt.items {
222 if let SelectItem::Expr { expr, .. } = item {
223 collect_aggregates(expr, &mut agg_specs);
224 }
225 }
226 for o in &stmt.order_by {
227 collect_aggregates(&o.expr, &mut agg_specs);
228 }
229 if let Some(h) = &stmt.having {
230 collect_aggregates(h, &mut agg_specs);
231 }
232 validate_agg_arities(stmt, &agg_specs)?;
238
239 let mut groups: hashbrown::HashMap<String, (Vec<Value>, Vec<AggState>)> =
243 hashbrown::HashMap::new();
244 let mut key_order: Vec<String> = Vec::new();
245 if rows.is_empty() && group_exprs.is_empty() {
248 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
250 groups.insert(String::new(), (Vec::new(), init));
251 key_order.push(String::new());
252 }
253
254 let col_pos = |e: &Expr| -> Option<usize> {
265 if let Expr::Column(c) = e
268 && c.qualifier.is_some()
269 {
270 eval::find_column_pos(c, &ctx)
271 } else {
272 None
273 }
274 };
275 let group_pos: Vec<Option<usize>> = group_exprs.iter().map(col_pos).collect();
276 let all_groups_bound = group_pos.iter().all(Option::is_some);
277 let arg_pos: Vec<Option<usize>> = agg_specs
278 .iter()
279 .map(|spec| spec.arg.as_ref().and_then(|e| col_pos(e)))
280 .collect();
281 let ci_positions: Vec<usize> = group_exprs
282 .iter()
283 .enumerate()
284 .filter(|(_, g)| {
285 matches!(
286 eval::column_collation(g, &ctx),
287 Some(spg_storage::Collation::CaseInsensitive)
288 )
289 })
290 .map(|(i, _)| i)
291 .collect();
292 for row in rows {
293 if all_groups_bound && ci_positions.is_empty() && !group_exprs.is_empty() {
297 let refs: Vec<&Value> = group_pos
298 .iter()
299 .map(|p| row.values.get(p.unwrap()).unwrap_or(&Value::Null))
300 .collect();
301 let key = encode_key_refs(&refs);
302 let entry = match groups.entry_ref(key.as_str()) {
303 hashbrown::hash_map::EntryRef::Occupied(o) => o.into_mut(),
304 hashbrown::hash_map::EntryRef::Vacant(v) => {
305 key_order.push(key.clone());
306 let init: Vec<AggState> =
307 (0..agg_specs.len()).map(|_| AggState::default()).collect();
308 let owned: Vec<Value> = refs.iter().map(|v| (*v).clone()).collect();
309 v.insert((owned, init))
310 }
311 };
312 for (i, spec) in agg_specs.iter().enumerate() {
313 let arg_owned: Value;
314 let arg_ref: &Value = match (&arg_pos[i], &spec.arg) {
315 (Some(p), _) => row.values.get(*p).unwrap_or(&Value::Null),
316 (None, None) => {
317 arg_owned = Value::Bool(true);
318 &arg_owned
319 }
320 (None, Some(e)) => {
321 arg_owned = eval::eval_expr(e, row, &ctx)?;
322 &arg_owned
323 }
324 };
325 let arg2_val = match &spec.arg2 {
326 None => None,
327 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
328 };
329 let order_keys = if spec.order_by.is_empty() {
330 None
331 } else {
332 let mut keys = Vec::with_capacity(spec.order_by.len());
333 for o in &spec.order_by {
334 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
335 }
336 Some(keys)
337 };
338 if spec.distinct {
339 let dkey = encode_key_refs(core::slice::from_ref(&arg_ref));
340 if !entry.1[i].seen.insert(dkey) {
341 continue;
342 }
343 }
344 update_state(
345 &mut entry.1[i],
346 &spec.name,
347 arg_ref,
348 arg2_val.as_ref(),
349 order_keys,
350 )?;
351 }
352 continue;
353 }
354 let group_vals: Vec<Value> = group_exprs
355 .iter()
356 .map(|g| eval::eval_expr(g, row, &ctx))
357 .collect::<Result<_, _>>()?;
358 let key = if ci_positions.is_empty() {
362 encode_key(&group_vals)
363 } else {
364 let mut key_vals = group_vals.clone();
365 for &i in &ci_positions {
366 if let Value::Text(s) = &key_vals[i] {
367 key_vals[i] = Value::Text(s.to_ascii_lowercase());
368 }
369 }
370 encode_key(&key_vals)
371 };
372 let entry = match groups.entry_ref(key.as_str()) {
374 hashbrown::hash_map::EntryRef::Occupied(o) => o.into_mut(),
375 hashbrown::hash_map::EntryRef::Vacant(v) => {
376 key_order.push(key.clone());
377 let init: Vec<AggState> =
378 (0..agg_specs.len()).map(|_| AggState::default()).collect();
379 v.insert((group_vals.clone(), init))
380 }
381 };
382 for (i, spec) in agg_specs.iter().enumerate() {
383 let arg_val = match &spec.arg {
384 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
386 };
387 let arg2_val = match &spec.arg2 {
393 None => None,
394 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
395 };
396 let order_keys = if spec.order_by.is_empty() {
399 None
400 } else {
401 let mut keys = Vec::with_capacity(spec.order_by.len());
402 for o in &spec.order_by {
403 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
404 }
405 Some(keys)
406 };
407 if spec.distinct {
412 let key = encode_key(core::slice::from_ref(&arg_val));
413 if !entry.1[i].seen.insert(key) {
414 continue;
415 }
416 }
417 update_state(
418 &mut entry.1[i],
419 &spec.name,
420 &arg_val,
421 arg2_val.as_ref(),
422 order_keys,
423 )?;
424 }
425 }
426
427 let group_types: Vec<DataType> = if rows.is_empty() {
429 group_exprs.iter().map(|_| DataType::Text).collect()
432 } else {
433 let probe = rows[0];
434 group_exprs
435 .iter()
436 .map(|g| {
437 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
438 })
439 .collect::<Result<_, _>>()?
440 };
441 let agg_types: Vec<DataType> = agg_specs
442 .iter()
443 .map(|spec| infer_agg_type(spec, schema_cols))
444 .collect();
445 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
446 for (i, ty) in group_types.iter().enumerate() {
447 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
448 }
449 for (i, ty) in agg_types.iter().enumerate() {
450 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
451 }
452
453 let mut synth_rows: Vec<Row> = Vec::new();
455 for k in &key_order {
456 let (gvals, states) = &groups[k];
457 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
458 values.extend(gvals.iter().cloned());
459 for (i, st) in states.iter().enumerate() {
460 let st_sorted;
464 let st_final: &AggState =
465 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
466 let mut idx: Vec<usize> = (0..st.items.len()).collect();
467 let ob = &agg_specs[i].order_by;
468 idx.sort_by(|&x, &y| {
469 for (k, o) in ob.iter().enumerate() {
470 let cmp = crate::order_by_value_cmp(
471 o.desc,
472 o.nulls_first,
473 &st.item_keys[x][k],
474 &st.item_keys[y][k],
475 );
476 if cmp != core::cmp::Ordering::Equal {
477 return cmp;
478 }
479 }
480 core::cmp::Ordering::Equal
481 });
482 let mut sorted = st.clone();
483 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
484 st_sorted = sorted;
485 &st_sorted
486 } else {
487 st
488 };
489 values.push(finalize(&agg_specs[i].name, st_final));
490 }
491 synth_rows.push(Row::new(values));
492 }
493
494 let columns: Vec<ColumnSchema> = stmt
499 .items
500 .iter()
501 .map(|item| match item {
502 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
503 detail: "SELECT * with aggregates is not supported".into(),
504 }),
505 SelectItem::Expr { expr, alias } => {
506 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
507 let name = alias.clone().unwrap_or_else(|| expr.to_string());
508 Ok(ColumnSchema::new(
509 name,
510 agg_or_group_type(&rewritten, &synth_schema),
511 true,
512 ))
513 }
514 })
515 .collect::<Result<_, _>>()?;
516
517 let synth_ctx = EvalContext::new(&synth_schema, None);
522 let having_rewritten = stmt
523 .having
524 .as_ref()
525 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
526 let items_rewritten: alloc::vec::Vec<Option<Expr>> = stmt
532 .items
533 .iter()
534 .map(|item| match item {
535 SelectItem::Expr { expr, .. } => Some(rewrite_expr(expr, &group_exprs, &agg_specs)),
536 SelectItem::Wildcard => None,
537 })
538 .collect();
539 let mut kept_synth: Vec<Row> = Vec::new();
540 let mut out_rows: Vec<Row> = Vec::new();
541 for srow in synth_rows {
542 if let Some(h) = &having_rewritten {
543 let cond = match correlated_eval {
544 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
545 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
546 };
547 if !matches!(cond, Value::Bool(true)) {
548 continue;
549 }
550 }
551 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
552 for rewritten in items_rewritten.iter().flatten() {
553 values.push(match correlated_eval {
554 Some(f) if crate::expr_has_subquery(rewritten) => f(rewritten, &srow, &synth_ctx)?,
555 _ => eval::eval_expr(rewritten, &srow, &synth_ctx)?,
556 });
557 }
558 kept_synth.push(srow);
559 out_rows.push(Row::new(values));
560 }
561
562 if !stmt.order_by.is_empty() {
565 let rewritten: Vec<Expr> = stmt
568 .order_by
569 .iter()
570 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
571 .collect();
572 let keys_meta: Vec<(bool, Option<bool>)> = stmt
573 .order_by
574 .iter()
575 .map(|o| (o.desc, o.nulls_first))
576 .collect();
577 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
578 .into_iter()
579 .zip(out_rows)
580 .map(|(s, o)| {
581 let mut keys = Vec::with_capacity(rewritten.len());
582 for e in &rewritten {
583 keys.push(match correlated_eval {
584 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
585 _ => eval::eval_expr(e, &s, &synth_ctx)?,
586 });
587 }
588 Ok::<_, EvalError>((keys, o))
589 })
590 .collect::<Result<_, _>>()?;
591 tagged.sort_by(|a, b| {
592 use core::cmp::Ordering;
593 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
594 let (desc, nf) = keys_meta[i];
595 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
596 if cmp != Ordering::Equal {
597 return cmp;
598 }
599 }
600 Ordering::Equal
601 });
602 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
603 }
604
605 Ok(AggResult {
606 columns,
607 rows: out_rows,
608 })
609}
610
611fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
617 fn walk(e: &Expr) -> Result<(), EvalError> {
618 if let Expr::FunctionCall { name, args } = e {
619 let lower = name.to_ascii_lowercase();
620 let expected: Option<usize> = match lower.as_str() {
621 "count_star" => Some(0),
622 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
623 | "bool_and" | "bool_or" | "every" => Some(1),
627 "string_agg" => Some(2),
628 _ => None,
629 };
630 if let Some(want) = expected
631 && args.len() != want
632 {
633 return Err(EvalError::TypeMismatch {
634 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
635 });
636 }
637 for a in args {
638 walk(a)?;
639 }
640 } else if let Expr::Binary { lhs, rhs, .. } = e {
641 walk(lhs)?;
642 walk(rhs)?;
643 } else if let Expr::Unary { expr, .. }
644 | Expr::Cast { expr, .. }
645 | Expr::IsNull { expr, .. } = e
646 {
647 walk(expr)?;
648 }
649 Ok(())
650 }
651 for item in &stmt.items {
652 if let SelectItem::Expr { expr, .. } = item {
653 walk(expr)?;
654 }
655 }
656 for o in &stmt.order_by {
657 walk(&o.expr)?;
658 }
659 if let Some(h) = &stmt.having {
660 walk(h)?;
661 }
662 Ok(())
663}
664
665fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
666 match e {
667 Expr::AggregateOrdered {
670 call,
671 order_by,
672 distinct,
673 } => {
674 if let Expr::FunctionCall { name, args } = call.as_ref() {
675 let lower = name.to_ascii_lowercase();
676 if is_aggregate_name(&lower) {
677 let canonical = if lower == "every" {
678 "bool_and".to_string()
679 } else {
680 lower
681 };
682 let spec = AggSpec {
683 name: canonical,
684 arg: args.first().cloned(),
685 arg2: if name.eq_ignore_ascii_case("string_agg") {
686 args.get(1).cloned()
687 } else {
688 None
689 },
690 distinct: *distinct,
691 order_by: order_by.clone(),
692 };
693 if !out.iter().any(|s| {
694 s.name == spec.name
695 && s.arg == spec.arg
696 && s.arg2 == spec.arg2
697 && s.distinct == spec.distinct
698 && s.order_by == spec.order_by
699 }) {
700 out.push(spec);
701 }
702 return;
703 }
704 }
705 collect_aggregates(call, out);
706 for o in order_by {
707 collect_aggregates(&o.expr, out);
708 }
709 }
710 Expr::FunctionCall { name, args } => {
711 let lower = name.to_ascii_lowercase();
712 if is_aggregate_name(&lower) {
713 let arg = if lower == "count_star" {
714 None
715 } else {
716 args.first().cloned()
717 };
718 let arg2 = if lower == "string_agg" {
722 args.get(1).cloned()
723 } else {
724 None
725 };
726 let canonical = if lower == "every" {
730 "bool_and".to_string()
731 } else {
732 lower
733 };
734 let spec = AggSpec {
735 name: canonical,
736 arg: arg.clone(),
737 arg2: arg2.clone(),
738 distinct: false,
739 order_by: Vec::new(),
740 };
741 if !out.iter().any(|s| {
742 s.name == spec.name
743 && s.arg == spec.arg
744 && s.arg2 == spec.arg2
745 && !s.distinct
746 && s.order_by == spec.order_by
747 }) {
748 out.push(spec);
749 }
750 } else {
753 for a in args {
754 collect_aggregates(a, out);
755 }
756 }
757 }
758 Expr::Binary { lhs, rhs, .. } => {
759 collect_aggregates(lhs, out);
760 collect_aggregates(rhs, out);
761 }
762 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
763 collect_aggregates(expr, out);
764 }
765 Expr::Like { expr, pattern, .. } => {
766 collect_aggregates(expr, out);
767 collect_aggregates(pattern, out);
768 }
769 Expr::InList { expr, list, .. } => {
770 collect_aggregates(expr, out);
771 for item in list {
772 collect_aggregates(item, out);
773 }
774 }
775 Expr::Extract { source, .. } => collect_aggregates(source, out),
776 Expr::ScalarSubquery(_)
779 | Expr::Exists { .. }
780 | Expr::InSubquery { .. }
781 | Expr::WindowFunction { .. }
782 | Expr::Literal(_)
783 | Expr::Placeholder(_)
784 | Expr::Column(_) => {}
785 Expr::Array(items) => {
788 for elem in items {
789 collect_aggregates(elem, out);
790 }
791 }
792 Expr::ArraySubscript { target, index } => {
793 collect_aggregates(target, out);
794 collect_aggregates(index, out);
795 }
796 Expr::AnyAll { expr, array, .. } => {
797 collect_aggregates(expr, out);
798 collect_aggregates(array, out);
799 }
800 Expr::Case {
801 operand,
802 branches,
803 else_branch,
804 } => {
805 if let Some(o) = operand {
806 collect_aggregates(o, out);
807 }
808 for (w, t) in branches {
809 collect_aggregates(w, out);
810 collect_aggregates(t, out);
811 }
812 if let Some(e) = else_branch {
813 collect_aggregates(e, out);
814 }
815 }
816 }
817}
818
819fn update_state(
820 st: &mut AggState,
821 name: &str,
822 v: &Value,
823 arg2: Option<&Value>,
824 order_keys: Option<Vec<Value>>,
825) -> Result<(), EvalError> {
826 let is_null = matches!(v, Value::Null);
827 match name {
828 "count_star" => st.count += 1,
829 "count" => {
830 if !is_null {
831 st.count += 1;
832 }
833 }
834 "sum" | "avg" => {
835 if is_null {
836 return Ok(());
837 }
838 st.count += 1;
839 match v {
840 Value::Int(n) => st.sum_int += i64::from(*n),
841 Value::BigInt(n) => st.sum_int += *n,
842 Value::Float(x) => {
843 st.use_float = true;
844 st.sum_float += *x;
845 }
846 other => {
847 return Err(EvalError::TypeMismatch {
848 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
849 });
850 }
851 }
852 }
853 "min" => {
854 if is_null {
855 return Ok(());
856 }
857 match &st.extreme {
858 None => st.extreme = Some(v.clone()),
859 Some(cur) => {
860 if value_cmp(v, cur) == core::cmp::Ordering::Less {
861 st.extreme = Some(v.clone());
862 }
863 }
864 }
865 }
866 "max" => {
867 if is_null {
868 return Ok(());
869 }
870 match &st.extreme {
871 None => st.extreme = Some(v.clone()),
872 Some(cur) => {
873 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
874 st.extreme = Some(v.clone());
875 }
876 }
877 }
878 }
879 "string_agg" => {
887 if let Some(sep) = arg2
888 && let Value::Text(s) = sep
889 {
890 st.separator = Some(s.clone());
891 }
892 if is_null {
893 return Ok(());
894 }
895 if let Value::Text(s) = v {
896 st.items.push(Value::Text(s.clone()));
897 if let Some(k) = order_keys {
898 st.item_keys.push(k);
899 }
900 st.count += 1;
901 } else {
902 return Err(EvalError::TypeMismatch {
903 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
904 });
905 }
906 }
907 "array_agg" => {
913 st.items.push(v.clone());
914 if let Some(k) = order_keys {
915 st.item_keys.push(k);
916 }
917 st.count += 1;
918 }
919 "bool_and" => {
923 if is_null {
924 return Ok(());
925 }
926 let b = match v {
927 Value::Bool(b) => *b,
928 other => {
929 return Err(EvalError::TypeMismatch {
930 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
931 });
932 }
933 };
934 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
935 }
936 "bool_or" => {
939 if is_null {
940 return Ok(());
941 }
942 let b = match v {
943 Value::Bool(b) => *b,
944 other => {
945 return Err(EvalError::TypeMismatch {
946 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
947 });
948 }
949 };
950 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
951 }
952 _ => unreachable!("non-aggregate {name} in update_state"),
953 }
954 Ok(())
955}
956
957#[allow(clippy::cast_precision_loss)]
958fn finalize(name: &str, st: &AggState) -> Value {
959 match name {
960 "count" | "count_star" => Value::BigInt(st.count),
961 "sum" => {
962 if st.count == 0 {
963 Value::Null
964 } else if st.use_float {
965 Value::Float(st.sum_float + (st.sum_int as f64))
966 } else {
967 Value::BigInt(st.sum_int)
968 }
969 }
970 "avg" => {
971 if st.count == 0 {
972 Value::Null
973 } else {
974 let total = if st.use_float {
975 st.sum_float + (st.sum_int as f64)
976 } else {
977 st.sum_int as f64
978 };
979 Value::Float(total / (st.count as f64))
980 }
981 }
982 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
983 "string_agg" => {
987 if st.items.is_empty() {
988 return Value::Null;
989 }
990 let sep = st.separator.clone().unwrap_or_default();
991 let mut out = String::new();
992 for (i, item) in st.items.iter().enumerate() {
993 if i > 0 {
994 out.push_str(&sep);
995 }
996 if let Value::Text(s) = item {
997 out.push_str(s);
998 }
999 }
1000 Value::Text(out)
1001 }
1002 "array_agg" => {
1009 if st.items.is_empty() {
1010 return Value::Null;
1011 }
1012 let probe = st.items.iter().find(|v| !v.is_null());
1013 match probe.and_then(spg_storage::Value::data_type) {
1014 Some(DataType::Int) | Some(DataType::SmallInt) => {
1015 let items: Vec<Option<i32>> = st
1016 .items
1017 .iter()
1018 .map(|v| match v {
1019 Value::Int(n) => Some(*n),
1020 Value::SmallInt(n) => Some(i32::from(*n)),
1021 _ => None,
1022 })
1023 .collect();
1024 Value::IntArray(items)
1025 }
1026 Some(DataType::BigInt) => {
1027 let items: Vec<Option<i64>> = st
1028 .items
1029 .iter()
1030 .map(|v| match v {
1031 Value::BigInt(n) => Some(*n),
1032 _ => None,
1033 })
1034 .collect();
1035 Value::BigIntArray(items)
1036 }
1037 _ => {
1038 let items: Vec<Option<String>> = st
1039 .items
1040 .iter()
1041 .map(|v| match v {
1042 Value::Text(s) => Some(s.clone()),
1043 Value::Null => None,
1044 other => Some(format!("{other:?}")),
1045 })
1046 .collect();
1047 Value::TextArray(items)
1048 }
1049 }
1050 }
1051 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
1055 _ => unreachable!(),
1056 }
1057}
1058
1059fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
1060 let arg_ty = spec
1064 .arg
1065 .as_ref()
1066 .and_then(|a| crate::describe::describe_expr(a, schema_cols))
1067 .map(|shape| shape.ty);
1068 match spec.name.as_str() {
1069 "count" | "count_star" => DataType::BigInt,
1070 "sum" => match arg_ty {
1071 Some(DataType::Float) => DataType::Float,
1072 _ => DataType::BigInt,
1073 },
1074 "avg" => DataType::Float,
1075 "string_agg" => DataType::Text,
1077 "array_agg" => match arg_ty {
1078 Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
1079 Some(DataType::BigInt) => DataType::BigIntArray,
1080 _ => DataType::TextArray,
1081 },
1082 "bool_and" | "bool_or" => DataType::Bool,
1085 _ => arg_ty.unwrap_or(DataType::Text),
1087 }
1088}
1089
1090fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
1091 if let Expr::Column(c) = e
1092 && let Some(s) = synth.iter().find(|s| s.name == c.name)
1093 {
1094 return s.ty;
1095 }
1096 crate::describe::describe_expr(e, synth)
1102 .map(|shape| shape.ty)
1103 .unwrap_or(DataType::Text)
1104}
1105
/// Rewrites `e` so it can be evaluated against a synthetic group row.
///
/// Replacement is tried in this order:
/// 1. An aggregate call, plain or wrapped in `AggregateOrdered`, that
///    matches a spec in `aggs` becomes the column `__agg_{i}`.
/// 2. An expression equal to `group_exprs[i]` becomes the column `__grp_{i}`.
/// 3. Otherwise the node is rebuilt with its children rewritten.
///
/// Subqueries are not rewritten with `aggs`; only their occurrences of the
/// outer GROUP BY expressions are replaced (via `rewrite_group_keys_in_select`).
fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
    // Ordered/DISTINCT aggregate: must match name, args, DISTINCT and ORDER BY exactly.
    if let Expr::AggregateOrdered {
        call,
        order_by,
        distinct,
    } = e
        && let Expr::FunctionCall { name, args } = call.as_ref()
    {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            // Specs store `every` under `bool_and`.
            let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
            let arg = args.first().cloned();
            let arg2 = if lower == "string_agg" {
                args.get(1).cloned()
            } else {
                None
            };
            for (i, spec) in aggs.iter().enumerate() {
                if spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && spec.distinct == *distinct
                    && spec.order_by == *order_by
                {
                    return Expr::Column(spg_sql::ast::ColumnName {
                        qualifier: None,
                        name: format!("__agg_{i}"),
                    });
                }
            }
        }
    }
    // Plain aggregate call: only matches specs without DISTINCT or ORDER BY.
    if let Expr::FunctionCall { name, args } = e {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            let arg = if lower == "count_star" {
                None
            } else {
                args.first().cloned()
            };
            let arg2 = if lower == "string_agg" {
                args.get(1).cloned()
            } else {
                None
            };
            let canonical: &str = if lower == "every" {
                "bool_and"
            } else {
                lower.as_str()
            };
            for (i, spec) in aggs.iter().enumerate() {
                if spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && !spec.distinct
                    && spec.order_by.is_empty()
                {
                    return Expr::Column(spg_sql::ast::ColumnName {
                        qualifier: None,
                        name: format!("__agg_{i}"),
                    });
                }
            }
        }
    }
    // Whole-expression match against a GROUP BY expression.
    for (i, g) in group_exprs.iter().enumerate() {
        if g == e {
            return Expr::Column(spg_sql::ast::ColumnName {
                qualifier: None,
                name: format!("__grp_{i}"),
            });
        }
    }
    // No direct replacement: rebuild the node with rewritten children.
    match e {
        Expr::AggregateOrdered {
            call,
            order_by,
            distinct,
        } => Expr::AggregateOrdered {
            call: Box::new(rewrite_expr(call, group_exprs, aggs)),
            distinct: *distinct,
            order_by: order_by
                .iter()
                .map(|o| spg_sql::ast::OrderBy {
                    expr: rewrite_expr(&o.expr, group_exprs, aggs),
                    desc: o.desc,
                    nulls_first: o.nulls_first,
                })
                .collect(),
        },
        Expr::Binary { lhs, op, rhs } => Expr::Binary {
            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
            op: *op,
            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
        },
        Expr::Unary { op, expr } => Expr::Unary {
            op: *op,
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
        },
        Expr::Cast { expr, target } => Expr::Cast {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            target: *target,
        },
        Expr::IsNull { expr, negated } => Expr::IsNull {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            negated: *negated,
        },
        Expr::FunctionCall { name, args } => Expr::FunctionCall {
            name: name.clone(),
            args: args
                .iter()
                .map(|a| rewrite_expr(a, group_exprs, aggs))
                .collect(),
        },
        Expr::Like {
            expr,
            pattern,
            negated,
            case_insensitive,
        } => Expr::Like {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
            negated: *negated,
            case_insensitive: *case_insensitive,
        },
        Expr::Extract { field, source } => Expr::Extract {
            field: *field,
            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
        },
        // Subqueries: only outer GROUP BY expressions are substituted inside them.
        Expr::ScalarSubquery(s) => {
            Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
        }
        Expr::Exists { subquery, negated } => Expr::Exists {
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        Expr::InSubquery {
            expr,
            subquery,
            negated,
        } => Expr::InSubquery {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        // Window functions and leaves are copied unchanged.
        Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
            e.clone()
        }
        Expr::Array(items) => Expr::Array(
            items
                .iter()
                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
                .collect(),
        ),
        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
        },
        Expr::AnyAll {
            expr,
            op,
            array,
            is_any,
        } => Expr::AnyAll {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            op: *op,
            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
            is_any: *is_any,
        },
        Expr::InList {
            expr,
            list,
            negated,
        } => Expr::InList {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            list: list
                .iter()
                .map(|item| rewrite_expr(item, group_exprs, aggs))
                .collect(),
            negated: *negated,
        },
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => Expr::Case {
            operand: operand
                .as_deref()
                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
            branches: branches
                .iter()
                .map(|(w, t)| {
                    (
                        rewrite_expr(w, group_exprs, aggs),
                        rewrite_expr(t, group_exprs, aggs),
                    )
                })
                .collect(),
            else_branch: else_branch
                .as_deref()
                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
        },
    }
}
1328
1329fn rewrite_group_keys_in_select(
1334 s: &spg_sql::ast::SelectStatement,
1335 group_exprs: &[Expr],
1336) -> spg_sql::ast::SelectStatement {
1337 let mut out = s.clone();
1338 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
1339 *e = rewrite_expr(e, group_exprs, &[]);
1340 Ok(())
1341 });
1342 out
1343}
1344
1345fn encode_one(out: &mut String, v: &Value) {
1348 match v {
1349 Value::Null => out.push_str("N|"),
1350 Value::SmallInt(n) => {
1351 out.push('s');
1352 out.push_str(&n.to_string());
1353 out.push('|');
1354 }
1355 Value::Int(n) => {
1356 out.push('I');
1357 out.push_str(&n.to_string());
1358 out.push('|');
1359 }
1360 Value::BigInt(n) => {
1361 out.push('B');
1362 out.push_str(&n.to_string());
1363 out.push('|');
1364 }
1365 Value::Float(x) => {
1366 out.push('F');
1367 out.push_str(&x.to_string());
1368 out.push('|');
1369 }
1370 Value::Bool(b) => {
1371 out.push(if *b { 'T' } else { 'f' });
1372 out.push('|');
1373 }
1374 Value::Text(s) => {
1375 out.push('S');
1376 out.push_str(s);
1377 out.push('|');
1378 }
1379 Value::Vector(v) => {
1380 out.push('V');
1381 for x in v {
1382 out.push_str(&x.to_string());
1383 out.push(',');
1384 }
1385 out.push('|');
1386 }
1387 Value::Sq8Vector(q) => {
1393 out.push('Q');
1394 out.push_str(&q.min.to_string());
1395 out.push('@');
1396 out.push_str(&q.max.to_string());
1397 out.push(':');
1398 for b in &q.bytes {
1399 out.push_str(&b.to_string());
1400 out.push(',');
1401 }
1402 out.push('|');
1403 }
1404 Value::HalfVector(h) => {
1408 out.push('H');
1409 for b in &h.bytes {
1410 out.push_str(&b.to_string());
1411 out.push(',');
1412 }
1413 out.push('|');
1414 }
1415 Value::Numeric { scaled, scale } => {
1416 out.push('D');
1417 out.push_str(&scaled.to_string());
1418 out.push('@');
1419 out.push_str(&scale.to_string());
1420 out.push('|');
1421 }
1422 Value::Date(d) => {
1423 out.push('d');
1424 out.push_str(&d.to_string());
1425 out.push('|');
1426 }
1427 Value::Timestamp(t) => {
1428 out.push('t');
1429 out.push_str(&t.to_string());
1430 out.push('|');
1431 }
1432 Value::Interval { months, micros } => {
1433 out.push('i');
1434 out.push_str(&months.to_string());
1435 out.push('m');
1436 out.push_str(µs.to_string());
1437 out.push('|');
1438 }
1439 Value::Json(s) => {
1440 out.push('j');
1441 out.push_str(s);
1442 out.push('|');
1443 }
1444 _ => {
1449 out.push('?');
1450 out.push_str(&format!("{v:?}"));
1451 out.push('|');
1452 }
1453 }
1454}
1455
1456fn encode_key_refs(vals: &[&Value]) -> String {
1459 let mut out = String::new();
1460 for v in vals {
1461 encode_one(&mut out, v);
1462 }
1463 out
1464}
1465
1466pub(crate) fn encode_key(vals: &[Value]) -> String {
1467 let mut out = String::new();
1468 for v in vals {
1469 encode_one(&mut out, v);
1470 }
1471 out
1472}
1473
1474#[allow(clippy::cast_precision_loss)]
1475fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1476 use core::cmp::Ordering::Equal;
1477 match (a, b) {
1478 (Value::Null, Value::Null) => Equal,
1479 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1481 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1482 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1483 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1484 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1485 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1486 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1487 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1488 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1489 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1490 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1491 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1492 _ => Equal,
1493 }
1494}