1use datafusion::physical_expr::aggregate::utils::Hashable;
19use datafusion::{arrow, common, error, logical_expr, scalar};
20use std::{cmp, collections, fmt, hash, mem};
21
/// Accumulator state for computing the mode (most frequent value) of a
/// primitive Arrow column whose native type is hashable.
///
/// The state is a frequency table keyed by the native value.
#[derive(fmt::Debug)]
pub struct PrimitiveModeAccumulator<T>
where
    T: arrow::array::ArrowPrimitiveType + Send,
    T::Native: Eq + hash::Hash,
{
    /// Occurrence count recorded for each distinct native value.
    value_counts: collections::HashMap<T::Native, i64>,
    /// Arrow data type passed to `ScalarValue::new_primitive` when scalars
    /// are built from this accumulator.
    data_type: arrow::datatypes::DataType,
}
31
32impl<T> PrimitiveModeAccumulator<T>
33where
34 T: arrow::array::ArrowPrimitiveType + Send,
35 T::Native: Eq + hash::Hash + Clone,
36{
37 pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
38 Self {
39 value_counts: collections::HashMap::default(),
40 data_type: data_type.clone(),
41 }
42 }
43}
44
45impl<T> logical_expr::Accumulator for PrimitiveModeAccumulator<T>
46where
47 T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
48 T::Native: Eq + hash::Hash + Clone + PartialOrd + fmt::Debug,
49{
50 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
51 if values.is_empty() {
52 return Ok(());
53 }
54 let arr = common::cast::as_primitive_array::<T>(&values[0])?;
55
56 for value in arr.iter().flatten() {
57 let counter = self.value_counts.entry(value).or_insert(0);
58 *counter += 1;
59 }
60
61 Ok(())
62 }
63
64 fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
65 let values: Vec<scalar::ScalarValue> = self
66 .value_counts
67 .keys()
68 .map(|key| scalar::ScalarValue::new_primitive::<T>(Some(*key), &self.data_type))
69 .collect::<error::Result<Vec<_>>>()?;
70
71 let frequencies: Vec<scalar::ScalarValue> = self
72 .value_counts
73 .values()
74 .map(|count| scalar::ScalarValue::from(*count))
75 .collect();
76
77 let values_scalar =
78 scalar::ScalarValue::new_list_nullable(&values, &self.data_type.clone());
79 let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
80 &frequencies,
81 &arrow::datatypes::DataType::Int64,
82 );
83
84 Ok(vec![
85 scalar::ScalarValue::List(values_scalar),
86 scalar::ScalarValue::List(frequencies_scalar),
87 ])
88 }
89
90 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
91 if states.is_empty() {
92 return Ok(());
93 }
94
95 let values_array = common::cast::as_primitive_array::<T>(&states[0])?;
96 let counts_array =
97 common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
98
99 for i in 0..values_array.len() {
100 let value = values_array.value(i);
101 let count = counts_array.value(i);
102 let entry = self.value_counts.entry(value).or_insert(0);
103 *entry += count;
104 }
105
106 Ok(())
107 }
108
109 fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
110 let mut max_value: Option<T::Native> = None;
111 let mut max_count: i64 = 0;
112
113 self.value_counts.iter().for_each(|(value, &count)| {
114 match count.cmp(&max_count) {
115 cmp::Ordering::Greater => {
116 max_value = Some(*value);
117 max_count = count;
118 }
119 cmp::Ordering::Equal => {
120 max_value = match max_value {
121 Some(ref current_max_value) if value > current_max_value => Some(*value),
122 Some(ref current_max_value) => Some(*current_max_value),
123 None => Some(*value),
124 };
125 }
126 _ => {} }
128 });
129
130 match max_value {
131 Some(val) => scalar::ScalarValue::new_primitive::<T>(Some(val), &self.data_type),
132 None => scalar::ScalarValue::new_primitive::<T>(None, &self.data_type),
133 }
134 }
135
136 fn size(&self) -> usize {
137 mem::size_of_val(&self.value_counts)
138 + self.value_counts.len() * mem::size_of::<(T::Native, i64)>()
139 }
140}
141
/// Accumulator state for computing the mode (most frequent value) of a
/// primitive Arrow column, with the native value wrapped in `Hashable` so it
/// can be used as a map key.
///
/// NOTE(review): presumably intended for float types, whose native values do
/// not implement `Eq`/`Hash` themselves — confirm against the call sites.
#[derive(Debug)]
pub struct FloatModeAccumulator<T>
where
    T: arrow::array::ArrowPrimitiveType,
{
    /// Occurrence count recorded for each distinct wrapped native value.
    value_counts: collections::HashMap<Hashable<T::Native>, i64>,
    /// Arrow data type passed to `ScalarValue::new_primitive` when scalars
    /// are built from this accumulator.
    data_type: arrow::datatypes::DataType,
}
150
151impl<T> FloatModeAccumulator<T>
152where
153 T: arrow::array::ArrowPrimitiveType,
154{
155 pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
156 Self {
157 value_counts: collections::HashMap::default(),
158 data_type: data_type.clone(),
159 }
160 }
161}
162
163impl<T> logical_expr::Accumulator for FloatModeAccumulator<T>
164where
165 T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
166 T::Native: PartialOrd + fmt::Debug + Clone,
167{
168 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
169 if values.is_empty() {
170 return Ok(());
171 }
172
173 let arr = common::cast::as_primitive_array::<T>(&values[0])?;
174
175 for value in arr.iter().flatten() {
176 let counter = self.value_counts.entry(Hashable(value)).or_insert(0);
177 *counter += 1;
178 }
179
180 Ok(())
181 }
182
183 fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
184 let values: Vec<scalar::ScalarValue> = self
185 .value_counts
186 .keys()
187 .map(|key| scalar::ScalarValue::new_primitive::<T>(Some(key.0), &self.data_type))
188 .collect::<error::Result<Vec<_>>>()?;
189
190 let frequencies: Vec<scalar::ScalarValue> = self
191 .value_counts
192 .values()
193 .map(|count| scalar::ScalarValue::from(*count))
194 .collect();
195
196 let values_scalar =
197 scalar::ScalarValue::new_list_nullable(&values, &self.data_type.clone());
198 let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
199 &frequencies,
200 &arrow::datatypes::DataType::Int64,
201 );
202
203 Ok(vec![
204 scalar::ScalarValue::List(values_scalar),
205 scalar::ScalarValue::List(frequencies_scalar),
206 ])
207 }
208
209 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
210 if states.is_empty() {
211 return Ok(());
212 }
213
214 let values_array = common::cast::as_primitive_array::<T>(&states[0])?;
215 let counts_array =
216 common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
217
218 for i in 0..values_array.len() {
219 let count = counts_array.value(i);
220 let entry = self
221 .value_counts
222 .entry(Hashable(values_array.value(i)))
223 .or_insert(0);
224 *entry += count;
225 }
226
227 Ok(())
228 }
229
230 fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
231 let mut max_value: Option<T::Native> = None;
232 let mut max_count: i64 = 0;
233
234 self.value_counts.iter().for_each(|(value, &count)| {
235 match count.cmp(&max_count) {
236 cmp::Ordering::Greater => {
237 max_value = Some(value.0);
238 max_count = count;
239 }
240 cmp::Ordering::Equal => {
241 max_value = match max_value {
242 Some(current_max_value) if value.0 > current_max_value => Some(value.0),
243 Some(current_max_value) => Some(current_max_value),
244 None => Some(value.0),
245 };
246 }
247 _ => {} }
249 });
250
251 match max_value {
252 Some(val) => scalar::ScalarValue::new_primitive::<T>(Some(val), &self.data_type),
253 None => scalar::ScalarValue::new_primitive::<T>(None, &self.data_type),
254 }
255 }
256
257 fn size(&self) -> usize {
258 mem::size_of_val(&self.value_counts)
259 + self.value_counts.len() * mem::size_of::<(Hashable<T::Native>, i64)>()
260 }
261}
262
#[cfg(test)]
mod tests {

