1use super::ast::*;
46use super::bridge::ExecutionResult;
47use super::error::SqlResult;
48use rayon::prelude::*;
49use sochdb_core::SochValue;
50use std::collections::{HashMap, HashSet};
51
/// Minimum number of input rows at which `accumulate_fast` splits the input
/// into chunks and aggregates them in parallel (via rayon) instead of
/// scanning the rows on the calling thread.
const PARALLEL_THRESHOLD: usize = 100_000;
54
/// Aggregate functions understood by the aggregate executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFn {
    /// `COUNT(...)`.
    Count,
    /// `SUM(...)`.
    Sum,
    /// `AVG(...)`, also spelled `MEAN(...)`.
    Avg,
    /// `MIN(...)`.
    Min,
    /// `MAX(...)`.
    Max,
    /// `MEDIAN(...)`.
    Median,
    /// `STDDEV(...)` and its aliases.
    Stddev,
}

impl AggFn {
    /// Maps a SQL function name to the aggregate it denotes.
    ///
    /// Matching ignores ASCII case. Returns `None` when `name` is not one of
    /// the recognised aggregate names or aliases.
    pub fn from_name(name: &str) -> Option<Self> {
        // Every accepted spelling, paired with the aggregate it selects.
        const NAMES: &[(&str, AggFn)] = &[
            ("COUNT", AggFn::Count),
            ("SUM", AggFn::Sum),
            ("AVG", AggFn::Avg),
            ("MEAN", AggFn::Avg),
            ("MIN", AggFn::Min),
            ("MAX", AggFn::Max),
            ("MEDIAN", AggFn::Median),
            ("STDDEV", AggFn::Stddev),
            ("STDDEV_SAMP", AggFn::Stddev),
            ("STDEV", AggFn::Stddev),
            ("SD", AggFn::Stddev),
        ];
        NAMES
            .iter()
            .find(|entry| name.eq_ignore_ascii_case(entry.0))
            .map(|entry| entry.1)
    }
}
86
/// One distinct aggregate call found in a SELECT statement.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical rendering of the call (see `render_agg_key`); used both to
    /// de-duplicate identical calls and as the name under which the
    /// finalized value is stored in the output row.
    key: String,
    /// Which aggregate function is being computed.
    func: AggFn,
    /// The first argument of the call, or `None` when the argument is `*`
    /// or absent.
    arg: Option<Expr>,
    /// Whether the call carried the `DISTINCT` modifier.
    distinct: bool,
}
98
99pub fn is_aggregate_query(select: &SelectStmt) -> bool {
101 if !select.group_by.is_empty() {
102 return true;
103 }
104 select
105 .columns
106 .iter()
107 .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)))
108}
109
110fn contains_aggregate(expr: &Expr) -> bool {
112 match expr {
113 Expr::Function(f) => {
114 AggFn::from_name(f.name.name()).is_some()
115 || f.args.iter().any(contains_aggregate)
116 }
117 Expr::BinaryOp { left, right, .. } => {
118 contains_aggregate(left) || contains_aggregate(right)
119 }
120 Expr::UnaryOp { expr, .. } => contains_aggregate(expr),
121 Expr::IsNull { expr, .. } => contains_aggregate(expr),
122 Expr::Case {
123 operand,
124 conditions,
125 else_result,
126 } => {
127 operand.as_deref().map(contains_aggregate).unwrap_or(false)
128 || conditions
129 .iter()
130 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
131 || else_result
132 .as_deref()
133 .map(contains_aggregate)
134 .unwrap_or(false)
135 }
136 _ => false,
137 }
138}
139
140fn collect_agg_specs(select: &SelectStmt) -> Vec<AggSpec> {
142 let mut specs: Vec<AggSpec> = Vec::new();
143 let mut seen: HashSet<String> = HashSet::new();
144
145 let walk = |expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>| {
146 collect_from_expr(expr, specs, seen);
147 };
148
149 for item in &select.columns {
150 if let SelectItem::Expr { expr, .. } = item {
151 walk(expr, &mut specs, &mut seen);
152 }
153 }
154 if let Some(h) = &select.having {
155 walk(h, &mut specs, &mut seen);
156 }
157 for ob in &select.order_by {
158 walk(&ob.expr, &mut specs, &mut seen);
159 }
160 specs
161}
162
163fn collect_from_expr(expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>) {
164 match expr {
165 Expr::Function(f) => {
166 if let Some(func) = AggFn::from_name(f.name.name()) {
167 let arg = f.args.first().cloned();
168 let is_star =
169 matches!(arg.as_ref(), Some(Expr::Column(c)) if c.column == "*");
170 let arg = if is_star { None } else { arg };
171 let key = render_agg_key(func, arg.as_ref(), f.distinct);
172 if seen.insert(key.clone()) {
173 specs.push(AggSpec {
174 key,
175 func,
176 arg,
177 distinct: f.distinct,
178 });
179 }
180 } else {
181 for a in &f.args {
182 collect_from_expr(a, specs, seen);
183 }
184 }
185 }
186 Expr::BinaryOp { left, right, .. } => {
187 collect_from_expr(left, specs, seen);
188 collect_from_expr(right, specs, seen);
189 }
190 Expr::UnaryOp { expr, .. } => collect_from_expr(expr, specs, seen),
191 Expr::IsNull { expr, .. } => collect_from_expr(expr, specs, seen),
192 Expr::Case {
193 operand,
194 conditions,
195 else_result,
196 } => {
197 if let Some(op) = operand {
198 collect_from_expr(op, specs, seen);
199 }
200 for (w, t) in conditions {
201 collect_from_expr(w, specs, seen);
202 collect_from_expr(t, specs, seen);
203 }
204 if let Some(e) = else_result {
205 collect_from_expr(e, specs, seen);
206 }
207 }
208 _ => {}
209 }
210}
211
212fn render_agg_key(func: AggFn, arg: Option<&Expr>, distinct: bool) -> String {
215 let fname = match func {
216 AggFn::Count => "count",
217 AggFn::Sum => "sum",
218 AggFn::Avg => "avg",
219 AggFn::Min => "min",
220 AggFn::Max => "max",
221 AggFn::Median => "median",
222 AggFn::Stddev => "stddev",
223 };
224 let arg_s = match arg {
225 None => "*".to_string(),
226 Some(e) => render_expr_name(e),
227 };
228 if distinct {
229 format!("{}(distinct {})", fname, arg_s)
230 } else {
231 format!("{}({})", fname, arg_s)
232 }
233}
234
235pub fn render_expr_name(expr: &Expr) -> String {
238 match expr {
239 Expr::Column(c) => {
240 if let Some(t) = &c.table {
241 format!("{}.{}", t, c.column)
242 } else {
243 c.column.clone()
244 }
245 }
246 Expr::Literal(Literal::Integer(n)) => n.to_string(),
247 Expr::Literal(Literal::Float(f)) => f.to_string(),
248 Expr::Literal(Literal::String(s)) => format!("'{}'", s),
249 Expr::Literal(Literal::Boolean(b)) => b.to_string(),
250 Expr::Literal(Literal::Null) => "null".to_string(),
251 Expr::Function(f) => {
252 if let Some(func) = AggFn::from_name(f.name.name()) {
253 let arg = f.args.first();
254 let is_star =
255 matches!(arg, Some(Expr::Column(c)) if c.column == "*");
256 render_agg_key(func, if is_star { None } else { arg }, f.distinct)
257 } else {
258 let args: Vec<String> = f.args.iter().map(render_expr_name).collect();
259 format!("{}({})", f.name.name().to_lowercase(), args.join(", "))
260 }
261 }
262 Expr::BinaryOp { left, op, right } => format!(
263 "{} {} {}",
264 render_expr_name(left),
265 binary_op_symbol(op),
266 render_expr_name(right)
267 ),
268 Expr::UnaryOp { op, expr } => match op {
269 UnaryOperator::Minus => format!("-{}", render_expr_name(expr)),
270 UnaryOperator::Plus => render_expr_name(expr),
271 UnaryOperator::Not => format!("not {}", render_expr_name(expr)),
272 UnaryOperator::BitNot => format!("~{}", render_expr_name(expr)),
273 },
274 _ => "expr".to_string(),
275 }
276}
277
278fn binary_op_symbol(op: &BinaryOperator) -> &'static str {
279 match op {
280 BinaryOperator::Plus => "+",
281 BinaryOperator::Minus => "-",
282 BinaryOperator::Multiply => "*",
283 BinaryOperator::Divide => "/",
284 BinaryOperator::Modulo => "%",
285 BinaryOperator::Eq => "=",
286 BinaryOperator::Ne => "<>",
287 BinaryOperator::Lt => "<",
288 BinaryOperator::Le => "<=",
289 BinaryOperator::Gt => ">",
290 BinaryOperator::Ge => ">=",
291 BinaryOperator::And => "and",
292 BinaryOperator::Or => "or",
293 _ => "?",
294 }
295}
296
297fn eval_scalar(
307 expr: &Expr,
308 row: &HashMap<String, SochValue>,
309 params: &[SochValue],
310) -> SochValue {
311 match expr {
312 Expr::Column(c) => {
313 if let Some(t) = &c.table {
314 let qualified = format!("{}.{}", t, c.column);
315 if let Some(v) = row.get(&qualified) {
316 return v.clone();
317 }
318 }
319 row.get(&c.column).cloned().unwrap_or(SochValue::Null)
320 }
321 Expr::Literal(lit) => literal_to_value(lit),
322 Expr::Placeholder(idx) => params
323 .get((*idx as usize).saturating_sub(1))
324 .cloned()
325 .unwrap_or(SochValue::Null),
326 Expr::Function(f) => {
327 let key = render_expr_name(&Expr::Function(f.clone()));
330 row.get(&key).cloned().unwrap_or(SochValue::Null)
331 }
332 Expr::BinaryOp { left, op, right } => {
333 let l = eval_scalar(left, row, params);
334 let r = eval_scalar(right, row, params);
335 eval_binary(&l, op, &r)
336 }
337 Expr::UnaryOp { op, expr } => {
338 let v = eval_scalar(expr, row, params);
339 match op {
340 UnaryOperator::Minus => match v {
341 SochValue::Int(i) => SochValue::Int(-i),
342 SochValue::Float(f) => SochValue::Float(-f),
343 _ => SochValue::Null,
344 },
345 UnaryOperator::Plus => v,
346 UnaryOperator::Not => match v {
347 SochValue::Bool(b) => SochValue::Bool(!b),
348 _ => SochValue::Null,
349 },
350 UnaryOperator::BitNot => match v {
351 SochValue::Int(i) => SochValue::Int(!i),
352 _ => SochValue::Null,
353 },
354 }
355 }
356 Expr::IsNull { expr, negated } => {
357 let v = eval_scalar(expr, row, params);
358 let is_null = v.is_null();
359 SochValue::Bool(if *negated { !is_null } else { is_null })
360 }
361 _ => SochValue::Null,
362 }
363}
364
365fn literal_to_value(lit: &Literal) -> SochValue {
366 match lit {
367 Literal::Integer(i) => SochValue::Int(*i),
368 Literal::Float(f) => SochValue::Float(*f),
369 Literal::String(s) => SochValue::Text(s.clone()),
370 Literal::Boolean(b) => SochValue::Bool(*b),
371 Literal::Null => SochValue::Null,
372 _ => SochValue::Null,
373 }
374}
375
376fn numeric(v: &SochValue) -> Option<f64> {
377 match v {
378 SochValue::Int(i) => Some(*i as f64),
379 SochValue::UInt(u) => Some(*u as f64),
380 SochValue::Float(f) => Some(*f),
381 SochValue::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
382 _ => None,
383 }
384}
385
386fn eval_binary(l: &SochValue, op: &BinaryOperator, r: &SochValue) -> SochValue {
387 use BinaryOperator::*;
388 match op {
389 Plus | Minus | Multiply | Divide | Modulo => {
390 if let (SochValue::Int(a), SochValue::Int(b)) = (l, r) {
392 return match op {
393 Plus => SochValue::Int(a.wrapping_add(*b)),
394 Minus => SochValue::Int(a.wrapping_sub(*b)),
395 Multiply => SochValue::Int(a.wrapping_mul(*b)),
396 Divide => {
397 if *b == 0 {
398 SochValue::Null
399 } else {
400 SochValue::Float(*a as f64 / *b as f64)
401 }
402 }
403 Modulo => {
404 if *b == 0 {
405 SochValue::Null
406 } else {
407 SochValue::Int(a % b)
408 }
409 }
410 _ => unreachable!(),
411 };
412 }
413 let (a, b) = match (numeric(l), numeric(r)) {
414 (Some(a), Some(b)) => (a, b),
415 _ => return SochValue::Null,
416 };
417 match op {
418 Plus => SochValue::Float(a + b),
419 Minus => SochValue::Float(a - b),
420 Multiply => SochValue::Float(a * b),
421 Divide => {
422 if b == 0.0 {
423 SochValue::Null
424 } else {
425 SochValue::Float(a / b)
426 }
427 }
428 Modulo => {
429 if b == 0.0 {
430 SochValue::Null
431 } else {
432 SochValue::Float(a % b)
433 }
434 }
435 _ => unreachable!(),
436 }
437 }
438 Eq | Ne | Lt | Le | Gt | Ge => {
439 if l.is_null() || r.is_null() {
440 return SochValue::Null;
441 }
442 let ord = compare_values(l, r);
443 let b = match op {
444 Eq => ord == std::cmp::Ordering::Equal,
445 Ne => ord != std::cmp::Ordering::Equal,
446 Lt => ord == std::cmp::Ordering::Less,
447 Le => ord != std::cmp::Ordering::Greater,
448 Gt => ord == std::cmp::Ordering::Greater,
449 Ge => ord != std::cmp::Ordering::Less,
450 _ => unreachable!(),
451 };
452 SochValue::Bool(b)
453 }
454 And => match (as_bool(l), as_bool(r)) {
455 (Some(a), Some(b)) => SochValue::Bool(a && b),
456 _ => SochValue::Null,
457 },
458 Or => match (as_bool(l), as_bool(r)) {
459 (Some(a), Some(b)) => SochValue::Bool(a || b),
460 _ => SochValue::Null,
461 },
462 _ => SochValue::Null,
463 }
464}
465
466fn as_bool(v: &SochValue) -> Option<bool> {
467 match v {
468 SochValue::Bool(b) => Some(*b),
469 SochValue::Int(i) => Some(*i != 0),
470 SochValue::Null => None,
471 _ => None,
472 }
473}
474
475pub fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
477 use std::cmp::Ordering;
478 match (numeric(a), numeric(b)) {
479 (Some(x), Some(y)) => return x.partial_cmp(&y).unwrap_or(Ordering::Equal),
480 _ => {}
481 }
482 match (a, b) {
483 (SochValue::Text(x), SochValue::Text(y)) => x.cmp(y),
484 (SochValue::Null, SochValue::Null) => Ordering::Equal,
485 (SochValue::Null, _) => Ordering::Less,
486 (_, SochValue::Null) => Ordering::Greater,
487 _ => Ordering::Equal,
488 }
489}
490
491fn key_repr(v: &SochValue) -> String {
494 match v {
495 SochValue::Null => "\u{0}N".to_string(),
496 SochValue::Int(i) => format!("i{}", i),
497 SochValue::UInt(u) => format!("i{}", u),
498 SochValue::Float(f) => {
499 if f.fract() == 0.0 && f.abs() < 9.0e15 {
500 format!("i{}", *f as i64)
501 } else {
502 format!("f{}", f)
503 }
504 }
505 SochValue::Text(s) => format!("s{}", s),
506 SochValue::Bool(b) => format!("b{}", b),
507 other => format!("{:?}", other),
508 }
509}
510
/// Running state of one aggregate for one group.
#[derive(Debug)]
enum Acc {
    /// `COUNT(*)`: number of rows seen.
    CountStar(u64),
    /// `COUNT(x)`: number of non-null values seen.
    Count(u64),
    /// `COUNT(DISTINCT x)`: `key_repr` of every non-null value seen.
    CountDistinct(HashSet<String>),
    /// `SUM(x)`: an integer and a float total are kept side by side so the
    /// result can stay an integer unless a float was seen or the integer
    /// total overflowed.
    Sum {
        /// Integer running total (meaningless once `overflowed` is set).
        int: i64,
        /// Float running total of every summed value.
        float: f64,
        /// Set once a float value has been added.
        saw_float: bool,
        /// Set once any summable value has been added.
        saw_any: bool,
        /// Set once the integer total could not represent the sum.
        overflowed: bool,
    },
    /// `AVG(x)`: sum and count of the numeric values seen.
    Avg {
        sum: f64,
        n: u64,
    },
    /// `MIN(x)`: smallest non-null value seen so far.
    Min(Option<SochValue>),
    /// `MAX(x)`: largest non-null value seen so far.
    Max(Option<SochValue>),
    /// `MEDIAN(x)`: every numeric value seen; selection happens in `finalize`.
    Median(Vec<f64>),
    /// `STDDEV(x)`: count, running mean and sum of squared deviations.
    Stddev {
        n: u64,
        mean: f64,
        m2: f64,
    },
}
542
543impl Acc {
544 fn new(spec: &AggSpec) -> Self {
545 match (spec.func, spec.arg.is_some(), spec.distinct) {
546 (AggFn::Count, false, _) => Acc::CountStar(0),
547 (AggFn::Count, true, true) => Acc::CountDistinct(HashSet::new()),
548 (AggFn::Count, true, false) => Acc::Count(0),
549 (AggFn::Sum, _, _) => Acc::Sum {
550 int: 0,
551 float: 0.0,
552 saw_float: false,
553 saw_any: false,
554 overflowed: false,
555 },
556 (AggFn::Avg, _, _) => Acc::Avg { sum: 0.0, n: 0 },
557 (AggFn::Min, _, _) => Acc::Min(None),
558 (AggFn::Max, _, _) => Acc::Max(None),
559 (AggFn::Median, _, _) => Acc::Median(Vec::new()),
560 (AggFn::Stddev, _, _) => Acc::Stddev {
561 n: 0,
562 mean: 0.0,
563 m2: 0.0,
564 },
565 }
566 }
567
568 fn update(&mut self, val: Option<&SochValue>) {
570 match self {
571 Acc::CountStar(n) => *n += 1,
572 Acc::Count(n) => {
573 if let Some(v) = val {
574 if !v.is_null() {
575 *n += 1;
576 }
577 }
578 }
579 Acc::CountDistinct(set) => {
580 if let Some(v) = val {
581 if !v.is_null() {
582 set.insert(key_repr(v));
583 }
584 }
585 }
586 Acc::Sum {
587 int,
588 float,
589 saw_float,
590 saw_any,
591 overflowed,
592 } => {
593 let Some(v) = val else { return };
594 match v {
595 SochValue::Int(i) => {
596 *saw_any = true;
597 match int.checked_add(*i) {
598 Some(s) => *int = s,
599 None => *overflowed = true,
600 }
601 *float += *i as f64;
602 }
603 SochValue::UInt(u) => {
604 *saw_any = true;
605 match int.checked_add(*u as i64) {
606 Some(s) => *int = s,
607 None => *overflowed = true,
608 }
609 *float += *u as f64;
610 }
611 SochValue::Float(f) => {
612 *saw_any = true;
613 *saw_float = true;
614 *float += *f;
615 }
616 _ => {}
617 }
618 }
619 Acc::Avg { sum, n } => {
620 if let Some(x) = val.and_then(numeric) {
621 *sum += x;
622 *n += 1;
623 }
624 }
625 Acc::Min(cur) => {
626 let Some(v) = val else { return };
627 if v.is_null() {
628 return;
629 }
630 match cur {
631 None => *cur = Some(v.clone()),
632 Some(c) => {
633 if compare_values(v, c) == std::cmp::Ordering::Less {
634 *cur = Some(v.clone());
635 }
636 }
637 }
638 }
639 Acc::Max(cur) => {
640 let Some(v) = val else { return };
641 if v.is_null() {
642 return;
643 }
644 match cur {
645 None => *cur = Some(v.clone()),
646 Some(c) => {
647 if compare_values(v, c) == std::cmp::Ordering::Greater {
648 *cur = Some(v.clone());
649 }
650 }
651 }
652 }
653 Acc::Median(vals) => {
654 if let Some(x) = val.and_then(numeric) {
655 vals.push(x);
656 }
657 }
658 Acc::Stddev { n, mean, m2 } => {
659 if let Some(x) = val.and_then(numeric) {
660 *n += 1;
661 let delta = x - *mean;
662 *mean += delta / *n as f64;
663 let delta2 = x - *mean;
664 *m2 += delta * delta2;
665 }
666 }
667 }
668 }
669
670 fn merge(&mut self, other: Acc) {
673 match (self, other) {
674 (Acc::CountStar(a), Acc::CountStar(b)) => *a += b,
675 (Acc::Count(a), Acc::Count(b)) => *a += b,
676 (Acc::CountDistinct(a), Acc::CountDistinct(b)) => a.extend(b),
677 (
678 Acc::Sum {
679 int,
680 float,
681 saw_float,
682 saw_any,
683 overflowed,
684 },
685 Acc::Sum {
686 int: i2,
687 float: f2,
688 saw_float: sf2,
689 saw_any: sa2,
690 overflowed: of2,
691 },
692 ) => {
693 match int.checked_add(i2) {
694 Some(s) => *int = s,
695 None => *overflowed = true,
696 }
697 *float += f2;
698 *saw_float |= sf2;
699 *saw_any |= sa2;
700 *overflowed |= of2;
701 }
702 (Acc::Avg { sum, n }, Acc::Avg { sum: s2, n: n2 }) => {
703 *sum += s2;
704 *n += n2;
705 }
706 (Acc::Min(a), Acc::Min(Some(b))) => match a {
707 None => *a = Some(b),
708 Some(cur) => {
709 if compare_values(&b, cur) == std::cmp::Ordering::Less {
710 *a = Some(b);
711 }
712 }
713 },
714 (Acc::Max(a), Acc::Max(Some(b))) => match a {
715 None => *a = Some(b),
716 Some(cur) => {
717 if compare_values(&b, cur) == std::cmp::Ordering::Greater {
718 *a = Some(b);
719 }
720 }
721 },
722 (Acc::Min(_), Acc::Min(None)) | (Acc::Max(_), Acc::Max(None)) => {}
723 (Acc::Median(a), Acc::Median(b)) => a.extend(b),
724 (
725 Acc::Stddev { n, mean, m2 },
726 Acc::Stddev {
727 n: nb,
728 mean: mb,
729 m2: m2b,
730 },
731 ) => {
732 if nb > 0 {
734 if *n == 0 {
735 *n = nb;
736 *mean = mb;
737 *m2 = m2b;
738 } else {
739 let na = *n as f64;
740 let nbf = nb as f64;
741 let delta = mb - *mean;
742 let total = na + nbf;
743 *mean += delta * nbf / total;
744 *m2 += m2b + delta * delta * na * nbf / total;
745 *n += nb;
746 }
747 }
748 }
749 _ => unreachable!("mismatched accumulator merge"),
750 }
751 }
752
753 fn finalize(self) -> SochValue {
754 match self {
755 Acc::CountStar(n) | Acc::Count(n) => SochValue::Int(n as i64),
756 Acc::CountDistinct(set) => SochValue::Int(set.len() as i64),
757 Acc::Sum {
758 int,
759 float,
760 saw_float,
761 saw_any,
762 overflowed,
763 } => {
764 if !saw_any {
765 SochValue::Null
766 } else if saw_float || overflowed {
767 SochValue::Float(float)
768 } else {
769 SochValue::Int(int)
770 }
771 }
772 Acc::Avg { sum, n } => {
773 if n == 0 {
774 SochValue::Null
775 } else {
776 SochValue::Float(sum / n as f64)
777 }
778 }
779 Acc::Min(v) | Acc::Max(v) => v.unwrap_or(SochValue::Null),
780 Acc::Median(mut vals) => {
781 if vals.is_empty() {
782 return SochValue::Null;
783 }
784 let mid = vals.len() / 2;
785 if vals.len() % 2 == 1 {
786 let (_, m, _) =
787 vals.select_nth_unstable_by(mid, |a, b| {
788 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
789 });
790 SochValue::Float(*m)
791 } else {
792 let (lo, hi_first, _) =
794 vals.select_nth_unstable_by(mid, |a, b| {
795 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
796 });
797 let lo_max = lo
798 .iter()
799 .copied()
800 .fold(f64::NEG_INFINITY, f64::max);
801 SochValue::Float((lo_max + *hi_first) / 2.0)
802 }
803 }
804 Acc::Stddev { n, m2, .. } => {
805 if n < 2 {
806 SochValue::Null
807 } else {
808 SochValue::Float((m2 / (n - 1) as f64).sqrt())
809 }
810 }
811 }
812 }
813}
814
/// Everything tracked for one output group while rows are being scanned.
struct GroupState {
    /// Values of the `GROUP BY` expressions for this group, in clause order.
    key_values: Vec<SochValue>,
    /// Copy of the first input row that fell into this group; it seeds the
    /// output row before group keys and aggregate results are written in.
    first_row: HashMap<String, SochValue>,
    /// One accumulator per aggregate spec, in spec order.
    accs: Vec<Acc>,
}
824
825#[derive(Debug, Clone, PartialEq, Eq, Hash)]
832enum KeyAtom<'a> {
833 Null,
834 Int(i64),
835 FBits(u64),
837 Str(&'a str),
838 Bool(bool),
839}
840
841impl<'a> KeyAtom<'a> {
842 fn from_value(v: &'a SochValue) -> Self {
843 match v {
844 SochValue::Null => KeyAtom::Null,
845 SochValue::Int(i) => KeyAtom::Int(*i),
846 SochValue::UInt(u) => KeyAtom::Int(*u as i64),
847 SochValue::Float(f) => {
848 if f.fract() == 0.0 && f.abs() < 9.0e15 {
849 KeyAtom::Int(*f as i64)
850 } else if f.is_nan() {
851 KeyAtom::FBits(f64::NAN.to_bits())
852 } else {
853 KeyAtom::FBits(f.to_bits())
854 }
855 }
856 SochValue::Text(s) => KeyAtom::Str(s.as_str()),
857 SochValue::Bool(b) => KeyAtom::Bool(*b),
858 _ => KeyAtom::Null,
859 }
860 }
861}
862
/// Hash key identifying a group, specialised by the number of group-by
/// columns (see `make_group_key`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum GroupKey<'a> {
    /// No `GROUP BY` columns: every row belongs to the same group.
    Empty,
    /// Exactly one group-by column; avoids allocating a `Vec`.
    One(KeyAtom<'a>),
    /// Two or more group-by columns, in clause order.
    Many(Vec<KeyAtom<'a>>),
}
869
/// Shared `Null` that `col_get` hands out by reference when a column is
/// missing from a row.
static NULL_VALUE: SochValue = SochValue::Null;
871
872#[inline]
874fn col_get<'r>(row: &'r HashMap<String, SochValue>, col: &PlainCol) -> &'r SochValue {
875 if let Some(q) = &col.qualified {
876 if let Some(v) = row.get(q) {
877 return v;
878 }
879 }
880 row.get(&col.name).unwrap_or(&NULL_VALUE)
881}
882
/// A bare column reference with its row-lookup keys precomputed.
struct PlainCol {
    /// Unqualified column name.
    name: String,
    /// `table.column` form, present only when the reference named a table.
    qualified: Option<String>,
}
888
889fn as_plain_col(expr: &Expr) -> Option<PlainCol> {
890 match expr {
891 Expr::Column(c) => Some(PlainCol {
892 name: c.column.clone(),
893 qualified: c.table.as_ref().map(|t| format!("{}.{}", t, c.column)),
894 }),
895 _ => None,
896 }
897}
898
899fn make_group_key<'r>(
901 row: &'r HashMap<String, SochValue>,
902 group_cols: &[PlainCol],
903) -> GroupKey<'r> {
904 match group_cols.len() {
905 0 => GroupKey::Empty,
906 1 => GroupKey::One(KeyAtom::from_value(col_get(row, &group_cols[0]))),
907 _ => GroupKey::Many(
908 group_cols
909 .iter()
910 .map(|c| KeyAtom::from_value(col_get(row, c)))
911 .collect(),
912 ),
913 }
914}
915
/// Fast grouping path for queries whose `GROUP BY` expressions and aggregate
/// arguments are all plain column references.
///
/// Returns `None` when any group-by expression or aggregate argument is not
/// a plain column, so the caller can fall back to the general path.
/// Otherwise returns one `GroupState` per group, in order of first
/// appearance. Inputs of at least `PARALLEL_THRESHOLD` rows are split into
/// chunks that are aggregated in parallel and then merged in chunk order.
fn accumulate_fast<'a>(
    select: &SelectStmt,
    specs: &[AggSpec],
    rows: &'a [HashMap<String, SochValue>],
) -> Option<Vec<GroupState>> {
    // Bail out (via `?`) unless every GROUP BY expression is a plain column.
    let group_cols: Vec<PlainCol> = select
        .group_by
        .iter()
        .map(as_plain_col)
        .collect::<Option<Vec<_>>>()?;
    // Same for aggregate arguments; `None` marks an argument-less aggregate.
    let arg_cols: Vec<Option<PlainCol>> = specs
        .iter()
        .map(|s| match &s.arg {
            None => Some(None),
            Some(e) => as_plain_col(e).map(Some),
        })
        .collect::<Option<Vec<_>>>()?;

    // Aggregates one contiguous slice of rows. `order` keeps groups in
    // first-seen order; `index` maps a group key to its slot in `order`.
    let accumulate_chunk = |chunk: &'a [HashMap<String, SochValue>]| -> Vec<(GroupKey<'a>, GroupState)> {
        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
        for row in chunk {
            let key = make_group_key(row, &group_cols);
            let idx = match index.get(&key) {
                Some(&i) => i,
                None => {
                    // First row of a new group: snapshot keys and the row.
                    let state = GroupState {
                        key_values: group_cols
                            .iter()
                            .map(|c| col_get(row, c).clone())
                            .collect(),
                        first_row: row.clone(),
                        accs: specs.iter().map(Acc::new).collect(),
                    };
                    order.push((key.clone(), state));
                    index.insert(key, order.len() - 1);
                    order.len() - 1
                }
            };
            let accs = &mut order[idx].1.accs;
            for (acc, arg) in accs.iter_mut().zip(arg_cols.iter()) {
                match arg {
                    None => acc.update(None),
                    Some(col) => acc.update(Some(col_get(row, col))),
                }
            }
        }
        order
    };

    let merged: Vec<(GroupKey<'a>, GroupState)> = if rows.len() >= PARALLEL_THRESHOLD {
        let n_threads = rayon::current_num_threads().max(1);
        // Aim for about four chunks per thread, but never fewer than
        // 16_384 rows per chunk.
        let chunk_size = (rows.len() / (n_threads * 4)).max(16_384);
        let partials: Vec<Vec<(GroupKey<'a>, GroupState)>> = rows
            .par_chunks(chunk_size)
            .map(accumulate_chunk)
            .collect();
        // Merge the per-chunk results in chunk order. The state from the
        // earliest chunk containing a group is kept (including its
        // `first_row`); later partials only contribute accumulators.
        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
        for partial in partials {
            for (key, state) in partial {
                match index.get(&key) {
                    Some(&i) => {
                        let dst = &mut order[i].1;
                        for (a, b) in dst.accs.iter_mut().zip(state.accs.into_iter()) {
                            a.merge(b);
                        }
                    }
                    None => {
                        order.push((key.clone(), state));
                        index.insert(key, order.len() - 1);
                    }
                }
            }
        }
        order
    } else {
        accumulate_chunk(rows)
    };

    // The borrowed keys were only needed for grouping; return the states.
    Some(merged.into_iter().map(|(_, s)| s).collect())
}
1005
/// Executes an aggregate / `GROUP BY` query over already-fetched `rows`.
///
/// Steps, in order:
/// 1. group the rows and accumulate every aggregate (fast path when
///    possible, otherwise a general path that evaluates expressions per row);
/// 2. build one row per group from the group's first input row, overlaid
///    with the group-key values and the finalized aggregate values;
/// 3. apply `HAVING`, keeping only rows for which it evaluates to `true`;
/// 4. apply `ORDER BY`, then `offset`, then `limit`;
/// 5. project the SELECT list into the final result rows.
///
/// `params` supplies values for placeholders. The function always returns
/// `Ok(ExecutionResult::Rows { .. })`.
pub fn execute_aggregate(
    select: &SelectStmt,
    rows: &[HashMap<String, SochValue>],
    params: &[SochValue],
    limit: Option<usize>,
    offset: Option<usize>,
) -> SqlResult<ExecutionResult> {
    let specs = collect_agg_specs(select);
    let grouped = !select.group_by.is_empty();

    let mut order: Vec<GroupState> = match accumulate_fast(select, &specs, rows) {
        Some(states) => states,
        None => {
            // General path: group keys and aggregate arguments may be
            // arbitrary expressions, so evaluate them for every row and
            // hash on their string representation.
            let mut order: Vec<GroupState> = Vec::new();
            let mut index: HashMap<Vec<String>, usize> = HashMap::new();

            for row in rows {
                let key_values: Vec<SochValue> = select
                    .group_by
                    .iter()
                    .map(|e| eval_scalar(e, row, params))
                    .collect();
                let hash_key: Vec<String> = key_values.iter().map(key_repr).collect();

                let idx = match index.get(&hash_key) {
                    Some(&i) => i,
                    None => {
                        let state = GroupState {
                            key_values,
                            first_row: row.clone(),
                            accs: specs.iter().map(Acc::new).collect(),
                        };
                        order.push(state);
                        index.insert(hash_key, order.len() - 1);
                        order.len() - 1
                    }
                };

                let state = &mut order[idx];
                for (acc, spec) in state.accs.iter_mut().zip(specs.iter()) {
                    match &spec.arg {
                        None => acc.update(None),
                        Some(arg) => {
                            let v = eval_scalar(arg, row, params);
                            acc.update(Some(&v));
                        }
                    }
                }
            }
            order
        }
    };

    // An aggregate without GROUP BY over zero rows still yields one row
    // built from empty accumulators.
    if !grouped && order.is_empty() {
        order.push(GroupState {
            key_values: Vec::new(),
            first_row: HashMap::new(),
            accs: specs.iter().map(Acc::new).collect(),
        });
    }

    let group_names: Vec<String> = select.group_by.iter().map(render_expr_name).collect();

    let mut out_rows: Vec<HashMap<String, SochValue>> = Vec::with_capacity(order.len());
    for state in order {
        // Start from the group's first input row, then overlay group keys
        // and aggregate results under their rendered names.
        let mut row = state.first_row;
        for (name, val) in group_names.iter().zip(state.key_values.into_iter()) {
            row.insert(name.clone(), val);
        }
        for (spec, acc) in specs.iter().zip(state.accs.into_iter()) {
            row.insert(spec.key.clone(), acc.finalize());
        }
        out_rows.push(row);
    }

    if let Some(having) = &select.having {
        // Rows whose HAVING result is anything other than `true` are dropped.
        out_rows.retain(|row| {
            matches!(eval_scalar(having, row, params), SochValue::Bool(true))
        });
    }

    if !select.order_by.is_empty() {
        // Make SELECT aliases resolvable by ORDER BY: materialise each
        // aliased expression under its alias unless the row already has a
        // value stored under that name.
        let alias_map: Vec<(String, Expr)> = select
            .columns
            .iter()
            .filter_map(|item| match item {
                SelectItem::Expr {
                    expr,
                    alias: Some(a),
                } => Some((a.clone(), expr.clone())),
                _ => None,
            })
            .collect();
        for row in &mut out_rows {
            for (alias, expr) in &alias_map {
                if !row.contains_key(alias) {
                    let v = eval_scalar(expr, row, params);
                    row.insert(alias.clone(), v);
                }
            }
        }
        out_rows.sort_by(|a, b| {
            // Compare by each ORDER BY item in turn; the first item that
            // differs decides.
            for item in &select.order_by {
                let va = eval_scalar(&item.expr, a, params);
                let vb = eval_scalar(&item.expr, b, params);
                let mut cmp = compare_values(&va, &vb);
                if !item.asc {
                    cmp = cmp.reverse();
                }
                if cmp != std::cmp::Ordering::Equal {
                    return cmp;
                }
            }
            std::cmp::Ordering::Equal
        });
    }

    // OFFSET is applied before LIMIT.
    if let Some(off) = offset {
        if off > 0 {
            out_rows.drain(..off.min(out_rows.len()));
        }
    }
    if let Some(lim) = limit {
        out_rows.truncate(lim);
    }

    // Output column names and the expression producing each of them.
    let mut columns: Vec<String> = Vec::new();
    let mut projections: Vec<(String, Expr)> = Vec::new();
    for item in &select.columns {
        match item {
            SelectItem::Wildcard | SelectItem::QualifiedWildcard(_) => {
                // A wildcard expands to the group keys followed by every
                // collected aggregate.
                for name in &group_names {
                    columns.push(name.clone());
                    projections.push((
                        name.clone(),
                        Expr::Column(ColumnRef::new(name.clone())),
                    ));
                }
                for spec in &specs {
                    columns.push(spec.key.clone());
                    projections.push((
                        spec.key.clone(),
                        Expr::Column(ColumnRef::new(spec.key.clone())),
                    ));
                }
            }
            SelectItem::Expr { expr, alias } => {
                // Without an alias the column is named after the expression.
                let name = alias.clone().unwrap_or_else(|| render_expr_name(expr));
                columns.push(name.clone());
                projections.push((name, expr.clone()));
            }
        }
    }

    let projected: Vec<HashMap<String, SochValue>> = out_rows
        .into_iter()
        .map(|row| {
            let mut out = HashMap::with_capacity(projections.len());
            for (name, expr) in &projections {
                let v = eval_scalar(expr, &row, params);
                out.insert(name.clone(), v);
            }
            out
        })
        .collect();

    Ok(ExecutionResult::Rows {
        columns,
        rows: projected,
    })
}
1195
1196#[cfg(test)]
1197mod tests {
1198 use super::super::bridge::{SqlBridge, SqlConnection};
1199 use super::*;
1200
1201 fn fcall(name: &str, arg: &str) -> Expr {
1202 Expr::Function(FunctionCall {
1203 name: ObjectName::new(name),
1204 args: vec![Expr::Column(ColumnRef::new(arg))],
1205 distinct: false,
1206 filter: None,
1207 over: None,
1208 })
1209 }
1210
1211 #[test]
1212 fn agg_fn_recognition() {
1213 assert_eq!(AggFn::from_name("median"), Some(AggFn::Median));
1214 assert_eq!(AggFn::from_name("STDDEV"), Some(AggFn::Stddev));
1215 assert_eq!(AggFn::from_name("stddev_samp"), Some(AggFn::Stddev));
1216 assert_eq!(AggFn::from_name("upper"), None);
1217 }
1218
1219 #[test]
1220 fn canonical_keys() {
1221 assert_eq!(render_expr_name(&fcall("SUM", "v1")), "sum(v1)");
1222 assert_eq!(render_expr_name(&fcall("Median", "v3")), "median(v3)");
1223 }
1224
    /// In-memory `SqlConnection` used by the tests: each table is a list of
    /// rows keyed by column name.
    struct DataConn {
        /// Table name -> rows of that table.
        tables: HashMap<String, Vec<HashMap<String, SochValue>>>,
    }
1233
1234 impl DataConn {
1235 fn new() -> Self {
1236 Self {
1237 tables: HashMap::new(),
1238 }
1239 }
1240
1241 fn with_table(
1242 mut self,
1243 name: &str,
1244 cols: &[&str],
1245 rows: Vec<Vec<SochValue>>,
1246 ) -> Self {
1247 let rows = rows
1248 .into_iter()
1249 .map(|vals| {
1250 cols.iter()
1251 .map(|c| c.to_string())
1252 .zip(vals.into_iter())
1253 .collect::<HashMap<_, _>>()
1254 })
1255 .collect();
1256 self.tables.insert(name.to_string(), rows);
1257 self
1258 }
1259 }
1260
1261 impl SqlConnection for DataConn {
1262 fn select(
1263 &self,
1264 table: &str,
1265 _: &[String],
1266 _where_clause: Option<&Expr>,
1267 _: &[OrderByItem],
1268 _: Option<usize>,
1269 _: Option<usize>,
1270 _: &[SochValue],
1271 ) -> SqlResult<ExecutionResult> {
1272 let rows = self.tables.get(table).cloned().unwrap_or_default();
1274 Ok(ExecutionResult::Rows {
1275 columns: vec![],
1276 rows,
1277 })
1278 }
1279 fn insert(
1280 &mut self,
1281 _: &str,
1282 _: Option<&[String]>,
1283 _: &[Vec<Expr>],
1284 _: Option<&OnConflict>,
1285 _: &[SochValue],
1286 ) -> SqlResult<ExecutionResult> {
1287 Ok(ExecutionResult::RowsAffected(0))
1288 }
1289 fn update(
1290 &mut self,
1291 _: &str,
1292 _: &[Assignment],
1293 _: Option<&Expr>,
1294 _: &[SochValue],
1295 ) -> SqlResult<ExecutionResult> {
1296 Ok(ExecutionResult::RowsAffected(0))
1297 }
1298 fn delete(
1299 &mut self,
1300 _: &str,
1301 _: Option<&Expr>,
1302 _: &[SochValue],
1303 ) -> SqlResult<ExecutionResult> {
1304 Ok(ExecutionResult::RowsAffected(0))
1305 }
1306 fn create_table(&mut self, _: &CreateTableStmt) -> SqlResult<ExecutionResult> {
1307 Ok(ExecutionResult::Ok)
1308 }
1309 fn drop_table(&mut self, _: &DropTableStmt) -> SqlResult<ExecutionResult> {
1310 Ok(ExecutionResult::Ok)
1311 }
1312 fn create_index(&mut self, _: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
1313 Ok(ExecutionResult::Ok)
1314 }
1315 fn drop_index(&mut self, _: &DropIndexStmt) -> SqlResult<ExecutionResult> {
1316 Ok(ExecutionResult::Ok)
1317 }
1318 fn alter_table(&mut self, _: &AlterTableStmt) -> SqlResult<ExecutionResult> {
1319 Ok(ExecutionResult::Ok)
1320 }
1321 fn begin(&mut self, _: &BeginStmt) -> SqlResult<ExecutionResult> {
1322 Ok(ExecutionResult::TransactionOk)
1323 }
1324 fn commit(&mut self) -> SqlResult<ExecutionResult> {
1325 Ok(ExecutionResult::TransactionOk)
1326 }
1327 fn rollback(&mut self, _: Option<&str>) -> SqlResult<ExecutionResult> {
1328 Ok(ExecutionResult::TransactionOk)
1329 }
1330 fn table_exists(&self, t: &str) -> SqlResult<bool> {
1331 Ok(self.tables.contains_key(t))
1332 }
1333 fn index_exists(&self, _: &str) -> SqlResult<bool> {
1334 Ok(false)
1335 }
1336 fn scan_all(
1337 &self,
1338 table: &str,
1339 _: &[String],
1340 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
1341 Ok(self.tables.get(table).cloned().unwrap_or_default())
1342 }
1343 fn eval_join_predicate(
1344 &self,
1345 expr: &Expr,
1346 row: &HashMap<String, SochValue>,
1347 params: &[SochValue],
1348 ) -> Option<bool> {
1349 match eval_scalar(expr, row, params) {
1350 SochValue::Bool(b) => Some(b),
1351 SochValue::Null => Some(false),
1352 _ => None,
1353 }
1354 }
1355 }
1356
    /// Shorthand for an integer test value.
    fn i(v: i64) -> SochValue {
        SochValue::Int(v)
    }
    /// Shorthand for a float test value.
    fn f(v: f64) -> SochValue {
        SochValue::Float(v)
    }
    /// Shorthand for a text test value.
    fn t(v: &str) -> SochValue {
        SochValue::Text(v.to_string())
    }
1366
1367 fn bench_bridge() -> SqlBridge<DataConn> {
1369 let conn = DataConn::new().with_table(
1370 "x",
1371 &["id1", "id3", "v1", "v2", "v3"],
1372 vec![
1373 vec![t("id001"), t("id0000001"), i(1), i(10), f(1.0)],
1374 vec![t("id001"), t("id0000002"), i(2), i(20), f(2.0)],
1375 vec![t("id002"), t("id0000001"), i(3), i(30), f(3.0)],
1376 vec![t("id002"), t("id0000002"), i(4), i(40), f(4.0)],
1377 ],
1378 );
1379 SqlBridge::new(conn)
1380 }
1381
1382 fn rows_of(result: ExecutionResult) -> Vec<HashMap<String, SochValue>> {
1383 match result {
1384 ExecutionResult::Rows { rows, .. } => rows,
1385 other => panic!("expected rows, got {:?}", other),
1386 }
1387 }
1388
1389 fn get<'a>(row: &'a HashMap<String, SochValue>, k: &str) -> &'a SochValue {
1390 row.get(k)
1391 .unwrap_or_else(|| panic!("column '{}' missing from {:?}", k, row))
1392 }
1393
/// Groups the bench table by `id1` and checks each group's `sum(v1)`, aliased as `v1`.
#[test]
fn groupby_sum_q1_shape() {
    let mut bridge = bench_bridge();
    let outcome = bridge
        .execute("SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1 ORDER BY id1")
        .unwrap();
    let rows = rows_of(outcome);
    assert_eq!(rows.len(), 2);

    // (group key, expected sum) in ORDER BY order: 1 + 2 and 3 + 4.
    let expected: [(&str, i64); 2] = [("id001", 3), ("id002", 7)];
    for (row, &(key, total)) in rows.iter().zip(expected.iter()) {
        assert_eq!(get(row, "id1"), &t(key));
        assert_eq!(get(row, "v1"), &i(total));
    }
}
1408
/// Groups by `(id1, id3)`, which makes every bench row its own group, and
/// checks `avg(v1)` on the first and last group.
#[test]
fn groupby_multi_key_mean() {
    let mut bridge = bench_bridge();
    let sql = "SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3 ORDER BY id1, id3";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 4);

    let (first, last) = (&rows[0], &rows[3]);
    assert_eq!(get(first, "m"), &f(1.0));
    assert_eq!(get(last, "m"), &f(4.0));
}
1423
/// Checks ungrouped `median` and `stddev` over `v3` = 1.0, 2.0, 3.0, 4.0.
#[test]
fn median_and_stddev() {
    let mut bridge = bench_bridge();
    let sql = "SELECT median(v3) AS med, stddev(v3) AS sd FROM x";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 1);

    let only = &rows[0];
    assert_eq!(get(only, "med"), &f(2.5));

    // Squared deviations from the mean 2.5 add up to 5; dividing by n - 1 = 3
    // gives the expected variance.
    let expected_sd = (5.0f64 / 3.0).sqrt();
    let sd_cell = get(only, "sd");
    if let SochValue::Float(sd) = sd_cell {
        assert!((sd - expected_sd).abs() < 1e-12, "sd={}", sd);
    } else {
        panic!("expected float sd, got {:?}", sd_cell);
    }
}
1442
/// With three unsorted values (5, 1, 3) the median must be the middle one, 3.
#[test]
fn median_odd_count() {
    let values: [f64; 3] = [5.0, 1.0, 3.0];
    let data: Vec<Vec<SochValue>> = values.iter().map(|&v| vec![f(v)]).collect();
    let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], data));

    let outcome = bridge.execute("SELECT median(v) AS m FROM t").unwrap();
    let rows = rows_of(outcome);
    assert_eq!(get(&rows[0], "m"), &f(3.0));
}
1454
/// Checks an arithmetic expression over two aggregates, `max(v1) - min(v2)`, per `id3` group.
#[test]
fn range_expression_q9_shape() {
    let mut bridge = bench_bridge();
    let sql = "SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3 ORDER BY id3";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 2);

    // id0000001: max(1, 3) - min(10, 30) = -7; id0000002: max(2, 4) - min(20, 40) = -16.
    let expected: [i64; 2] = [-7, -16];
    for (row, &range) in rows.iter().zip(expected.iter()) {
        assert_eq!(get(row, "range_v1_v2"), &i(range));
    }
}
1471
/// `count(*)` counts every row of a group while `count(v)` skips rows whose `v` is null.
#[test]
fn count_star_vs_count_col_with_nulls() {
    let data = vec![
        vec![t("a"), i(1)],
        vec![t("a"), SochValue::Null],
        vec![t("b"), i(2)],
    ];
    let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["g", "v"], data));

    let sql = "SELECT g, count(*) AS n, count(v) AS nv FROM t GROUP BY g ORDER BY g";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 2);

    // (count(*), count(v)) for groups "a" and "b" in ORDER BY order.
    let expected: [(i64, i64); 2] = [(2, 1), (1, 1)];
    for (row, &(all_rows, non_null)) in rows.iter().zip(expected.iter()) {
        assert_eq!(get(row, "n"), &i(all_rows));
        assert_eq!(get(row, "nv"), &i(non_null));
    }
}
1496
/// Each `id3` group of the bench table contains both `id1` values, so
/// `count(DISTINCT id1)` must be 2 for every group.
#[test]
fn count_distinct() {
    let mut bridge = bench_bridge();
    let sql = "SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3 ORDER BY id3";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 2);

    for row in rows.iter() {
        assert_eq!(get(row, "u"), &i(2));
    }
}
1509
/// `HAVING sum(v1) > 5` must drop the `id001` group (sum 3) and keep `id002` (sum 7).
#[test]
fn having_filters_groups() {
    let mut bridge = bench_bridge();
    let sql = "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 HAVING sum(v1) > 5";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 1);

    let survivor = &rows[0];
    assert_eq!(get(survivor, "id1"), &t("id002"));
    assert_eq!(get(survivor, "s"), &i(7));
}
1521
/// Ordering by the aggregate alias descending with `LIMIT 1` must return only
/// the group with the largest sum, `id002`.
#[test]
fn order_by_aggregate_desc_with_limit() {
    let mut bridge = bench_bridge();
    let sql = "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 ORDER BY s DESC LIMIT 1";
    let rows = rows_of(bridge.execute(sql).unwrap());
    assert_eq!(rows.len(), 1);

    let top = &rows[0];
    assert_eq!(get(top, "id1"), &t("id002"));
}
1532
/// An aggregate query without `GROUP BY` over an empty table must still
/// produce exactly one row: `count(*)` = 0 and `sum(v)` = null.
#[test]
fn ungrouped_aggregate_over_empty_table() {
    let empty = DataConn::new().with_table("e", &["v"], Vec::new());
    let mut bridge = SqlBridge::new(empty);

    let outcome = bridge.execute("SELECT count(*) AS n, sum(v) AS s FROM e").unwrap();
    let rows = rows_of(outcome);
    assert_eq!(rows.len(), 1, "ungrouped agg over empty input = one row");

    let only = &rows[0];
    assert_eq!(get(only, "n"), &i(0));
    assert_eq!(get(only, "s"), &SochValue::Null);
}
1544
/// With `GROUP BY`, an empty table has no groups and therefore no output rows.
#[test]
fn grouped_aggregate_over_empty_table_yields_no_rows() {
    let empty = DataConn::new().with_table("e", &["g", "v"], Vec::new());
    let mut bridge = SqlBridge::new(empty);

    let outcome = bridge.execute("SELECT g, sum(v) AS s FROM e GROUP BY g").unwrap();
    let rows = rows_of(outcome);
    assert!(rows.is_empty());
}
1554
/// Summing two `i64::MAX` values cannot fit in an `i64`; the result must come
/// back as a float larger than 1.8e19.
#[test]
fn sum_overflow_promotes_to_float() {
    let data: Vec<Vec<SochValue>> = (0..2).map(|_| vec![i(i64::MAX)]).collect();
    let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], data));

    let outcome = bridge.execute("SELECT sum(v) AS s FROM t").unwrap();
    let rows = rows_of(outcome);

    let total = get(&rows[0], "s");
    if let SochValue::Float(v) = total {
        assert!(*v > 1.8e19);
    } else {
        panic!("expected float after overflow, got {:?}", total);
    }
}
1569
/// Aggregates over the result of an inner join between tables `a` and `b`.
///
/// `k1` matches two rows of `a` (v = 1, 2) against one row of `b` (w = 10),
/// and `k2` matches one row (v = 3) against w = 20, so both groups end up
/// with `sum(a.v)` = 3 and `sum(b.w)` = 20.
#[test]
fn aggregate_after_join() {
    let left_rows = vec![
        vec![t("k1"), i(1)],
        vec![t("k1"), i(2)],
        vec![t("k2"), i(3)],
    ];
    let right_rows = vec![vec![t("k1"), i(10)], vec![t("k2"), i(20)]];
    let conn = DataConn::new()
        .with_table("a", &["id", "v"], left_rows)
        .with_table("b", &["id", "w"], right_rows);
    let mut bridge = SqlBridge::new(conn);

    let outcome = bridge
        .execute(
            "SELECT a.id, sum(a.v) AS sv, sum(b.w) AS sw \
             FROM a JOIN b ON a.id = b.id GROUP BY a.id ORDER BY a.id",
        )
        .unwrap();
    let rows = rows_of(outcome);
    assert_eq!(rows.len(), 2);

    for row in rows.iter() {
        assert_eq!(get(row, "sv"), &i(3));
        assert_eq!(get(row, "sw"), &i(20));
    }
}
1602
/// Aggregate function names written in lowercase must be accepted by `execute`.
#[test]
fn lowercase_function_names_parse() {
    let mut b = bench_bridge();
    let queries: [&str; 3] = [
        "SELECT id1, sum(v1) FROM x GROUP BY id1",
        "SELECT median(v3) FROM x",
        "SELECT stddev(v3) FROM x",
    ];
    for &sql in queries.iter() {
        assert!(b.execute(sql).is_ok());
    }
}
1611
/// Cross-checks grouped count / sum / avg / median / stddev on a 150_000-row
/// table against values computed directly in the test.
///
/// The row count exceeds the file's `PARALLEL_THRESHOLD` (100_000), which is
/// presumably what routes the query through the parallel path; confirm against
/// the executor.
#[test]
fn parallel_path_matches_reference_computation() {
    let n: usize = 150_000;
    let groups = 7usize;

    // Fill the table and the per-group reference buckets in one pass, so both
    // see the same value for each index.
    let mut data: Vec<Vec<SochValue>> = Vec::with_capacity(n);
    let mut per_group: Vec<Vec<f64>> = vec![Vec::new(); groups];
    for idx in 0..n {
        let bucket = idx % groups;
        let value = (idx * 31 % 1000) as f64 / 4.0;
        data.push(vec![t(&format!("g{}", bucket)), f(value)]);
        per_group[bucket].push(value);
    }

    let mut bridge = SqlBridge::new(DataConn::new().with_table("big", &["g", "v"], data));
    let outcome = bridge
        .execute(
            "SELECT g, count(*) AS n, sum(v) AS s, avg(v) AS m, \
             median(v) AS med, stddev(v) AS sd FROM big GROUP BY g ORDER BY g",
        )
        .unwrap();
    let rows = rows_of(outcome);
    assert_eq!(rows.len(), groups);

    for (gi, row) in rows.iter().enumerate() {
        // Reference statistics for group `gi`.
        let vals = &per_group[gi];
        let cnt = vals.len() as f64;
        let sum: f64 = vals.iter().sum();
        let mean = sum / cnt;
        // Variance uses the n - 1 denominator.
        let squared_dev: f64 = vals
            .iter()
            .map(|x| {
                let lhs = x - mean;
                let rhs = x - mean;
                lhs * rhs
            })
            .sum::<f64>();
        let var = squared_dev / (cnt - 1.0);

        let mut sorted = vals.clone();
        sorted.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        let mid = sorted.len() / 2;
        let med = if sorted.len() % 2 == 1 {
            sorted[mid]
        } else {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        };

        assert_eq!(get(row, "g"), &t(&format!("g{}", gi)));
        assert_eq!(get(row, "n"), &i(vals.len() as i64));

        let sum_cell = get(row, "s");
        if let SochValue::Float(v) = sum_cell {
            assert!((v - sum).abs() < 1e-6, "sum");
        } else {
            panic!("sum type {:?}", sum_cell);
        }

        let mean_cell = get(row, "m");
        if let SochValue::Float(v) = mean_cell {
            assert!((v - mean).abs() < 1e-9, "mean");
        } else {
            panic!("mean type {:?}", mean_cell);
        }

        let median_cell = get(row, "med");
        if let SochValue::Float(v) = median_cell {
            assert!((v - med).abs() < 1e-9, "median");
        } else {
            panic!("median type {:?}", median_cell);
        }

        let sd_cell = get(row, "sd");
        if let SochValue::Float(v) = sd_cell {
            assert!((v - var.sqrt()).abs() < 1e-9, "sd {} vs {}", v, var.sqrt());
        } else {
            panic!("sd type {:?}", sd_cell);
        }
    }
}
1680
/// An aggregate selected without an alias must appear in the result's column
/// list under the name `sum(v1)`.
#[test]
fn unaliased_aggregate_column_name_is_canonical() {
    let mut bridge = bench_bridge();
    let result = bridge
        .execute("SELECT id1, sum(v1) FROM x GROUP BY id1")
        .unwrap();

    let wanted = "sum(v1)".to_string();
    let cols = result.columns().unwrap().clone();
    assert!(cols.contains(&wanted), "cols={:?}", cols);
}
1688}