datafusion_functions_aggregate_common/aggregate/count_distinct/
groups.rs1use arrow::array::{
19 ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray,
20};
21use arrow::buffer::{OffsetBuffer, ScalarBuffer};
22use arrow::datatypes::{ArrowPrimitiveType, Field};
23use datafusion_common::HashSet;
24use datafusion_common::hash_utils::RandomState;
25use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
26use std::hash::Hash;
27use std::mem::size_of;
28use std::sync::Arc;
29
30use crate::aggregate::groups_accumulator::accumulate::accumulate;
31
32pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
33where
34 T::Native: Eq + Hash,
35{
36 seen: HashSet<(usize, T::Native), RandomState>,
37 counts: Vec<i64>,
38}
39
40impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
41where
42 T::Native: Eq + Hash,
43{
44 pub fn new() -> Self {
45 Self {
46 seen: HashSet::default(),
47 counts: Vec::new(),
48 }
49 }
50}
51
52impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
53where
54 T::Native: Eq + Hash,
55{
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
62 for PrimitiveDistinctCountGroupsAccumulator<T>
63where
64 T::Native: Eq + Hash,
65{
66 fn update_batch(
67 &mut self,
68 values: &[ArrayRef],
69 group_indices: &[usize],
70 opt_filter: Option<&BooleanArray>,
71 total_num_groups: usize,
72 ) -> datafusion_common::Result<()> {
73 debug_assert_eq!(values.len(), 1);
74 self.counts.resize(total_num_groups, 0);
75 let arr = values[0].as_primitive::<T>();
76 accumulate(group_indices, arr, opt_filter, |group_idx, value| {
77 if self.seen.insert((group_idx, value)) {
78 self.counts[group_idx] += 1;
79 }
80 });
81 Ok(())
82 }
83
84 fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
85 let counts = emit_to.take_needed(&mut self.counts);
86
87 match emit_to {
88 EmitTo::All => {
89 self.seen.clear();
90 }
91 EmitTo::First(n) => {
92 let mut remaining = HashSet::default();
93 for (group_idx, value) in self.seen.drain() {
94 if group_idx >= n {
95 remaining.insert((group_idx - n, value));
96 }
97 }
98 self.seen = remaining;
99 }
100 }
101
102 Ok(Arc::new(Int64Array::from(counts)))
103 }
104
105 fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
106 let num_emitted = match emit_to {
107 EmitTo::All => self.counts.len(),
108 EmitTo::First(n) => n,
109 };
110
111 let mut offsets = Vec::with_capacity(num_emitted + 1);
113 offsets.push(0i32);
114 let mut total = 0i32;
115 for &c in &self.counts[..num_emitted] {
116 total += c as i32;
117 offsets.push(total);
118 }
119
120 let mut all_values = vec![T::Native::default(); total as usize];
121 let mut cursors: Vec<i32> = offsets[..num_emitted].to_vec();
122
123 if matches!(emit_to, EmitTo::All) {
124 for (group_idx, value) in self.seen.drain() {
125 let pos = cursors[group_idx] as usize;
126 all_values[pos] = value;
127 cursors[group_idx] += 1;
128 }
129 self.counts.clear();
130 } else {
131 let mut remaining = HashSet::default();
132 for (group_idx, value) in self.seen.drain() {
133 if group_idx < num_emitted {
134 let pos = cursors[group_idx] as usize;
135 all_values[pos] = value;
136 cursors[group_idx] += 1;
137 } else {
138 remaining.insert((group_idx - num_emitted, value));
139 }
140 }
141 self.seen = remaining;
142 let _ = emit_to.take_needed(&mut self.counts);
143 }
144
145 let values_array = Arc::new(PrimitiveArray::<T>::new(
146 ScalarBuffer::from(all_values),
147 None,
148 ));
149 let list_array = ListArray::new(
150 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
151 OffsetBuffer::new(offsets.into()),
152 values_array,
153 None,
154 );
155
156 Ok(vec![Arc::new(list_array)])
157 }
158
159 fn merge_batch(
160 &mut self,
161 values: &[ArrayRef],
162 group_indices: &[usize],
163 _opt_filter: Option<&BooleanArray>,
164 total_num_groups: usize,
165 ) -> datafusion_common::Result<()> {
166 debug_assert_eq!(values.len(), 1);
167 self.counts.resize(total_num_groups, 0);
168 let list_array = values[0].as_list::<i32>();
169 let inner = list_array.values().as_primitive::<T>();
170 let inner_values = inner.values();
171 let offsets = list_array.offsets();
172
173 for (row_idx, &group_idx) in group_indices.iter().enumerate() {
174 let start = offsets[row_idx] as usize;
175 let end = offsets[row_idx + 1] as usize;
176 for &value in &inner_values[start..end] {
177 if self.seen.insert((group_idx, value)) {
178 self.counts[group_idx] += 1;
179 }
180 }
181 }
182
183 Ok(())
184 }
185
186 fn size(&self) -> usize {
187 size_of::<Self>()
188 + self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
189 + self.counts.capacity() * size_of::<i64>()
190 }
191}