    use super::*;

    use datafusion::logical_expr::Accumulator;
    use std::sync;

    /// Feeds `input` as a single batch to a fresh [`PrimitiveModeAccumulator`]
    /// and asserts that the evaluated mode equals `expected`.
    fn assert_primitive_mode<T>(
        data_type: arrow::datatypes::DataType,
        input: arrow::array::PrimitiveArray<T>,
        expected: Option<T::Native>,
    ) -> error::Result<()>
    where
        T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
        T::Native: Eq + hash::Hash + Clone + PartialOrd + fmt::Debug,
    {
        let mut acc = PrimitiveModeAccumulator::<T>::new(&data_type);
        let batch: arrow::array::ArrayRef = sync::Arc::new(input);
        acc.update_batch(&[batch])?;
        assert_eq!(
            acc.evaluate()?,
            scalar::ScalarValue::new_primitive::<T>(expected, &data_type)?
        );
        Ok(())
    }

    /// Feeds `input` as a single batch to a fresh [`FloatModeAccumulator`]
    /// and asserts that the evaluated mode equals `expected`.
    fn assert_float_mode<T>(
        data_type: arrow::datatypes::DataType,
        input: arrow::array::PrimitiveArray<T>,
        expected: Option<T::Native>,
    ) -> error::Result<()>
    where
        T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
        T::Native: PartialOrd + fmt::Debug + Clone,
    {
        let mut acc = FloatModeAccumulator::<T>::new(&data_type);
        let batch: arrow::array::ArrayRef = sync::Arc::new(input);
        acc.update_batch(&[batch])?;
        assert_eq!(
            acc.evaluate()?,
            scalar::ScalarValue::new_primitive::<T>(expected, &data_type)?
        );
        Ok(())
    }

