Skip to main content

reddb_server/storage/query/batch/operators/
aggregate.rs

//! `BatchAggregate` — group rows by key columns and reduce numeric columns.
//!
//! The reducer supports COUNT / SUM / AVG / MIN / MAX over Int64 and
//! Float64 columns. Group keys may be Int64, Float64 (keyed by bit
//! pattern), Bool, Text, or NULL. Output is a `Vec<AggregateRow>`, an
//! operator-level primitive; the SQL dispatch layer converts it into a
//! result batch as part of the B5 (projections) sprint.
8
9use indexmap::IndexMap;
10
11use super::super::column_batch::{ColumnBatch, ColumnVector, ValueRef};
12
/// Reduction applied to one column of every group.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// Number of rows in the group; the spec's column is not read.
    Count,
    /// Sum of the non-null Int64/Float64 values.
    Sum,
    /// Arithmetic mean of the non-null Int64/Float64 values.
    Avg,
    /// Smallest non-null Int64/Float64 value.
    Min,
    /// Largest non-null Int64/Float64 value.
    Max,
}

/// One aggregate to evaluate per group: which column, reduced how.
#[derive(Clone, Debug)]
pub struct AggregateSpec {
    /// Index of the column to reduce. [`AggregateOp::Count`] ignores it
    /// and counts rows instead.
    pub column: usize,
    /// The reduction to apply to `column`.
    pub op: AggregateOp,
}
28
/// One component of a group key: the value of a single group column for
/// one row.
///
/// Equality and hashing are structural and derived together, so they are
/// guaranteed to agree: a part only matches another part of the same
/// variant carrying the same payload, and every `Null` matches every other
/// `Null`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroupKeyPart {
    /// Value of an Int64 group column.
    Int64(i64),
    /// IEEE-754 bit pattern of a Float64 value. Keeping the bits instead of
    /// the `f64` is what makes `Eq` and `Hash` derivable: two float parts
    /// are equal exactly when their bit patterns are identical.
    Float64Bits(u64),
    /// Value of a Bool group column.
    Bool(bool),
    /// Owned copy of a Text group column value.
    Text(String),
    /// A NULL group value.
    Null,
}

