Skip to main content

datafusion_functions_aggregate_common/aggregate/count_distinct/
groups.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray,
20};
21use arrow::buffer::{OffsetBuffer, ScalarBuffer};
22use arrow::datatypes::{ArrowPrimitiveType, Field};
23use datafusion_common::HashSet;
24use datafusion_common::hash_utils::RandomState;
25use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
26use std::hash::Hash;
27use std::mem::size_of;
28use std::sync::Arc;
29
30use crate::aggregate::groups_accumulator::accumulate::accumulate;
31
32pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
33where
34    T::Native: Eq + Hash,
35{
36    seen: HashSet<(usize, T::Native), RandomState>,
37    counts: Vec<i64>,
38}
39
40impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
41where
42    T::Native: Eq + Hash,
43{
44    pub fn new() -> Self {
45        Self {
46            seen: HashSet::default(),
47            counts: Vec::new(),
48        }
49    }
50}
51
52impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
53where
54    T::Native: Eq + Hash,
55{
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
62    for PrimitiveDistinctCountGroupsAccumulator<T>
63where
64    T::Native: Eq + Hash,
65{
66    fn update_batch(
67        &mut self,
68        values: &[ArrayRef],
69        group_indices: &[usize],
70        opt_filter: Option<&BooleanArray>,
71        total_num_groups: usize,
72    ) -> datafusion_common::Result<()> {
73        debug_assert_eq!(values.len(), 1);
74        self.counts.resize(total_num_groups, 0);
75        let arr = values[0].as_primitive::<T>();
76        accumulate(group_indices, arr, opt_filter, |group_idx, value| {
77            if self.seen.insert((group_idx, value)) {
78                self.counts[group_idx] += 1;
79            }
80        });
81        Ok(())
82    }
83
84    fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
85        let counts = emit_to.take_needed(&mut self.counts);
86
87        match emit_to {
88            EmitTo::All => {
89                self.seen.clear();
90            }
91            EmitTo::First(n) => {
92                let mut remaining = HashSet::default();
93                for (group_idx, value) in self.seen.drain() {
94                    if group_idx >= n {
95                        remaining.insert((group_idx - n, value));
96                    }
97                }
98                self.seen = remaining;
99            }
100        }
101
102        Ok(Arc::new(Int64Array::from(counts)))
103    }
104
105    fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
106        let num_emitted = match emit_to {
107            EmitTo::All => self.counts.len(),
108            EmitTo::First(n) => n,
109        };
110
111        // Prefix-sum counts[..num_emitted] into offsets
112        let mut offsets = Vec::with_capacity(num_emitted + 1);
113        offsets.push(0i32);
114        let mut total = 0i32;
115        for &c in &self.counts[..num_emitted] {
116            total += c as i32;
117            offsets.push(total);
118        }
119
120        let mut all_values = vec![T::Native::default(); total as usize];
121        let mut cursors: Vec<i32> = offsets[..num_emitted].to_vec();
122
123        if matches!(emit_to, EmitTo::All) {
124            for (group_idx, value) in self.seen.drain() {
125                let pos = cursors[group_idx] as usize;
126                all_values[pos] = value;
127                cursors[group_idx] += 1;
128            }
129            self.counts.clear();
130        } else {
131            let mut remaining = HashSet::default();
132            for (group_idx, value) in self.seen.drain() {
133                if group_idx < num_emitted {
134                    let pos = cursors[group_idx] as usize;
135                    all_values[pos] = value;
136                    cursors[group_idx] += 1;
137                } else {
138                    remaining.insert((group_idx - num_emitted, value));
139                }
140            }
141            self.seen = remaining;
142            let _ = emit_to.take_needed(&mut self.counts);
143        }
144
145        let values_array = Arc::new(PrimitiveArray::<T>::new(
146            ScalarBuffer::from(all_values),
147            None,
148        ));
149        let list_array = ListArray::new(
150            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
151            OffsetBuffer::new(offsets.into()),
152            values_array,
153            None,
154        );
155
156        Ok(vec![Arc::new(list_array)])
157    }
158
159    fn merge_batch(
160        &mut self,
161        values: &[ArrayRef],
162        group_indices: &[usize],
163        _opt_filter: Option<&BooleanArray>,
164        total_num_groups: usize,
165    ) -> datafusion_common::Result<()> {
166        debug_assert_eq!(values.len(), 1);
167        self.counts.resize(total_num_groups, 0);
168        let list_array = values[0].as_list::<i32>();
169        let inner = list_array.values().as_primitive::<T>();
170        let inner_values = inner.values();
171        let offsets = list_array.offsets();
172
173        for (row_idx, &group_idx) in group_indices.iter().enumerate() {
174            let start = offsets[row_idx] as usize;
175            let end = offsets[row_idx + 1] as usize;
176            for &value in &inner_values[start..end] {
177                if self.seen.insert((group_idx, value)) {
178                    self.counts[group_idx] += 1;
179                }
180            }
181        }
182
183        Ok(())
184    }
185
186    fn size(&self) -> usize {
187        size_of::<Self>()
188            + self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
189            + self.counts.capacity() * size_of::<i64>()
190    }
191}