    /// Data type shared by the `Time64` test cases.
    fn time64_micros() -> arrow::datatypes::DataType {
        arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond)
    }

    // ---- Int64 ----

    #[test]
    fn test_mode_accumulator_single_mode_int64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Int64Type>(
            arrow::datatypes::DataType::Int64,
            arrow::array::Int64Array::from(vec![1, 2, 2, 3, 3, 3]),
            Some(3),
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_int64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Int64Type>(
            arrow::datatypes::DataType::Int64,
            arrow::array::Int64Array::from(vec![
                None,
                Some(1),
                Some(2),
                Some(2),
                Some(3),
                Some(3),
                Some(3),
            ]),
            Some(3),
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_int64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Int64Type>(
            arrow::datatypes::DataType::Int64,
            arrow::array::Int64Array::from(vec![1, 2, 2, 3, 3]),
            Some(3),
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_int64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Int64Type>(
            arrow::datatypes::DataType::Int64,
            arrow::array::Int64Array::from(vec![None, None, None, None]),
            None,
        )
    }

    // ---- Float64 ----

    #[test]
    fn test_mode_accumulator_single_mode_float64() -> error::Result<()> {
        assert_float_mode::<arrow::datatypes::Float64Type>(
            arrow::datatypes::DataType::Float64,
            arrow::array::Float64Array::from(vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0]),
            Some(3.0),
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_float64() -> error::Result<()> {
        assert_float_mode::<arrow::datatypes::Float64Type>(
            arrow::datatypes::DataType::Float64,
            arrow::array::Float64Array::from(vec![
                None,
                Some(1.0),
                Some(2.0),
                Some(2.0),
                Some(3.0),
                Some(3.0),
                Some(3.0),
            ]),
            Some(3.0),
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_float64() -> error::Result<()> {
        assert_float_mode::<arrow::datatypes::Float64Type>(
            arrow::datatypes::DataType::Float64,
            arrow::array::Float64Array::from(vec![1.0, 2.0, 2.0, 3.0, 3.0]),
            Some(3.0),
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_float64() -> error::Result<()> {
        assert_float_mode::<arrow::datatypes::Float64Type>(
            arrow::datatypes::DataType::Float64,
            arrow::array::Float64Array::from(vec![None, None, None, None]),
            None,
        )
    }

    // ---- Date64 ----

    #[test]
    fn test_mode_accumulator_single_mode_date64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Date64Type>(
            arrow::datatypes::DataType::Date64,
            arrow::array::Date64Array::from(vec![
                1609459200000,
                1609545600000,
                1609545600000,
                1609632000000,
                1609632000000,
                1609632000000,
            ]),
            Some(1609632000000),
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_date64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Date64Type>(
            arrow::datatypes::DataType::Date64,
            arrow::array::Date64Array::from(vec![
                None,
                Some(1609459200000),
                Some(1609545600000),
                Some(1609545600000),
                Some(1609632000000),
                Some(1609632000000),
                Some(1609632000000),
            ]),
            Some(1609632000000),
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_date64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Date64Type>(
            arrow::datatypes::DataType::Date64,
            arrow::array::Date64Array::from(vec![
                1609459200000,
                1609545600000,
                1609545600000,
                1609632000000,
                1609632000000,
            ]),
            Some(1609632000000),
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_date64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Date64Type>(
            arrow::datatypes::DataType::Date64,
            arrow::array::Date64Array::from(vec![None, None, None, None]),
            None,
        )
    }

    // ---- Time64 (microseconds) ----

    #[test]
    fn test_mode_accumulator_single_mode_time64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Time64MicrosecondType>(
            time64_micros(),
            arrow::array::Time64MicrosecondArray::from(vec![
                3600000000,
                7200000000,
                7200000000,
                10800000000,
                10800000000,
                10800000000,
            ]),
            Some(10800000000),
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_time64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Time64MicrosecondType>(
            time64_micros(),
            arrow::array::Time64MicrosecondArray::from(vec![
                None,
                Some(3600000000),
                Some(7200000000),
                Some(7200000000),
                Some(10800000000),
                Some(10800000000),
                Some(10800000000),
            ]),
            Some(10800000000),
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_time64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Time64MicrosecondType>(
            time64_micros(),
            arrow::array::Time64MicrosecondArray::from(vec![
                3600000000,
                7200000000,
                7200000000,
                10800000000,
                10800000000,
            ]),
            Some(10800000000),
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_time64() -> error::Result<()> {
        assert_primitive_mode::<arrow::datatypes::Time64MicrosecondType>(
            time64_micros(),
            arrow::array::Time64MicrosecondArray::from(vec![None, None, None, None]),
            None,
        )
    }
}