Skip to main content

reddb_server/storage/query/batch/operators/
aggregate.rs

//! `BatchAggregate` — group rows by key columns and reduce numeric columns.
//!
//! The reducer supports COUNT / SUM / AVG / MIN / MAX over Int64 and
//! Float64 columns. Group keys may be Int64, Float64 (stored as bit
//! patterns), Bool, Text, or Null. The output is a `Vec<AggregateRow>`, an
//! operator-level primitive; converting it into a result batch is left to
//! the SQL dispatch layer as part of the B5 (projections) sprint.
8
9use std::collections::HashMap;
10
11use super::super::column_batch::{ColumnBatch, ColumnVector, ValueRef};
12
/// Reduction applied to one column within each group.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregateOp {
    /// Number of rows in the group.
    Count,
    /// Sum of the numeric values seen.
    Sum,
    /// Arithmetic mean of the numeric values seen.
    Avg,
    /// Smallest numeric value seen.
    Min,
    /// Largest numeric value seen.
    Max,
}

/// A single requested aggregate: which reduction runs over which column.
#[derive(Clone, Debug)]
pub struct AggregateSpec {
    /// Index of the input column to reduce. `Count` does not read it
    /// (it counts rows).
    pub column: usize,
    /// Reduction to apply to `column`.
    pub op: AggregateOp,
}
28
/// One component of a group key: the value of a single grouping column.
///
/// Floats are held as their raw bit pattern (`f64::to_bits`) so the type
/// can derive `Eq` and `Hash`. Equality on `Float64Bits` is therefore
/// bitwise, not numeric: `-0.0` and `0.0` are different parts here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroupKeyPart {
    Int64(i64),
    Float64Bits(u64),
    Bool(bool),
    Text(String),
    Null,
}

