// datafusion_expr_common/accumulator.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Accumulator module contains the trait definition for aggregation function's accumulators.

20use arrow::array::ArrayRef;
21use datafusion_common::{Result, ScalarValue, internal_err};
22use std::fmt::Debug;
23
24/// Tracks an aggregate function's state.
25///
26/// `Accumulator`s are stateful objects that implement a single group. They
27/// aggregate values from multiple rows together into a final output aggregate.
28///
29/// [`GroupsAccumulator]` is an additional more performant (but also complex) API
30/// that manages state for multiple groups at once.
31///
32/// An accumulator knows how to:
33/// * update its state from inputs via [`update_batch`]
34///
35/// * compute the final value from its internal state via [`evaluate`]
36///
37/// * retract an update to its state from given inputs via
38///   [`retract_batch`] (when used as a window aggregate [window
39///   function])
40///
41/// * convert its internal state to a vector of aggregate values via
42///   [`state`] and combine the state from multiple accumulators
43///   via [`merge_batch`], as part of efficient multi-phase grouping.
44///
45/// [`update_batch`]: Self::update_batch
46/// [`retract_batch`]: Self::retract_batch
47/// [`state`]: Self::state
48/// [`evaluate`]: Self::evaluate
49/// [`merge_batch`]: Self::merge_batch
50/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
51pub trait Accumulator: Send + Sync + Debug {
52    /// Updates the accumulator's state from its input.
53    ///
54    /// `values` contains the arguments to this aggregate function.
55    ///
56    /// For example, the `SUM` accumulator maintains a running sum,
57    /// and `update_batch` adds each of the input values to the
58    /// running sum.
59    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
60
61    /// Returns the final aggregate value.
62    ///
63    /// For example, the `SUM` accumulator maintains a running sum,
64    /// and `evaluate` will produce that running sum as its output.
65    ///
66    /// This function gets `&mut self` to allow for the accumulator to build
67    /// arrow-compatible internal state that can be returned without copying
68    /// when possible (for example distinct strings).
69    ///
70    /// ## Correctness
71    ///
72    /// This function must not consume the internal state, as it is also used in window
73    /// aggregate functions where it can be executed multiple times depending on the
74    /// current window frame. Consuming the internal state can cause the next invocation
75    /// to have incorrect results.
76    ///
77    /// - Even if this accumulator doesn't implement [`retract_batch`] it may still be used
78    ///   in window aggregate functions where the window frame is
79    ///   `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
80    ///
81    /// It is fine to modify the state (e.g. re-order elements within internal state vec) so long
82    /// as this doesn't cause an incorrect computation on the next call of evaluate.
83    ///
84    /// [`retract_batch`]: Self::retract_batch
85    fn evaluate(&mut self) -> Result<ScalarValue>;
86
87    /// Returns the allocated size required for this accumulator, in
88    /// bytes, including `Self`.
89    ///
90    /// This value is used to calculate the memory used during
91    /// execution so DataFusion can stay within its allotted limit.
92    ///
93    /// "Allocated" means that for internal containers such as `Vec`,
94    /// the `capacity` should be used not the `len`.
95    fn size(&self) -> usize;
96
97    /// Returns the intermediate state of the accumulator, consuming the
98    /// intermediate state.
99    ///
100    /// This function should not be called twice, otherwise it will
101    /// result in potentially non-deterministic behavior.
102    ///
103    /// This function gets `&mut self` to allow for the accumulator to build
104    /// arrow-compatible internal state that can be returned without copying
105    /// when possible (for example distinct strings).
106    ///
107    /// Intermediate state is used for "multi-phase" grouping in
108    /// DataFusion, where an aggregate is computed in parallel with
109    /// multiple `Accumulator` instances, as described below:
110    ///
111    /// # Multi-Phase Grouping
112    ///
113    /// ```text
114    ///                               ▲
115    ///                               │                   evaluate() is called to
116    ///                               │                   produce the final aggregate
117    ///                               │                   value per group
118    ///                               │
119    ///                  ┌─────────────────────────┐
120    ///                  │GroupBy                  │
121    ///                  │(AggregateMode::Final)   │      state() is called for each
122    ///                  │                         │      group and the resulting
123    ///                  └─────────────────────────┘      RecordBatches passed to the
124    ///                                                   Final GroupBy via merge_batch()
125    ///                               ▲
126    ///                               │
127    ///              ┌────────────────┴───────────────┐
128    ///              │                                │
129    ///              │                                │
130    /// ┌─────────────────────────┐      ┌─────────────────────────┐
131    /// │        GroupBy          │      │        GroupBy          │
132    /// │(AggregateMode::Partial) │      │(AggregateMode::Partial) │
133    /// └─────────────────────────┘      └─────────────────────────┘
134    ///              ▲                                ▲
135    ///              │                                │    update_batch() is called for
136    ///              │                                │    each input RecordBatch
137    ///         .─────────.                      .─────────.
138    ///      ,─'           '─.                ,─'           '─.
139    ///     ;      Input      :              ;      Input      :
140    ///     :   Partition 0   ;              :   Partition 1   ;
141    ///      ╲               ╱                ╲               ╱
142    ///       '─.         ,─'                  '─.         ,─'
143    ///          `───────'                        `───────'
144    /// ```
145    ///
146    /// The partial state is serialized as `Arrays` and then combined
147    /// with other partial states from different instances of this
148    /// Accumulator (that ran on different partitions, for example).
149    ///
150    /// The state can be and often is a different type than the output
151    /// type of the [`Accumulator`] and needs different merge
152    /// operations (for example, the partial state for `COUNT` needs
153    /// to be summed together)
154    ///
155    /// Some accumulators can return multiple values for their
156    /// intermediate states. For example, the average accumulator
157    /// tracks `sum` and `n`, and this function should return a vector
158    /// of two values, sum and n.
159    ///
160    /// Note that [`ScalarValue::List`] can be used to pass multiple
161    /// values if the number of intermediate values is not known at
162    /// planning time (e.g. for `MEDIAN`)
163    ///
164    /// # Multi-phase repartitioned Grouping
165    ///
166    /// Many multi-phase grouping plans contain a Repartition operation
167    /// as well as shown below:
168    ///
169    /// ```text
170    ///                ▲                          ▲
171    ///                │                          │
172    ///                │                          │
173    ///                │                          │
174    ///                │                          │
175    ///                │                          │
176    ///    ┌───────────────────────┐  ┌───────────────────────┐       4. Each AggregateMode::Final
177    ///    │GroupBy                │  │GroupBy                │       GroupBy has an entry for its
178    ///    │(AggregateMode::Final) │  │(AggregateMode::Final) │       subset of groups (in this case
179    ///    │                       │  │                       │       that means half the entries)
180    ///    └───────────────────────┘  └───────────────────────┘
181    ///                ▲                          ▲
182    ///                │                          │
183    ///                └─────────────┬────────────┘
184    ///                              │
185    ///                              │
186    ///                              │
187    ///                 ┌─────────────────────────┐                   3. Repartitioning by hash(group
188    ///                 │       Repartition       │                   keys) ensures that each distinct
189    ///                 │         HASH(x)         │                   group key now appears in exactly
190    ///                 └─────────────────────────┘                   one partition
191    ///                              ▲
192    ///                              │
193    ///              ┌───────────────┴─────────────┐
194    ///              │                             │
195    ///              │                             │
196    /// ┌─────────────────────────┐  ┌──────────────────────────┐     2. Each AggregateMode::Partial
197    /// │        GroupBy          │  │       GroupBy            │     GroupBy has an entry for *all*
198    /// │(AggregateMode::Partial) │  │ (AggregateMode::Partial) │     the groups
199    /// └─────────────────────────┘  └──────────────────────────┘
200    ///              ▲                             ▲
201    ///              │                             │
202    ///              │                             │
203    ///         .─────────.                   .─────────.
204    ///      ,─'           '─.             ,─'           '─.
205    ///     ;      Input      :           ;      Input      :         1. Since input data is
206    ///     :   Partition 0   ;           :   Partition 1   ;         arbitrarily or RoundRobin
207    ///      ╲               ╱             ╲               ╱          distributed, each partition
208    ///       '─.         ,─'               '─.         ,─'           likely has all distinct
209    ///          `───────'                     `───────'
210    /// ```
211    ///
212    /// This structure is used so that the `AggregateMode::Partial` accumulators
213    /// reduces the cardinality of the input as soon as possible. Typically,
214    /// each partial accumulator sees all groups in the input as the group keys
215    /// are evenly distributed across the input.
216    ///
217    /// The final output is computed by repartitioning the result of
218    /// [`Self::state`] from each Partial aggregate and `hash(group keys)` so
219    /// that each distinct group key appears in exactly one of the
220    /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are
221    /// then unioned together to produce the overall final output.
222    ///
223    /// Here is an example that shows the distribution of groups in the
224    /// different phases
225    ///
226    /// ```text
227    ///               ┌─────┐                ┌─────┐
228    ///               │  1  │                │  3  │
229    ///               ├─────┤                ├─────┤
230    ///               │  2  │                │  4  │                After repartitioning by
231    ///               └─────┘                └─────┘                hash(group keys), each distinct
232    ///               ┌─────┐                ┌─────┐                group key now appears in exactly
233    ///               │  1  │                │  3  │                one partition
234    ///               ├─────┤                ├─────┤
235    ///               │  2  │                │  4  │
236    ///               └─────┘                └─────┘
237    ///
238    ///
239    /// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
240    ///
241    ///               ┌─────┐                ┌─────┐
242    ///               │  2  │                │  2  │
243    ///               ├─────┤                ├─────┤
244    ///               │  1  │                │  2  │
245    ///               ├─────┤                ├─────┤
246    ///               │  3  │                │  3  │
247    ///               ├─────┤                ├─────┤
248    ///               │  4  │                │  1  │
249    ///               └─────┘                └─────┘                Input data is arbitrarily or
250    ///                 ...                    ...                  RoundRobin distributed, each
251    ///               ┌─────┐                ┌─────┐                partition likely has all
252    ///               │  1  │                │  4  │                distinct group keys
253    ///               ├─────┤                ├─────┤
254    ///               │  4  │                │  3  │
255    ///               ├─────┤                ├─────┤
256    ///               │  1  │                │  1  │
257    ///               ├─────┤                ├─────┤
258    ///               │  4  │                │  3  │
259    ///               └─────┘                └─────┘
260    ///
261    ///           group values           group values
262    ///           in partition 0         in partition 1
263    /// ```
264    fn state(&mut self) -> Result<Vec<ScalarValue>>;
265
266    /// Updates the accumulator's state from an `Array` containing one
267    /// or more intermediate values.
268    ///
269    /// For some aggregates (such as `SUM`), merge_batch is the same
270    /// as `update_batch`, but for some aggregates (such as `COUNT`)
271    /// the operations differ. See [`Self::state`] for more details on how
272    /// state is used and merged.
273    ///
274    /// The `states` array passed was formed by concatenating the
275    /// results of calling [`Self::state`] on zero or more other
276    /// `Accumulator` instances.
277    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
278
279    /// Retracts (removed) an update (caused by the given inputs) to
280    /// accumulator's state.
281    ///
282    /// This is the inverse operation of [`Self::update_batch`] and is used
283    /// to incrementally calculate window aggregates where the `OVER`
284    /// clause defines a bounded window.
285    ///
286    /// # Example
287    ///
288    /// For example, given the following input partition
289    ///
290    /// ```text
291    ///                     │      current      │
292    ///                            window
293    ///                     │                   │
294    ///                ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
295    ///     Input      │ A  │ B  │ C  │ D  │ E  │ F  │ G  │ H  │ I  │
296    ///   partition    └────┴────┴────┴────┼────┴────┴────┴────┼────┘
297    ///
298    ///                                    │         next      │
299    ///                                             window
300    /// ```
301    ///
302    /// First, [`Self::evaluate`] will be called to produce the output
303    /// for the current window.
304    ///
305    /// Then, to advance to the next window:
306    ///
307    /// First, [`Self::retract_batch`] will be called with the values
308    /// that are leaving the window, `[B, C, D]` and then
309    /// [`Self::update_batch`] will be called with the values that are
310    /// entering the window, `[F, G, H]`.
311    fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
312        // TODO add retract for all accumulators
313        internal_err!(
314            "Retract should be implemented for aggregate functions when used with custom window frame queries"
315        )
316    }
317
318    /// Does the accumulator support incrementally updating its value
319    /// by *removing* values.
320    ///
321    /// If this function returns true, [`Self::retract_batch`] will be
322    /// called for sliding window functions such as queries with an
323    /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)`
324    fn supports_retract_batch(&self) -> bool {
325        false
326    }
327}