Skip to main content

nodedb_array/query/
aggregate.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Reductions over a sparse tile.
4//!
5//! Five reducers (Sum / Count / Min / Max / Mean) operate on a single
6//! attribute column. The shape is partial-friendly: each tile produces
7//! an [`AggregateResult`] which the executor merges across tiles via
8//! [`AggregateResult::merge`]. Mean carries (sum, count) so merges are
9//! exact rather than averaging averages.
10//!
//! Group-by ([`group_by_dim`]) buckets live rows by one dim's values and
//! returns a `Vec<GroupAggregate>` (one key/partial pair per distinct
//! value) ordered by first appearance. Sort/order is the caller's job —
//! keeping insertion order avoids forcing `Ord` on `CoordValue`.
15
16use std::collections::HashMap;
17
18use crate::tile::sparse_tile::{RowKind, SparseTile};
19use crate::types::cell_value::value::CellValue;
20use crate::types::coord::value::CoordValue;
21
/// Which reduction to apply to one attribute column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reducer {
    /// Sum of the numeric cells.
    Sum,
    /// Number of non-null cells, regardless of dtype.
    Count,
    /// Smallest numeric cell.
    Min,
    /// Largest numeric cell.
    Max,
    /// Arithmetic mean of the numeric cells.
    Mean,
}

/// Partial result for one reducer over one (group of) cell(s).
/// `Mean` retains `(sum, count)` so partials merge exactly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregateResult {
    /// Running sum and the number of cells folded into it.
    Sum {
        value: f64,
        count: u64,
    },
    /// Number of cells folded in.
    Count {
        count: u64,
    },
    /// Running minimum and the number of cells folded into it.
    Min {
        value: f64,
        count: u64,
    },
    /// Running maximum and the number of cells folded into it.
    Max {
        value: f64,
        count: u64,
    },
    /// Running sum and count; the mean is only formed in [`AggregateResult::finalize`].
    Mean {
        sum: f64,
        count: u64,
    },
    /// No non-null cells observed yet for this reducer.
    Empty(Reducer),
}

impl AggregateResult {
    /// Merge two partials of the same reducer kind.
    ///
    /// `Empty` acts as the identity on either side. Mismatched reducers
    /// produce a left-biased result (the executor only ever merges
    /// same-kind partials, so this is a programming-error path).
    ///
    /// `Min`/`Max` use [`f64::min`]/[`f64::max`], which ignore a NaN
    /// operand. That makes the merged extreme independent of merge order:
    /// a plain `a <= b` comparison would keep or drop a NaN depending on
    /// which side it arrived on.
    pub fn merge(self, other: AggregateResult) -> AggregateResult {
        use AggregateResult::*;
        match (self, other) {
            (Empty(_), x) | (x, Empty(_)) => x,
            (
                Sum {
                    value: a,
                    count: ca,
                },
                Sum {
                    value: b,
                    count: cb,
                },
            ) => Sum {
                value: a + b,
                count: ca + cb,
            },
            (Count { count: a }, Count { count: b }) => Count { count: a + b },
            (
                Min {
                    value: a,
                    count: ca,
                },
                Min {
                    value: b,
                    count: cb,
                },
            ) => Min {
                value: a.min(b),
                count: ca + cb,
            },
            (
                Max {
                    value: a,
                    count: ca,
                },
                Max {
                    value: b,
                    count: cb,
                },
            ) => Max {
                value: a.max(b),
                count: ca + cb,
            },
            (Mean { sum: a, count: ca }, Mean { sum: b, count: cb }) => Mean {
                sum: a + b,
                count: ca + cb,
            },
            (lhs, _) => lhs,
        }
    }

