1use std::collections::HashMap;
17
18use crate::tile::sparse_tile::{RowKind, SparseTile};
19use crate::types::cell_value::value::CellValue;
20use crate::types::coord::value::CoordValue;
21
/// Aggregation applied to the cells of one attribute column.
///
/// `Count` counts every non-null cell regardless of its type; the other
/// reducers only consider cells holding a numeric (`Int64` / `Float64`) value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reducer {
    Sum,
    Count,
    Min,
    Max,
    Mean,
}

/// Mergeable partial state of a [`Reducer`] over some set of cells.
///
/// Non-empty variants carry the number of contributing cells so that partials
/// can be combined exactly with [`AggregateResult::merge`] and turned into a
/// scalar with [`AggregateResult::finalize`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregateResult {
    /// Running sum and number of contributing cells.
    Sum { value: f64, count: u64 },
    /// Number of contributing cells.
    Count { count: u64 },
    /// Smallest value seen and number of contributing cells.
    Min { value: f64, count: u64 },
    /// Largest value seen and number of contributing cells.
    Max { value: f64, count: u64 },
    /// Sum and count; the division happens only in `finalize`, so merging
    /// partial means stays exact.
    Mean { sum: f64, count: u64 },
    /// Nothing has contributed yet; remembers which reducer was requested.
    Empty(Reducer),
}

impl AggregateResult {
    /// Combines two partial results of the same reducer.
    ///
    /// `Empty` is the identity: merging with it returns the other operand.
    /// Counts and sums are added. `Min`/`Max` use [`f64::min`] / [`f64::max`],
    /// so a NaN operand is ignored in favour of the other one and the result
    /// does not depend on merge order.
    ///
    /// Merging two *different* non-empty variants is not meaningful; in that
    /// case `self` is returned unchanged and `other` is discarded.
    pub fn merge(self, other: AggregateResult) -> AggregateResult {
        use AggregateResult::*;
        match (self, other) {
            (Empty(_), x) | (x, Empty(_)) => x,
            (Sum { value: a, count: ca }, Sum { value: b, count: cb }) => Sum {
                value: a + b,
                count: ca + cb,
            },
            (Count { count: a }, Count { count: b }) => Count { count: a + b },
            (Min { value: a, count: ca }, Min { value: b, count: cb }) => Min {
                // f64::min ignores NaN; a plain `<=` comparison would make the
                // outcome depend on which side the NaN is on.
                value: a.min(b),
                count: ca + cb,
            },
            (Max { value: a, count: ca }, Max { value: b, count: cb }) => Max {
                value: a.max(b),
                count: ca + cb,
            },
            (Mean { sum: a, count: ca }, Mean { sum: b, count: cb }) => Mean {
                sum: a + b,
                count: ca + cb,
            },
            // Mismatched reducers: keep the left-hand side.
            (lhs, _) => lhs,
        }
    }

