datafusion_expr_common/groups_accumulator.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Vectorized [`GroupsAccumulator`]
19
20use arrow::array::{ArrayRef, BooleanArray};
21use datafusion_common::{Result, not_impl_err};
22
/// Describes how many rows should be emitted during grouping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmitTo {
    /// Emit every group.
    All,
    /// Emit only the first `n` groups, shifting the indexes of all
    /// remaining groups down by `n`.
    ///
    /// For example, with `n = 10`, group indexes `0, 1, ... 9` are emitted
    /// and the groups previously numbered `10, 11, 12, ...` become
    /// `0, 1, 2, ...`.
    First(usize),
}

impl EmitTo {
    /// Takes the values that must be emitted off the front of `v`.
    ///
    /// The returned `Vec` holds the emitted elements in their original
    /// order; the elements that are not emitted are left behind in `v`.
    ///
    /// For [`EmitTo::All`] the whole vector is moved out without copying
    /// any elements and `v` is left empty.
    ///
    /// # Panics
    ///
    /// For [`EmitTo::First`], panics if `n` is greater than `v.len()`.
    pub fn take_needed<T>(&self, v: &mut Vec<T>) -> Vec<T> {
        let n = match *self {
            // Hand over the entire vector, leaving a new (empty) one behind.
            Self::All => return std::mem::take(v),
            Self::First(n) => n,
        };

        // `v` keeps elements `0..n`; `remaining` receives elements `n..`.
        let remaining = v.split_off(n);
        // Install the tail as the new contents of `v` and return the head.
        std::mem::replace(v, remaining)
    }
}
58
59/// `GroupsAccumulator` implements a single aggregate (e.g. AVG) and
60/// stores the state for *all* groups internally.
61///
62/// Logically, a [`GroupsAccumulator`] stores a mapping from each group index to
63/// the state of the aggregate for that group. For example an implementation for
64/// `min` might look like
65///
66/// ```text
67/// ┌─────┐
68/// │ 0 │───────────▶ 100
69/// ├─────┤
70/// │ 1 │───────────▶ 200
71/// └─────┘
72/// ... ...
73/// ┌─────┐
74/// │ N-2 │───────────▶ 50
75/// ├─────┤
76/// │ N-1 │───────────▶ 200
77/// └─────┘
78///
79///
80/// Logical group Current Min
81/// number value for that
82/// group
83/// ```
84///
85/// # Notes on Implementing `GroupsAccumulator`
86///
87/// All aggregates must first implement the simpler [`Accumulator`] trait, which
88/// handles state for a single group. Implementing `GroupsAccumulator` is
89/// optional and is harder to implement than `Accumulator`, but can be much
90/// faster for queries with many group values. See the [Aggregating Millions of
91/// Groups Fast blog] for more background.
92/// For more background, please also see the [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog]
93///
94/// [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog]: https://datafusion.apache.org/blog/2023/08/05/datafusion_fast_grouping
95///
96/// [`NullState`] can help keep the state for groups that have not seen any
97/// values and produce the correct output for those groups.
98///
99/// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html
100///
101/// # Details
102/// Each group is assigned a `group_index` by the hash table and each
103/// accumulator manages the specific state, one per `group_index`.
104///
105/// `group_index`es are contiguous (there aren't gaps), and thus it is
106/// expected that each `GroupsAccumulator` will use something like `Vec<..>`
107/// to store the group states.
108///
109/// [`Accumulator`]: crate::accumulator::Accumulator
110/// [Aggregating Millions of Groups Fast blog]: https://arrow.apache.org/blog/2023/08/05/datafusion_fast_grouping/
111pub trait GroupsAccumulator: Send {
112 /// Updates the accumulator's state from its arguments, encoded as
113 /// a vector of [`ArrayRef`]s.
114 ///
115 /// * `values`: the input arguments to the accumulator
116 ///
117 /// * `group_indices`: The group indices to which each row in `values` belongs.
118 ///
119 /// * `opt_filter`: if present, only update aggregate state using
120 /// `values[i]` if `opt_filter[i]` is true
121 ///
122 /// * `total_num_groups`: the number of groups (the largest
123 /// group_index is thus `total_num_groups - 1`).
124 ///
125 /// Note that subsequent calls to update_batch may have larger
126 /// total_num_groups as new groups are seen.
127 ///
128 /// See [`NullState`] to help keep the state for groups that have not seen any
129 /// values and produce the correct output for those groups.
130 ///
131 /// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html
132 fn update_batch(
133 &mut self,
134 values: &[ArrayRef],
135 group_indices: &[usize],
136 opt_filter: Option<&BooleanArray>,
137 total_num_groups: usize,
138 ) -> Result<()>;
139
140 /// Returns the final aggregate value for each group as a single
141 /// `RecordBatch`, resetting the internal state.
142 ///
143 /// The rows returned *must* be in group_index order: The value
144 /// for group_index 0, followed by 1, etc. Any group_index that
145 /// did not have values, should be null.
146 ///
147 /// For example, a `SUM` accumulator maintains a running sum for
148 /// each group, and `evaluate` will produce that running sum as
149 /// its output for all groups, in group_index order
150 ///
151 /// If `emit_to` is [`EmitTo::All`], the accumulator should
152 /// return all groups and release / reset its internal state
153 /// equivalent to when it was first created.
154 ///
155 /// If `emit_to` is [`EmitTo::First`], only the first `n` groups
156 /// should be emitted and the state for those first groups
157 /// removed. State for the remaining groups must be retained for
158 /// future use. The group_indices on subsequent calls to
159 /// `update_batch` or `merge_batch` will be shifted down by
160 /// `n`. See [`EmitTo::First`] for more details.
161 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;
162
163 /// Returns the intermediate aggregate state for this accumulator,
164 /// used for multi-phase grouping, resetting its internal state.
165 ///
166 /// See [`Accumulator::state`] for more information on multi-phase
167 /// aggregation.
168 ///
169 /// For example, `AVG` might return two arrays: `SUM` and `COUNT`
170 /// but the `MIN` aggregate would just return a single array.
171 ///
172 /// Note more sophisticated internal state can be passed as
173 /// single `StructArray` rather than multiple arrays.
174 ///
175 /// See [`Self::evaluate`] for details on the required output
176 /// order and `emit_to`.
177 ///
178 /// [`Accumulator::state`]: crate::accumulator::Accumulator::state
179 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
180
181 /// Merges intermediate state (the output from [`Self::state`])
182 /// into this accumulator's current state.
183 ///
184 /// For some aggregates (such as `SUM`), `merge_batch` is the same
185 /// as `update_batch`, but for some aggregates (such as `COUNT`,
186 /// where the partial counts must be summed) the operations
187 /// differ. See [`Self::state`] for more details on how state is
188 /// used and merged.
189 ///
190 /// * `values`: arrays produced from previously calling `state` on other accumulators.
191 ///
192 /// Other arguments are the same as for [`Self::update_batch`].
193 fn merge_batch(
194 &mut self,
195 values: &[ArrayRef],
196 group_indices: &[usize],
197 opt_filter: Option<&BooleanArray>,
198 total_num_groups: usize,
199 ) -> Result<()>;
200
201 /// Converts an input batch directly to the intermediate aggregate state.
202 ///
203 /// This is the equivalent of treating each input row as its own group. It
204 /// is invoked when the Partial phase of a multi-phase aggregation is not
205 /// reducing the cardinality enough to warrant spending more effort on
206 /// pre-aggregation (see `Background` section below), and switches to
207 /// passing intermediate state directly on to the next aggregation phase.
208 ///
209 /// Examples:
210 /// * `COUNT`: an array of 1s for each row in the input batch.
211 /// * `SUM/MIN/MAX`: the input values themselves.
212 ///
213 /// # Arguments
214 /// * `values`: the input arguments to the accumulator
215 /// * `opt_filter`: if present, any row where `opt_filter[i]` is false should be ignored
216 ///
217 /// # Background
218 ///
219 /// In a multi-phase aggregation (see [`Accumulator::state`]), the initial
220 /// Partial phase reduces the cardinality of the input data as soon as
221 /// possible in the plan.
222 ///
223 /// This strategy is very effective for queries with a small number of
224 /// groups, as most of the data is aggregated immediately and only a small
225 /// amount of data must be repartitioned (see [`Accumulator::state`] for
226 /// background)
227 ///
228 /// However, for queries with a large number of groups, the Partial phase
229 /// often does not reduce the cardinality enough to warrant the memory and
230 /// CPU cost of actually performing the aggregation. For such cases, the
231 /// HashAggregate operator will dynamically switch to passing intermediate
232 /// state directly to the next aggregation phase with minimal processing
233 /// using this method.
234 ///
235 /// [`Accumulator::state`]: crate::accumulator::Accumulator::state
236 fn convert_to_state(
237 &self,
238 _values: &[ArrayRef],
239 _opt_filter: Option<&BooleanArray>,
240 ) -> Result<Vec<ArrayRef>> {
241 not_impl_err!("Input batch conversion to state not implemented")
242 }
243
244 /// Returns `true` if [`Self::convert_to_state`] is implemented to support
245 /// intermediate aggregate state conversion.
246 fn supports_convert_to_state(&self) -> bool {
247 false
248 }
249
250 /// Amount of memory used to store the state of this accumulator,
251 /// in bytes.
252 ///
253 /// This function is called once per batch, so it should be `O(n)` to
254 /// compute, not `O(num_groups)`
255 fn size(&self) -> usize;
256}