datafusion_functions_aggregate_common/aggregate/count_distinct/
native.rs1use std::collections::HashSet;
24use std::fmt::Debug;
25use std::hash::Hash;
26use std::mem::size_of_val;
27use std::sync::Arc;
28
29use ahash::RandomState;
30use arrow::array::ArrayRef;
31use arrow::array::PrimitiveArray;
32use arrow::array::types::ArrowPrimitiveType;
33use arrow::datatypes::DataType;
34
35use datafusion_common::ScalarValue;
36use datafusion_common::cast::{as_list_array, as_primitive_array};
37use datafusion_common::utils::SingleRowListArrayBuilder;
38use datafusion_common::utils::memory::estimate_memory_size;
39use datafusion_expr_common::accumulator::Accumulator;
40
41use crate::utils::GenericDistinctBuffer;
42
43#[derive(Debug)]
44pub struct PrimitiveDistinctCountAccumulator<T>
45where
46 T: ArrowPrimitiveType + Send,
47 T::Native: Eq + Hash,
48{
49 values: HashSet<T::Native, RandomState>,
50 data_type: DataType,
51}
52
53impl<T> PrimitiveDistinctCountAccumulator<T>
54where
55 T: ArrowPrimitiveType + Send,
56 T::Native: Eq + Hash,
57{
58 pub fn new(data_type: &DataType) -> Self {
59 Self {
60 values: HashSet::default(),
61 data_type: data_type.clone(),
62 }
63 }
64}
65
66impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
67where
68 T: ArrowPrimitiveType + Send + Debug,
69 T::Native: Eq + Hash,
70{
71 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
72 let arr = Arc::new(
73 PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
74 .with_data_type(self.data_type.clone()),
75 );
76 Ok(vec![
77 SingleRowListArrayBuilder::new(arr).build_list_scalar(),
78 ])
79 }
80
81 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
82 if values.is_empty() {
83 return Ok(());
84 }
85
86 let arr = as_primitive_array::<T>(&values[0])?;
87 arr.iter().for_each(|value| {
88 if let Some(value) = value {
89 self.values.insert(value);
90 }
91 });
92
93 Ok(())
94 }
95
96 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
97 if states.is_empty() {
98 return Ok(());
99 }
100 assert_eq!(
101 states.len(),
102 1,
103 "count_distinct states must be single array"
104 );
105
106 let arr = as_list_array(&states[0])?;
107 arr.iter().try_for_each(|maybe_list| {
108 if let Some(list) = maybe_list {
109 let list = as_primitive_array::<T>(&list)?;
110 self.values.extend(list.values())
111 };
112 Ok(())
113 })
114 }
115
116 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
117 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
118 }
119
120 fn size(&self) -> usize {
121 let num_elements = self.values.len();
122 let fixed_size = size_of_val(self) + size_of_val(&self.values);
123
124 estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
125 }
126}
127
128#[derive(Debug)]
129pub struct FloatDistinctCountAccumulator<T: ArrowPrimitiveType> {
130 values: GenericDistinctBuffer<T>,
131}
132
133impl<T: ArrowPrimitiveType> FloatDistinctCountAccumulator<T> {
134 pub fn new() -> Self {
135 Self {
136 values: GenericDistinctBuffer::new(T::DATA_TYPE),
137 }
138 }
139}
140
141impl<T: ArrowPrimitiveType> Default for FloatDistinctCountAccumulator<T> {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl<T: ArrowPrimitiveType + Debug> Accumulator for FloatDistinctCountAccumulator<T> {
148 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
149 self.values.state()
150 }
151
152 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
153 self.values.update_batch(values)
154 }
155
156 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
157 self.values.merge_batch(states)
158 }
159
160 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
161 Ok(ScalarValue::Int64(Some(self.values.values.len() as i64)))
162 }
163
164 fn size(&self) -> usize {
165 size_of_val(self) + self.values.size()
166 }
167}