    /// Converts a partial result into its final scalar value.
    ///
    /// Returns `None` when no cell contributed (an `Empty` result or a `Mean`
    /// with zero count). The exception is `Count`: counting zero cells yields
    /// `Some(0.0)`, matching SQL `COUNT` semantics.
    pub fn finalize(self) -> Option<f64> {
        match self {
            AggregateResult::Sum { value, .. }
            | AggregateResult::Min { value, .. }
            | AggregateResult::Max { value, .. } => Some(value),
            AggregateResult::Count { count } => Some(count as f64),
            AggregateResult::Mean { sum, count } if count > 0 => Some(sum / count as f64),
            AggregateResult::Mean { .. } => None,
            AggregateResult::Empty(Reducer::Count) => Some(0.0),
            AggregateResult::Empty(_) => None,
        }
    }
}
125
126fn empty(r: Reducer) -> AggregateResult {
127 AggregateResult::Empty(r)
128}
129
130pub fn aggregate_attr(tile: &SparseTile, attr_idx: usize, reducer: Reducer) -> AggregateResult {
135 let Some(col) = tile.attr_cols.get(attr_idx) else {
136 return empty(reducer);
137 };
138 let mut acc = empty(reducer);
139 for v in col {
140 let one = single_cell(v, reducer);
141 acc = acc.merge(one);
142 }
143 acc
144}
145
146fn single_cell(v: &CellValue, reducer: Reducer) -> AggregateResult {
147 use AggregateResult::*;
148 let n = match v {
149 CellValue::Int64(x) => Some(*x as f64),
150 CellValue::Float64(x) => Some(*x),
151 CellValue::String(_) | CellValue::Bytes(_) | CellValue::Null => None,
152 };
153 match (reducer, n) {
154 (Reducer::Count, _) if !v.is_null() => Count { count: 1 },
155 (Reducer::Count, _) => Empty(Reducer::Count),
156 (Reducer::Sum, Some(x)) => Sum { value: x, count: 1 },
157 (Reducer::Min, Some(x)) => Min { value: x, count: 1 },
158 (Reducer::Max, Some(x)) => Max { value: x, count: 1 },
159 (Reducer::Mean, Some(x)) => Mean { sum: x, count: 1 },
160 (r, None) => Empty(r),
161 }
162}
163
164#[derive(Debug, Clone)]
168pub struct GroupAggregate {
169 pub key: CoordValue,
170 pub result: AggregateResult,
171}
172
173pub fn group_by_dim(
176 tile: &SparseTile,
177 dim_idx: usize,
178 attr_idx: usize,
179 reducer: Reducer,
180) -> Vec<GroupAggregate> {
181 let Some(dict) = tile.dim_dicts.get(dim_idx) else {
182 return Vec::new();
183 };
184 let Some(col) = tile.attr_cols.get(attr_idx) else {
185 return Vec::new();
186 };
187 let mut order: Vec<CoordValue> = Vec::new();
188 let mut by_key: HashMap<CoordValue, AggregateResult> = HashMap::new();
189 let mut live_idx = 0usize;
192 for row in 0..tile.row_count() {
193 let kind = match tile.row_kind(row) {
194 Ok(k) => k,
195 Err(_) => break,
196 };
197 if kind != RowKind::Live {
198 continue;
199 }
200 let cell = match col.get(live_idx) {
201 Some(c) => c,
202 None => break,
203 };
204 live_idx += 1;
205 let key = dict.values[dict.indices[row] as usize].clone();
206 let one = single_cell(cell, reducer);
207 match by_key.get_mut(&key) {
208 Some(slot) => *slot = slot.merge(one),
209 None => {
210 order.push(key.clone());
211 by_key.insert(key, empty(reducer).merge(one));
212 }
213 }
214 }
215 order
216 .into_iter()
217 .map(|k| GroupAggregate {
218 result: by_key.remove(&k).unwrap_or(empty(reducer)),
219 key: k,
220 })
221 .collect()
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::attr_spec::{AttrSpec, AttrType};
    use crate::schema::dim_spec::{DimSpec, DimType};
    use crate::schema::{ArraySchema, ArraySchemaBuilder};
    use crate::tile::sparse_tile::SparseTileBuilder;
    use crate::types::domain::{Domain, DomainBound};

    /// Schema with one Int64 dimension `k` over [0, 15] and one Float64
    /// attribute `v` (third `AttrSpec::new` argument is `true`; presumably
    /// "nullable" — the tests push null cells into it).
    fn schema() -> ArraySchema {
        let key_dim = DimSpec::new(
            "k",
            DimType::Int64,
            Domain::new(DomainBound::Int64(0), DomainBound::Int64(15)),
        );
        ArraySchemaBuilder::new("g")
            .dim(key_dim)
            .attr(AttrSpec::new("v", AttrType::Float64, true))
            .tile_extents(vec![16])
            .build()
            .unwrap()
    }

    /// Builds a tile with one row per `(k, v)`; `None` becomes a null cell.
    fn tile(rows: &[(i64, Option<f64>)]) -> SparseTile {
        let s = schema();
        let mut builder = SparseTileBuilder::new(&s);
        for &(k, v) in rows {
            let cell = match v {
                Some(x) => CellValue::Float64(x),
                None => CellValue::Null,
            };
            builder.push(&[CoordValue::Int64(k)], &[cell]).unwrap();
        }
        builder.build()
    }

    #[test]
    fn sum_skips_nulls() {
        let t = tile(&[(0, Some(1.0)), (1, None), (2, Some(3.0))]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Sum).finalize(), Some(4.0));
    }

    #[test]
    fn count_includes_nonnull_only() {
        let t = tile(&[(0, Some(1.0)), (1, None), (2, Some(3.0))]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Count).finalize(), Some(2.0));
    }

    #[test]
    fn min_max_mean() {
        let t = tile(&[(0, Some(1.0)), (1, Some(5.0)), (2, Some(3.0))]);
        let expectations = [(Reducer::Min, 1.0), (Reducer::Max, 5.0), (Reducer::Mean, 3.0)];
        for (reducer, want) in expectations {
            assert_eq!(aggregate_attr(&t, 0, reducer).finalize(), Some(want), "{reducer:?}");
        }
    }

    #[test]
    fn merge_combines_partials_exactly() {
        let left = AggregateResult::Mean { sum: 10.0, count: 4 };
        let right = AggregateResult::Mean { sum: 6.0, count: 2 };
        let merged = left.merge(right);
        assert_eq!(merged.finalize(), Some(16.0 / 6.0));
    }

    #[test]
    fn empty_reducer_finalizes_to_none() {
        let t = tile(&[(0, None)]);
        assert_eq!(aggregate_attr(&t, 0, Reducer::Sum).finalize(), None);
    }

    #[test]
    fn group_by_buckets_by_dim_value() {
        let t = tile(&[(0, Some(1.0)), (1, Some(2.0)), (0, Some(3.0)), (1, Some(4.0))]);
        let groups = group_by_dim(&t, 0, 0, Reducer::Sum);
        assert_eq!(groups.len(), 2);
        let expected = [(CoordValue::Int64(0), 4.0), (CoordValue::Int64(1), 6.0)];
        for (group, (key, total)) in groups.iter().zip(expected) {
            assert_eq!(group.key, key);
            assert_eq!(group.result.finalize(), Some(total));
        }
    }
}