datafusion_functions_aggregate_common/aggregate/groups_accumulator/prim_op.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::mem::size_of;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
22use arrow::buffer::NullBuffer;
23use arrow::compute;
24use arrow::datatypes::ArrowPrimitiveType;
25use arrow::datatypes::DataType;
26use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
27use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
28
29use super::accumulate::NullState;
30
31/// An accumulator that implements a single operation over
32/// [`ArrowPrimitiveType`] where the accumulated state is the same as
33/// the input type (such as `Sum`)
34///
35/// F: The function to apply to two elements. The first argument is
36/// the existing value and should be updated with the second value
37/// (e.g. [`BitAndAssign`] style).
38///
39/// [`BitAndAssign`]: std::ops::BitAndAssign
40#[derive(Debug)]
41pub struct PrimitiveGroupsAccumulator<T, F>
42where
43    T: ArrowPrimitiveType + Send,
44    F: Fn(&mut T::Native, T::Native) + Send + Sync,
45{
46    /// values per group, stored as the native type
47    values: Vec<T::Native>,
48
49    /// The output type (needed for Decimal precision and scale)
50    data_type: DataType,
51
52    /// The starting value for new groups
53    starting_value: T::Native,
54
55    /// Track nulls in the input / filters
56    null_state: NullState,
57
58    /// Function that computes the primitive result
59    prim_fn: F,
60}
61
62impl<T, F> PrimitiveGroupsAccumulator<T, F>
63where
64    T: ArrowPrimitiveType + Send,
65    F: Fn(&mut T::Native, T::Native) + Send + Sync,
66{
67    pub fn new(data_type: &DataType, prim_fn: F) -> Self {
68        Self {
69            values: vec![],
70            data_type: data_type.clone(),
71            null_state: NullState::new(),
72            starting_value: T::default_value(),
73            prim_fn,
74        }
75    }
76
77    /// Set the starting values for new groups
78    pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
79        self.starting_value = starting_value;
80        self
81    }
82}
83
84impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
85where
86    T: ArrowPrimitiveType + Send,
87    F: Fn(&mut T::Native, T::Native) + Send + Sync,
88{
89    fn update_batch(
90        &mut self,
91        values: &[ArrayRef],
92        group_indices: &[usize],
93        opt_filter: Option<&BooleanArray>,
94        total_num_groups: usize,
95    ) -> Result<()> {
96        assert_eq!(values.len(), 1, "single argument to update_batch");
97        let values = values[0].as_primitive::<T>();
98
99        // update values
100        self.values.resize(total_num_groups, self.starting_value);
101
102        // NullState dispatches / handles tracking nulls and groups that saw no values
103        self.null_state.accumulate(
104            group_indices,
105            values,
106            opt_filter,
107            total_num_groups,
108            |group_index, new_value| {
109                let value = &mut self.values[group_index];
110                (self.prim_fn)(value, new_value);
111            },
112        );
113
114        Ok(())
115    }
116
117    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
118        let values = emit_to.take_needed(&mut self.values);
119        let nulls = self.null_state.build(emit_to);
120        let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
121            .with_data_type(self.data_type.clone());
122        Ok(Arc::new(values))
123    }
124
125    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
126        self.evaluate(emit_to).map(|arr| vec![arr])
127    }
128
129    fn merge_batch(
130        &mut self,
131        values: &[ArrayRef],
132        group_indices: &[usize],
133        opt_filter: Option<&BooleanArray>,
134        total_num_groups: usize,
135    ) -> Result<()> {
136        // update / merge are the same
137        self.update_batch(values, group_indices, opt_filter, total_num_groups)
138    }
139
140    /// Converts an input batch directly to a state batch
141    ///
142    /// The state is:
143    /// - self.prim_fn for all non null, non filtered values
144    /// - null otherwise
145    fn convert_to_state(
146        &self,
147        values: &[ArrayRef],
148        opt_filter: Option<&BooleanArray>,
149    ) -> Result<Vec<ArrayRef>> {
150        let values = values[0].as_primitive::<T>().clone();
151
152        // Initializing state with starting values
153        let initial_state =
154            PrimitiveArray::<T>::from_value(self.starting_value, values.len());
155
156        // Recalculating values in case there is filter
157        let values = match opt_filter {
158            None => values,
159            Some(filter) => {
160                let (filter_values, filter_nulls) = filter.clone().into_parts();
161                // Calculating filter mask as a result of bitand of filter, and converting it to null buffer
162                let filter_bool = match filter_nulls {
163                    Some(filter_nulls) => filter_nulls.inner() & &filter_values,
164                    None => filter_values,
165                };
166                let filter_nulls = NullBuffer::from(filter_bool);
167
168                // Rebuilding input values with a new nulls mask, which is equal to
169                // the union of original nulls and filter mask
170                let (dt, values_buf, original_nulls) = values.into_parts();
171                let nulls_buf =
172                    NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
173                PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
174            }
175        };
176
177        let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
178            (self.prim_fn)(&mut x, y);
179            x
180        });
181        let state_values = state_values
182            .map_err(|_| {
183                internal_datafusion_err!(
184                    "initial_values underlying buffer must not be shared"
185                )
186            })?
187            .map_err(DataFusionError::from)?
188            .with_data_type(self.data_type.clone());
189
190        Ok(vec![Arc::new(state_values)])
191    }
192
193    fn supports_convert_to_state(&self) -> bool {
194        true
195    }
196
197    fn size(&self) -> usize {
198        self.values.capacity() * size_of::<T::Native>() + self.null_state.size()
199    }
200}