datafusion_functions_extra/common/mode/
native.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::physical_expr::aggregate::utils::Hashable;
19use datafusion::{arrow, common, error, logical_expr, scalar};
20use std::{cmp, collections, fmt, hash, mem};
21
22#[derive(fmt::Debug)]
23pub struct PrimitiveModeAccumulator<T>
24where
25    T: arrow::array::ArrowPrimitiveType + Send,
26    T::Native: Eq + hash::Hash,
27{
28    value_counts: collections::HashMap<T::Native, i64>,
29    data_type: arrow::datatypes::DataType,
30}
31
32impl<T> PrimitiveModeAccumulator<T>
33where
34    T: arrow::array::ArrowPrimitiveType + Send,
35    T::Native: Eq + hash::Hash + Clone,
36{
37    pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
38        Self {
39            value_counts: collections::HashMap::default(),
40            data_type: data_type.clone(),
41        }
42    }
43}
44
45impl<T> logical_expr::Accumulator for PrimitiveModeAccumulator<T>
46where
47    T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
48    T::Native: Eq + hash::Hash + Clone + PartialOrd + fmt::Debug,
49{
50    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
51        if values.is_empty() {
52            return Ok(());
53        }
54        let arr = common::cast::as_primitive_array::<T>(&values[0])?;
55
56        for value in arr.iter().flatten() {
57            let counter = self.value_counts.entry(value).or_insert(0);
58            *counter += 1;
59        }
60
61        Ok(())
62    }
63
64    fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
65        let values: Vec<scalar::ScalarValue> = self
66            .value_counts
67            .keys()
68            .map(|key| scalar::ScalarValue::new_primitive::<T>(Some(*key), &self.data_type))
69            .collect::<error::Result<Vec<_>>>()?;
70
71        let frequencies: Vec<scalar::ScalarValue> = self
72            .value_counts
73            .values()
74            .map(|count| scalar::ScalarValue::from(*count))
75            .collect();
76
77        let values_scalar =
78            scalar::ScalarValue::new_list_nullable(&values, &self.data_type.clone());
79        let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
80            &frequencies,
81            &arrow::datatypes::DataType::Int64,
82        );
83
84        Ok(vec![
85            scalar::ScalarValue::List(values_scalar),
86            scalar::ScalarValue::List(frequencies_scalar),
87        ])
88    }
89
90    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
91        if states.is_empty() {
92            return Ok(());
93        }
94
95        let values_array = common::cast::as_primitive_array::<T>(&states[0])?;
96        let counts_array =
97            common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
98
99        for i in 0..values_array.len() {
100            let value = values_array.value(i);
101            let count = counts_array.value(i);
102            let entry = self.value_counts.entry(value).or_insert(0);
103            *entry += count;
104        }
105
106        Ok(())
107    }
108
109    fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
110        let mut max_value: Option<T::Native> = None;
111        let mut max_count: i64 = 0;
112
113        self.value_counts.iter().for_each(|(value, &count)| {
114            match count.cmp(&max_count) {
115                cmp::Ordering::Greater => {
116                    max_value = Some(*value);
117                    max_count = count;
118                }
119                cmp::Ordering::Equal => {
120                    max_value = match max_value {
121                        Some(ref current_max_value) if value > current_max_value => Some(*value),
122                        Some(ref current_max_value) => Some(*current_max_value),
123                        None => Some(*value),
124                    };
125                }
126                _ => {} // Do nothing if count is less than max_count
127            }
128        });
129
130        match max_value {
131            Some(val) => scalar::ScalarValue::new_primitive::<T>(Some(val), &self.data_type),
132            None => scalar::ScalarValue::new_primitive::<T>(None, &self.data_type),
133        }
134    }
135
136    fn size(&self) -> usize {
137        mem::size_of_val(&self.value_counts)
138            + self.value_counts.len() * mem::size_of::<(T::Native, i64)>()
139    }
140}
141
142#[derive(Debug)]
143pub struct FloatModeAccumulator<T>
144where
145    T: arrow::array::ArrowPrimitiveType,
146{
147    value_counts: collections::HashMap<Hashable<T::Native>, i64>,
148    data_type: arrow::datatypes::DataType,
149}
150
151impl<T> FloatModeAccumulator<T>
152where
153    T: arrow::array::ArrowPrimitiveType,
154{
155    pub fn new(data_type: &arrow::datatypes::DataType) -> Self {
156        Self {
157            value_counts: collections::HashMap::default(),
158            data_type: data_type.clone(),
159        }
160    }
161}
162
163impl<T> logical_expr::Accumulator for FloatModeAccumulator<T>
164where
165    T: arrow::array::ArrowPrimitiveType + Send + fmt::Debug,
166    T::Native: PartialOrd + fmt::Debug + Clone,
167{
168    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> error::Result<()> {
169        if values.is_empty() {
170            return Ok(());
171        }
172
173        let arr = common::cast::as_primitive_array::<T>(&values[0])?;
174
175        for value in arr.iter().flatten() {
176            let counter = self.value_counts.entry(Hashable(value)).or_insert(0);
177            *counter += 1;
178        }
179
180        Ok(())
181    }
182
183    fn state(&mut self) -> error::Result<Vec<scalar::ScalarValue>> {
184        let values: Vec<scalar::ScalarValue> = self
185            .value_counts
186            .keys()
187            .map(|key| scalar::ScalarValue::new_primitive::<T>(Some(key.0), &self.data_type))
188            .collect::<error::Result<Vec<_>>>()?;
189
190        let frequencies: Vec<scalar::ScalarValue> = self
191            .value_counts
192            .values()
193            .map(|count| scalar::ScalarValue::from(*count))
194            .collect();
195
196        let values_scalar =
197            scalar::ScalarValue::new_list_nullable(&values, &self.data_type.clone());
198        let frequencies_scalar = scalar::ScalarValue::new_list_nullable(
199            &frequencies,
200            &arrow::datatypes::DataType::Int64,
201        );
202
203        Ok(vec![
204            scalar::ScalarValue::List(values_scalar),
205            scalar::ScalarValue::List(frequencies_scalar),
206        ])
207    }
208
209    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> error::Result<()> {
210        if states.is_empty() {
211            return Ok(());
212        }
213
214        let values_array = common::cast::as_primitive_array::<T>(&states[0])?;
215        let counts_array =
216            common::cast::as_primitive_array::<arrow::datatypes::Int64Type>(&states[1])?;
217
218        for i in 0..values_array.len() {
219            let count = counts_array.value(i);
220            let entry = self
221                .value_counts
222                .entry(Hashable(values_array.value(i)))
223                .or_insert(0);
224            *entry += count;
225        }
226
227        Ok(())
228    }
229
230    fn evaluate(&mut self) -> error::Result<scalar::ScalarValue> {
231        let mut max_value: Option<T::Native> = None;
232        let mut max_count: i64 = 0;
233
234        self.value_counts.iter().for_each(|(value, &count)| {
235            match count.cmp(&max_count) {
236                cmp::Ordering::Greater => {
237                    max_value = Some(value.0);
238                    max_count = count;
239                }
240                cmp::Ordering::Equal => {
241                    max_value = match max_value {
242                        Some(current_max_value) if value.0 > current_max_value => Some(value.0),
243                        Some(current_max_value) => Some(current_max_value),
244                        None => Some(value.0),
245                    };
246                }
247                _ => {} // Do nothing if count is less than max_count
248            }
249        });
250
251        match max_value {
252            Some(val) => scalar::ScalarValue::new_primitive::<T>(Some(val), &self.data_type),
253            None => scalar::ScalarValue::new_primitive::<T>(None, &self.data_type),
254        }
255    }
256
257    fn size(&self) -> usize {
258        mem::size_of_val(&self.value_counts)
259            + self.value_counts.len() * mem::size_of::<(Hashable<T::Native>, i64)>()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::logical_expr::Accumulator;
    use std::sync;

    /// Feeds `input` to `acc` as a single batch, then checks that `evaluate`
    /// yields `expected` wrapped as a `T` scalar of `data_type`.
    fn assert_mode<T, A>(
        mut acc: A,
        input: arrow::array::ArrayRef,
        expected: Option<T::Native>,
        data_type: &arrow::datatypes::DataType,
    ) -> error::Result<()>
    where
        T: arrow::array::ArrowPrimitiveType,
        A: Accumulator,
    {
        acc.update_batch(&[input])?;
        let actual = acc.evaluate()?;
        let wanted = scalar::ScalarValue::new_primitive::<T>(expected, data_type)?;
        assert_eq!(actual, wanted);
        Ok(())
    }

    // ---- Int64 -------------------------------------------------------------

    #[test]
    fn test_mode_accumulator_single_mode_int64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Int64;
        assert_mode::<arrow::datatypes::Int64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Int64Type>::new(&dt),
            sync::Arc::new(arrow::array::Int64Array::from(vec![1, 2, 2, 3, 3, 3])),
            Some(3),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_int64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Int64;
        assert_mode::<arrow::datatypes::Int64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Int64Type>::new(&dt),
            sync::Arc::new(arrow::array::Int64Array::from(vec![
                None,
                Some(1),
                Some(2),
                Some(2),
                Some(3),
                Some(3),
                Some(3),
            ])),
            Some(3),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_int64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Int64;
        assert_mode::<arrow::datatypes::Int64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Int64Type>::new(&dt),
            sync::Arc::new(arrow::array::Int64Array::from(vec![1, 2, 2, 3, 3])),
            Some(3),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_int64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Int64;
        assert_mode::<arrow::datatypes::Int64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Int64Type>::new(&dt),
            sync::Arc::new(arrow::array::Int64Array::from(vec![None, None, None, None])),
            None,
            &dt,
        )
    }

    // ---- Float64 -----------------------------------------------------------

    #[test]
    fn test_mode_accumulator_single_mode_float64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Float64;
        assert_mode::<arrow::datatypes::Float64Type, _>(
            FloatModeAccumulator::<arrow::datatypes::Float64Type>::new(&dt),
            sync::Arc::new(arrow::array::Float64Array::from(vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0,
            ])),
            Some(3.0),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_float64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Float64;
        assert_mode::<arrow::datatypes::Float64Type, _>(
            FloatModeAccumulator::<arrow::datatypes::Float64Type>::new(&dt),
            sync::Arc::new(arrow::array::Float64Array::from(vec![
                None,
                Some(1.0),
                Some(2.0),
                Some(2.0),
                Some(3.0),
                Some(3.0),
                Some(3.0),
            ])),
            Some(3.0),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_float64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Float64;
        assert_mode::<arrow::datatypes::Float64Type, _>(
            FloatModeAccumulator::<arrow::datatypes::Float64Type>::new(&dt),
            sync::Arc::new(arrow::array::Float64Array::from(vec![
                1.0, 2.0, 2.0, 3.0, 3.0,
            ])),
            Some(3.0),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_float64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Float64;
        assert_mode::<arrow::datatypes::Float64Type, _>(
            FloatModeAccumulator::<arrow::datatypes::Float64Type>::new(&dt),
            sync::Arc::new(arrow::array::Float64Array::from(vec![
                None, None, None, None,
            ])),
            None,
            &dt,
        )
    }

    // ---- Date64 ------------------------------------------------------------

    #[test]
    fn test_mode_accumulator_single_mode_date64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Date64;
        assert_mode::<arrow::datatypes::Date64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Date64Type>::new(&dt),
            sync::Arc::new(arrow::array::Date64Array::from(vec![
                1609459200000,
                1609545600000,
                1609545600000,
                1609632000000,
                1609632000000,
                1609632000000,
            ])),
            Some(1609632000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_date64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Date64;
        assert_mode::<arrow::datatypes::Date64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Date64Type>::new(&dt),
            sync::Arc::new(arrow::array::Date64Array::from(vec![
                None,
                Some(1609459200000),
                Some(1609545600000),
                Some(1609545600000),
                Some(1609632000000),
                Some(1609632000000),
                Some(1609632000000),
            ])),
            Some(1609632000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_date64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Date64;
        assert_mode::<arrow::datatypes::Date64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Date64Type>::new(&dt),
            sync::Arc::new(arrow::array::Date64Array::from(vec![
                1609459200000,
                1609545600000,
                1609545600000,
                1609632000000,
                1609632000000,
            ])),
            Some(1609632000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_date64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Date64;
        assert_mode::<arrow::datatypes::Date64Type, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Date64Type>::new(&dt),
            sync::Arc::new(arrow::array::Date64Array::from(vec![
                None, None, None, None,
            ])),
            None,
            &dt,
        )
    }

    // ---- Time64(Microsecond) -----------------------------------------------

    #[test]
    fn test_mode_accumulator_single_mode_time64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond);
        assert_mode::<arrow::datatypes::Time64MicrosecondType, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Time64MicrosecondType>::new(&dt),
            sync::Arc::new(arrow::array::Time64MicrosecondArray::from(vec![
                3600000000,
                7200000000,
                7200000000,
                10800000000,
                10800000000,
                10800000000,
            ])),
            Some(10800000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_with_nulls_time64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond);
        assert_mode::<arrow::datatypes::Time64MicrosecondType, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Time64MicrosecondType>::new(&dt),
            sync::Arc::new(arrow::array::Time64MicrosecondArray::from(vec![
                None,
                Some(3600000000),
                Some(7200000000),
                Some(7200000000),
                Some(10800000000),
                Some(10800000000),
                Some(10800000000),
            ])),
            Some(10800000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_tie_case_time64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond);
        assert_mode::<arrow::datatypes::Time64MicrosecondType, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Time64MicrosecondType>::new(&dt),
            sync::Arc::new(arrow::array::Time64MicrosecondArray::from(vec![
                3600000000,
                7200000000,
                7200000000,
                10800000000,
                10800000000,
            ])),
            Some(10800000000),
            &dt,
        )
    }

    #[test]
    fn test_mode_accumulator_only_nulls_time64() -> error::Result<()> {
        let dt = arrow::datatypes::DataType::Time64(arrow::datatypes::TimeUnit::Microsecond);
        assert_mode::<arrow::datatypes::Time64MicrosecondType, _>(
            PrimitiveModeAccumulator::<arrow::datatypes::Time64MicrosecondType>::new(&dt),
            sync::Arc::new(arrow::array::Time64MicrosecondArray::from(vec![
                None, None, None, None,
            ])),
            None,
            &dt,
        )
    }
}