datafusion_functions_aggregate_common/aggregate/sum_distinct/
numeric.rs1use std::collections::HashSet;
21use std::fmt::Debug;
22use std::mem::{size_of, size_of_val};
23
24use ahash::RandomState;
25use arrow::array::Array;
26use arrow::array::ArrayRef;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::ArrowPrimitiveType;
29use arrow::array::AsArray;
30use arrow::datatypes::ArrowNativeType;
31use arrow::datatypes::DataType;
32
33use datafusion_common::Result;
34use datafusion_common::ScalarValue;
35use datafusion_expr_common::accumulator::Accumulator;
36
37use crate::utils::Hashable;
38
39pub struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
41 values: HashSet<Hashable<T::Native>, RandomState>,
42 data_type: DataType,
43}
44
45impl<T: ArrowPrimitiveType> Debug for DistinctSumAccumulator<T> {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(f, "DistinctSumAccumulator({})", self.data_type)
48 }
49}
50
51impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
52 pub fn new(data_type: &DataType) -> Self {
53 Self {
54 values: HashSet::default(),
55 data_type: data_type.clone(),
56 }
57 }
58
59 pub fn distinct_count(&self) -> usize {
60 self.values.len()
61 }
62}
63
64impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
65 fn state(&mut self) -> Result<Vec<ScalarValue>> {
66 let state_out = {
69 let distinct_values = self
70 .values
71 .iter()
72 .map(|value| {
73 ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
74 })
75 .collect::<Result<Vec<_>>>()?;
76
77 vec![ScalarValue::List(ScalarValue::new_list_nullable(
78 &distinct_values,
79 &self.data_type,
80 ))]
81 };
82 Ok(state_out)
83 }
84
85 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
86 if values.is_empty() {
87 return Ok(());
88 }
89
90 let array = values[0].as_primitive::<T>();
91 match array.nulls().filter(|x| x.null_count() > 0) {
92 Some(n) => {
93 for idx in n.valid_indices() {
94 self.values.insert(Hashable(array.value(idx)));
95 }
96 }
97 None => array.values().iter().for_each(|x| {
98 self.values.insert(Hashable(*x));
99 }),
100 }
101 Ok(())
102 }
103
104 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
105 for x in states[0].as_list::<i32>().iter().flatten() {
106 self.update_batch(&[x])?
107 }
108 Ok(())
109 }
110
111 fn evaluate(&mut self) -> Result<ScalarValue> {
112 let mut acc = T::Native::usize_as(0);
113 for distinct_value in self.values.iter() {
114 acc = acc.add_wrapping(distinct_value.0)
115 }
116 let v = (!self.values.is_empty()).then_some(acc);
117 ScalarValue::new_primitive::<T>(v, &self.data_type)
118 }
119
120 fn size(&self) -> usize {
121 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
122 }
123}