datafusion_functions_aggregate/
array_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val, take};
23use std::sync::Arc;
24
25use arrow::array::{
26    new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
27};
28use arrow::compute::{filter, SortOptions};
29use arrow::datatypes::{DataType, Field, FieldRef, Fields};
30
31use datafusion_common::cast::as_list_array;
32use datafusion_common::utils::{
33    compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder,
34};
35use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
40};
41use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
42use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
43use datafusion_functions_aggregate_common::utils::ordering_fields;
44use datafusion_macros::user_doc;
45use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
46
// Declares the expression-level helpers for the `ArrayAgg` UDAF: the
// `array_agg` expression builder (taking a single `expression` argument) and
// the `array_agg_udaf()` accessor that `ArrayAgg::reverse_expr` calls. The
// string literal is the doc text attached to the generated helper.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
54
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
    /// Type signature of the function; the default instance accepts exactly
    /// one argument of any type.
    signature: Signature,
    /// Whether the input is known to already arrive in the ORDER BY order.
    /// Set through `with_beneficial_ordering` and forwarded to the
    /// order-sensitive accumulator, which then skips its own sorting.
    is_input_pre_ordered: bool,
}
84
85impl Default for ArrayAgg {
86    fn default() -> Self {
87        Self {
88            signature: Signature::any(1, Volatility::Immutable),
89            is_input_pre_ordered: false,
90        }
91    }
92}
93
94impl AggregateUDFImpl for ArrayAgg {
95    fn as_any(&self) -> &dyn std::any::Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "array_agg"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        Ok(DataType::List(Arc::new(Field::new_list_field(
109            arg_types[0].clone(),
110            true,
111        ))))
112    }
113
114    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115        if args.is_distinct {
116            return Ok(vec![Field::new_list(
117                format_state_name(args.name, "distinct_array_agg"),
118                // See COMMENTS.md to understand why nullable is set to true
119                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
120                true,
121            )
122            .into()]);
123        }
124
125        let mut fields = vec![Field::new_list(
126            format_state_name(args.name, "array_agg"),
127            // See COMMENTS.md to understand why nullable is set to true
128            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
129            true,
130        )
131        .into()];
132
133        if args.ordering_fields.is_empty() {
134            return Ok(fields);
135        }
136
137        let orderings = args.ordering_fields.to_vec();
138        fields.push(
139            Field::new_list(
140                format_state_name(args.name, "array_agg_orderings"),
141                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
142                false,
143            )
144            .into(),
145        );
146
147        Ok(fields)
148    }
149
150    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
151        AggregateOrderSensitivity::SoftRequirement
152    }
153
154    fn with_beneficial_ordering(
155        self: Arc<Self>,
156        beneficial_ordering: bool,
157    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
158        Ok(Some(Arc::new(Self {
159            signature: self.signature.clone(),
160            is_input_pre_ordered: beneficial_ordering,
161        })))
162    }
163
164    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
165        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
166        let ignore_nulls =
167            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
168
169        if acc_args.is_distinct {
170            // Limitation similar to Postgres. The aggregation function can only mix
171            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
172            // also in the arguments of the function. This implies that if the
173            // aggregation function only accepts one argument, only one argument
174            // can be used in the ORDER BY, For example:
175            //
176            // ARRAY_AGG(DISTINCT col)
177            //
178            // can only be mixed with an ORDER BY if the order expression is "col".
179            //
180            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
181            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
182            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
183            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
184            let sort_option = match acc_args.order_bys {
185                [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
186                [] => None,
187                _ => {
188                    return exec_err!(
189                        "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
190                    );
191                }
192            };
193            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
194                &data_type,
195                sort_option,
196                ignore_nulls,
197            )?));
198        }
199
200        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
201            return Ok(Box::new(ArrayAggAccumulator::try_new(
202                &data_type,
203                ignore_nulls,
204            )?));
205        };
206
207        let ordering_dtypes = ordering
208            .iter()
209            .map(|e| e.expr.data_type(acc_args.schema))
210            .collect::<Result<Vec<_>>>()?;
211
212        OrderSensitiveArrayAggAccumulator::try_new(
213            &data_type,
214            &ordering_dtypes,
215            ordering,
216            self.is_input_pre_ordered,
217            acc_args.is_reversed,
218            ignore_nulls,
219        )
220        .map(|acc| Box::new(acc) as _)
221    }
222
223    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
224        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
225    }
226
227    fn documentation(&self) -> Option<&Documentation> {
228        self.doc()
229    }
230}
231
/// Accumulator for `ARRAY_AGG(expr)`, used by [`ArrayAgg`] when the aggregate
/// has neither `DISTINCT` nor `ORDER BY`.
///
/// Input arrays are kept as-is (cheap `ArrayRef` clones or slices) and only
/// concatenated into a single list value when the result is evaluated.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    /// Arrays collected so far, in arrival order.
    values: Vec<ArrayRef>,
    /// Item type of the resulting list.
    datatype: DataType,
    /// When true, null input values are dropped instead of collected.
    ignore_nulls: bool,
}
238
239impl ArrayAggAccumulator {
240    /// new array_agg accumulator based on given item data type
241    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
242        Ok(Self {
243            values: vec![],
244            datatype: datatype.clone(),
245            ignore_nulls,
246        })
247    }
248
249    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
250    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
251    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
252        let offsets = list_array.value_offsets();
253        // Offsets always have at least 1 value
254        let initial_offset = offsets[0];
255        let null_count = list_array.null_count();
256
257        // If no nulls than just use the fast path
258        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
259        if null_count == 0 {
260            // According to Arrow specification, the first offset can be non-zero
261            let list_values = list_array.values().slice(
262                initial_offset as usize,
263                (offsets[offsets.len() - 1] - initial_offset) as usize,
264            );
265            return Some(list_values);
266        }
267
268        // If all the values are null than just return an empty values array
269        if list_array.null_count() == list_array.len() {
270            return Some(list_array.values().slice(0, 0));
271        }
272
273        // According to the Arrow spec, null values can point to non-empty lists
274        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
275
276        // Unwrapping is safe as we just checked if there is a null value
277        let nulls = list_array.nulls().unwrap();
278
279        let mut valid_slices_iter = nulls.valid_slices();
280
281        // This is safe as we validated that there is at least 1 valid value in the array
282        let (start, end) = valid_slices_iter.next().unwrap();
283
284        let start_offset = offsets[start];
285
286        // End is exclusive, so it already point to the last offset value
287        // This is valid as the length of the array is always 1 less than the length of the offsets
288        let mut end_offset_of_last_valid_value = offsets[end];
289
290        for (start, end) in valid_slices_iter {
291            // If there is a null value that point to a non-empty list than the start offset of the valid value
292            // will be different that the end offset of the last valid value
293            if offsets[start] != end_offset_of_last_valid_value {
294                return None;
295            }
296
297            // End is exclusive, so it already point to the last offset value
298            // This is valid as the length of the array is always 1 less than the length of the offsets
299            end_offset_of_last_valid_value = offsets[end];
300        }
301
302        let consecutive_valid_values = list_array.values().slice(
303            start_offset as usize,
304            (end_offset_of_last_valid_value - start_offset) as usize,
305        );
306
307        Some(consecutive_valid_values)
308    }
309}
310
311impl Accumulator for ArrayAggAccumulator {
312    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
313        // Append value like Int64Array(1,2,3)
314        if values.is_empty() {
315            return Ok(());
316        }
317
318        if values.len() != 1 {
319            return internal_err!("expects single batch");
320        }
321
322        let val = &values[0];
323        let nulls = if self.ignore_nulls {
324            val.logical_nulls()
325        } else {
326            None
327        };
328
329        let val = match nulls {
330            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
331            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
332            None => Arc::clone(val),
333        };
334
335        if !val.is_empty() {
336            self.values.push(val)
337        }
338
339        Ok(())
340    }
341
342    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
343        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
344        if states.is_empty() {
345            return Ok(());
346        }
347
348        if states.len() != 1 {
349            return internal_err!("expects single state");
350        }
351
352        let list_arr = as_list_array(&states[0])?;
353
354        match Self::get_optional_values_to_merge_as_is(list_arr) {
355            Some(values) => {
356                // Make sure we don't insert empty lists
357                if !values.is_empty() {
358                    self.values.push(values);
359                }
360            }
361            None => {
362                for arr in list_arr.iter().flatten() {
363                    self.values.push(arr);
364                }
365            }
366        }
367
368        Ok(())
369    }
370
371    fn state(&mut self) -> Result<Vec<ScalarValue>> {
372        Ok(vec![self.evaluate()?])
373    }
374
375    fn evaluate(&mut self) -> Result<ScalarValue> {
376        // Transform Vec<ListArr> to ListArr
377        let element_arrays: Vec<&dyn Array> =
378            self.values.iter().map(|a| a.as_ref()).collect();
379
380        if element_arrays.is_empty() {
381            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
382        }
383
384        let concated_array = arrow::compute::concat(&element_arrays)?;
385
386        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
387    }
388
389    fn size(&self) -> usize {
390        size_of_val(self)
391            + (size_of::<ArrayRef>() * self.values.capacity())
392            + self
393                .values
394                .iter()
395                // Each ArrayRef might be just a reference to a bigger array, and many
396                // ArrayRefs here might be referencing exactly the same array, so if we
397                // were to call `arr.get_array_memory_size()`, we would be double-counting
398                // the same underlying data many times.
399                //
400                // Instead, we do an approximation by estimating how much memory each
401                // ArrayRef would occupy if its underlying data was fully owned by this
402                // accumulator.
403                //
404                // Note that this is just an estimation, but the reality is that this
405                // accumulator might not own any data.
406                .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
407                .sum::<usize>()
408            + self.datatype.size()
409            - size_of_val(&self.datatype)
410    }
411}
412
/// Accumulator for `ARRAY_AGG(DISTINCT expr [ORDER BY expr])`.
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    /// Distinct values seen so far; the set's iteration order is unspecified.
    values: HashSet<ScalarValue>,
    /// Item type of the resulting list.
    datatype: DataType,
    /// When `Some`, the result is sorted by the values themselves using these
    /// options. When `None`, values come out in the set's iteration order.
    sort_options: Option<SortOptions>,
    /// When true, null input values are skipped.
    ignore_nulls: bool,
}
420
421impl DistinctArrayAggAccumulator {
422    pub fn try_new(
423        datatype: &DataType,
424        sort_options: Option<SortOptions>,
425        ignore_nulls: bool,
426    ) -> Result<Self> {
427        Ok(Self {
428            values: HashSet::new(),
429            datatype: datatype.clone(),
430            sort_options,
431            ignore_nulls,
432        })
433    }
434}
435
436impl Accumulator for DistinctArrayAggAccumulator {
437    fn state(&mut self) -> Result<Vec<ScalarValue>> {
438        Ok(vec![self.evaluate()?])
439    }
440
441    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442        if values.is_empty() {
443            return Ok(());
444        }
445
446        let val = &values[0];
447        let nulls = if self.ignore_nulls {
448            val.logical_nulls()
449        } else {
450            None
451        };
452
453        let nulls = nulls.as_ref();
454        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
455            for i in 0..val.len() {
456                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
457                    self.values
458                        .insert(ScalarValue::try_from_array(val, i)?.compacted());
459                }
460            }
461        }
462
463        Ok(())
464    }
465
466    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
467        if states.is_empty() {
468            return Ok(());
469        }
470
471        if states.len() != 1 {
472            return internal_err!("expects single state");
473        }
474
475        states[0]
476            .as_list::<i32>()
477            .iter()
478            .flatten()
479            .try_for_each(|val| self.update_batch(&[val]))
480    }
481
482    fn evaluate(&mut self) -> Result<ScalarValue> {
483        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
484        if values.is_empty() {
485            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
486        }
487
488        if let Some(opts) = self.sort_options {
489            let mut delayed_cmp_err = Ok(());
490            values.sort_by(|a, b| {
491                if a.is_null() {
492                    return match opts.nulls_first {
493                        true => Ordering::Less,
494                        false => Ordering::Greater,
495                    };
496                }
497                if b.is_null() {
498                    return match opts.nulls_first {
499                        true => Ordering::Greater,
500                        false => Ordering::Less,
501                    };
502                }
503                match opts.descending {
504                    true => b.try_cmp(a),
505                    false => a.try_cmp(b),
506                }
507                .unwrap_or_else(|err| {
508                    delayed_cmp_err = Err(err);
509                    Ordering::Equal
510                })
511            });
512            delayed_cmp_err?;
513        };
514
515        let arr = ScalarValue::new_list(&values, &self.datatype, true);
516        Ok(ScalarValue::List(arr))
517    }
518
519    fn size(&self) -> usize {
520        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
521            - size_of_val(&self.values)
522            + self.datatype.size()
523            - size_of_val(&self.datatype)
524            - size_of_val(&self.sort_options)
525            + size_of::<Option<SortOptions>>()
526    }
527}
528
/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
///
/// `values[i]` and `ordering_values[i]` describe the same input row. The
/// methods of this accumulator push, sort and merge the two vectors together
/// so that they stay index-aligned.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions: `datatypes[0]` is the item type of the result list and
    /// `datatypes[1..]` are the types of the ordering expressions, in order.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the input is known to be pre-ordered; when true the accumulator
    /// skips sorting its collected rows.
    is_input_pre_ordered: bool,
    /// Whether the aggregation is running in reverse: the final result lists
    /// the values in reverse of their sorted order.
    reverse: bool,
    /// Whether the aggregation should ignore null values.
    ignore_nulls: bool,
}
553
554impl OrderSensitiveArrayAggAccumulator {
555    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
556    /// item data type.
557    pub fn try_new(
558        datatype: &DataType,
559        ordering_dtypes: &[DataType],
560        ordering_req: LexOrdering,
561        is_input_pre_ordered: bool,
562        reverse: bool,
563        ignore_nulls: bool,
564    ) -> Result<Self> {
565        let mut datatypes = vec![datatype.clone()];
566        datatypes.extend(ordering_dtypes.iter().cloned());
567        Ok(Self {
568            values: vec![],
569            ordering_values: vec![],
570            datatypes,
571            ordering_req,
572            is_input_pre_ordered,
573            reverse,
574            ignore_nulls,
575        })
576    }
577
578    fn sort(&mut self) {
579        let sort_options = self
580            .ordering_req
581            .iter()
582            .map(|sort_expr| sort_expr.options)
583            .collect::<Vec<_>>();
584        let mut values = take(&mut self.values)
585            .into_iter()
586            .zip(take(&mut self.ordering_values))
587            .collect::<Vec<_>>();
588        let mut delayed_cmp_err = Ok(());
589        values.sort_by(|(_, left_ordering), (_, right_ordering)| {
590            compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
591                |err| {
592                    delayed_cmp_err = Err(err);
593                    Ordering::Equal
594                },
595            )
596        });
597        (self.values, self.ordering_values) = values.into_iter().unzip();
598    }
599
600    fn evaluate_orderings(&self) -> Result<ScalarValue> {
601        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
602
603        let column_wise_ordering_values = if self.ordering_values.is_empty() {
604            fields
605                .iter()
606                .map(|f| new_empty_array(f.data_type()))
607                .collect::<Vec<_>>()
608        } else {
609            (0..fields.len())
610                .map(|i| {
611                    let column_values = self.ordering_values.iter().map(|x| x[i].clone());
612                    ScalarValue::iter_to_array(column_values)
613                })
614                .collect::<Result<_>>()?
615        };
616
617        let ordering_array = StructArray::try_new(
618            Fields::from(fields),
619            column_wise_ordering_values,
620            None,
621        )?;
622        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
623    }
624}
625
626impl Accumulator for OrderSensitiveArrayAggAccumulator {
627    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
628        if values.is_empty() {
629            return Ok(());
630        }
631
632        let val = &values[0];
633        let ord = &values[1..];
634        let nulls = if self.ignore_nulls {
635            val.logical_nulls()
636        } else {
637            None
638        };
639
640        let nulls = nulls.as_ref();
641        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
642            for i in 0..val.len() {
643                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
644                    self.values
645                        .push(ScalarValue::try_from_array(val, i)?.compacted());
646                    self.ordering_values.push(
647                        get_row_at_idx(ord, i)?
648                            .into_iter()
649                            .map(|v| v.compacted())
650                            .collect(),
651                    )
652                }
653            }
654        }
655
656        Ok(())
657    }
658
659    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
660        if states.is_empty() {
661            return Ok(());
662        }
663
664        // First entry in the state is the aggregation result. Second entry
665        // stores values received for ordering requirement columns for each
666        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
667        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
668        // received from its ordering requirement expression. (This information
669        // is necessary for during merging).
670        let [array_agg_values, agg_orderings] =
671            take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
672        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
673            return exec_err!("Expects to receive a list array");
674        };
675
676        // Stores ARRAY_AGG results coming from each partition
677        let mut partition_values = vec![];
678        // Stores ordering requirement expression results coming from each partition
679        let mut partition_ordering_values = vec![];
680
681        // Existing values should be merged also.
682        if !self.is_input_pre_ordered {
683            self.sort();
684        }
685        partition_values.push(take(&mut self.values).into());
686        partition_ordering_values.push(take(&mut self.ordering_values).into());
687
688        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
689        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
690        for v in array_agg_res.into_iter() {
691            partition_values.push(v.into());
692        }
693
694        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
695
696        for partition_ordering_rows in orderings.into_iter() {
697            // Extract value from struct to ordering_rows for each group/partition
698            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
699                    if let ScalarValue::Struct(s) = ordering_row {
700                        let mut ordering_columns_per_row = vec![];
701
702                        for column in s.columns() {
703                            let sv = ScalarValue::try_from_array(column, 0)?;
704                            ordering_columns_per_row.push(sv);
705                        }
706
707                        Ok(ordering_columns_per_row)
708                    } else {
709                        exec_err!(
710                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
711                            ordering_row.data_type()
712                        )
713                    }
714                }).collect::<Result<VecDeque<_>>>()?;
715
716            partition_ordering_values.push(ordering_value);
717        }
718
719        let sort_options = self
720            .ordering_req
721            .iter()
722            .map(|sort_expr| sort_expr.options)
723            .collect::<Vec<_>>();
724
725        (self.values, self.ordering_values) = merge_ordered_arrays(
726            &mut partition_values,
727            &mut partition_ordering_values,
728            &sort_options,
729        )?;
730
731        Ok(())
732    }
733
734    fn state(&mut self) -> Result<Vec<ScalarValue>> {
735        if !self.is_input_pre_ordered {
736            self.sort();
737        }
738
739        let mut result = vec![self.evaluate()?];
740        result.push(self.evaluate_orderings()?);
741
742        Ok(result)
743    }
744
745    fn evaluate(&mut self) -> Result<ScalarValue> {
746        if !self.is_input_pre_ordered {
747            self.sort();
748        }
749
750        if self.values.is_empty() {
751            return Ok(ScalarValue::new_null_list(
752                self.datatypes[0].clone(),
753                true,
754                1,
755            ));
756        }
757
758        let values = self.values.clone();
759        let array = if self.reverse {
760            ScalarValue::new_list_from_iter(
761                values.into_iter().rev(),
762                &self.datatypes[0],
763                true,
764            )
765        } else {
766            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
767        };
768        Ok(ScalarValue::List(array))
769    }
770
771    fn size(&self) -> usize {
772        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
773            - size_of_val(&self.values);
774
775        // Add size of the `self.ordering_values`
776        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
777        for row in &self.ordering_values {
778            total += ScalarValue::size_of_vec(row) - size_of_val(row);
779        }
780
781        // Add size of the `self.datatypes`
782        total += size_of::<DataType>() * self.datatypes.capacity();
783        for dtype in &self.datatypes {
784            total += dtype.size() - size_of_val(dtype);
785        }
786
787        // Add size of the `self.ordering_req`
788        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
789        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
790        total
791    }
792}
793
#[cfg(test)]
mod tests {
    //! Unit tests for the `ARRAY_AGG` accumulators: the plain, `DISTINCT`,
    //! and `ORDER BY` variants, exercised through two-accumulator merges, plus
    //! checks on the reported memory `size()`.