    /// Final scalar — useful for tests and the planner's last fold.
    ///
    /// Returns `None` for `Empty` (for every reducer, including Count) and
    /// for a `Mean` whose count is zero.
    pub fn finalize(self) -> Option<f64> {
        match self {
            AggregateResult::Sum { value, .. } => Some(value),
            AggregateResult::Count { count } => Some(count as f64),
            AggregateResult::Min { value, .. } | AggregateResult::Max { value, .. } => Some(value),
            AggregateResult::Mean { sum, count } if count > 0 => Some(sum / count as f64),
            _ => None,
        }
    }
}
125
126fn empty(r: Reducer) -> AggregateResult {
127    AggregateResult::Empty(r)
128}
129
130/// Reduce one attr column over all rows of a tile. `attr_idx` must be
131/// in range; non-numeric / null cells are skipped (Count includes
132/// non-null cells regardless of dtype, since "count of cells with this
133/// attribute populated" is a sensible fold across all attr types).
134pub fn aggregate_attr(tile: &SparseTile, attr_idx: usize, reducer: Reducer) -> AggregateResult {
135    let Some(col) = tile.attr_cols.get(attr_idx) else {
136        return empty(reducer);
137    };
138    let mut acc = empty(reducer);
139    for v in col {
140        let one = single_cell(v, reducer);
141        acc = acc.merge(one);
142    }
143    acc
144}
145
146fn single_cell(v: &CellValue, reducer: Reducer) -> AggregateResult {
147    use AggregateResult::*;
148    let n = match v {
149        CellValue::Int64(x) => Some(*x as f64),
150        CellValue::Float64(x) => Some(*x),
151        CellValue::String(_) | CellValue::Bytes(_) | CellValue::Null => None,
152    };
153    match (reducer, n) {
154        (Reducer::Count, _) if !v.is_null() => Count { count: 1 },
155        (Reducer::Count, _) => Empty(Reducer::Count),
156        (Reducer::Sum, Some(x)) => Sum { value: x, count: 1 },
157        (Reducer::Min, Some(x)) => Min { value: x, count: 1 },
158        (Reducer::Max, Some(x)) => Max { value: x, count: 1 },
159        (Reducer::Mean, Some(x)) => Mean { sum: x, count: 1 },
160        (r, None) => Empty(r),
161    }
162}
163
/// One bucket of a group-by aggregate, as produced by [`group_by_dim`].
/// The key is the dim value the rows share; `result` is the reducer's
/// running partial over that bucket.
#[derive(Debug, Clone)]
pub struct GroupAggregate {
    /// Dim value shared by every row folded into this bucket.
    pub key: CoordValue,
    /// Reducer partial over this bucket's cells; call `finalize` for a scalar.
    pub result: AggregateResult,
}
172
173/// Group rows by one dim's values and reduce `attr_idx` per group.
174/// Returns groups in first-seen order.
175pub fn group_by_dim(
176    tile: &SparseTile,
177    dim_idx: usize,
178    attr_idx: usize,
179    reducer: Reducer,
180) -> Vec<GroupAggregate> {
181    let Some(dict) = tile.dim_dicts.get(dim_idx) else {
182        return Vec::new();
183    };
184    let Some(col) = tile.attr_cols.get(attr_idx) else {
185        return Vec::new();
186    };
187    let mut order: Vec<CoordValue> = Vec::new();
188    let mut by_key: HashMap<CoordValue, AggregateResult> = HashMap::new();
189    // Iterate physical rows; track live_idx separately because attr_cols only
190    // has entries for Live rows — sentinel rows carry no attr payload.
191    let mut live_idx = 0usize;
192    for row in 0..tile.row_count() {
193        let kind = match tile.row_kind(row) {
194            Ok(k) => k,
195            Err(_) => break,
196        };
197        if kind != RowKind::Live {
198            continue;
199        }
200        let cell = match col.get(live_idx) {
201            Some(c) => c,
202            None => break,
203        };
204        live_idx += 1;
205        let key = dict.values[dict.indices[row] as usize].clone();
206        let one = single_cell(cell, reducer);
207        match by_key.get_mut(&key) {
208            Some(slot) => *slot = slot.merge(one),
209            None => {
210                order.push(key.clone());
211                by_key.insert(key, empty(reducer).merge(one));
212            }
213        }
214    }
215    order
216        .into_iter()
217        .map(|k| GroupAggregate {
218            result: by_key.remove(&k).unwrap_or(empty(reducer)),
219            key: k,
220        })
221        .collect()
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ArraySchema;
    use crate::schema::ArraySchemaBuilder;
    use crate::schema::attr_spec::{AttrSpec, AttrType};
    use crate::schema::dim_spec::{DimSpec, DimType};
    use crate::tile::sparse_tile::SparseTileBuilder;
    use crate::types::domain::{Domain, DomainBound};

    /// Schema with one Int64 dim `k` over [0, 15] (a single 16-wide tile)
    /// and one Float64 attr `v`.
    fn schema() -> ArraySchema {
        let k_domain = Domain::new(DomainBound::Int64(0), DomainBound::Int64(15));
        ArraySchemaBuilder::new("g")
            .dim(DimSpec::new("k", DimType::Int64, k_domain))
            .attr(AttrSpec::new("v", AttrType::Float64, true))
            .tile_extents(vec![16])
            .build()
            .unwrap()
    }

    /// Build a tile with one row per `(k, v)` pair; `None` becomes a null cell.
    fn tile(rows: &[(i64, Option<f64>)]) -> SparseTile {
        let s = schema();
        let mut builder = SparseTileBuilder::new(&s);
        for &(k, v) in rows {
            let cell = match v {
                Some(x) => CellValue::Float64(x),
                None => CellValue::Null,
            };
            builder.push(&[CoordValue::Int64(k)], &[cell]).unwrap();
        }
        builder.build()
    }

    #[test]
    fn sum_skips_nulls() {
        let t = tile(&[(0, Some(1.0)), (1, None), (2, Some(3.0))]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Sum).finalize(), Some(4.0));
    }

    #[test]
    fn count_includes_nonnull_only() {
        let t = tile(&[(0, Some(1.0)), (1, None), (2, Some(3.0))]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Count).finalize(), Some(2.0));
    }

    #[test]
    fn min_max_mean() {
        let t = tile(&[(0, Some(1.0)), (1, Some(5.0)), (2, Some(3.0))]);
        let cases = [(Reducer::Min, 1.0), (Reducer::Max, 5.0), (Reducer::Mean, 3.0)];
        for (reducer, want) in cases {
            assert_eq!(
                aggregate_attr(&t, 0, reducer).finalize(),
                Some(want),
                "{reducer:?}"
            );
        }
    }

    #[test]
    fn merge_combines_partials_exactly() {
        let lhs = AggregateResult::Mean {
            sum: 10.0,
            count: 4,
        };
        let rhs = AggregateResult::Mean { sum: 6.0, count: 2 };
        // Exact merge is (10 + 6) / (4 + 2), not the mean of 2.5 and 3.0.
        assert_eq!(lhs.merge(rhs).finalize(), Some(16.0 / 6.0));
    }

    #[test]
    fn empty_reducer_finalizes_to_none() {
        let t = tile(&[(0, None)]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Sum).finalize(), None);
    }

    #[test]
    fn group_by_buckets_by_dim_value() {
        let t = tile(&[
            (0, Some(1.0)),
            (1, Some(2.0)),
            (0, Some(3.0)),
            (1, Some(4.0)),
        ]);
        let got: Vec<(CoordValue, Option<f64>)> = group_by_dim(&t, 0, 0, Reducer::Sum)
            .iter()
            .map(|g| (g.key.clone(), g.result.finalize()))
            .collect();
        // First-seen order: key 0, then key 1.
        assert_eq!(
            got,
            vec![
                (CoordValue::Int64(0), Some(4.0)),
                (CoordValue::Int64(1), Some(6.0)),
            ]
        );
    }
}
313}