Skip to main content

reddb_server/storage/query/executors/
aggregation.rs

1//! Aggregation Framework
2//!
3//! Provides aggregation functions for query results.
4//!
5//! # Supported Functions
6//!
7//! - **COUNT**: Count rows (including COUNT(*) and COUNT DISTINCT)
8//! - **SUM**: Sum numeric values
9//! - **AVG**: Average of numeric values
10//! - **MIN**: Minimum value
11//! - **MAX**: Maximum value
12//! - **STDDEV**: Standard deviation
13//! - **VARIANCE**: Statistical variance
14//! - **PERCENTILE**: Nth percentile value
15//!
16//! # GROUP BY
17//!
18//! Aggregations can be grouped by one or more columns.
19//! HAVING clause filters groups after aggregation.
20
21use std::collections::HashMap;
22
23use super::super::engine::binding::{Binding, Value, Var};
24use super::value_compare::total_compare_values;
25
26// ============================================================================
27// Aggregator Trait
28// ============================================================================
29
30/// Trait for aggregation functions
31pub trait Aggregator: Send + Sync {
32    /// Process a single value
33    fn accumulate(&mut self, value: Option<&Value>);
34
35    /// Get the final aggregated result
36    fn finalize(&self) -> Value;
37
38    /// Reset for new group
39    fn reset(&mut self);
40
41    /// Create a fresh copy for a new group
42    fn new_instance(&self) -> Box<dyn Aggregator>;
43
44    /// Name of the aggregator
45    fn name(&self) -> &'static str;
46}
47
48// ============================================================================
49// COUNT Aggregator
50// ============================================================================
51
52/// COUNT aggregator
53#[derive(Debug, Clone, Default)]
54pub struct CountAggregator {
55    count: i64,
56    count_all: bool, // COUNT(*) vs COUNT(column)
57}
58
59impl CountAggregator {
60    /// Create COUNT(*) aggregator
61    pub fn count_all() -> Self {
62        Self {
63            count: 0,
64            count_all: true,
65        }
66    }
67
68    /// Create COUNT(column) aggregator (ignores nulls)
69    pub fn count_column() -> Self {
70        Self {
71            count: 0,
72            count_all: false,
73        }
74    }
75}
76
77impl Aggregator for CountAggregator {
78    fn accumulate(&mut self, value: Option<&Value>) {
79        if self.count_all || (value.is_some() && !matches!(value, Some(Value::Null))) {
80            self.count += 1;
81        }
82    }
83
84    fn finalize(&self) -> Value {
85        Value::Integer(self.count)
86    }
87
88    fn reset(&mut self) {
89        self.count = 0;
90    }
91
92    fn new_instance(&self) -> Box<dyn Aggregator> {
93        Box::new(Self {
94            count: 0,
95            count_all: self.count_all,
96        })
97    }
98
99    fn name(&self) -> &'static str {
100        "COUNT"
101    }
102}
103
104// ============================================================================
105// COUNT DISTINCT Aggregator
106// ============================================================================
107
108/// COUNT DISTINCT aggregator
109#[derive(Debug, Clone, Default)]
110pub struct CountDistinctAggregator {
111    seen: std::collections::HashSet<String>,
112}
113
114impl CountDistinctAggregator {
115    pub fn new() -> Self {
116        Self {
117            seen: std::collections::HashSet::new(),
118        }
119    }
120}
121
122impl Aggregator for CountDistinctAggregator {
123    fn accumulate(&mut self, value: Option<&Value>) {
124        if let Some(v) = value {
125            if !matches!(v, Value::Null) {
126                self.seen.insert(value_to_string(v));
127            }
128        }
129    }
130
131    fn finalize(&self) -> Value {
132        Value::Integer(self.seen.len() as i64)
133    }
134
135    fn reset(&mut self) {
136        self.seen.clear();
137    }
138
139    fn new_instance(&self) -> Box<dyn Aggregator> {
140        Box::new(Self::new())
141    }
142
143    fn name(&self) -> &'static str {
144        "COUNT_DISTINCT"
145    }
146}
147
148// ============================================================================
149// SUM Aggregator
150// ============================================================================
151
152/// SUM aggregator
153#[derive(Debug, Clone, Default)]
154pub struct SumAggregator {
155    sum: f64,
156    has_value: bool,
157    all_integers: bool,
158}
159
160impl SumAggregator {
161    pub fn new() -> Self {
162        Self {
163            sum: 0.0,
164            has_value: false,
165            all_integers: true,
166        }
167    }
168}
169
170impl Aggregator for SumAggregator {
171    fn accumulate(&mut self, value: Option<&Value>) {
172        if let Some(v) = value {
173            match v {
174                Value::Integer(i) => {
175                    self.sum += *i as f64;
176                    self.has_value = true;
177                }
178                Value::Float(f) => {
179                    self.sum += *f;
180                    self.has_value = true;
181                    if f.fract() != 0.0 {
182                        self.all_integers = false;
183                    }
184                }
185                _ => {}
186            }
187        }
188    }
189
190    fn finalize(&self) -> Value {
191        if self.has_value {
192            if self.all_integers && self.sum.fract() == 0.0 {
193                Value::Integer(self.sum as i64)
194            } else {
195                Value::Float(self.sum)
196            }
197        } else {
198            Value::Null
199        }
200    }
201
202    fn reset(&mut self) {
203        self.sum = 0.0;
204        self.has_value = false;
205        self.all_integers = true;
206    }
207
208    fn new_instance(&self) -> Box<dyn Aggregator> {
209        Box::new(Self::new())
210    }
211
212    fn name(&self) -> &'static str {
213        "SUM"
214    }
215}
216
217// ============================================================================
218// AVG Aggregator
219// ============================================================================
220
221/// AVG aggregator
222#[derive(Debug, Clone, Default)]
223pub struct AvgAggregator {
224    sum: f64,
225    count: i64,
226}
227
228impl AvgAggregator {
229    pub fn new() -> Self {
230        Self { sum: 0.0, count: 0 }
231    }
232}
233
234impl Aggregator for AvgAggregator {
235    fn accumulate(&mut self, value: Option<&Value>) {
236        if let Some(v) = value {
237            if let Some(n) = value_to_number(v) {
238                self.sum += n;
239                self.count += 1;
240            }
241        }
242    }
243
244    fn finalize(&self) -> Value {
245        if self.count > 0 {
246            Value::Float(self.sum / self.count as f64)
247        } else {
248            Value::Null
249        }
250    }
251
252    fn reset(&mut self) {
253        self.sum = 0.0;
254        self.count = 0;
255    }
256
257    fn new_instance(&self) -> Box<dyn Aggregator> {
258        Box::new(Self::new())
259    }
260
261    fn name(&self) -> &'static str {
262        "AVG"
263    }
264}
265
266// ============================================================================
267// MIN Aggregator
268// ============================================================================
269
270/// MIN aggregator
271#[derive(Debug, Clone, Default)]
272pub struct MinAggregator {
273    min: Option<Value>,
274}
275
276impl MinAggregator {
277    pub fn new() -> Self {
278        Self { min: None }
279    }
280}
281
282impl Aggregator for MinAggregator {
283    fn accumulate(&mut self, value: Option<&Value>) {
284        if let Some(v) = value {
285            if matches!(v, Value::Null) {
286                return;
287            }
288            match &self.min {
289                None => self.min = Some(v.clone()),
290                Some(current) => {
291                    if total_compare_values(v, current) == std::cmp::Ordering::Less {
292                        self.min = Some(v.clone());
293                    }
294                }
295            }
296        }
297    }
298
299    fn finalize(&self) -> Value {
300        self.min.clone().unwrap_or(Value::Null)
301    }
302
303    fn reset(&mut self) {
304        self.min = None;
305    }
306
307    fn new_instance(&self) -> Box<dyn Aggregator> {
308        Box::new(Self::new())
309    }
310
311    fn name(&self) -> &'static str {
312        "MIN"
313    }
314}
315
316// ============================================================================
317// MAX Aggregator
318// ============================================================================
319
320/// MAX aggregator
321#[derive(Debug, Clone, Default)]
322pub struct MaxAggregator {
323    max: Option<Value>,
324}
325
326impl MaxAggregator {
327    pub fn new() -> Self {
328        Self { max: None }
329    }
330}
331
332impl Aggregator for MaxAggregator {
333    fn accumulate(&mut self, value: Option<&Value>) {
334        if let Some(v) = value {
335            if matches!(v, Value::Null) {
336                return;
337            }
338            match &self.max {
339                None => self.max = Some(v.clone()),
340                Some(current) => {
341                    if total_compare_values(v, current) == std::cmp::Ordering::Greater {
342                        self.max = Some(v.clone());
343                    }
344                }
345            }
346        }
347    }
348
349    fn finalize(&self) -> Value {
350        self.max.clone().unwrap_or(Value::Null)
351    }
352
353    fn reset(&mut self) {
354        self.max = None;
355    }
356
357    fn new_instance(&self) -> Box<dyn Aggregator> {
358        Box::new(Self::new())
359    }
360
361    fn name(&self) -> &'static str {
362        "MAX"
363    }
364}
365
366// ============================================================================
367// SAMPLE Aggregator
368// ============================================================================
369
370/// SAMPLE aggregator (returns first non-null value)
371#[derive(Debug, Clone, Default)]
372pub struct SampleAggregator {
373    value: Option<Value>,
374}
375
376impl SampleAggregator {
377    pub fn new() -> Self {
378        Self { value: None }
379    }
380}
381
382impl Aggregator for SampleAggregator {
383    fn accumulate(&mut self, value: Option<&Value>) {
384        if self.value.is_some() {
385            return;
386        }
387        if let Some(v) = value {
388            if !matches!(v, Value::Null) {
389                self.value = Some(v.clone());
390            }
391        }
392    }
393
394    fn finalize(&self) -> Value {
395        self.value.clone().unwrap_or(Value::Null)
396    }
397
398    fn reset(&mut self) {
399        self.value = None;
400    }
401
402    fn new_instance(&self) -> Box<dyn Aggregator> {
403        Box::new(Self::new())
404    }
405
406    fn name(&self) -> &'static str {
407        "SAMPLE"
408    }
409}
410
411// ============================================================================
412// GROUP_CONCAT Aggregator
413// ============================================================================
414
415/// GROUP_CONCAT aggregator
416#[derive(Debug, Clone)]
417pub struct GroupConcatAggregator {
418    separator: String,
419    values: Vec<String>,
420}
421
422impl GroupConcatAggregator {
423    pub fn new(separator: Option<String>) -> Self {
424        Self {
425            separator: separator.unwrap_or_else(|| " ".to_string()),
426            values: Vec::new(),
427        }
428    }
429}
430
431impl Aggregator for GroupConcatAggregator {
432    fn accumulate(&mut self, value: Option<&Value>) {
433        if let Some(v) = value {
434            if !matches!(v, Value::Null) {
435                self.values.push(value_to_string(v));
436            }
437        }
438    }
439
440    fn finalize(&self) -> Value {
441        if self.values.is_empty() {
442            Value::Null
443        } else {
444            Value::String(self.values.join(&self.separator))
445        }
446    }
447
448    fn reset(&mut self) {
449        self.values.clear();
450    }
451
452    fn new_instance(&self) -> Box<dyn Aggregator> {
453        Box::new(Self::new(Some(self.separator.clone())))
454    }
455
456    fn name(&self) -> &'static str {
457        "GROUP_CONCAT"
458    }
459}
460
461// ============================================================================
462// STDDEV Aggregator (Population Standard Deviation)
463// ============================================================================
464
465/// Standard deviation aggregator
466#[derive(Debug, Clone, Default)]
467pub struct StdDevAggregator {
468    values: Vec<f64>,
469}
470
471impl StdDevAggregator {
472    pub fn new() -> Self {
473        Self { values: Vec::new() }
474    }
475}
476
477impl Aggregator for StdDevAggregator {
478    fn accumulate(&mut self, value: Option<&Value>) {
479        if let Some(v) = value {
480            if let Some(n) = value_to_number(v) {
481                self.values.push(n);
482            }
483        }
484    }
485
486    fn finalize(&self) -> Value {
487        if self.values.is_empty() {
488            return Value::Null;
489        }
490
491        let n = self.values.len() as f64;
492        let mean = self.values.iter().sum::<f64>() / n;
493        let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
494
495        Value::Float(variance.sqrt())
496    }
497
498    fn reset(&mut self) {
499        self.values.clear();
500    }
501
502    fn new_instance(&self) -> Box<dyn Aggregator> {
503        Box::new(Self::new())
504    }
505
506    fn name(&self) -> &'static str {
507        "STDDEV"
508    }
509}
510
511// ============================================================================
512// VARIANCE Aggregator
513// ============================================================================
514
515/// Variance aggregator
516#[derive(Debug, Clone, Default)]
517pub struct VarianceAggregator {
518    values: Vec<f64>,
519}
520
521impl VarianceAggregator {
522    pub fn new() -> Self {
523        Self { values: Vec::new() }
524    }
525}
526
527impl Aggregator for VarianceAggregator {
528    fn accumulate(&mut self, value: Option<&Value>) {
529        if let Some(v) = value {
530            if let Some(n) = value_to_number(v) {
531                self.values.push(n);
532            }
533        }
534    }
535
536    fn finalize(&self) -> Value {
537        if self.values.is_empty() {
538            return Value::Null;
539        }
540
541        let n = self.values.len() as f64;
542        let mean = self.values.iter().sum::<f64>() / n;
543        let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
544
545        Value::Float(variance)
546    }
547
548    fn reset(&mut self) {
549        self.values.clear();
550    }
551
552    fn new_instance(&self) -> Box<dyn Aggregator> {
553        Box::new(Self::new())
554    }
555
556    fn name(&self) -> &'static str {
557        "VARIANCE"
558    }
559}
560
561// ============================================================================
562// PERCENTILE Aggregator
563// ============================================================================
564
565/// Percentile aggregator
566#[derive(Debug, Clone)]
567pub struct PercentileAggregator {
568    values: Vec<f64>,
569    percentile: f64, // 0.0 to 1.0 (e.g., 0.5 for median)
570}
571
572impl PercentileAggregator {
573    pub fn new(percentile: f64) -> Self {
574        Self {
575            values: Vec::new(),
576            percentile: percentile.clamp(0.0, 1.0),
577        }
578    }
579
580    /// Create median aggregator (50th percentile)
581    pub fn median() -> Self {
582        Self::new(0.5)
583    }
584}
585
586impl Aggregator for PercentileAggregator {
587    fn accumulate(&mut self, value: Option<&Value>) {
588        if let Some(v) = value {
589            if let Some(n) = value_to_number(v) {
590                self.values.push(n);
591            }
592        }
593    }
594
595    fn finalize(&self) -> Value {
596        if self.values.is_empty() {
597            return Value::Null;
598        }
599
600        let mut sorted = self.values.clone();
601        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
602
603        let index = (self.percentile * (sorted.len() - 1) as f64).round() as usize;
604        Value::Float(sorted[index])
605    }
606
607    fn reset(&mut self) {
608        self.values.clear();
609    }
610
611    fn new_instance(&self) -> Box<dyn Aggregator> {
612        Box::new(Self::new(self.percentile))
613    }
614
615    fn name(&self) -> &'static str {
616        "PERCENTILE"
617    }
618}
619
620// ============================================================================
621// GROUP BY Executor
622// ============================================================================
623
624/// Definition of an aggregation to compute
625pub struct AggregationDef {
626    /// Source variable to aggregate
627    pub source_var: Var,
628    /// Result variable name
629    pub result_var: Var,
630    /// Aggregator factory
631    pub aggregator: Box<dyn Aggregator>,
632}
633
634impl std::fmt::Debug for AggregationDef {
635    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636        f.debug_struct("AggregationDef")
637            .field("source_var", &self.source_var)
638            .field("result_var", &self.result_var)
639            .field("aggregator", &self.aggregator.name())
640            .finish()
641    }
642}
643
/// Soft memory budget for in-process hash aggregation.
///
/// `execute_group_by` compares `groups.len() * AVG_GROUP_ENTRY_BYTES`
/// against this value. Crossing it only prints a warning to stderr, once
/// per process and only in builds with `debug_assertions`; nothing is
/// evicted or spilled. Full spill-to-disk requires changing the calling
/// convention to a row-at-a-time streaming API (tracked separately).
const WORK_MEM_BYTES: usize = 64 << 20; // 64 MiB