    use super::*;
    use arrow::array::{ListBuilder, StringBuilder};
    use arrow::datatypes::{FieldRef, Schema};
    use datafusion_common::cast::as_generic_string_array;
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        // Without ORDER BY this test does not rely on the DISTINCT output
        // order, so it sorts before comparing.
        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);

        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    #[test]
    fn duplicates_on_second_batch_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "c"])])?;
        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["c", "b", "a"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);

        Ok(())
    }

    #[test]
    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        // The turbofish is needed because an all-`None` literal gives the
        // compiler no element type to infer.
        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();
        assert_eq!(result, vec!["NULL", "a"]);
        Ok(())
    }

    #[test]
    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);
        assert_eq!(result, vec!["NULL"]);
        Ok(())
    }

    // The `does_not_over_account_memory*` tests pin exact `size()` byte
    // counts; they are expected to need updating whenever the accumulators'
    // internal layout or compaction strategy changes.
    #[test]
    fn does_not_over_account_memory() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        assert_eq!(acc1.size(), 266);

        Ok(())
    }

    #[test]
    fn does_not_over_account_memory_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["d", "e", "f"],
        ])])?;
        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
        acc1 = merge(acc1, acc2)?;

        // without compaction, the size is 16660
        assert_eq!(acc1.size(), 1660);

        Ok(())
    }

    #[test]
    fn does_not_over_account_memory_ordered() -> Result<()> {
        let mut acc = ArrayAggAccumulatorBuilder::string()
            .order_by_col("col", SortOptions::new(false, false))
            .build()?;

        acc.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["c", "d", "e"],
            vec!["b", "c", "d"],
        ])])?;

        // without compaction, the size is 17112
        assert_eq!(acc.size(), 2112);

        Ok(())
    }

    /// Test helper that assembles the [`AccumulatorArgs`] needed to build an
    /// `ARRAY_AGG` accumulator over a single input column named `col`.
    struct ArrayAggAccumulatorBuilder {
        /// Field passed as `AccumulatorArgs::return_field`.
        return_field: FieldRef,
        /// Whether to build the `DISTINCT` variant.
        distinct: bool,
        /// `ORDER BY` expressions, appended via [`Self::order_by_col`].
        order_bys: Vec<PhysicalSortExpr>,
        /// Input schema; holds the single column `col`.
        schema: Schema,
    }

    impl ArrayAggAccumulatorBuilder {
        /// Builder for a `Utf8` element type.
        fn string() -> Self {
            Self::new(DataType::Utf8)
        }

        /// Builder for `data_type`, non-distinct, with no ordering.
        ///
        /// The schema declares `col` as a nullable `List<data_type>`, while
        /// `return_field` carries `data_type` itself.
        fn new(data_type: DataType) -> Self {
            Self {
                return_field: Field::new("f", data_type.clone(), true).into(),
                distinct: false,
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::new_list(data_type, true),
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }

        /// Requests the `DISTINCT` variant.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an `ORDER BY col` expression using `sort_options`.
        ///
        /// # Panics
        /// Panics if `col` is not a field of the builder's schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let new_order = PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            );
            self.order_bys.push(new_order);
            self
        }

        /// Builds one accumulator whose only input expression is column 0.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            ArrayAgg::default().accumulator(AccumulatorArgs {
                return_field: Arc::clone(&self.return_field),
                schema: &self.schema,
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[Arc::new(Column::new("col", 0))],
            })
        }

        /// Builds two independent, identically configured accumulators, for
        /// exercising `merge`.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Extracts the string child values of a `ScalarValue::List`.
    ///
    /// Returns an internal error if `value` is not a `List`. Reads
    /// `list.values()`, i.e. every child value regardless of row offsets,
    /// which presumes the single-row lists `evaluate` produces.
    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
        let ScalarValue::List(list) = value else {
            return internal_err!("ScalarValue was not a List");
        };
        Ok(as_generic_string_array::<i32>(list.values())?
            .iter()
            .map(|v| v.map(|v| v.to_string()))
            .collect())
    }

    /// Renders `None` entries as the literal `"NULL"` so results compare
    /// against plain string vectors.
    fn print_nulls(values: Vec<Option<String>>) -> Vec<String> {
        values
            .into_iter()
            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
            .collect()
    }

    /// Builds a `ListArray` of non-null `Utf8` values, one list row per
    /// inner `Vec`.
    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
        let mut builder = ListBuilder::new(StringBuilder::new());
        for string_list in data {
            // `append_value` accepts any iterator of `Option<_>`, so no
            // intermediate Vec is needed.
            builder.append_value(string_list.into_iter().map(Some));
        }

        Arc::new(builder.finish())
    }

    /// Converts a fixed-size array of scalar-convertible values into an
    /// Arrow array.
    ///
    /// # Panics
    /// Panics if `ScalarValue::iter_to_array` rejects the values.
    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
    where
        ScalarValue: From<T>,
    {
        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
    }

    /// Merges `acc2` into `acc1` and returns `acc1`.
    ///
    /// It does this through the partial-aggregation round trip: `acc2.state()`
    /// is converted to arrays and fed to `acc1.merge_batch`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}