1use std::collections::HashMap;
10
11use super::super::column_batch::{ColumnBatch, ColumnVector, ValueRef};
12
/// Aggregate function evaluated over one column inside each group.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// Number of rows in the group; the spec's column is not inspected.
    Count,
    /// Sum of the group's non-null Int64/Float64 values.
    Sum,
    /// Mean of the group's non-null Int64/Float64 values.
    Avg,
    /// Smallest of the group's non-null Int64/Float64 values.
    Min,
    /// Largest of the group's non-null Int64/Float64 values.
    Max,
}

/// A single requested aggregate: the column to read and the operation to apply.
#[derive(Clone, Debug)]
pub struct AggregateSpec {
    /// Index into the batch's column list.
    pub column: usize,
    /// Operation evaluated over `column`.
    pub op: AggregateOp,
}
28
/// One component of a group key, owned so it can be stored as a `HashMap` key.
///
/// Floats are carried as an IEEE-754 bit pattern (`u64`) so that the type can
/// implement `Eq` and `Hash`; equality on that variant is therefore bitwise.
#[derive(Clone, Debug, PartialEq)]
pub enum GroupKeyPart {
    Int64(i64),
    Float64Bits(u64),
    Bool(bool),
    Text(String),
    Null,
}

// Every payload type (`i64`, `u64`, `bool`, `String`) has a total equality,
// so the derived `PartialEq` is reflexive and `Eq` is sound.
impl Eq for GroupKeyPart {}

impl std::hash::Hash for GroupKeyPart {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // A one-byte tag per variant keeps payloads of different kinds apart,
        // e.g. `Int64(1)` and `Float64Bits(1)` feed different byte streams.
        let tag: u8 = match self {
            GroupKeyPart::Int64(_) => 0,
            GroupKeyPart::Float64Bits(_) => 1,
            GroupKeyPart::Bool(_) => 2,
            GroupKeyPart::Text(_) => 3,
            GroupKeyPart::Null => 4,
        };
        tag.hash(state);
        match self {
            GroupKeyPart::Int64(n) => n.hash(state),
            GroupKeyPart::Float64Bits(bits) => bits.hash(state),
            GroupKeyPart::Bool(flag) => flag.hash(state),
            GroupKeyPart::Text(text) => text.hash(state),
            GroupKeyPart::Null => {}
        }
    }
}

