// datafusion_functions_aggregate_common/aggregate/count_distinct/native.rs
use std::collections::HashSet;
24use std::fmt::Debug;
25use std::hash::Hash;
26use std::mem::size_of_val;
27use std::sync::Arc;
28
29use ahash::RandomState;
30use arrow::array::types::ArrowPrimitiveType;
31use arrow::array::ArrayRef;
32use arrow::array::PrimitiveArray;
33use arrow::datatypes::DataType;
34
35use datafusion_common::cast::{as_list_array, as_primitive_array};
36use datafusion_common::utils::memory::estimate_memory_size;
37use datafusion_common::utils::SingleRowListArrayBuilder;
38use datafusion_common::ScalarValue;
39use datafusion_expr_common::accumulator::Accumulator;
40
41use crate::utils::Hashable;
42
43#[derive(Debug)]
44pub struct PrimitiveDistinctCountAccumulator<T>
45where
46 T: ArrowPrimitiveType + Send,
47 T::Native: Eq + Hash,
48{
49 values: HashSet<T::Native, RandomState>,
50 data_type: DataType,
51}
52
53impl<T> PrimitiveDistinctCountAccumulator<T>
54where
55 T: ArrowPrimitiveType + Send,
56 T::Native: Eq + Hash,
57{
58 pub fn new(data_type: &DataType) -> Self {
59 Self {
60 values: HashSet::default(),
61 data_type: data_type.clone(),
62 }
63 }
64}
65
66impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
67where
68 T: ArrowPrimitiveType + Send + Debug,
69 T::Native: Eq + Hash,
70{
71 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
72 let arr = Arc::new(
73 PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
74 .with_data_type(self.data_type.clone()),
75 );
76 Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
77 }
78
79 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
80 if values.is_empty() {
81 return Ok(());
82 }
83
84 let arr = as_primitive_array::<T>(&values[0])?;
85 arr.iter().for_each(|value| {
86 if let Some(value) = value {
87 self.values.insert(value);
88 }
89 });
90
91 Ok(())
92 }
93
94 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
95 if states.is_empty() {
96 return Ok(());
97 }
98 assert_eq!(
99 states.len(),
100 1,
101 "count_distinct states must be single array"
102 );
103
104 let arr = as_list_array(&states[0])?;
105 arr.iter().try_for_each(|maybe_list| {
106 if let Some(list) = maybe_list {
107 let list = as_primitive_array::<T>(&list)?;
108 self.values.extend(list.values())
109 };
110 Ok(())
111 })
112 }
113
114 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
115 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
116 }
117
118 fn size(&self) -> usize {
119 let num_elements = self.values.len();
120 let fixed_size = size_of_val(self) + size_of_val(&self.values);
121
122 estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
123 }
124}
125
126#[derive(Debug)]
127pub struct FloatDistinctCountAccumulator<T>
128where
129 T: ArrowPrimitiveType + Send,
130{
131 values: HashSet<Hashable<T::Native>, RandomState>,
132}
133
134impl<T> FloatDistinctCountAccumulator<T>
135where
136 T: ArrowPrimitiveType + Send,
137{
138 pub fn new() -> Self {
139 Self {
140 values: HashSet::default(),
141 }
142 }
143}
144
145impl<T> Default for FloatDistinctCountAccumulator<T>
146where
147 T: ArrowPrimitiveType + Send,
148{
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl<T> Accumulator for FloatDistinctCountAccumulator<T>
155where
156 T: ArrowPrimitiveType + Send + Debug,
157{
158 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
159 let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
160 self.values.iter().map(|v| v.0),
161 )) as ArrayRef;
162 Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
163 }
164
165 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
166 if values.is_empty() {
167 return Ok(());
168 }
169
170 let arr = as_primitive_array::<T>(&values[0])?;
171 arr.iter().for_each(|value| {
172 if let Some(value) = value {
173 self.values.insert(Hashable(value));
174 }
175 });
176
177 Ok(())
178 }
179
180 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
181 if states.is_empty() {
182 return Ok(());
183 }
184 assert_eq!(
185 states.len(),
186 1,
187 "count_distinct states must be single array"
188 );
189
190 let arr = as_list_array(&states[0])?;
191 arr.iter().try_for_each(|maybe_list| {
192 if let Some(list) = maybe_list {
193 let list = as_primitive_array::<T>(&list)?;
194 self.values
195 .extend(list.values().iter().map(|v| Hashable(*v)));
196 };
197 Ok(())
198 })
199 }
200
201 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
202 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
203 }
204
205 fn size(&self) -> usize {
206 let num_elements = self.values.len();
207 let fixed_size = size_of_val(self) + size_of_val(&self.values);
208
209 estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
210 }
211}