Skip to main content

datafusion_functions_aggregate/
approx_distinct.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the `approx_distinct` aggregate function and the accumulators that evaluate it at runtime during query execution.
19
20use crate::hyperloglog::{HLL_HASH_STATE, HyperLogLog};
21use arrow::array::{Array, BinaryArray, StringViewArray};
22use arrow::array::{
23    GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{
26    ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int32Type, Int64Type,
27    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
29    TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type,
30};
31use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
32use datafusion_common::ScalarValue;
33use datafusion_common::{
34    DataFusionError, Result, downcast_value, internal_datafusion_err, internal_err,
35    not_impl_err,
36};
37use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
38use datafusion_expr::utils::format_state_name;
39use datafusion_expr::{
40    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
41};
42use datafusion_functions_aggregate_common::aggregate::count_distinct::{
43    Bitmap65536DistinctCountAccumulator, Bitmap65536DistinctCountAccumulatorI16,
44    BoolArray256DistinctCountAccumulator, BoolArray256DistinctCountAccumulatorI8,
45};
46use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
47use datafusion_macros::user_doc;
48use std::fmt::{Debug, Formatter};
49use std::hash::{BuildHasher, Hash};
50use std::marker::PhantomData;
51
// Wires `ApproxDistinct` up through `make_udaf_expr_and_func!`.
// From the argument order this appears to be:
// - the UDAF type;
// - the expression-API function name (`approx_distinct`);
// - its argument name;
// - a doc string;
// - the accessor name (`approx_distinct_udaf`).
// NOTE(review): argument roles inferred from the macro's name and call
// shape — confirm against the macro definition.
make_udaf_expr_and_func!(
    ApproxDistinct,
    approx_distinct,
    expression,
    "approximate number of distinct input values",
    approx_distinct_udaf
);
59
60impl<T: Hash + ?Sized> From<&HyperLogLog<T>> for ScalarValue {
61    fn from(v: &HyperLogLog<T>) -> ScalarValue {
62        let values = v.as_ref().to_vec();
63        ScalarValue::Binary(Some(values))
64    }
65}
66
67impl<T: Hash + ?Sized> TryFrom<&[u8]> for HyperLogLog<T> {
68    type Error = DataFusionError;
69    fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> {
70        let arr: [u8; 16384] = v.try_into().map_err(|_| {
71            internal_datafusion_err!("Impossibly got invalid binary array from states")
72        })?;
73        Ok(HyperLogLog::<T>::new_with_registers(arr))
74    }
75}
76
77impl<T: Hash + ?Sized> TryFrom<&ScalarValue> for HyperLogLog<T> {
78    type Error = DataFusionError;
79    fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> {
80        if let ScalarValue::Binary(Some(slice)) = v {
81            slice.as_slice().try_into()
82        } else {
83            internal_err!(
84                "Impossibly got invalid scalar value while converting to HyperLogLog"
85            )
86        }
87    }
88}
89
90#[derive(Debug)]
91struct ApproxDistinctBitmapWrapper<A: Accumulator> {
92    inner: A,
93}
94
95impl<A: Accumulator> Accumulator for ApproxDistinctBitmapWrapper<A> {
96    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
97        self.inner.update_batch(values)
98    }
99
100    fn evaluate(&mut self) -> Result<ScalarValue> {
101        match self.inner.evaluate()? {
102            ScalarValue::Int64(Some(v)) => Ok(ScalarValue::UInt64(Some(v as u64))),
103            other => internal_err!("unexpected: {other}"),
104        }
105    }
106
107    fn size(&self) -> usize {
108        self.inner.size()
109    }
110
111    fn state(&mut self) -> Result<Vec<ScalarValue>> {
112        self.inner.state()
113    }
114
115    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
116        self.inner.merge_batch(states)
117    }
118}
119
120#[derive(Debug)]
121struct NumericHLLAccumulator<T>
122where
123    T: ArrowPrimitiveType,
124    T::Native: Hash,
125{
126    hll: HyperLogLog<T::Native>,
127}
128
129impl<T> NumericHLLAccumulator<T>
130where
131    T: ArrowPrimitiveType,
132    T::Native: Hash,
133{
134    pub fn new() -> Self {
135        Self {
136            hll: HyperLogLog::new(),
137        }
138    }
139}
140
141#[derive(Debug)]
142struct StringHLLAccumulator<T>
143where
144    T: OffsetSizeTrait,
145{
146    hll: HyperLogLog<str>,
147    phantom_data: PhantomData<T>,
148}
149
150impl<T> StringHLLAccumulator<T>
151where
152    T: OffsetSizeTrait,
153{
154    pub fn new() -> Self {
155        Self {
156            hll: HyperLogLog::new(),
157            phantom_data: PhantomData,
158        }
159    }
160}
161
162#[derive(Debug)]
163struct StringViewHLLAccumulator {
164    hll: HyperLogLog<str>,
165}
166
167impl StringViewHLLAccumulator {
168    pub fn new() -> Self {
169        Self {
170            hll: HyperLogLog::new(),
171        }
172    }
173}
174
175#[derive(Debug)]
176struct BinaryHLLAccumulator<T>
177where
178    T: OffsetSizeTrait,
179{
180    hll: HyperLogLog<[u8]>,
181    phantom_data: PhantomData<T>,
182}
183
184impl<T> BinaryHLLAccumulator<T>
185where
186    T: OffsetSizeTrait,
187{
188    pub fn new() -> Self {
189        Self {
190            hll: HyperLogLog::new(),
191            phantom_data: PhantomData,
192        }
193    }
194}
195
/// Expands to the `Accumulator` methods shared by every HyperLogLog-backed
/// accumulator in this file: `merge_batch`, `state`, `evaluate` and `size`.
///
/// The invoking type must have a field named `hll` holding a `HyperLogLog`.
/// Each accumulator writes only its own `update_batch`.
macro_rules! default_accumulator_impl {
    () => {
        /// Merges a batch of intermediate states into this accumulator.
        ///
        /// `states` must hold exactly one `BinaryArray` column. Each row is a
        /// serialized sketch as produced by `state`.
        ///
        /// # Errors
        /// Returns an error if:
        /// - the number of state columns is not one;
        /// - the column is not a `BinaryArray`;
        /// - a row is null;
        /// - a row cannot be decoded into a `HyperLogLog`.
        fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
            // A malformed state batch is reported as an error rather than
            // panicking the executing thread.
            if states.len() != 1 {
                return internal_err!(
                    "expect only 1 element in the states, got {}",
                    states.len()
                );
            }
            let binary_array = downcast_value!(states[0], BinaryArray);
            for v in binary_array.iter() {
                let v = v.ok_or_else(|| {
                    internal_datafusion_err!(
                        "Impossibly got empty binary array from states"
                    )
                })?;
                let other = v.try_into()?;
                self.hll.merge(&other);
            }
            Ok(())
        }

        /// Serializes the sketch as a single `Binary` state value.
        fn state(&mut self) -> Result<Vec<ScalarValue>> {
            let value = ScalarValue::from(&self.hll);
            Ok(vec![value])
        }

        /// Returns the current distinct-count estimate as `UInt64`.
        fn evaluate(&mut self) -> Result<ScalarValue> {
            Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
        }

        /// The HLL has a static size, so the accumulator's own size covers it.
        /// NOTE(review): this relies on the registers being stored inline
        /// (`new_with_registers` takes a fixed-size array) — confirm.
        fn size(&self) -> usize {
            std::mem::size_of_val(self)
        }
    };
}
228
229impl<T> Accumulator for BinaryHLLAccumulator<T>
230where
231    T: OffsetSizeTrait,
232{
233    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
234        let array: &GenericBinaryArray<T> =
235            downcast_value!(values[0], GenericBinaryArray, T);
236        // flatten because we would skip nulls
237        self.hll.extend(array.into_iter().flatten());
238        Ok(())
239    }
240
241    default_accumulator_impl!();
242}
243
244impl Accumulator for StringViewHLLAccumulator {
245    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
246        let array: &StringViewArray = downcast_value!(values[0], StringViewArray);
247
248        // When all strings are stored inline in the StringView (≤ 12 bytes),
249        // hash the raw u128 view directly instead of materializing a &str.
250        if array.data_buffers().is_empty() {
251            for (i, &view) in array.views().iter().enumerate() {
252                if !array.is_null(i) {
253                    self.hll.add_hashed(HLL_HASH_STATE.hash_one(view));
254                }
255            }
256        } else {
257            self.hll.extend(array.iter().flatten());
258        }
259
260        Ok(())
261    }
262
263    default_accumulator_impl!();
264}
265
266impl<T> Accumulator for StringHLLAccumulator<T>
267where
268    T: OffsetSizeTrait,
269{
270    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
271        let array: &GenericStringArray<T> =
272            downcast_value!(values[0], GenericStringArray, T);
273        // flatten because we would skip nulls
274        self.hll.extend(array.into_iter().flatten());
275        Ok(())
276    }
277
278    default_accumulator_impl!();
279}
280
281impl<T> Accumulator for NumericHLLAccumulator<T>
282where
283    T: ArrowPrimitiveType + Debug,
284    T::Native: Hash,
285{
286    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
287        let array: &PrimitiveArray<T> = downcast_value!(values[0], PrimitiveArray, T);
288        // flatten because we would skip nulls
289        self.hll.extend(array.into_iter().flatten());
290        Ok(())
291    }
292
293    default_accumulator_impl!();
294}
295
296impl Debug for ApproxDistinct {
297    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("ApproxDistinct")
299            .field("name", &self.name())
300            .field("signature", &self.signature)
301            .finish()
302    }
303}
304
305impl Default for ApproxDistinct {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
// The `approx_distinct` aggregate UDAF.
// `#[user_doc]` carries the user-facing SQL documentation. It presumably
// generates the `doc()` method that `AggregateUDFImpl::documentation`
// returns — confirm against `datafusion_macros::user_doc`.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm.",
    syntax_example = "approx_distinct(expression)",
    sql_example = r#"```sql
> SELECT approx_distinct(column_name) FROM table_name;
+-----------------------------------+
| approx_distinct(column_name)      |
+-----------------------------------+
| 42                                |
+-----------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct ApproxDistinct {
    // Accepted argument signature; built by `ApproxDistinct::new` as
    // exactly one argument of any type.
    signature: Signature,
}
329
330impl ApproxDistinct {
331    pub fn new() -> Self {
332        Self {
333            signature: Signature::any(1, Volatility::Immutable),
334        }
335    }
336}
337
338#[cold]
339fn get_small_int_approx_accumulator(
340    data_type: &DataType,
341) -> Result<Box<dyn Accumulator>> {
342    match data_type {
343        DataType::UInt8 => Ok(Box::new(ApproxDistinctBitmapWrapper {
344            inner: BoolArray256DistinctCountAccumulator::new(),
345        })),
346        DataType::Int8 => Ok(Box::new(ApproxDistinctBitmapWrapper {
347            inner: BoolArray256DistinctCountAccumulatorI8::new(),
348        })),
349        DataType::UInt16 => Ok(Box::new(ApproxDistinctBitmapWrapper {
350            inner: Bitmap65536DistinctCountAccumulator::new(),
351        })),
352        DataType::Int16 => Ok(Box::new(ApproxDistinctBitmapWrapper {
353            inner: Bitmap65536DistinctCountAccumulatorI16::new(),
354        })),
355        _ => internal_err!("unsupported small int type: {}", data_type),
356    }
357}
358
359#[cold]
360fn get_small_int_state_field(name: &str, data_type: &DataType) -> Result<Vec<FieldRef>> {
361    Ok(vec![
362        Field::new_list(
363            format_state_name(name, "approx_distinct"),
364            Field::new_list_field(data_type.clone(), true),
365            false,
366        )
367        .into(),
368    ])
369}
370
371impl AggregateUDFImpl for ApproxDistinct {
372    fn name(&self) -> &str {
373        "approx_distinct"
374    }
375
376    fn signature(&self) -> &Signature {
377        &self.signature
378    }
379
380    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
381        Ok(DataType::UInt64)
382    }
383
384    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
385        let data_type = args.input_fields[0].data_type();
386        match data_type {
387            DataType::Null => Ok(vec![
388                Field::new(
389                    format_state_name(args.name, self.name()),
390                    DataType::Null,
391                    true,
392                )
393                .into(),
394            ]),
395            DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => {
396                get_small_int_state_field(args.name, data_type)
397            }
398            _ => Ok(vec![
399                Field::new(
400                    format_state_name(args.name, "hll_registers"),
401                    DataType::Binary,
402                    false,
403                )
404                .into(),
405            ]),
406        }
407    }
408
409    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
410        let data_type = acc_args.expr_fields[0].data_type();
411
412        let accumulator: Box<dyn Accumulator> = match data_type {
413            DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => {
414                return get_small_int_approx_accumulator(data_type);
415            }
416            DataType::UInt32 => Box::new(NumericHLLAccumulator::<UInt32Type>::new()),
417            DataType::UInt64 => Box::new(NumericHLLAccumulator::<UInt64Type>::new()),
418            DataType::Int32 => Box::new(NumericHLLAccumulator::<Int32Type>::new()),
419            DataType::Int64 => Box::new(NumericHLLAccumulator::<Int64Type>::new()),
420            DataType::Date32 => Box::new(NumericHLLAccumulator::<Date32Type>::new()),
421            DataType::Date64 => Box::new(NumericHLLAccumulator::<Date64Type>::new()),
422            DataType::Time32(TimeUnit::Second) => {
423                Box::new(NumericHLLAccumulator::<Time32SecondType>::new())
424            }
425            DataType::Time32(TimeUnit::Millisecond) => {
426                Box::new(NumericHLLAccumulator::<Time32MillisecondType>::new())
427            }
428            DataType::Time64(TimeUnit::Microsecond) => {
429                Box::new(NumericHLLAccumulator::<Time64MicrosecondType>::new())
430            }
431            DataType::Time64(TimeUnit::Nanosecond) => {
432                Box::new(NumericHLLAccumulator::<Time64NanosecondType>::new())
433            }
434            DataType::Timestamp(TimeUnit::Second, _) => {
435                Box::new(NumericHLLAccumulator::<TimestampSecondType>::new())
436            }
437            DataType::Timestamp(TimeUnit::Millisecond, _) => {
438                Box::new(NumericHLLAccumulator::<TimestampMillisecondType>::new())
439            }
440            DataType::Timestamp(TimeUnit::Microsecond, _) => {
441                Box::new(NumericHLLAccumulator::<TimestampMicrosecondType>::new())
442            }
443            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
444                Box::new(NumericHLLAccumulator::<TimestampNanosecondType>::new())
445            }
446            DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()),
447            DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()),
448            DataType::Utf8View => Box::new(StringViewHLLAccumulator::new()),
449            DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()),
450            DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()),
451            DataType::Null => {
452                Box::new(NoopAccumulator::new(ScalarValue::UInt64(Some(0))))
453            }
454            other => {
455                return not_impl_err!(
456                    "Support for 'approx_distinct' for data type {other} is not implemented"
457                );
458            }
459        };
460        Ok(accumulator)
461    }
462
463    fn documentation(&self) -> Option<&Documentation> {
464        self.doc()
465    }
466}