datafusion_functions_aggregate_common/aggregate/count_distinct/
native.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Specialized implementation of `COUNT DISTINCT` for "Native" arrays such as
//! [`Int64Array`] and [`Float64Array`]
//!
//! [`Int64Array`]: arrow::array::Int64Array
//! [`Float64Array`]: arrow::array::Float64Array
23use std::collections::HashSet;
24use std::fmt::Debug;
25use std::hash::Hash;
26use std::mem::size_of_val;
27use std::sync::Arc;
28
29use ahash::RandomState;
30use arrow::array::types::ArrowPrimitiveType;
31use arrow::array::ArrayRef;
32use arrow::array::PrimitiveArray;
33use arrow::datatypes::DataType;
34
35use datafusion_common::cast::{as_list_array, as_primitive_array};
36use datafusion_common::utils::memory::estimate_memory_size;
37use datafusion_common::utils::SingleRowListArrayBuilder;
38use datafusion_common::ScalarValue;
39use datafusion_expr_common::accumulator::Accumulator;
40
41use crate::utils::Hashable;
42
43#[derive(Debug)]
44pub struct PrimitiveDistinctCountAccumulator<T>
45where
46    T: ArrowPrimitiveType + Send,
47    T::Native: Eq + Hash,
48{
49    values: HashSet<T::Native, RandomState>,
50    data_type: DataType,
51}
52
53impl<T> PrimitiveDistinctCountAccumulator<T>
54where
55    T: ArrowPrimitiveType + Send,
56    T::Native: Eq + Hash,
57{
58    pub fn new(data_type: &DataType) -> Self {
59        Self {
60            values: HashSet::default(),
61            data_type: data_type.clone(),
62        }
63    }
64}
65
66impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
67where
68    T: ArrowPrimitiveType + Send + Debug,
69    T::Native: Eq + Hash,
70{
71    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
72        let arr = Arc::new(
73            PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
74                .with_data_type(self.data_type.clone()),
75        );
76        Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
77    }
78
79    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
80        if values.is_empty() {
81            return Ok(());
82        }
83
84        let arr = as_primitive_array::<T>(&values[0])?;
85        arr.iter().for_each(|value| {
86            if let Some(value) = value {
87                self.values.insert(value);
88            }
89        });
90
91        Ok(())
92    }
93
94    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
95        if states.is_empty() {
96            return Ok(());
97        }
98        assert_eq!(
99            states.len(),
100            1,
101            "count_distinct states must be single array"
102        );
103
104        let arr = as_list_array(&states[0])?;
105        arr.iter().try_for_each(|maybe_list| {
106            if let Some(list) = maybe_list {
107                let list = as_primitive_array::<T>(&list)?;
108                self.values.extend(list.values())
109            };
110            Ok(())
111        })
112    }
113
114    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
115        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
116    }
117
118    fn size(&self) -> usize {
119        let num_elements = self.values.len();
120        let fixed_size = size_of_val(self) + size_of_val(&self.values);
121
122        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
123    }
124}
125
126#[derive(Debug)]
127pub struct FloatDistinctCountAccumulator<T>
128where
129    T: ArrowPrimitiveType + Send,
130{
131    values: HashSet<Hashable<T::Native>, RandomState>,
132}
133
134impl<T> FloatDistinctCountAccumulator<T>
135where
136    T: ArrowPrimitiveType + Send,
137{
138    pub fn new() -> Self {
139        Self {
140            values: HashSet::default(),
141        }
142    }
143}
144
145impl<T> Default for FloatDistinctCountAccumulator<T>
146where
147    T: ArrowPrimitiveType + Send,
148{
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl<T> Accumulator for FloatDistinctCountAccumulator<T>
155where
156    T: ArrowPrimitiveType + Send + Debug,
157{
158    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
159        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
160            self.values.iter().map(|v| v.0),
161        )) as ArrayRef;
162        Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
163    }
164
165    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
166        if values.is_empty() {
167            return Ok(());
168        }
169
170        let arr = as_primitive_array::<T>(&values[0])?;
171        arr.iter().for_each(|value| {
172            if let Some(value) = value {
173                self.values.insert(Hashable(value));
174            }
175        });
176
177        Ok(())
178    }
179
180    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
181        if states.is_empty() {
182            return Ok(());
183        }
184        assert_eq!(
185            states.len(),
186            1,
187            "count_distinct states must be single array"
188        );
189
190        let arr = as_list_array(&states[0])?;
191        arr.iter().try_for_each(|maybe_list| {
192            if let Some(list) = maybe_list {
193                let list = as_primitive_array::<T>(&list)?;
194                self.values
195                    .extend(list.values().iter().map(|v| Hashable(*v)));
196            };
197            Ok(())
198        })
199    }
200
201    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
202        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
203    }
204
205    fn size(&self) -> usize {
206        let num_elements = self.values.len();
207        let fixed_size = size_of_val(self) + size_of_val(&self.values);
208
209        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
210    }
211}