datafusion_expr_common/accumulator.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Accumulator module contains the trait definition for aggregation functions' accumulators.

use arrow::array::ArrayRef;
use datafusion_common::{internal_err, Result, ScalarValue};
use std::fmt::Debug;

/// Tracks an aggregate function's state.
///
/// `Accumulator`s are stateful objects that manage the state for a single
/// group. They aggregate values from multiple rows together into a final
/// output aggregate value.
///
/// [`GroupsAccumulator`] is an additional, more performant (but also more
/// complex) API that manages state for multiple groups at once.
///
/// An accumulator knows how to:
/// * update its state from inputs via [`update_batch`]
///
/// * compute the final value from its internal state via [`evaluate`]
///
/// * retract an update to its state from given inputs via
///   [`retract_batch`] (when used as a window aggregate [window
///   function])
///
/// * convert its internal state to a vector of aggregate values via
///   [`state`] and combine the state from multiple accumulators
///   via [`merge_batch`], as part of efficient multi-phase grouping.
///
/// [`GroupsAccumulator`]: crate::GroupsAccumulator
/// [`update_batch`]: Self::update_batch
/// [`retract_batch`]: Self::retract_batch
/// [`state`]: Self::state
/// [`evaluate`]: Self::evaluate
/// [`merge_batch`]: Self::merge_batch
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
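///
/// # Example: a minimal `SUM`-style accumulator
///
/// A hedged sketch of how these methods fit together for a simple floating
/// point sum. It is marked `ignore` because the import paths and the use of
/// `arrow::compute::sum` here are assumptions for illustration, not tested
/// examples from this crate.
///
/// ```ignore
/// use arrow::array::{ArrayRef, Float64Array};
/// use datafusion_common::{Result, ScalarValue};
/// // path to the trait is assumed for illustration
/// use datafusion_expr_common::accumulator::Accumulator;
///
/// #[derive(Debug, Default)]
/// struct SumAccumulator {
///     sum: f64,
/// }
///
/// impl Accumulator for SumAccumulator {
///     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
///         // the first (and only) argument is the column being summed
///         let arr = values[0].as_any().downcast_ref::<Float64Array>().unwrap();
///         self.sum += arrow::compute::sum(arr).unwrap_or(0.0);
///         Ok(())
///     }
///
///     fn evaluate(&mut self) -> Result<ScalarValue> {
///         Ok(ScalarValue::Float64(Some(self.sum)))
///     }
///
///     fn size(&self) -> usize {
///         std::mem::size_of_val(self)
///     }
///
///     fn state(&mut self) -> Result<Vec<ScalarValue>> {
///         // the partial state of a SUM is just the running sum
///         Ok(vec![ScalarValue::Float64(Some(self.sum))])
///     }
///
///     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
///         // partial sums are combined the same way inputs are
///         self.update_batch(states)
///     }
/// }
/// ```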
pub trait Accumulator: Send + Sync + Debug {
    /// Updates the accumulator's state from its input.
    ///
    /// `values` contains the arguments to this aggregate function.
    ///
    /// For example, the `SUM` accumulator maintains a running sum,
    /// and `update_batch` adds each of the input values to the
    /// running sum.
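    ///
    /// A hedged caller-side sketch (hypothetical `acc`, a `SUM` accumulator
    /// over a `Float64` column; marked `ignore` because it is illustrative
    /// only):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use arrow::array::{ArrayRef, Float64Array};
    ///
    /// let input: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
    /// acc.update_batch(&[input])?;
    /// // the running sum inside `acc` is now 6.0
    /// ```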
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

    /// Returns the final aggregate value, consuming the internal state.
    ///
    /// For example, the `SUM` accumulator maintains a running sum,
    /// and `evaluate` will produce that running sum as its output.
    ///
    /// This function should not be called twice; otherwise it will
    /// result in potentially non-deterministic behavior.
    ///
    /// This function takes `&mut self` to allow the accumulator to build
    /// arrow-compatible internal state that can be returned without copying
    /// when possible (for example, distinct strings).
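    ///
    /// Continuing the hypothetical `SUM` accumulator above (`ignore`d,
    /// illustrative only):
    ///
    /// ```ignore
    /// let total = acc.evaluate()?;
    /// assert_eq!(total, ScalarValue::Float64(Some(6.0)));
    /// ```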
    fn evaluate(&mut self) -> Result<ScalarValue>;

    /// Returns the allocated size required for this accumulator, in
    /// bytes, including `Self`.
    ///
    /// This value is used to calculate the memory used during
    /// execution so DataFusion can stay within its allotted limit.
    ///
    /// "Allocated" means that for internal containers such as `Vec`,
    /// the `capacity` should be used, not the `len`.
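    ///
    /// A hedged sketch for a hypothetical accumulator that keeps its distinct
    /// values in a `Vec<u64>` field named `seen` (`ignore`d, illustrative
    /// only):
    ///
    /// ```ignore
    /// fn size(&self) -> usize {
    ///     // `Self` itself plus the heap allocation behind the Vec,
    ///     // counted by capacity rather than len
    ///     std::mem::size_of_val(self)
    ///         + self.seen.capacity() * std::mem::size_of::<u64>()
    /// }
    /// ```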
    fn size(&self) -> usize;

    /// Returns the intermediate state of the accumulator, consuming
    /// that state in the process.
    ///
    /// This function should not be called twice; otherwise it will
    /// result in potentially non-deterministic behavior.
    ///
    /// This function takes `&mut self` to allow the accumulator to build
    /// arrow-compatible internal state that can be returned without copying
    /// when possible (for example, distinct strings).
    ///
    /// Intermediate state is used for "multi-phase" grouping in
    /// DataFusion, where an aggregate is computed in parallel with
    /// multiple `Accumulator` instances, as described below:
    ///
    /// # Multi-Phase Grouping
    ///
    /// ```text
    ///                               ▲
    ///                               │                   evaluate() is called to
    ///                               │                   produce the final aggregate
    ///                               │                   value per group
    ///                               │
    ///                  ┌─────────────────────────┐
    ///                  │GroupBy                  │
    ///                  │(AggregateMode::Final)   │      state() is called for each
    ///                  │                         │      group and the resulting
    ///                  └─────────────────────────┘      RecordBatches passed to the
    ///                                                   Final GroupBy via merge_batch()
    ///                               ▲
    ///                               │
    ///              ┌────────────────┴───────────────┐
    ///              │                                │
    ///              │                                │
    /// ┌─────────────────────────┐      ┌─────────────────────────┐
    /// │        GroupBy          │      │        GroupBy          │
    /// │(AggregateMode::Partial) │      │(AggregateMode::Partial) │
    /// └─────────────────────────┘      └─────────────────────────┘
    ///              ▲                                ▲
    ///              │                                │    update_batch() is called for
    ///              │                                │    each input RecordBatch
    ///         .─────────.                      .─────────.
    ///      ,─'           '─.                ,─'           '─.
    ///     ;      Input      :              ;      Input      :
    ///     :   Partition 0   ;              :   Partition 1   ;
    ///      ╲               ╱                ╲               ╱
    ///       '─.         ,─'                  '─.         ,─'
    ///          `───────'                        `───────'
    /// ```
    ///
    /// The partial state is serialized as `Arrays` and then combined
    /// with other partial states from different instances of this
    /// Accumulator (that ran on different partitions, for example).
    ///
    /// The state can be, and often is, a different type than the output
    /// type of the [`Accumulator`] and needs different merge
    /// operations (for example, the partial state for `COUNT` needs
    /// to be summed together).
    ///
    /// Some accumulators can return multiple values for their
    /// intermediate states. For example, the average accumulator
    /// tracks `sum` and `n`, and this function should return a vector
    /// of two values, `sum` and `n`.
    ///
    /// Note that [`ScalarValue::List`] can be used to pass multiple
    /// values if the number of intermediate values is not known at
    /// planning time (e.g. for `MEDIAN`).
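    ///
    /// A hedged sketch of `state` for a hypothetical `AVG` accumulator that
    /// keeps `sum` and `n` as fields (`ignore`d, illustrative only):
    ///
    /// ```ignore
    /// fn state(&mut self) -> Result<Vec<ScalarValue>> {
    ///     // both intermediate values are shipped to the Final phase,
    ///     // which sums the sums and the counts before dividing
    ///     Ok(vec![
    ///         ScalarValue::Float64(Some(self.sum)),
    ///         ScalarValue::UInt64(Some(self.n)),
    ///     ])
    /// }
    /// ```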
    ///
    /// # Multi-phase repartitioned Grouping
    ///
    /// Many multi-phase grouping plans contain a Repartition operation
    /// as well, as shown below:
    ///
    /// ```text
    ///                ▲                          ▲
    ///                │                          │
    ///                │                          │
    ///                │                          │
    ///                │                          │
    ///                │                          │
    ///    ┌───────────────────────┐  ┌───────────────────────┐       4. Each AggregateMode::Final
    ///    │GroupBy                │  │GroupBy                │       GroupBy has an entry for its
    ///    │(AggregateMode::Final) │  │(AggregateMode::Final) │       subset of groups (in this case
    ///    │                       │  │                       │       that means half the entries)
    ///    └───────────────────────┘  └───────────────────────┘
    ///                ▲                          ▲
    ///                │                          │
    ///                └─────────────┬────────────┘
    ///                              │
    ///                              │
    ///                              │
    ///                 ┌─────────────────────────┐                   3. Repartitioning by hash(group
    ///                 │       Repartition       │                   keys) ensures that each distinct
    ///                 │         HASH(x)         │                   group key now appears in exactly
    ///                 └─────────────────────────┘                   one partition
    ///                              ▲
    ///                              │
    ///              ┌───────────────┴─────────────┐
    ///              │                             │
    ///              │                             │
    /// ┌─────────────────────────┐  ┌──────────────────────────┐     2. Each AggregateMode::Partial
    /// │        GroupBy          │  │       GroupBy            │     GroupBy has an entry for *all*
    /// │(AggregateMode::Partial) │  │ (AggregateMode::Partial) │     the groups
    /// └─────────────────────────┘  └──────────────────────────┘
    ///              ▲                             ▲
    ///              │                             │
    ///              │                             │
    ///         .─────────.                   .─────────.
    ///      ,─'           '─.             ,─'           '─.
    ///     ;      Input      :           ;      Input      :         1. Since input data is
    ///     :   Partition 0   ;           :   Partition 1   ;         arbitrarily or RoundRobin
    ///      ╲               ╱             ╲               ╱          distributed, each partition
    ///       '─.         ,─'               '─.         ,─'           likely has all distinct
    ///          `───────'                     `───────'              group keys
    /// ```
    ///
    /// This structure is used so that the `AggregateMode::Partial` accumulators
    /// reduce the cardinality of the input as soon as possible. Typically,
    /// each partial accumulator sees all groups in the input, as the group keys
    /// are evenly distributed across the input.
    ///
    /// The final output is computed by repartitioning the result of
    /// [`Self::state`] from each Partial aggregate by `hash(group keys)` so
    /// that each distinct group key appears in exactly one of the
    /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are
    /// then unioned together to produce the overall final output.
    ///
    /// Here is an example that shows the distribution of groups in the
    /// different phases:
    ///
    /// ```text
    ///               ┌─────┐                ┌─────┐
    ///               │  1  │                │  3  │
    ///               ├─────┤                ├─────┤
    ///               │  2  │                │  4  │                After repartitioning by
    ///               └─────┘                └─────┘                hash(group keys), each distinct
    ///               ┌─────┐                ┌─────┐                group key now appears in exactly
    ///               │  1  │                │  3  │                one partition
    ///               ├─────┤                ├─────┤
    ///               │  2  │                │  4  │
    ///               └─────┘                └─────┘
    ///
    ///
    /// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
    ///
    ///               ┌─────┐                ┌─────┐
    ///               │  2  │                │  2  │
    ///               ├─────┤                ├─────┤
    ///               │  1  │                │  2  │
    ///               ├─────┤                ├─────┤
    ///               │  3  │                │  3  │
    ///               ├─────┤                ├─────┤
    ///               │  4  │                │  1  │
    ///               └─────┘                └─────┘                Input data is arbitrarily or
    ///                 ...                    ...                  RoundRobin distributed, each
    ///               ┌─────┐                ┌─────┐                partition likely has all
    ///               │  1  │                │  4  │                distinct group keys
    ///               ├─────┤                ├─────┤
    ///               │  4  │                │  3  │
    ///               ├─────┤                ├─────┤
    ///               │  1  │                │  1  │
    ///               ├─────┤                ├─────┤
    ///               │  4  │                │  3  │
    ///               └─────┘                └─────┘
    ///
    ///           group values           group values
    ///           in partition 0         in partition 1
    /// ```
    fn state(&mut self) -> Result<Vec<ScalarValue>>;

    /// Updates the accumulator's state from an `Array` containing one
    /// or more intermediate values.
    ///
    /// For some aggregates (such as `SUM`), `merge_batch` is the same
    /// as `update_batch`, but for some aggregates (such as `COUNT`)
    /// the operations differ. See [`Self::state`] for more details on how
    /// state is used and merged.
    ///
    /// The `states` array passed in was formed by concatenating the
    /// results of calling [`Self::state`] on zero or more other
    /// `Accumulator` instances.
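    ///
    /// A hedged sketch of `merge_batch` for a hypothetical `COUNT` accumulator
    /// whose partial state is an `Int64` count (`ignore`d, illustrative only):
    ///
    /// ```ignore
    /// fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
    ///     // unlike update_batch (which counts input rows), merging sums
    ///     // the partial counts produced by other accumulators
    ///     let counts = states[0].as_any().downcast_ref::<Int64Array>().unwrap();
    ///     self.count += arrow::compute::sum(counts).unwrap_or(0);
    ///     Ok(())
    /// }
    /// ```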
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;

    /// Retracts (removes) an update (caused by the given inputs) to
    /// the accumulator's state.
    ///
    /// This is the inverse operation of [`Self::update_batch`] and is used
    /// to incrementally calculate window aggregates where the `OVER`
    /// clause defines a bounded window.
    ///
    /// # Example
    ///
    /// For example, given the following input partition:
    ///
    /// ```text
    ///                     │      current      │
    ///                            window
    ///                     │                   │
    ///                ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
    ///     Input      │ A  │ B  │ C  │ D  │ E  │ F  │ G  │ H  │ I  │
    ///   partition    └────┴────┴────┴────┼────┴────┴────┴────┼────┘
    ///
    ///                                    │         next      │
    ///                                             window
    /// ```
    ///
    /// First, [`Self::evaluate`] will be called to produce the output
    /// for the current window.
    ///
    /// Then, to advance to the next window:
    ///
    /// First, [`Self::retract_batch`] will be called with the values
    /// that are leaving the window, `[B, C, D]`, and then
    /// [`Self::update_batch`] will be called with the values that are
    /// entering the window, `[F, G, H]`.
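    ///
    /// A hedged caller-side sketch (hypothetical `acc`, a sliding `SUM`;
    /// `leaving` and `entering` are `ArrayRef`s holding `[B, C, D]` and
    /// `[F, G, H]`; `ignore`d, illustrative only):
    ///
    /// ```ignore
    /// let current = acc.evaluate()?;     // value for the current window
    /// acc.retract_batch(&[leaving])?;    // remove [B, C, D]
    /// acc.update_batch(&[entering])?;    // add [F, G, H]
    /// let next = acc.evaluate()?;        // value for the next window
    /// ```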
    fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
        // TODO add retract for all accumulators
        internal_err!(
            "Retract should be implemented for aggregate functions when used with custom window frame queries"
        )
    }

    /// Does the accumulator support incrementally updating its value
    /// by *removing* values?
    ///
    /// If this function returns `true`, [`Self::retract_batch`] will be
    /// called for sliding window functions such as queries with an
    /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)` clause.
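    ///
    /// A hedged sketch: an accumulator that implements [`Self::retract_batch`]
    /// opts in by overriding this method (`ignore`d, illustrative only):
    ///
    /// ```ignore
    /// fn supports_retract_batch(&self) -> bool {
    ///     true
    /// }
    /// ```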
    fn supports_retract_batch(&self) -> bool {
        false
    }
}