/// A full group key: one part per grouping column, in grouping-column order.
type GroupKey = Vec<GroupKeyPart>;
67
68#[derive(Debug, Clone)]
69pub struct AggregateResult {
70    pub op: AggregateOp,
71    pub column: usize,
72    pub value: f64,
73    /// For averages we also expose the intermediate count so callers
74    /// can merge partial aggregations across batches.
75    pub count: u64,
76}
77
78#[derive(Debug, Clone)]
79pub struct AggregateRow {
80    pub key: GroupKey,
81    pub results: Vec<AggregateResult>,
82}
83
84/// Group `batch` by `group_columns` and produce one row per key with
85/// each `AggregateSpec` applied. `group_columns` may be empty, in
86/// which case the whole batch reduces to a single row.
87pub fn batch_aggregate(
88    batch: &ColumnBatch,
89    group_columns: &[usize],
90    specs: &[AggregateSpec],
91) -> Vec<AggregateRow> {
92    if batch.is_empty() {
93        return Vec::new();
94    }
95    let mut groups: HashMap<GroupKey, Vec<Accumulator>> = HashMap::new();
96    for row in 0..batch.len() {
97        let key: GroupKey = group_columns
98            .iter()
99            .map(|c| group_key_part(batch, row, *c))
100            .collect();
101        let accs = groups
102            .entry(key)
103            .or_insert_with(|| specs.iter().map(Accumulator::new).collect());
104        for (idx, spec) in specs.iter().enumerate() {
105            accs[idx].observe(batch, row, spec);
106        }
107    }
108    let mut out: Vec<AggregateRow> = groups
109        .into_iter()
110        .map(|(key, accs)| {
111            let results = accs
112                .into_iter()
113                .zip(specs.iter())
114                .map(|(acc, spec)| acc.finalize(spec))
115                .collect();
116            AggregateRow { key, results }
117        })
118        .collect();
119    // Deterministic output ordering simplifies test assertions.
120    out.sort_by(|a, b| compare_keys(&a.key, &b.key));
121    out
122}
123
124fn group_key_part(batch: &ColumnBatch, row: usize, column: usize) -> GroupKeyPart {
125    match batch.value(row, column) {
126        ValueRef::Int64(v) => GroupKeyPart::Int64(v),
127        ValueRef::Float64(v) => GroupKeyPart::Float64Bits(v.to_bits()),
128        ValueRef::Bool(v) => GroupKeyPart::Bool(v),
129        ValueRef::Text(v) => GroupKeyPart::Text(v.to_string()),
130        ValueRef::Null => GroupKeyPart::Null,
131    }
132}
133
134#[derive(Debug, Clone)]
135struct Accumulator {
136    count: u64,
137    sum: f64,
138    min: f64,
139    max: f64,
140    any_observed: bool,
141}
142
143impl Accumulator {
144    fn new(_spec: &AggregateSpec) -> Self {
145        Self {
146            count: 0,
147            sum: 0.0,
148            min: f64::INFINITY,
149            max: f64::NEG_INFINITY,
150            any_observed: false,
151        }
152    }
153
154    fn observe(&mut self, batch: &ColumnBatch, row: usize, spec: &AggregateSpec) {
155        match spec.op {
156            AggregateOp::Count => {
157                self.count += 1;
158            }
159            AggregateOp::Sum | AggregateOp::Avg | AggregateOp::Min | AggregateOp::Max => {
160                if let Some(v) = numeric_value(batch, row, spec.column) {
161                    self.count += 1;
162                    self.sum += v;
163                    if v < self.min {
164                        self.min = v;
165                    }
166                    if v > self.max {
167                        self.max = v;
168                    }
169                    self.any_observed = true;
170                }
171            }
172        }
173    }
174
175    fn finalize(self, spec: &AggregateSpec) -> AggregateResult {
176        let value = match spec.op {
177            AggregateOp::Count => self.count as f64,
178            AggregateOp::Sum => self.sum,
179            AggregateOp::Avg => {
180                if self.count == 0 {
181                    0.0
182                } else {
183                    self.sum / self.count as f64
184                }
185            }
186            AggregateOp::Min => {
187                if self.any_observed {
188                    self.min
189                } else {
190                    0.0
191                }
192            }
193            AggregateOp::Max => {
194                if self.any_observed {
195                    self.max
196                } else {
197                    0.0
198                }
199            }
200        };
201        AggregateResult {
202            op: spec.op,
203            column: spec.column,
204            value,
205            count: self.count,
206        }
207    }
208}
209
210fn numeric_value(batch: &ColumnBatch, row: usize, column: usize) -> Option<f64> {
211    let col = batch.columns.get(column)?;
212    if !col.is_valid(row) {
213        return None;
214    }
215    match col {
216        ColumnVector::Int64 { data, .. } => Some(data[row] as f64),
217        ColumnVector::Float64 { data, .. } => Some(data[row]),
218        _ => None,
219    }
220}
221
222fn compare_keys(a: &[GroupKeyPart], b: &[GroupKeyPart]) -> std::cmp::Ordering {
223    for (x, y) in a.iter().zip(b.iter()) {
224        let ord = compare_key_part(x, y);
225        if ord != std::cmp::Ordering::Equal {
226            return ord;
227        }
228    }
229    a.len().cmp(&b.len())
230}
231
232fn compare_key_part(x: &GroupKeyPart, y: &GroupKeyPart) -> std::cmp::Ordering {
233    use std::cmp::Ordering;
234    use GroupKeyPart::*;
235    match (x, y) {
236        (Int64(a), Int64(b)) => a.cmp(b),
237        (Float64Bits(a), Float64Bits(b)) => f64::from_bits(*a)
238            .partial_cmp(&f64::from_bits(*b))
239            .unwrap_or(Ordering::Equal),
240        (Bool(a), Bool(b)) => a.cmp(b),
241        (Text(a), Text(b)) => a.cmp(b),
242        (Null, Null) => Ordering::Equal,
243        (Null, _) => Ordering::Less,
244        (_, Null) => Ordering::Greater,
245        _ => Ordering::Equal,
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::super::super::column_batch::{ColumnKind, Field, Schema};
    use super::*;
    use std::sync::Arc;

    /// Five rows of (`region` Text, `amount` Float64):
    /// us → 10, 30, 40; eu → 20, 50.
    fn batch() -> ColumnBatch {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "amount".into(),
                kind: ColumnKind::Float64,
                nullable: false,
            },
        ]));
        ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec![
                        "us".into(),
                        "eu".into(),
                        "us".into(),
                        "us".into(),
                        "eu".into(),
                    ],
                    validity: None,
                },
                ColumnVector::Float64 {
                    data: vec![10.0, 20.0, 30.0, 40.0, 50.0],
                    validity: None,
                },
            ],
        )
    }

    #[test]
    fn count_star_over_whole_batch() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].results[0].value, 5.0);
    }

    #[test]
    fn sum_grouped_by_region() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[AggregateSpec {
                column: 1,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 2);
        // Ordering is deterministic (Text Ord) — eu first, us second.
        assert_eq!(out[0].key[0], GroupKeyPart::Text("eu".into()));
        assert_eq!(out[0].results[0].value, 70.0);
        assert_eq!(out[1].key[0], GroupKeyPart::Text("us".into()));
        assert_eq!(out[1].results[0].value, 80.0);
    }

    #[test]
    fn avg_handles_empty_group_cleanly() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Avg,
                },
                // Column 0 is Text, so no row contributes a numeric value:
                // every group is empty as far as this aggregate goes.
                AggregateSpec {
                    column: 0,
                    op: AggregateOp::Avg,
                },
            ],
        );
        assert_eq!(out.len(), 2);
        let eu_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("eu".into()))
            .unwrap();
        assert_eq!(eu_row.results[0].value, 35.0);
        let us_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert!((us_row.results[0].value - (80.0 / 3.0)).abs() < 1e-6);
        for row in &out {
            assert_eq!(row.results[1].value, 0.0);
            assert_eq!(row.results[1].count, 0);
        }
    }

    #[test]
    fn min_and_max_agree_on_shape() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Min,
                },
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Max,
                },
            ],
        );
        let us = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert_eq!(us.results[0].value, 10.0);
        assert_eq!(us.results[1].value, 40.0);
        let eu = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("eu".into()))
            .unwrap();
        assert_eq!(eu.results[0].value, 20.0);
        assert_eq!(eu.results[1].value, 50.0);
    }

    #[test]
    fn empty_batch_returns_empty() {
        let b = batch();
        let empty = b.take(&[]);
        let out = batch_aggregate(
            &empty,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert!(out.is_empty());
    }

    #[test]
    fn multi_key_grouping_preserves_combinations() {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "tier".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
            Field {
                name: "v".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
        ]));
        let b = ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec!["a".into(), "a".into(), "b".into(), "a".into()],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![1, 2, 1, 1],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![10, 20, 30, 40],
                    validity: None,
                },
            ],
        );
        let out = batch_aggregate(
            &b,
            &[0, 1],
            &[AggregateSpec {
                column: 2,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 3);
        // (a, 1) → 10 + 40 = 50; (a, 2) → 20; (b, 1) → 30.
        let find = |r: &str, t: i64| {
            out.iter()
                .find(|row| {
                    row.key[0] == GroupKeyPart::Text(r.into())
                        && row.key[1] == GroupKeyPart::Int64(t)
                })
                .unwrap()
                .results[0]
                .value
        };
        assert_eq!(find("a", 1), 50.0);
        assert_eq!(find("a", 2), 20.0);
        assert_eq!(find("b", 1), 30.0);
    }
}