/// A complete group key: one part per grouping column, in `group_columns` order.
type GroupKey = Vec<GroupKeyPart>;
67
68#[derive(Debug, Clone)]
69pub struct AggregateResult {
70 pub op: AggregateOp,
71 pub column: usize,
72 pub value: f64,
73 pub count: u64,
76}
77
78#[derive(Debug, Clone)]
79pub struct AggregateRow {
80 pub key: GroupKey,
81 pub results: Vec<AggregateResult>,
82}
83
84pub fn batch_aggregate(
88 batch: &ColumnBatch,
89 group_columns: &[usize],
90 specs: &[AggregateSpec],
91) -> Vec<AggregateRow> {
92 if batch.is_empty() {
93 return Vec::new();
94 }
95 let mut groups: HashMap<GroupKey, Vec<Accumulator>> = HashMap::new();
96 for row in 0..batch.len() {
97 let key: GroupKey = group_columns
98 .iter()
99 .map(|c| group_key_part(batch, row, *c))
100 .collect();
101 let accs = groups
102 .entry(key)
103 .or_insert_with(|| specs.iter().map(Accumulator::new).collect());
104 for (idx, spec) in specs.iter().enumerate() {
105 accs[idx].observe(batch, row, spec);
106 }
107 }
108 let mut out: Vec<AggregateRow> = groups
109 .into_iter()
110 .map(|(key, accs)| {
111 let results = accs
112 .into_iter()
113 .zip(specs.iter())
114 .map(|(acc, spec)| acc.finalize(spec))
115 .collect();
116 AggregateRow { key, results }
117 })
118 .collect();
119 out.sort_by(|a, b| compare_keys(&a.key, &b.key));
121 out
122}
123
124fn group_key_part(batch: &ColumnBatch, row: usize, column: usize) -> GroupKeyPart {
125 match batch.value(row, column) {
126 ValueRef::Int64(v) => GroupKeyPart::Int64(v),
127 ValueRef::Float64(v) => GroupKeyPart::Float64Bits(v.to_bits()),
128 ValueRef::Bool(v) => GroupKeyPart::Bool(v),
129 ValueRef::Text(v) => GroupKeyPart::Text(v.to_string()),
130 ValueRef::Null => GroupKeyPart::Null,
131 }
132}
133
/// Running state for one aggregate spec within one group.
#[derive(Clone, Debug)]
struct Accumulator {
    /// Rows counted so far: all rows for `Count`, otherwise non-null numeric values.
    count: u64,
    /// Running sum of the observed numeric values.
    sum: f64,
    /// Smallest observed numeric value; only meaningful once `any_observed` is set.
    min: f64,
    /// Largest observed numeric value; only meaningful once `any_observed` is set.
    max: f64,
    /// Set once at least one numeric value has been observed.
    any_observed: bool,
}
142
143impl Accumulator {
144 fn new(_spec: &AggregateSpec) -> Self {
145 Self {
146 count: 0,
147 sum: 0.0,
148 min: f64::INFINITY,
149 max: f64::NEG_INFINITY,
150 any_observed: false,
151 }
152 }
153
154 fn observe(&mut self, batch: &ColumnBatch, row: usize, spec: &AggregateSpec) {
155 match spec.op {
156 AggregateOp::Count => {
157 self.count += 1;
158 }
159 AggregateOp::Sum | AggregateOp::Avg | AggregateOp::Min | AggregateOp::Max => {
160 if let Some(v) = numeric_value(batch, row, spec.column) {
161 self.count += 1;
162 self.sum += v;
163 if v < self.min {
164 self.min = v;
165 }
166 if v > self.max {
167 self.max = v;
168 }
169 self.any_observed = true;
170 }
171 }
172 }
173 }
174
175 fn finalize(self, spec: &AggregateSpec) -> AggregateResult {
176 let value = match spec.op {
177 AggregateOp::Count => self.count as f64,
178 AggregateOp::Sum => self.sum,
179 AggregateOp::Avg => {
180 if self.count == 0 {
181 0.0
182 } else {
183 self.sum / self.count as f64
184 }
185 }
186 AggregateOp::Min => {
187 if self.any_observed {
188 self.min
189 } else {
190 0.0
191 }
192 }
193 AggregateOp::Max => {
194 if self.any_observed {
195 self.max
196 } else {
197 0.0
198 }
199 }
200 };
201 AggregateResult {
202 op: spec.op,
203 column: spec.column,
204 value,
205 count: self.count,
206 }
207 }
208}
209
210fn numeric_value(batch: &ColumnBatch, row: usize, column: usize) -> Option<f64> {
211 let col = batch.columns.get(column)?;
212 if !col.is_valid(row) {
213 return None;
214 }
215 match col {
216 ColumnVector::Int64 { data, .. } => Some(data[row] as f64),
217 ColumnVector::Float64 { data, .. } => Some(data[row]),
218 _ => None,
219 }
220}
221
222fn compare_keys(a: &[GroupKeyPart], b: &[GroupKeyPart]) -> std::cmp::Ordering {
223 for (x, y) in a.iter().zip(b.iter()) {
224 let ord = compare_key_part(x, y);
225 if ord != std::cmp::Ordering::Equal {
226 return ord;
227 }
228 }
229 a.len().cmp(&b.len())
230}
231
232fn compare_key_part(x: &GroupKeyPart, y: &GroupKeyPart) -> std::cmp::Ordering {
233 use std::cmp::Ordering;
234 use GroupKeyPart::*;
235 match (x, y) {
236 (Int64(a), Int64(b)) => a.cmp(b),
237 (Float64Bits(a), Float64Bits(b)) => f64::from_bits(*a)
238 .partial_cmp(&f64::from_bits(*b))
239 .unwrap_or(Ordering::Equal),
240 (Bool(a), Bool(b)) => a.cmp(b),
241 (Text(a), Text(b)) => a.cmp(b),
242 (Null, Null) => Ordering::Equal,
243 (Null, _) => Ordering::Less,
244 (_, Null) => Ordering::Greater,
245 _ => Ordering::Equal,
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::super::super::column_batch::{ColumnKind, Field, Schema};
    use super::*;
    use std::sync::Arc;

    /// Fixture with five rows: a non-nullable Text `region` column and a
    /// non-nullable Float64 `amount` column.
    ///
    /// region: us, eu, us, us, eu
    /// amount: 10, 20, 30, 40, 50   (eu sums to 70, us sums to 80)
    fn batch() -> ColumnBatch {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "amount".into(),
                kind: ColumnKind::Float64,
                nullable: false,
            },
        ]));
        ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec![
                        "us".into(),
                        "eu".into(),
                        "us".into(),
                        "us".into(),
                        "eu".into(),
                    ],
                    validity: None,
                },
                ColumnVector::Float64 {
                    data: vec![10.0, 20.0, 30.0, 40.0, 50.0],
                    validity: None,
                },
            ],
        )
    }

    // With no grouping columns the whole batch collapses into a single row,
    // and Count counts every row.
    #[test]
    fn count_star_over_whole_batch() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].results[0].value, 5.0);
    }

    // Output rows are sorted by key, so "eu" comes before "us".
    #[test]
    fn sum_grouped_by_region() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[AggregateSpec {
                column: 1,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].key[0], GroupKeyPart::Text("eu".into()));
        assert_eq!(out[0].results[0].value, 70.0);
        assert_eq!(out[1].key[0], GroupKeyPart::Text("us".into()));
        assert_eq!(out[1].results[0].value, 80.0);
    }

    // NOTE(review): despite its name, this test never builds an empty group;
    // both groups have rows. It checks ordinary per-group averages
    // (eu = 70 / 2, us = 80 / 3). The zero-observation branch of `finalize`
    // is not exercised here.
    #[test]
    fn avg_handles_empty_group_cleanly() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[AggregateSpec {
                column: 1,
                op: AggregateOp::Avg,
            }],
        );
        let eu_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("eu".into()))
            .unwrap();
        assert_eq!(eu_row.results[0].value, 35.0);
        let us_row = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert!((us_row.results[0].value - (80.0 / 3.0)).abs() < 1e-6);
    }

    // Two specs over the same column: results come back in spec order
    // (Min first, then Max).
    #[test]
    fn min_and_max_agree_on_shape() {
        let b = batch();
        let out = batch_aggregate(
            &b,
            &[0],
            &[
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Min,
                },
                AggregateSpec {
                    column: 1,
                    op: AggregateOp::Max,
                },
            ],
        );
        let us = out
            .iter()
            .find(|r| r.key[0] == GroupKeyPart::Text("us".into()))
            .unwrap();
        assert_eq!(us.results[0].value, 10.0);
        assert_eq!(us.results[1].value, 40.0);
    }

    // `take(&[])` presumably yields a zero-row batch (TODO confirm against
    // ColumnBatch::take). With no rows, batch_aggregate returns no rows at all,
    // not a single Count row of 0.
    #[test]
    fn empty_batch_returns_empty() {
        let b = batch();
        let empty = b.take(&[]);
        let out = batch_aggregate(
            &empty,
            &[],
            &[AggregateSpec {
                column: 0,
                op: AggregateOp::Count,
            }],
        );
        assert!(out.is_empty());
    }

    // Two grouping columns. ("a", 1) occurs twice (10 + 40), so there are
    // three distinct groups: ("a", 1), ("a", 2) and ("b", 1).
    #[test]
    fn multi_key_grouping_preserves_combinations() {
        let schema = Arc::new(Schema::new(vec![
            Field {
                name: "region".into(),
                kind: ColumnKind::Text,
                nullable: false,
            },
            Field {
                name: "tier".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
            Field {
                name: "v".into(),
                kind: ColumnKind::Int64,
                nullable: false,
            },
        ]));
        let b = ColumnBatch::new(
            schema,
            vec![
                ColumnVector::Text {
                    data: vec!["a".into(), "a".into(), "b".into(), "a".into()],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![1, 2, 1, 1],
                    validity: None,
                },
                ColumnVector::Int64 {
                    data: vec![10, 20, 30, 40],
                    validity: None,
                },
            ],
        );
        let out = batch_aggregate(
            &b,
            &[0, 1],
            &[AggregateSpec {
                column: 2,
                op: AggregateOp::Sum,
            }],
        );
        assert_eq!(out.len(), 3);
        // Looks up the Sum result for the (region, tier) group; panics if absent.
        let find = |r: &str, t: i64| {
            out.iter()
                .find(|row| {
                    row.key[0] == GroupKeyPart::Text(r.into())
                        && row.key[1] == GroupKeyPart::Int64(t)
                })
                .unwrap()
                .results[0]
                .value
        };
        assert_eq!(find("a", 1), 50.0);
        assert_eq!(find("a", 2), 20.0);
        assert_eq!(find("b", 1), 30.0);
    }
}