datafusion_functions_aggregate_common/aggregate/groups_accumulator/
prim_op.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::mem::size_of;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
22use arrow::buffer::NullBuffer;
23use arrow::compute;
24use arrow::datatypes::ArrowPrimitiveType;
25use arrow::datatypes::DataType;
26use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
27use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
28
29use super::accumulate::NullState;
30
31/// An accumulator that implements a single operation over
32/// [`ArrowPrimitiveType`] where the accumulated state is the same as
33/// the input type (such as `Sum`)
34///
35/// F: The function to apply to two elements. The first argument is
36/// the existing value and should be updated with the second value
37/// (e.g. [`BitAndAssign`] style).
38///
39/// [`BitAndAssign`]: std::ops::BitAndAssign
40#[derive(Debug)]
41pub struct PrimitiveGroupsAccumulator<T, F>
42where
43    T: ArrowPrimitiveType + Send,
44    F: Fn(&mut T::Native, T::Native) + Send + Sync,
45{
46    /// values per group, stored as the native type
47    values: Vec<T::Native>,
48
49    /// The output type (needed for Decimal precision and scale)
50    data_type: DataType,
51
52    /// The starting value for new groups
53    starting_value: T::Native,
54
55    /// Track nulls in the input / filters
56    null_state: NullState,
57
58    /// Function that computes the primitive result
59    prim_fn: F,
60}
61
62impl<T, F> PrimitiveGroupsAccumulator<T, F>
63where
64    T: ArrowPrimitiveType + Send,
65    F: Fn(&mut T::Native, T::Native) + Send + Sync,
66{
67    pub fn new(data_type: &DataType, prim_fn: F) -> Self {
68        Self {
69            values: vec![],
70            data_type: data_type.clone(),
71            null_state: NullState::new(),
72            starting_value: T::default_value(),
73            prim_fn,
74        }
75    }
76
77    /// Set the starting values for new groups
78    pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
79        self.starting_value = starting_value;
80        self
81    }
82}
83
84impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
85where
86    T: ArrowPrimitiveType + Send,
87    F: Fn(&mut T::Native, T::Native) + Send + Sync,
88{
89    fn update_batch(
90        &mut self,
91        values: &[ArrayRef],
92        group_indices: &[usize],
93        opt_filter: Option<&BooleanArray>,
94        total_num_groups: usize,
95    ) -> Result<()> {
96        assert_eq!(values.len(), 1, "single argument to update_batch");
97        let values = values[0].as_primitive::<T>();
98
99        // update values
100        self.values.resize(total_num_groups, self.starting_value);
101
102        // NullState dispatches / handles tracking nulls and groups that saw no values
103        self.null_state.accumulate(
104            group_indices,
105            values,
106            opt_filter,
107            total_num_groups,
108            |group_index, new_value| {
109                let value = &mut self.values[group_index];
110                (self.prim_fn)(value, new_value);
111            },
112        );
113
114        Ok(())
115    }
116
117    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
118        let values = emit_to.take_needed(&mut self.values);
119        let nulls = self.null_state.build(emit_to);
120        let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
121            .with_data_type(self.data_type.clone());
122        Ok(Arc::new(values))
123    }
124
125    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
126        self.evaluate(emit_to).map(|arr| vec![arr])
127    }
128
129    fn merge_batch(
130        &mut self,
131        values: &[ArrayRef],
132        group_indices: &[usize],
133        opt_filter: Option<&BooleanArray>,
134        total_num_groups: usize,
135    ) -> Result<()> {
136        // update / merge are the same
137        self.update_batch(values, group_indices, opt_filter, total_num_groups)
138    }
139
140    /// Converts an input batch directly to a state batch
141    ///
142    /// The state is:
143    /// - self.prim_fn for all non null, non filtered values
144    /// - null otherwise
145    ///
146    fn convert_to_state(
147        &self,
148        values: &[ArrayRef],
149        opt_filter: Option<&BooleanArray>,
150    ) -> Result<Vec<ArrayRef>> {
151        let values = values[0].as_primitive::<T>().clone();
152
153        // Initializing state with starting values
154        let initial_state =
155            PrimitiveArray::<T>::from_value(self.starting_value, values.len());
156
157        // Recalculating values in case there is filter
158        let values = match opt_filter {
159            None => values,
160            Some(filter) => {
161                let (filter_values, filter_nulls) = filter.clone().into_parts();
162                // Calculating filter mask as a result of bitand of filter, and converting it to null buffer
163                let filter_bool = match filter_nulls {
164                    Some(filter_nulls) => filter_nulls.inner() & &filter_values,
165                    None => filter_values,
166                };
167                let filter_nulls = NullBuffer::from(filter_bool);
168
169                // Rebuilding input values with a new nulls mask, which is equal to
170                // the union of original nulls and filter mask
171                let (dt, values_buf, original_nulls) = values.into_parts();
172                let nulls_buf =
173                    NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
174                PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
175            }
176        };
177
178        let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
179            (self.prim_fn)(&mut x, y);
180            x
181        });
182        let state_values = state_values
183            .map_err(|_| {
184                internal_datafusion_err!(
185                    "initial_values underlying buffer must not be shared"
186                )
187            })?
188            .map_err(DataFusionError::from)?
189            .with_data_type(self.data_type.clone());
190
191        Ok(vec![Arc::new(state_values)])
192    }
193
194    fn supports_convert_to_state(&self) -> bool {
195        true
196    }
197
198    fn size(&self) -> usize {
199        self.values.capacity() * size_of::<T::Native>() + self.null_state.size()
200    }
201}