datafusion_functions_aggregate_common/aggregate/sum_distinct/
numeric.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the accumulator for `SUM DISTINCT` for primitive numeric types
19
20use std::collections::HashSet;
21use std::fmt::Debug;
22use std::mem::{size_of, size_of_val};
23
24use ahash::RandomState;
25use arrow::array::Array;
26use arrow::array::ArrayRef;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::ArrowPrimitiveType;
29use arrow::array::AsArray;
30use arrow::datatypes::ArrowNativeType;
31use arrow::datatypes::DataType;
32
33use datafusion_common::Result;
34use datafusion_common::ScalarValue;
35use datafusion_expr_common::accumulator::Accumulator;
36
37use crate::utils::Hashable;
38
39/// Accumulator for computing SUM(DISTINCT expr)
40pub struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
41    values: HashSet<Hashable<T::Native>, RandomState>,
42    data_type: DataType,
43}
44
45impl<T: ArrowPrimitiveType> Debug for DistinctSumAccumulator<T> {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(f, "DistinctSumAccumulator({})", self.data_type)
48    }
49}
50
51impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
52    pub fn new(data_type: &DataType) -> Self {
53        Self {
54            values: HashSet::default(),
55            data_type: data_type.clone(),
56        }
57    }
58
59    pub fn distinct_count(&self) -> usize {
60        self.values.len()
61    }
62}
63
64impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
65    fn state(&mut self) -> Result<Vec<ScalarValue>> {
66        // 1. Stores aggregate state in `ScalarValue::List`
67        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
68        let state_out = {
69            let distinct_values = self
70                .values
71                .iter()
72                .map(|value| {
73                    ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
74                })
75                .collect::<Result<Vec<_>>>()?;
76
77            vec![ScalarValue::List(ScalarValue::new_list_nullable(
78                &distinct_values,
79                &self.data_type,
80            ))]
81        };
82        Ok(state_out)
83    }
84
85    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
86        if values.is_empty() {
87            return Ok(());
88        }
89
90        let array = values[0].as_primitive::<T>();
91        match array.nulls().filter(|x| x.null_count() > 0) {
92            Some(n) => {
93                for idx in n.valid_indices() {
94                    self.values.insert(Hashable(array.value(idx)));
95                }
96            }
97            None => array.values().iter().for_each(|x| {
98                self.values.insert(Hashable(*x));
99            }),
100        }
101        Ok(())
102    }
103
104    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
105        for x in states[0].as_list::<i32>().iter().flatten() {
106            self.update_batch(&[x])?
107        }
108        Ok(())
109    }
110
111    fn evaluate(&mut self) -> Result<ScalarValue> {
112        let mut acc = T::Native::usize_as(0);
113        for distinct_value in self.values.iter() {
114            acc = acc.add_wrapping(distinct_value.0)
115        }
116        let v = (!self.values.is_empty()).then_some(acc);
117        ScalarValue::new_primitive::<T>(v, &self.data_type)
118    }
119
120    fn size(&self) -> usize {
121        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
122    }
123}