/// Estimated heap cost of one group entry in the aggregation HashMap.
///
/// Rough per-entry breakdown:
///   - a String group-key (~32 B avg)
///   - group key Var/Value pairs (~64 B)
///   - one Box<dyn Aggregator> per agg_def (~64 B each, assume ≤4 defs → ~256 B)
///
/// 512 B is a deliberately conservative round figure; it is an estimate,
/// not a measurement of actual heap usage.
const AVG_GROUP_ENTRY_BYTES: usize = 1 << 9; // 512 B
661
662/// 1-pass streaming GROUP BY.
663///
664/// Previous implementation accumulated ALL input bindings per group
665/// (`HashMap<String, Vec<Binding>>`), then ran aggregations in a
666/// second pass. Memory cost: O(input_rows) in the groups map.
667///
668/// This version keeps only the incremental aggregation state per group
669/// — one `Box<dyn Aggregator>` per `AggregationDef`. Memory cost drops
670/// to O(distinct_groups × agg_defs), which is dramatically lower for
671/// high-cardinality inputs with few distinct groups.
672pub fn execute_group_by(
673    bindings: Vec<Binding>,
674    group_vars: &[Var],
675    aggregations: &[AggregationDef],
676) -> Vec<Binding> {
677    // Each entry: (snapshot of group-key values from first binding,
678    //              incremental aggregator state for each agg_def)
679    let mut groups: HashMap<String, (Binding, Vec<Box<dyn Aggregator>>)> = HashMap::new();
680
681    for binding in &bindings {
682        let key = make_group_key(binding, group_vars);
683        let entry = groups.entry(key).or_insert_with(|| {
684            // Capture group key values once from the first binding in this group.
685            let mut key_binding = Binding::empty();
686            for var in group_vars {
687                if let Some(value) = binding.get(var) {
688                    let partial = Binding::one(var.clone(), value.clone());
689                    key_binding = key_binding.merge(&partial).unwrap_or(key_binding);
690                }
691            }
692            // Allocate one fresh aggregator instance per agg def.
693            let agg_instances = aggregations
694                .iter()
695                .map(|a| a.aggregator.new_instance())
696                .collect();
697            (key_binding, agg_instances)
698        });
699
700        // Accumulate each aggregation in a single pass over the binding.
701        for (i, agg_def) in aggregations.iter().enumerate() {
702            entry.1[i].accumulate(binding.get(&agg_def.source_var));
703        }
704
705        // Memory guard: O(1) check, avoids estimating actual heap usage.
706        // When the number of distinct groups × avg cost exceeds WORK_MEM,
707        // we've likely exhausted the intended budget. For now we continue
708        // (the data is already in memory via the input Vec<Binding>) but
709        // emit a debug trace so operators can see when this fires.
710        #[cfg(debug_assertions)]
711        if groups.len() * AVG_GROUP_ENTRY_BYTES > WORK_MEM_BYTES {
712            // Only log once — on entry count crossing the threshold,
713            // not on every subsequent row.
714            static WARNED: std::sync::atomic::AtomicBool =
715                std::sync::atomic::AtomicBool::new(false);
716            if !WARNED.swap(true, std::sync::atomic::Ordering::Relaxed) {
717                eprintln!(
718                    "[reddb] hash-agg: {} distinct groups × {} B ≈ {} MiB exceeds WORK_MEM {}  MiB; \
719                     disk spill not yet wired — upgrade calling convention to streaming for OOM safety",
720                    groups.len(),
721                    AVG_GROUP_ENTRY_BYTES,
722                    (groups.len() * AVG_GROUP_ENTRY_BYTES) / (1024 * 1024),
723                    WORK_MEM_BYTES / (1024 * 1024),
724                );
725            }
726        }
727    }
728
729    // Finalize: emit one output Binding per distinct group.
730    let mut results = Vec::with_capacity(groups.len());
731    for (_, (key_binding, agg_instances)) in groups {
732        let mut result = key_binding;
733        for (i, agg_def) in aggregations.iter().enumerate() {
734            let agg_result = agg_instances[i].finalize();
735            let partial = Binding::one(agg_def.result_var.clone(), agg_result);
736            result = result.merge(&partial).unwrap_or(result);
737        }
738        results.push(result);
739    }
740    results
741}
742
743/// Execute HAVING clause (filter on aggregated results)
744pub fn execute_having<F>(bindings: Vec<Binding>, predicate: F) -> Vec<Binding>
745where
746    F: Fn(&Binding) -> bool,
747{
748    bindings.into_iter().filter(|b| predicate(b)).collect()
749}
750
751// ============================================================================
752// Helper Functions
753// ============================================================================
754
755/// Format a single `Value` as a String. **Cold path** — used by
756/// non-hot consumers like GROUP_CONCAT result formatting. The
757/// hot group-by key path inlines this logic into a shared buffer
758/// in [`make_group_key`] to avoid per-row allocations.
759fn value_to_string(value: &Value) -> String {
760    match value {
761        Value::Node(id) => format!("node:{}", id),
762        Value::Edge(id) => format!("edge:{}", id),
763        Value::String(s) => s.clone(),
764        Value::Integer(i) => i.to_string(),
765        Value::Float(f) => f.to_string(),
766        Value::Boolean(b) => b.to_string(),
767        Value::Uri(u) => u.clone(),
768        Value::Null => "null".to_string(),
769    }
770}
771
772/// Build a group-by key for one row.
773///
774/// **Hot path** — called once per row in `execute_group_by`. The
775/// previous implementation paid `N+2` String allocations per row
776/// (one per `value_to_string`, one per `format!`, one for the final
777/// `join("|")`). This version writes everything into a single
778/// `String` buffer with one allocation.
779///
780/// On a 3-column GROUP BY the difference is ~5 allocations vs 1,
781/// which on a 1M-row aggregation saves ~4M small allocations.
782fn make_group_key(binding: &Binding, group_vars: &[Var]) -> String {
783    use std::fmt::Write;
784    // Tunable initial capacity. 64 bytes covers most numeric / short
785    // text group keys in one allocation; longer text grows in place
786    // through String's exponential growth.
787    let mut key = String::with_capacity(64);
788    for (i, var) in group_vars.iter().enumerate() {
789        if i > 0 {
790            key.push('|');
791        }
792        key.push_str(var.name());
793        key.push('=');
794        match binding.get(var) {
795            None => key.push_str("NULL"),
796            Some(Value::Null) => key.push_str("null"),
797            Some(Value::String(s)) => key.push_str(s),
798            Some(Value::Integer(n)) => {
799                let _ = write!(key, "{n}");
800            }
801            Some(Value::Float(f)) => {
802                let _ = write!(key, "{f}");
803            }
804            Some(Value::Boolean(b)) => {
805                let _ = write!(key, "{b}");
806            }
807            Some(Value::Node(id)) => {
808                key.push_str("node:");
809                key.push_str(id);
810            }
811            Some(Value::Edge(id)) => {
812                key.push_str("edge:");
813                key.push_str(id);
814            }
815            Some(Value::Uri(u)) => key.push_str(u),
816        }
817    }
818    key
819}
820
821fn value_to_number(value: &Value) -> Option<f64> {
822    match value {
823        Value::Integer(i) => Some(*i as f64),
824        Value::Float(f) => Some(*f),
825        Value::String(s) => s.parse().ok(),
826        _ => None,
827    }
828}
829
830// ============================================================================
831// Aggregator Factory
832// ============================================================================
833
834/// Create an aggregator by name
835pub fn create_aggregator(name: &str) -> Option<Box<dyn Aggregator>> {
836    match name.to_uppercase().as_str() {
837        "COUNT" => Some(Box::new(CountAggregator::count_all())),
838        "COUNT_COLUMN" => Some(Box::new(CountAggregator::count_column())),
839        "COUNT_DISTINCT" => Some(Box::new(CountDistinctAggregator::new())),
840        "SUM" => Some(Box::new(SumAggregator::new())),
841        "AVG" => Some(Box::new(AvgAggregator::new())),
842        "MIN" => Some(Box::new(MinAggregator::new())),
843        "MAX" => Some(Box::new(MaxAggregator::new())),
844        "STDDEV" => Some(Box::new(StdDevAggregator::new())),
845        "VARIANCE" => Some(Box::new(VarianceAggregator::new())),
846        "MEDIAN" => Some(Box::new(PercentileAggregator::median())),
847        "SAMPLE" => Some(Box::new(SampleAggregator::new())),
848        "GROUP_CONCAT" => Some(Box::new(GroupConcatAggregator::new(None))),
849        _ => None,
850    }
851}
852
853// ============================================================================
854// Tests
855// ============================================================================
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Binding` holding every `(name, value)` pair, merged left
    /// to right; an empty slice gives `Binding::empty()`.
    fn make_binding(pairs: &[(&str, Value)]) -> Binding {
        match pairs.split_first() {
            None => Binding::empty(),
            Some(((first_name, first_value), rest)) => rest.iter().fold(
                Binding::one(Var::new(*first_name), first_value.clone()),
                |acc, (k, v)| {
                    let next = Binding::one(Var::new(k), v.clone());
                    acc.merge(&next).unwrap_or(acc)
                },
            ),
        }
    }

    /// Feeds every input, in order, into `agg`.
    fn feed(agg: &mut dyn Aggregator, inputs: &[Option<Value>]) {
        for input in inputs {
            agg.accumulate(input.as_ref());
        }
    }

    #[test]
    fn test_count() {
        let mut counter = CountAggregator::count_all();
        feed(
            &mut counter,
            &[
                Some(Value::Integer(1)),
                Some(Value::Integer(2)),
                None,
                Some(Value::Null),
            ],
        );

        // COUNT(*) counts missing and null inputs too.
        assert_eq!(counter.finalize(), Value::Integer(4));
    }

    #[test]
    fn test_count_column() {
        let mut counter = CountAggregator::count_column();
        feed(
            &mut counter,
            &[
                Some(Value::Integer(1)),
                None,
                Some(Value::Null),
                Some(Value::Integer(2)),
            ],
        );

        // Only the two non-null values are counted.
        assert_eq!(counter.finalize(), Value::Integer(2));
    }

    #[test]
    fn test_sum() {
        let mut sum = SumAggregator::new();
        feed(
            &mut sum,
            &[
                Some(Value::Integer(10)),
                Some(Value::Float(5.5)),
                Some(Value::Integer(4)),
            ],
        );

        assert_eq!(sum.finalize(), Value::Float(19.5));
    }

    #[test]
    fn test_avg() {
        let mut avg = AvgAggregator::new();
        let inputs: Vec<Option<Value>> = [10, 20, 30]
            .iter()
            .map(|&n| Some(Value::Integer(n)))
            .collect();
        feed(&mut avg, &inputs);

        assert_eq!(avg.finalize(), Value::Float(20.0));
    }

    #[test]
    fn test_min_max() {
        let inputs: Vec<Option<Value>> = [5, 2, 8, 1, 9]
            .iter()
            .map(|&n| Some(Value::Integer(n)))
            .collect();

        let mut min = MinAggregator::new();
        let mut max = MaxAggregator::new();
        feed(&mut min, &inputs);
        feed(&mut max, &inputs);

        assert_eq!(min.finalize(), Value::Integer(1));
        assert_eq!(max.finalize(), Value::Integer(9));
    }

    #[test]
    fn test_count_distinct() {
        let mut distinct = CountDistinctAggregator::new();
        let inputs: Vec<Option<Value>> = ["a", "b", "a", "c"]
            .iter()
            .map(|s| Some(Value::String(s.to_string())))
            .collect();
        feed(&mut distinct, &inputs);

        // "a" appears twice but is counted once.
        assert_eq!(distinct.finalize(), Value::Integer(3));
    }

    #[test]
    fn test_group_by() {
        let rows = [
            ("Sales", 50000),
            ("Sales", 60000),
            ("Engineering", 80000),
            ("Engineering", 90000),
        ];
        let bindings: Vec<Binding> = rows
            .iter()
            .map(|(dept, salary)| {
                make_binding(&[
                    ("dept", Value::String(dept.to_string())),
                    ("salary", Value::Integer(*salary)),
                ])
            })
            .collect();

        let aggs = vec![
            AggregationDef {
                source_var: Var::new("salary"),
                result_var: Var::new("total"),
                aggregator: Box::new(SumAggregator::new()),
            },
            AggregationDef {
                source_var: Var::new("salary"),
                result_var: Var::new("count"),
                aggregator: Box::new(CountAggregator::count_all()),
            },
        ];

        let results = execute_group_by(bindings, &[Var::new("dept")], &aggs);

        assert_eq!(results.len(), 2);

        // Output order is unspecified, so look the Sales row up by value.
        let dept_var = Var::new("dept");
        let sales_value = Value::String("Sales".to_string());
        let sales = results
            .iter()
            .find(|b| b.get(&dept_var) == Some(&sales_value))
            .expect("Sales group not found");

        assert_eq!(sales.get(&Var::new("total")), Some(&Value::Integer(110000)));
        assert_eq!(sales.get(&Var::new("count")), Some(&Value::Integer(2)));
    }

    #[test]
    fn test_percentile() {
        let mut p50 = PercentileAggregator::median();
        let inputs: Vec<Option<Value>> = (1..=9).map(|n| Some(Value::Integer(n))).collect();
        feed(&mut p50, &inputs);

        // Median of 1-9 is 5
        assert_eq!(p50.finalize(), Value::Float(5.0));
    }
}