datafusion_functions_aggregate_common/aggregate/count_distinct/
native.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Specialized implementation of `COUNT DISTINCT` for "Native" arrays such as
19//! [`Int64Array`] and [`Float64Array`]
20//!
21//! [`Int64Array`]: arrow::array::Int64Array
22//! [`Float64Array`]: arrow::array::Float64Array
23use std::collections::HashSet;
24use std::fmt::Debug;
25use std::hash::Hash;
26use std::mem::size_of_val;
27use std::sync::Arc;
28
29use ahash::RandomState;
30use arrow::array::ArrayRef;
31use arrow::array::PrimitiveArray;
32use arrow::array::types::ArrowPrimitiveType;
33use arrow::datatypes::DataType;
34
35use datafusion_common::ScalarValue;
36use datafusion_common::cast::{as_list_array, as_primitive_array};
37use datafusion_common::utils::SingleRowListArrayBuilder;
38use datafusion_common::utils::memory::estimate_memory_size;
39use datafusion_expr_common::accumulator::Accumulator;
40
41use crate::utils::GenericDistinctBuffer;
42
43#[derive(Debug)]
44pub struct PrimitiveDistinctCountAccumulator<T>
45where
46    T: ArrowPrimitiveType + Send,
47    T::Native: Eq + Hash,
48{
49    values: HashSet<T::Native, RandomState>,
50    data_type: DataType,
51}
52
53impl<T> PrimitiveDistinctCountAccumulator<T>
54where
55    T: ArrowPrimitiveType + Send,
56    T::Native: Eq + Hash,
57{
58    pub fn new(data_type: &DataType) -> Self {
59        Self {
60            values: HashSet::default(),
61            data_type: data_type.clone(),
62        }
63    }
64}
65
66impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
67where
68    T: ArrowPrimitiveType + Send + Debug,
69    T::Native: Eq + Hash,
70{
71    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
72        let arr = Arc::new(
73            PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
74                .with_data_type(self.data_type.clone()),
75        );
76        Ok(vec![
77            SingleRowListArrayBuilder::new(arr).build_list_scalar(),
78        ])
79    }
80
81    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
82        if values.is_empty() {
83            return Ok(());
84        }
85
86        let arr = as_primitive_array::<T>(&values[0])?;
87        arr.iter().for_each(|value| {
88            if let Some(value) = value {
89                self.values.insert(value);
90            }
91        });
92
93        Ok(())
94    }
95
96    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
97        if states.is_empty() {
98            return Ok(());
99        }
100        assert_eq!(
101            states.len(),
102            1,
103            "count_distinct states must be single array"
104        );
105
106        let arr = as_list_array(&states[0])?;
107        arr.iter().try_for_each(|maybe_list| {
108            if let Some(list) = maybe_list {
109                let list = as_primitive_array::<T>(&list)?;
110                self.values.extend(list.values())
111            };
112            Ok(())
113        })
114    }
115
116    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
117        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
118    }
119
120    fn size(&self) -> usize {
121        let num_elements = self.values.len();
122        let fixed_size = size_of_val(self) + size_of_val(&self.values);
123
124        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
125    }
126}
127
128#[derive(Debug)]
129pub struct FloatDistinctCountAccumulator<T: ArrowPrimitiveType> {
130    values: GenericDistinctBuffer<T>,
131}
132
133impl<T: ArrowPrimitiveType> FloatDistinctCountAccumulator<T> {
134    pub fn new() -> Self {
135        Self {
136            values: GenericDistinctBuffer::new(T::DATA_TYPE),
137        }
138    }
139}
140
141impl<T: ArrowPrimitiveType> Default for FloatDistinctCountAccumulator<T> {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl<T: ArrowPrimitiveType + Debug> Accumulator for FloatDistinctCountAccumulator<T> {
148    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
149        self.values.state()
150    }
151
152    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
153        self.values.update_batch(values)
154    }
155
156    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
157        self.values.merge_batch(states)
158    }
159
160    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
161        Ok(ScalarValue::Int64(Some(self.values.values.len() as i64)))
162    }
163
164    fn size(&self) -> usize {
165        size_of_val(self) + self.values.size()
166    }
167}