1use indexmap::IndexMap;
10
11use super::super::column_batch::{ColumnBatch, ColumnVector, ValueRef};
12
/// The aggregate functions that [`batch_aggregate`] can evaluate per group.
///
/// "Numeric value" below means a valid cell of an `Int64` or `Float64`
/// column; `Int64` cells are converted to `f64` before being folded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// Number of rows in the group. The spec's column is not consulted, so
    /// every row is counted regardless of validity.
    Count,
    /// Sum of the numeric values in the target column.
    Sum,
    /// Sum of the numeric values divided by how many were seen.
    Avg,
    /// Smallest numeric value seen in the target column.
    Min,
    /// Largest numeric value seen in the target column.
    Max,
}
21
22#[derive(Debug, Clone)]
23pub struct AggregateSpec {
24 pub column: usize,
26 pub op: AggregateOp,
27}
28
/// One component of a group key: the value of a single grouping column for
/// a single row, in a form that supports `Eq` and `Hash`.
///
/// `Eq` and `Hash` are derived, so two parts are equal (and hash equally)
/// exactly when they are the same variant with the same payload; parts of
/// different variants never compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroupKeyPart {
    /// A 64-bit signed integer value.
    Int64(i64),
    /// The bit pattern of an `f64`, stored as `u64` because `f64` itself
    /// implements neither `Eq` nor `Hash`. Equality is therefore bitwise.
    Float64Bits(u64),
    /// A boolean value.
    Bool(bool),
    /// An owned copy of a text value.
    Text(String),
    /// A null cell.
    Null,
}