/// A full group key: one [`GroupKeyPart`] per group column, in the order
/// the group columns were requested.
type GroupKey = Vec<GroupKeyPart>;
67
68#[derive(Debug, Clone)]
69pub struct AggregateResult {
70    pub op: AggregateOp,
71    pub column: usize,
72    pub value: f64,
73    /// For averages we also expose the intermediate count so callers
74    /// can merge partial aggregations across batches.
75    pub count: u64,
76}
77
78#[derive(Debug, Clone)]
79pub struct AggregateRow {
80    pub key: GroupKey,
81    pub results: Vec<AggregateResult>,
82}
83
84/// Group `batch` by `group_columns` and produce one row per key with
85/// each `AggregateSpec` applied. `group_columns` may be empty, in
86/// which case the whole batch reduces to a single row.
87pub fn batch_aggregate(
88    batch: &ColumnBatch,
89    group_columns: &[usize],
90    specs: &[AggregateSpec],
91) -> Vec<AggregateRow> {
92    if batch.is_empty() {
93        return Vec::new();
94    }
95    // IndexMap (insertion-order iteration) instead of HashMap (#630): the
96    // output is sorted by key below, but using HashMap also makes the
97    // *intermediate* group-discovery order randomized per-process, which
98    // is the shape the original bug report attributed to non-deterministic
99    // column ordering. IndexMap pins that intermediate order too.
100    let mut groups: IndexMap<GroupKey, Vec<Accumulator>> = IndexMap::new();
101    for row in 0..batch.len() {
102        let key: GroupKey = group_columns
103            .iter()
104            .map(|c| group_key_part(batch, row, *c))
105            .collect();
106        let accs = groups
107            .entry(key)
108            .or_insert_with(|| specs.iter().map(Accumulator::new).collect());
109        for (idx, spec) in specs.iter().enumerate() {
110            accs[idx].observe(batch, row, spec);
111        }
112    }
113    let mut out: Vec<AggregateRow> = groups
114        .into_iter()
115        .map(|(key, accs)| {
116            let results = accs
117                .into_iter()
118                .zip(specs.iter())
119                .map(|(acc, spec)| acc.finalize(spec))
120                .collect();
121            AggregateRow { key, results }
122        })
123        .collect();
124    // Deterministic output ordering simplifies test assertions.
125    out.sort_by(|a, b| compare_keys(&a.key, &b.key));
126    out
127}
128
129fn group_key_part(batch: &ColumnBatch, row: usize, column: usize) -> GroupKeyPart {
130    match batch.value(row, column) {
131        ValueRef::Int64(v) => GroupKeyPart::Int64(v),
132        ValueRef::Float64(v) => GroupKeyPart::Float64Bits(v.to_bits()),
133        ValueRef::Bool(v) => GroupKeyPart::Bool(v),
134        ValueRef::Text(v) => GroupKeyPart::Text(v.to_string()),
135        ValueRef::Null => GroupKeyPart::Null,
136    }
137}
138
139#[derive(Debug, Clone)]
140struct Accumulator {
141    count: u64,
142    sum: f64,
143    min: f64,
144    max: f64,
145    any_observed: bool,
146}
147
148impl Accumulator {
149    fn new(_spec: &AggregateSpec) -> Self {
150        Self {
151            count: 0,
152            sum: 0.0,
153            min: f64::INFINITY,
154            max: f64::NEG_INFINITY,
155            any_observed: false,
156        }
157    }
158
159    fn observe(&mut self, batch: &ColumnBatch, row: usize, spec: &AggregateSpec) {
160        match spec.op {
161            AggregateOp::Count => {
162                self.count += 1;
163            }
164            AggregateOp::Sum | AggregateOp::Avg | AggregateOp::Min | AggregateOp::Max => {
165                if let Some(v) = numeric_value(batch, row, spec.column) {
166                    self.count += 1;
167                    self.sum += v;
168                    if v < self.min {
169                        self.min = v;
170                    }
171                    if v > self.max {
172                        self.max = v;
173                    }
174                    self.any_observed = true;
175                }
176            }
177        }
178    }
179
180    fn finalize(self, spec: &AggregateSpec) -> AggregateResult {
181        let value = match spec.op {
182            AggregateOp::Count => self.count as f64,
183            AggregateOp::Sum => self.sum,
184            AggregateOp::Avg => {
185                if self.count == 0 {
186                    0.0
187                } else {
188                    self.sum / self.count as f64
189                }
190            }
191            AggregateOp::Min => {
192                if self.any_observed {
193                    self.min
194                } else {
195                    0.0
196                }
197            }
198            AggregateOp::Max => {
199                if self.any_observed {
200                    self.max
201                } else {
202                    0.0
203                }
204            }
205        };
206        AggregateResult {
207            op: spec.op,
208            column: spec.column,
209            value,
210            count: self.count,
211        }
212    }
213}
214
215fn numeric_value(batch: &ColumnBatch, row: usize, column: usize) -> Option<f64> {
216    let col = batch.columns.get(column)?;
217    if !col.is_valid(row) {
218        return None;
219    }
220    match col {
221        ColumnVector::Int64 { data, .. } => Some(data[row] as f64),
222        ColumnVector::Float64 { data, .. } => Some(data[row]),
223        _ => None,
224    }
225}
226
227fn compare_keys(a: &[GroupKeyPart], b: &[GroupKeyPart]) -> std::cmp::Ordering {
228    for (x, y) in a.iter().zip(b.iter()) {
229        let ord = compare_key_part(x, y);
230        if ord != std::cmp::Ordering::Equal {
231            return ord;
232        }
233    }
234    a.len().cmp(&b.len())
235}
236
237fn compare_key_part(x: &GroupKeyPart, y: &GroupKeyPart) -> std::cmp::Ordering {
238    use std::cmp::Ordering;
239    use GroupKeyPart::*;
240    match (x, y) {
241        (Int64(a), Int64(b)) => a.cmp(b),
242        (Float64Bits(a), Float64Bits(b)) => f64::from_bits(*a)
243            .partial_cmp(&f64::from_bits(*b))
244            .unwrap_or(Ordering::Equal),
245        (Bool(a), Bool(b)) => a.cmp(b),
246        (Text(a), Text(b)) => a.cmp(b),
247        (Null, Null) => Ordering::Equal,
248        (Null, _) => Ordering::Less,
249        (_, Null) => Ordering::Greater,
250        _ => Ordering::Equal,
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::super::super::column_batch::{ColumnKind, Field, Schema};
    use super::*;
    use std::sync::Arc;

    /// Five-row fixture: `region` (Text) and `amount` (Float64).
    /// "us" rows are 10 + 30 + 40 = 80; "eu" rows are 20 + 50 = 70.
    fn batch() -> ColumnBatch {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "amount".into(),
                kind: ColumnKind::Float64,
                nullable: false,
            },
        ]));
        ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec![
                        "us".into(),
                        "eu".into(),
                        "us".into(),
                        "us".into(),
                        "eu".into(),
                    ],
                    validity: None,
                },
                ColumnVector::Float64 {
                    data: vec![10.0, 20.0, 30.0, 40.0, 50.0],
                    validity: None,
                },
            ],
        )
    }

    #[test]
    fn count_star_over_whole_batch() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].results[0].value, 5.0);
    }

    #[test]
    fn sum_grouped_by_region() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[AggregateSpec {
                column: 1,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 2);
        // Ordering is deterministic (Text Ord) — eu first, us second.
        assert_eq!(out[0].key[0], GroupKeyPart::Text("eu".into()));
        assert_eq!(out[0].results[0].value, 70.0);
        assert_eq!(out[1].key[0], GroupKeyPart::Text("us".into()));
        assert_eq!(out[1].results[0].value, 80.0);
    }

    #[test]
    fn avg_grouped_by_region() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[AggregateSpec {
                column: 1,
                op: AggregateOp::Avg,
            }],
        );
        let eu_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("eu".into()))
            .unwrap();
        assert_eq!(eu_row.results[0].value, 35.0);
        let us_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert!((us_row.results[0].value - (80.0 / 3.0)).abs() < 1e-6);
    }

    #[test]
    fn avg_handles_empty_group_cleanly() {
        // Column 0 is Text, so no row contributes a numeric value: every
        // group's AVG / MIN / MAX sees zero inputs and must fall back to
        // 0.0 rather than dividing by zero or leaking the ±inf seeds.
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[
                AggregateSpec {
                    column: 0,
                    op: AggregateOp::Avg,
                },
                AggregateSpec {
                    column: 0,
                    op: AggregateOp::Min,
                },
                AggregateSpec {
                    column: 0,
                    op: AggregateOp::Max,
                },
            ],
        );
        assert_eq!(out.len(), 2);
        for row in &out {
            assert_eq!(row.results.len(), 3);
            for result in &row.results {
                assert_eq!(result.value, 0.0);
                assert_eq!(result.count, 0);
            }
        }
    }

    #[test]
    fn min_and_max_agree_on_shape() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Min,
                },
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Max,
                },
            ],
        );
        let us = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert_eq!(us.results[0].value, 10.0);
        assert_eq!(us.results[1].value, 40.0);
    }

    #[test]
    fn empty_batch_returns_empty() {
        // Even with no group columns, an empty batch yields no rows at all.
        let b = batch();
        let empty = b.take(&[]);
        let out = batch_aggregate(
            &empty,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert!(out.is_empty());
    }

    #[test]
    fn multi_key_grouping_preserves_combinations() {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "tier".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
            Field {
                name: "v".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
        ]));
        let b = ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec!["a".into(), "a".into(), "b".into(), "a".into()],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![1, 2, 1, 1],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![10, 20, 30, 40],
                    validity: None,
                },
            ],
        );
        let out = batch_aggregate(
            &b,
            &[0, 1],
            &[AggregateSpec {
                column: 2,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 3);
        // (a, 1) → 10 + 40 = 50; (a, 2) → 20; (b, 1) → 30.
        let find = |r: &str, t: i64| {
            out.iter()
                .find(|row| {
                    row.key[0] == GroupKeyPart::Text(r.into())
                        && row.key[1] == GroupKeyPart::Int64(t)
                })
                .unwrap()
                .results[0]
                .value
        };
        assert_eq!(find("a", 1), 50.0);
        assert_eq!(find("a", 2), 20.0);
        assert_eq!(find("b", 1), 30.0);
    }
}