/// A full group key: one [`GroupKeyPart`] per grouping column, in the order
/// the grouping columns were requested.
type GroupKey = Vec<GroupKeyPart>;
67
68#[derive(Debug, Clone)]
69pub struct AggregateResult {
70 pub op: AggregateOp,
71 pub column: usize,
72 pub value: f64,
73 pub count: u64,
76}
77
78#[derive(Debug, Clone)]
79pub struct AggregateRow {
80 pub key: GroupKey,
81 pub results: Vec<AggregateResult>,
82}
83
84pub fn batch_aggregate(
88 batch: &ColumnBatch,
89 group_columns: &[usize],
90 specs: &[AggregateSpec],
91) -> Vec<AggregateRow> {
92 if batch.is_empty() {
93 return Vec::new();
94 }
95 let mut groups: IndexMap<GroupKey, Vec<Accumulator>> = IndexMap::new();
101 for row in 0..batch.len() {
102 let key: GroupKey = group_columns
103 .iter()
104 .map(|c| group_key_part(batch, row, *c))
105 .collect();
106 let accs = groups
107 .entry(key)
108 .or_insert_with(|| specs.iter().map(Accumulator::new).collect());
109 for (idx, spec) in specs.iter().enumerate() {
110 accs[idx].observe(batch, row, spec);
111 }
112 }
113 let mut out: Vec<AggregateRow> = groups
114 .into_iter()
115 .map(|(key, accs)| {
116 let results = accs
117 .into_iter()
118 .zip(specs.iter())
119 .map(|(acc, spec)| acc.finalize(spec))
120 .collect();
121 AggregateRow { key, results }
122 })
123 .collect();
124 out.sort_by(|a, b| compare_keys(&a.key, &b.key));
126 out
127}
128
129fn group_key_part(batch: &ColumnBatch, row: usize, column: usize) -> GroupKeyPart {
130 match batch.value(row, column) {
131 ValueRef::Int64(v) => GroupKeyPart::Int64(v),
132 ValueRef::Float64(v) => GroupKeyPart::Float64Bits(v.to_bits()),
133 ValueRef::Bool(v) => GroupKeyPart::Bool(v),
134 ValueRef::Text(v) => GroupKeyPart::Text(v.to_string()),
135 ValueRef::Null => GroupKeyPart::Null,
136 }
137}
138
/// Running state for one aggregate within one group.
///
/// Every field is maintained for `Sum`, `Avg`, `Min` and `Max`, whichever of
/// them was requested; `Count` only ever touches `count`.
#[derive(Clone, Debug)]
struct Accumulator {
    /// Rows observed (`Count`) or numeric values observed (other ops).
    count: u64,
    /// Running sum of the numeric values.
    sum: f64,
    /// Smallest numeric value so far; starts at `f64::INFINITY`.
    min: f64,
    /// Largest numeric value so far; starts at `f64::NEG_INFINITY`.
    max: f64,
    /// Whether at least one numeric value has been folded in.
    any_observed: bool,
}
147
148impl Accumulator {
149 fn new(_spec: &AggregateSpec) -> Self {
150 Self {
151 count: 0,
152 sum: 0.0,
153 min: f64::INFINITY,
154 max: f64::NEG_INFINITY,
155 any_observed: false,
156 }
157 }
158
159 fn observe(&mut self, batch: &ColumnBatch, row: usize, spec: &AggregateSpec) {
160 match spec.op {
161 AggregateOp::Count => {
162 self.count += 1;
163 }
164 AggregateOp::Sum | AggregateOp::Avg | AggregateOp::Min | AggregateOp::Max => {
165 if let Some(v) = numeric_value(batch, row, spec.column) {
166 self.count += 1;
167 self.sum += v;
168 if v < self.min {
169 self.min = v;
170 }
171 if v > self.max {
172 self.max = v;
173 }
174 self.any_observed = true;
175 }
176 }
177 }
178 }
179
180 fn finalize(self, spec: &AggregateSpec) -> AggregateResult {
181 let value = match spec.op {
182 AggregateOp::Count => self.count as f64,
183 AggregateOp::Sum => self.sum,
184 AggregateOp::Avg => {
185 if self.count == 0 {
186 0.0
187 } else {
188 self.sum / self.count as f64
189 }
190 }
191 AggregateOp::Min => {
192 if self.any_observed {
193 self.min
194 } else {
195 0.0
196 }
197 }
198 AggregateOp::Max => {
199 if self.any_observed {
200 self.max
201 } else {
202 0.0
203 }
204 }
205 };
206 AggregateResult {
207 op: spec.op,
208 column: spec.column,
209 value,
210 count: self.count,
211 }
212 }
213}
214
215fn numeric_value(batch: &ColumnBatch, row: usize, column: usize) -> Option<f64> {
216 let col = batch.columns.get(column)?;
217 if !col.is_valid(row) {
218 return None;
219 }
220 match col {
221 ColumnVector::Int64 { data, .. } => Some(data[row] as f64),
222 ColumnVector::Float64 { data, .. } => Some(data[row]),
223 _ => None,
224 }
225}
226
227fn compare_keys(a: &[GroupKeyPart], b: &[GroupKeyPart]) -> std::cmp::Ordering {
228 for (x, y) in a.iter().zip(b.iter()) {
229 let ord = compare_key_part(x, y);
230 if ord != std::cmp::Ordering::Equal {
231 return ord;
232 }
233 }
234 a.len().cmp(&b.len())
235}
236
237fn compare_key_part(x: &GroupKeyPart, y: &GroupKeyPart) -> std::cmp::Ordering {
238 use std::cmp::Ordering;
239 use GroupKeyPart::*;
240 match (x, y) {
241 (Int64(a), Int64(b)) => a.cmp(b),
242 (Float64Bits(a), Float64Bits(b)) => f64::from_bits(*a)
243 .partial_cmp(&f64::from_bits(*b))
244 .unwrap_or(Ordering::Equal),
245 (Bool(a), Bool(b)) => a.cmp(b),
246 (Text(a), Text(b)) => a.cmp(b),
247 (Null, Null) => Ordering::Equal,
248 (Null, _) => Ordering::Less,
249 (_, Null) => Ordering::Greater,
250 _ => Ordering::Equal,
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::super::super::column_batch::{ColumnKind, Field, Schema};
    use super::*;
    use std::sync::Arc;

    /// Describes a non-nullable column.
    fn field(name: &'static str, kind: ColumnKind) -> Field {
        Field {
            name: name.into(),
            kind,
            nullable: false,
        }
    }

    /// Shorthand for an aggregate request.
    fn spec(column: usize, op: AggregateOp) -> AggregateSpec {
        AggregateSpec { column, op }
    }

    /// Shorthand for a text key part.
    fn text(part: &str) -> GroupKeyPart {
        GroupKeyPart::Text(part.into())
    }

    /// Finds the output row whose first key part is the text `region`.
    fn group<'a>(rows: &'a [AggregateRow], region: &str) -> &'a AggregateRow {
        let wanted = text(region);
        rows.iter().find(|row| row.key[0] == wanted).unwrap()
    }

    /// Five rows: `region` is us/eu/us/us/eu, `amount` is 10..=50 in tens.
    fn batch() -> ColumnBatch {
        let region = ColumnVector::Text {
            data: vec![
                "us".into(),
                "eu".into(),
                "us".into(),
                "us".into(),
                "eu".into(),
            ],
            validity: None,
        };
        let amount = ColumnVector::Float64 {
            data: vec![10.0, 20.0, 30.0, 40.0, 50.0],
            validity: None,
        };
        let schema = Arc::new(Schema::new(vec![
            field("region", ColumnKind::Text),
            field("amount", ColumnKind::Float64),
        ]));
        ColumnBatch::new(schema, vec![region, amount])
    }

    #[test]
    fn count_star_over_whole_batch() {
        let input = batch();
        let rows = batch_aggregate(&input, &[], &[spec(0, AggregateOp::Count)]);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].results[0].value, 5.0);
    }

    #[test]
    fn sum_grouped_by_region() {
        let input = batch();
        let rows = batch_aggregate(&input, &[0], &[spec(1, AggregateOp::Sum)]);
        assert_eq!(rows.len(), 2);
        // Output is sorted by key, so "eu" precedes "us" even though "us"
        // appears first in the input.
        assert_eq!(rows[0].key[0], text("eu"));
        assert_eq!(rows[0].results[0].value, 70.0);
        assert_eq!(rows[1].key[0], text("us"));
        assert_eq!(rows[1].results[0].value, 80.0);
    }

    // NOTE(review): despite the name, no group in this batch is empty; the
    // test only checks the per-group averages.
    #[test]
    fn avg_handles_empty_group_cleanly() {
        let input = batch();
        let rows = batch_aggregate(&input, &[0], &[spec(1, AggregateOp::Avg)]);
        assert_eq!(group(&rows, "eu").results[0].value, 35.0);
        let us_avg = group(&rows, "us").results[0].value;
        assert!((us_avg - (80.0 / 3.0)).abs() < 1e-6);
    }

    #[test]
    fn min_and_max_agree_on_shape() {
        let input = batch();
        let rows = batch_aggregate(
            &input,
            &[0],
            &[spec(1, AggregateOp::Min), spec(1, AggregateOp::Max)],
        );
        let us = group(&rows, "us");
        assert_eq!(us.results[0].value, 10.0);
        assert_eq!(us.results[1].value, 40.0);
    }

    #[test]
    fn empty_batch_returns_empty() {
        let input = batch();
        let empty = input.take(&[]);
        let rows = batch_aggregate(&empty, &[], &[spec(0, AggregateOp::Count)]);
        assert!(rows.is_empty());
    }

    #[test]
    fn multi_key_grouping_preserves_combinations() {
        let region = ColumnVector::Text {
            data: vec!["a".into(), "a".into(), "b".into(), "a".into()],
            validity: None,
        };
        let tier = ColumnVector::Int64 {
            data: vec![1, 2, 1, 1],
            validity: None,
        };
        let v = ColumnVector::Int64 {
            data: vec![10, 20, 30, 40],
            validity: None,
        };
        let schema = Arc::new(Schema::new(vec![
            field("region", ColumnKind::Text),
            field("tier", ColumnKind::Int64),
            field("v", ColumnKind::Int64),
        ]));
        let input = ColumnBatch::new(schema, vec![region, tier, v]);

        let rows = batch_aggregate(&input, &[0, 1], &[spec(2, AggregateOp::Sum)]);
        assert_eq!(rows.len(), 3);

        let sum_for = |region: &str, tier: i64| {
            let wanted = [text(region), GroupKeyPart::Int64(tier)];
            rows.iter()
                .find(|row| row.key[0] == wanted[0] && row.key[1] == wanted[1])
                .unwrap()
                .results[0]
                .value
        };
        assert_eq!(sum_for("a", 1), 50.0);
        assert_eq!(sum_for("a", 2), 20.0);
        assert_eq!(sum_for("b", 1), 30.0);
    }
}