datafusion_functions_aggregate/
array_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val, take};
23use std::sync::Arc;
24
25use arrow::array::{
26    new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
27};
28use arrow::compute::{filter, SortOptions};
29use arrow::datatypes::{DataType, Field, FieldRef, Fields};
30
31use datafusion_common::cast::as_list_array;
32use datafusion_common::utils::{
33    compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder,
34};
35use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
40};
41use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
42use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
43use datafusion_functions_aggregate_common::utils::ordering_fields;
44use datafusion_macros::user_doc;
45use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
46
// Registers the `ArrayAgg` UDAF through the crate's UDAF helper macro.
// Judging by the arguments, this presumably generates an `array_agg(expression)`
// expression builder (described as "input values, including nulls, concatenated
// into an array") plus the `array_agg_udaf()` accessor. That accessor is used
// by `ArrayAgg::reverse_expr` to hand back the reversed form of this aggregate.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
54
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
    /// Accepts exactly one argument of any type (see the `Default` impl).
    signature: Signature,
    /// Set through `with_beneficial_ordering`. It is forwarded to the
    /// order-sensitive accumulator, which skips its own sorting when `true`.
    is_input_pre_ordered: bool,
}
84
85impl Default for ArrayAgg {
86    fn default() -> Self {
87        Self {
88            signature: Signature::any(1, Volatility::Immutable),
89            is_input_pre_ordered: false,
90        }
91    }
92}
93
94impl AggregateUDFImpl for ArrayAgg {
95    fn as_any(&self) -> &dyn std::any::Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "array_agg"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        Ok(DataType::List(Arc::new(Field::new_list_field(
109            arg_types[0].clone(),
110            true,
111        ))))
112    }
113
114    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115        if args.is_distinct {
116            return Ok(vec![Field::new_list(
117                format_state_name(args.name, "distinct_array_agg"),
118                // See COMMENTS.md to understand why nullable is set to true
119                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
120                true,
121            )
122            .into()]);
123        }
124
125        let mut fields = vec![Field::new_list(
126            format_state_name(args.name, "array_agg"),
127            // See COMMENTS.md to understand why nullable is set to true
128            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
129            true,
130        )
131        .into()];
132
133        if args.ordering_fields.is_empty() {
134            return Ok(fields);
135        }
136
137        let orderings = args.ordering_fields.to_vec();
138        fields.push(
139            Field::new_list(
140                format_state_name(args.name, "array_agg_orderings"),
141                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
142                false,
143            )
144            .into(),
145        );
146
147        Ok(fields)
148    }
149
150    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
151        AggregateOrderSensitivity::SoftRequirement
152    }
153
154    fn with_beneficial_ordering(
155        self: Arc<Self>,
156        beneficial_ordering: bool,
157    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
158        Ok(Some(Arc::new(Self {
159            signature: self.signature.clone(),
160            is_input_pre_ordered: beneficial_ordering,
161        })))
162    }
163
164    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
165        let field = &acc_args.expr_fields[0];
166        let data_type = field.data_type();
167        let ignore_nulls = acc_args.ignore_nulls && field.is_nullable();
168
169        if acc_args.is_distinct {
170            // Limitation similar to Postgres. The aggregation function can only mix
171            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
172            // also in the arguments of the function. This implies that if the
173            // aggregation function only accepts one argument, only one argument
174            // can be used in the ORDER BY, For example:
175            //
176            // ARRAY_AGG(DISTINCT col)
177            //
178            // can only be mixed with an ORDER BY if the order expression is "col".
179            //
180            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
181            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
182            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
183            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
184            let sort_option = match acc_args.order_bys {
185                [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
186                [] => None,
187                _ => {
188                    return exec_err!(
189                        "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
190                    );
191                }
192            };
193            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
194                data_type,
195                sort_option,
196                ignore_nulls,
197            )?));
198        }
199
200        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
201            return Ok(Box::new(ArrayAggAccumulator::try_new(
202                data_type,
203                ignore_nulls,
204            )?));
205        };
206
207        let ordering_dtypes = ordering
208            .iter()
209            .map(|e| e.expr.data_type(acc_args.schema))
210            .collect::<Result<Vec<_>>>()?;
211
212        OrderSensitiveArrayAggAccumulator::try_new(
213            data_type,
214            &ordering_dtypes,
215            ordering,
216            self.is_input_pre_ordered,
217            acc_args.is_reversed,
218            ignore_nulls,
219        )
220        .map(|acc| Box::new(acc) as _)
221    }
222
223    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
224        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
225    }
226
227    fn documentation(&self) -> Option<&Documentation> {
228        self.doc()
229    }
230}
231
232#[derive(Debug)]
233pub struct ArrayAggAccumulator {
234    values: Vec<ArrayRef>,
235    datatype: DataType,
236    ignore_nulls: bool,
237}
238
239impl ArrayAggAccumulator {
240    /// new array_agg accumulator based on given item data type
241    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
242        Ok(Self {
243            values: vec![],
244            datatype: datatype.clone(),
245            ignore_nulls,
246        })
247    }
248
249    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
250    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
251    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
252        let offsets = list_array.value_offsets();
253        // Offsets always have at least 1 value
254        let initial_offset = offsets[0];
255        let null_count = list_array.null_count();
256
257        // If no nulls than just use the fast path
258        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
259        if null_count == 0 {
260            // According to Arrow specification, the first offset can be non-zero
261            let list_values = list_array.values().slice(
262                initial_offset as usize,
263                (offsets[offsets.len() - 1] - initial_offset) as usize,
264            );
265            return Some(list_values);
266        }
267
268        // If all the values are null than just return an empty values array
269        if list_array.null_count() == list_array.len() {
270            return Some(list_array.values().slice(0, 0));
271        }
272
273        // According to the Arrow spec, null values can point to non-empty lists
274        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
275
276        // Unwrapping is safe as we just checked if there is a null value
277        let nulls = list_array.nulls().unwrap();
278
279        let mut valid_slices_iter = nulls.valid_slices();
280
281        // This is safe as we validated that there is at least 1 valid value in the array
282        let (start, end) = valid_slices_iter.next().unwrap();
283
284        let start_offset = offsets[start];
285
286        // End is exclusive, so it already point to the last offset value
287        // This is valid as the length of the array is always 1 less than the length of the offsets
288        let mut end_offset_of_last_valid_value = offsets[end];
289
290        for (start, end) in valid_slices_iter {
291            // If there is a null value that point to a non-empty list than the start offset of the valid value
292            // will be different that the end offset of the last valid value
293            if offsets[start] != end_offset_of_last_valid_value {
294                return None;
295            }
296
297            // End is exclusive, so it already point to the last offset value
298            // This is valid as the length of the array is always 1 less than the length of the offsets
299            end_offset_of_last_valid_value = offsets[end];
300        }
301
302        let consecutive_valid_values = list_array.values().slice(
303            start_offset as usize,
304            (end_offset_of_last_valid_value - start_offset) as usize,
305        );
306
307        Some(consecutive_valid_values)
308    }
309}
310
311impl Accumulator for ArrayAggAccumulator {
312    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
313        // Append value like Int64Array(1,2,3)
314        if values.is_empty() {
315            return Ok(());
316        }
317
318        if values.len() != 1 {
319            return internal_err!("expects single batch");
320        }
321
322        let val = &values[0];
323        let nulls = if self.ignore_nulls {
324            val.logical_nulls()
325        } else {
326            None
327        };
328
329        let val = match nulls {
330            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
331            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
332            None => Arc::clone(val),
333        };
334
335        if !val.is_empty() {
336            self.values.push(val)
337        }
338
339        Ok(())
340    }
341
342    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
343        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
344        if states.is_empty() {
345            return Ok(());
346        }
347
348        if states.len() != 1 {
349            return internal_err!("expects single state");
350        }
351
352        let list_arr = as_list_array(&states[0])?;
353
354        match Self::get_optional_values_to_merge_as_is(list_arr) {
355            Some(values) => {
356                // Make sure we don't insert empty lists
357                if !values.is_empty() {
358                    self.values.push(values);
359                }
360            }
361            None => {
362                for arr in list_arr.iter().flatten() {
363                    self.values.push(arr);
364                }
365            }
366        }
367
368        Ok(())
369    }
370
371    fn state(&mut self) -> Result<Vec<ScalarValue>> {
372        Ok(vec![self.evaluate()?])
373    }
374
375    fn evaluate(&mut self) -> Result<ScalarValue> {
376        // Transform Vec<ListArr> to ListArr
377        let element_arrays: Vec<&dyn Array> =
378            self.values.iter().map(|a| a.as_ref()).collect();
379
380        if element_arrays.is_empty() {
381            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
382        }
383
384        let concated_array = arrow::compute::concat(&element_arrays)?;
385
386        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
387    }
388
389    fn size(&self) -> usize {
390        size_of_val(self)
391            + (size_of::<ArrayRef>() * self.values.capacity())
392            + self
393                .values
394                .iter()
395                // Each ArrayRef might be just a reference to a bigger array, and many
396                // ArrayRefs here might be referencing exactly the same array, so if we
397                // were to call `arr.get_array_memory_size()`, we would be double-counting
398                // the same underlying data many times.
399                //
400                // Instead, we do an approximation by estimating how much memory each
401                // ArrayRef would occupy if its underlying data was fully owned by this
402                // accumulator.
403                //
404                // Note that this is just an estimation, but the reality is that this
405                // accumulator might not own any data.
406                .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
407                .sum::<usize>()
408            + self.datatype.size()
409            - size_of_val(&self.datatype)
410    }
411}
412
413#[derive(Debug)]
414struct DistinctArrayAggAccumulator {
415    values: HashSet<ScalarValue>,
416    datatype: DataType,
417    sort_options: Option<SortOptions>,
418    ignore_nulls: bool,
419}
420
421impl DistinctArrayAggAccumulator {
422    pub fn try_new(
423        datatype: &DataType,
424        sort_options: Option<SortOptions>,
425        ignore_nulls: bool,
426    ) -> Result<Self> {
427        Ok(Self {
428            values: HashSet::new(),
429            datatype: datatype.clone(),
430            sort_options,
431            ignore_nulls,
432        })
433    }
434}
435
436impl Accumulator for DistinctArrayAggAccumulator {
437    fn state(&mut self) -> Result<Vec<ScalarValue>> {
438        Ok(vec![self.evaluate()?])
439    }
440
441    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442        if values.is_empty() {
443            return Ok(());
444        }
445
446        let val = &values[0];
447        let nulls = if self.ignore_nulls {
448            val.logical_nulls()
449        } else {
450            None
451        };
452
453        let nulls = nulls.as_ref();
454        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
455            for i in 0..val.len() {
456                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
457                    self.values
458                        .insert(ScalarValue::try_from_array(val, i)?.compacted());
459                }
460            }
461        }
462
463        Ok(())
464    }
465
466    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
467        if states.is_empty() {
468            return Ok(());
469        }
470
471        if states.len() != 1 {
472            return internal_err!("expects single state");
473        }
474
475        states[0]
476            .as_list::<i32>()
477            .iter()
478            .flatten()
479            .try_for_each(|val| self.update_batch(&[val]))
480    }
481
482    fn evaluate(&mut self) -> Result<ScalarValue> {
483        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
484        if values.is_empty() {
485            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
486        }
487
488        if let Some(opts) = self.sort_options {
489            let mut delayed_cmp_err = Ok(());
490            values.sort_by(|a, b| {
491                if a.is_null() {
492                    return match opts.nulls_first {
493                        true => Ordering::Less,
494                        false => Ordering::Greater,
495                    };
496                }
497                if b.is_null() {
498                    return match opts.nulls_first {
499                        true => Ordering::Greater,
500                        false => Ordering::Less,
501                    };
502                }
503                match opts.descending {
504                    true => b.try_cmp(a),
505                    false => a.try_cmp(b),
506                }
507                .unwrap_or_else(|err| {
508                    delayed_cmp_err = Err(err);
509                    Ordering::Equal
510                })
511            });
512            delayed_cmp_err?;
513        };
514
515        let arr = ScalarValue::new_list(&values, &self.datatype, true);
516        Ok(ScalarValue::List(arr))
517    }
518
519    fn size(&self) -> usize {
520        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
521            - size_of_val(&self.values)
522            + self.datatype.size()
523            - size_of_val(&self.datatype)
524            - size_of_val(&self.sort_options)
525            + size_of::<Option<SortOptions>>()
526    }
527}
528
529/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
530/// partition setting, partial aggregations are computed for every partition,
531/// and then their results are merged.
532#[derive(Debug)]
533pub(crate) struct OrderSensitiveArrayAggAccumulator {
534    /// Stores entries in the `ARRAY_AGG` result.
535    values: Vec<ScalarValue>,
536    /// Stores values of ordering requirement expressions corresponding to each
537    /// entry in `values`. This information is used when merging results from
538    /// different partitions. For detailed information how merging is done, see
539    /// [`merge_ordered_arrays`].
540    ordering_values: Vec<Vec<ScalarValue>>,
541    /// Stores datatypes of expressions inside values and ordering requirement
542    /// expressions.
543    datatypes: Vec<DataType>,
544    /// Stores the ordering requirement of the `Accumulator`.
545    ordering_req: LexOrdering,
546    /// Whether the input is known to be pre-ordered
547    is_input_pre_ordered: bool,
548    /// Whether the aggregation is running in reverse.
549    reverse: bool,
550    /// Whether the aggregation should ignore null values.
551    ignore_nulls: bool,
552}
553
554impl OrderSensitiveArrayAggAccumulator {
555    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
556    /// item data type.
557    pub fn try_new(
558        datatype: &DataType,
559        ordering_dtypes: &[DataType],
560        ordering_req: LexOrdering,
561        is_input_pre_ordered: bool,
562        reverse: bool,
563        ignore_nulls: bool,
564    ) -> Result<Self> {
565        let mut datatypes = vec![datatype.clone()];
566        datatypes.extend(ordering_dtypes.iter().cloned());
567        Ok(Self {
568            values: vec![],
569            ordering_values: vec![],
570            datatypes,
571            ordering_req,
572            is_input_pre_ordered,
573            reverse,
574            ignore_nulls,
575        })
576    }
577
578    fn sort(&mut self) {
579        let sort_options = self
580            .ordering_req
581            .iter()
582            .map(|sort_expr| sort_expr.options)
583            .collect::<Vec<_>>();
584        let mut values = take(&mut self.values)
585            .into_iter()
586            .zip(take(&mut self.ordering_values))
587            .collect::<Vec<_>>();
588        let mut delayed_cmp_err = Ok(());
589        values.sort_by(|(_, left_ordering), (_, right_ordering)| {
590            compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
591                |err| {
592                    delayed_cmp_err = Err(err);
593                    Ordering::Equal
594                },
595            )
596        });
597        (self.values, self.ordering_values) = values.into_iter().unzip();
598    }
599
600    fn evaluate_orderings(&self) -> Result<ScalarValue> {
601        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
602
603        let column_wise_ordering_values = if self.ordering_values.is_empty() {
604            fields
605                .iter()
606                .map(|f| new_empty_array(f.data_type()))
607                .collect::<Vec<_>>()
608        } else {
609            (0..fields.len())
610                .map(|i| {
611                    let column_values = self.ordering_values.iter().map(|x| x[i].clone());
612                    ScalarValue::iter_to_array(column_values)
613                })
614                .collect::<Result<_>>()?
615        };
616
617        let ordering_array = StructArray::try_new(
618            Fields::from(fields),
619            column_wise_ordering_values,
620            None,
621        )?;
622        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
623    }
624}
625
626impl Accumulator for OrderSensitiveArrayAggAccumulator {
627    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
628        if values.is_empty() {
629            return Ok(());
630        }
631
632        let val = &values[0];
633        let ord = &values[1..];
634        let nulls = if self.ignore_nulls {
635            val.logical_nulls()
636        } else {
637            None
638        };
639
640        let nulls = nulls.as_ref();
641        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
642            for i in 0..val.len() {
643                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
644                    self.values
645                        .push(ScalarValue::try_from_array(val, i)?.compacted());
646                    self.ordering_values.push(
647                        get_row_at_idx(ord, i)?
648                            .into_iter()
649                            .map(|v| v.compacted())
650                            .collect(),
651                    )
652                }
653            }
654        }
655
656        Ok(())
657    }
658
659    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
660        if states.is_empty() {
661            return Ok(());
662        }
663
664        // First entry in the state is the aggregation result. Second entry
665        // stores values received for ordering requirement columns for each
666        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
667        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
668        // received from its ordering requirement expression. (This information
669        // is necessary for during merging).
670        let [array_agg_values, agg_orderings] =
671            take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
672        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
673            return exec_err!("Expects to receive a list array");
674        };
675
676        // Stores ARRAY_AGG results coming from each partition
677        let mut partition_values = vec![];
678        // Stores ordering requirement expression results coming from each partition
679        let mut partition_ordering_values = vec![];
680
681        // Existing values should be merged also.
682        if !self.is_input_pre_ordered {
683            self.sort();
684        }
685        partition_values.push(take(&mut self.values).into());
686        partition_ordering_values.push(take(&mut self.ordering_values).into());
687
688        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
689        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
690        for maybe_v in array_agg_res.into_iter() {
691            if let Some(v) = maybe_v {
692                partition_values.push(v.into());
693            } else {
694                partition_values.push(vec![].into());
695            }
696        }
697
698        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
699        for partition_ordering_rows in orderings.into_iter().flatten() {
700            // Extract value from struct to ordering_rows for each group/partition
701            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
702                    if let ScalarValue::Struct(s) = ordering_row {
703                        let mut ordering_columns_per_row = vec![];
704
705                        for column in s.columns() {
706                            let sv = ScalarValue::try_from_array(column, 0)?;
707                            ordering_columns_per_row.push(sv);
708                        }
709
710                        Ok(ordering_columns_per_row)
711                    } else {
712                        exec_err!(
713                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
714                            ordering_row.data_type()
715                        )
716                    }
717                }).collect::<Result<VecDeque<_>>>()?;
718
719            partition_ordering_values.push(ordering_value);
720        }
721
722        let sort_options = self
723            .ordering_req
724            .iter()
725            .map(|sort_expr| sort_expr.options)
726            .collect::<Vec<_>>();
727
728        (self.values, self.ordering_values) = merge_ordered_arrays(
729            &mut partition_values,
730            &mut partition_ordering_values,
731            &sort_options,
732        )?;
733
734        Ok(())
735    }
736
737    fn state(&mut self) -> Result<Vec<ScalarValue>> {
738        if !self.is_input_pre_ordered {
739            self.sort();
740        }
741
742        let mut result = vec![self.evaluate()?];
743        result.push(self.evaluate_orderings()?);
744
745        Ok(result)
746    }
747
748    fn evaluate(&mut self) -> Result<ScalarValue> {
749        if !self.is_input_pre_ordered {
750            self.sort();
751        }
752
753        if self.values.is_empty() {
754            return Ok(ScalarValue::new_null_list(
755                self.datatypes[0].clone(),
756                true,
757                1,
758            ));
759        }
760
761        let values = self.values.clone();
762        let array = if self.reverse {
763            ScalarValue::new_list_from_iter(
764                values.into_iter().rev(),
765                &self.datatypes[0],
766                true,
767            )
768        } else {
769            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
770        };
771        Ok(ScalarValue::List(array))
772    }
773
774    fn size(&self) -> usize {
775        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
776            - size_of_val(&self.values);
777
778        // Add size of the `self.ordering_values`
779        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
780        for row in &self.ordering_values {
781            total += ScalarValue::size_of_vec(row) - size_of_val(row);
782        }
783
784        // Add size of the `self.datatypes`
785        total += size_of::<DataType>() * self.datatypes.capacity();
786        for dtype in &self.datatypes {
787            total += dtype.size() - size_of_val(dtype);
788        }
789
790        // Add size of the `self.ordering_req`
791        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
792        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
793        total
794    }
795}
796
/// Unit tests for the `ARRAY_AGG` accumulators.
///
/// Most tests follow one pattern:
/// - build two identically-configured accumulators with `ArrayAggAccumulatorBuilder`;
/// - feed each one a batch via `update_batch`;
/// - fold the second into the first through its serialized state with `merge`;
/// - compare the evaluated list (rendered by `str_arr` + `print_nulls`) against the expected strings.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ListBuilder, StringBuilder};
    use arrow::datatypes::{FieldRef, Schema};
    use datafusion_common::cast::as_generic_string_array;
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    /// Without DISTINCT or ORDER BY, merging keeps every value, with the
    /// first accumulator's values ahead of the second's (order is asserted).
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    /// DISTINCT over disjoint inputs keeps all values.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        // Sorted before comparing, so element order is deliberately not
        // checked (presumably DISTINCT without ORDER BY has no defined order).
        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    /// Without DISTINCT, values seen by both accumulators are kept twice,
    /// in first-accumulator-then-second order.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);

        Ok(())
    }

    /// DISTINCT collapses values present in both accumulators to one copy.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        // Order is not asserted; only the set of values is.
        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    /// DISTINCT deduplicates when the second accumulator repeats the first
    /// one's values alongside new ones.
    #[test]
    fn duplicates_on_second_batch_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "c"])])?;
        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        // Order is not asserted; only the set of values is.
        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d"]);

        Ok(())
    }

    /// DISTINCT with an ascending ORDER BY on the aggregated column yields
    /// globally sorted output after merging two unsorted inputs.
    ///
    /// `SortOptions::new(descending, nulls_first)`: `(false, false)` is
    /// ascending with NULLs last.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    /// Descending counterpart of `no_duplicates_distinct_sort_asc`.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);

        Ok(())
    }

    /// DISTINCT + ascending ORDER BY: values repeated across accumulators
    /// appear once, and the output is sorted.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    /// Descending counterpart of `duplicates_distinct_sort_asc`.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["c", "b", "a"]);

        Ok(())
    }

    /// Each accumulator sees one NULL. DISTINCT keeps a single NULL, and
    /// ascending order with `nulls_first = true` puts it at the front.
    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);

        Ok(())
    }

    /// As above, but with `nulls_first = false` the single NULL goes last.
    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);

        Ok(())
    }

    /// Descending order with `nulls_first = true`: the single NULL comes first.
    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);

        Ok(())
    }

    /// Descending order with `nulls_first = false`: the single NULL comes last.
    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);

        Ok(())
    }

    /// An accumulator that saw only NULLs merges correctly with one that saw
    /// a value; DISTINCT keeps one NULL and the value.
    #[test]
    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        // The turbofish is needed: an all-`None` array literal gives the
        // compiler nothing to infer the element type from.
        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();
        assert_eq!(result, vec!["NULL", "a"]);
        Ok(())
    }

    /// When both accumulators saw only NULLs, DISTINCT yields exactly one
    /// NULL element.
    #[test]
    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);
        assert_eq!(result, vec!["NULL"]);
        Ok(())
    }

    /// Pins the exact byte count reported by `Accumulator::size` after a
    /// merge (no DISTINCT, no ORDER BY).
    ///
    /// NOTE(review): the expected value is an exact byte count and is
    /// presumably sensitive to input allocation sizes and to the
    /// accumulator's internal layout; confirm before updating it.
    #[test]
    fn does_not_over_account_memory() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        assert_eq!(acc1.size(), 266);

        Ok(())
    }

    /// Pins `size()` for a DISTINCT accumulator after a merge.
    ///
    /// Each input row is itself a list of strings (built by
    /// `string_list_data`), matching the schema's `List` column type.
    #[test]
    fn does_not_over_account_memory_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["d", "e", "f"],
        ])])?;
        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
        acc1 = merge(acc1, acc2)?;

        // without compaction, the size is 16660
        assert_eq!(acc1.size(), 1660);

        Ok(())
    }

    /// Pins `size()` for a single ordered accumulator (ascending ORDER BY on
    /// the input column) fed three list-typed rows; no merge is involved.
    #[test]
    fn does_not_over_account_memory_ordered() -> Result<()> {
        let mut acc = ArrayAggAccumulatorBuilder::string()
            .order_by_col("col", SortOptions::new(false, false))
            .build()?;

        acc.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["c", "d", "e"],
            vec!["b", "c", "d"],
        ])])?;

        // without compaction, the size is 17112
        assert_eq!(acc.size(), 2184);

        Ok(())
    }

    /// Test helper that assembles the `AccumulatorArgs` used to create
    /// `ARRAY_AGG` accumulators over a single-column schema (column `"col"`).
    struct ArrayAggAccumulatorBuilder {
        /// Passed through as `AccumulatorArgs::return_field`.
        return_field: FieldRef,
        /// Passed through as `AccumulatorArgs::is_distinct`.
        distinct: bool,
        /// ORDER BY expressions, passed through as `AccumulatorArgs::order_bys`.
        order_bys: Vec<PhysicalSortExpr>,
        /// Input schema with one nullable column `"col"` (type set in `new`).
        schema: Schema,
    }

    impl ArrayAggAccumulatorBuilder {
        /// Builder for `Utf8` (string) data.
        fn string() -> Self {
            Self::new(DataType::Utf8)
        }

        /// Creates a builder with DISTINCT off and no ORDER BY.
        ///
        /// `return_field` is declared with `data_type` itself, while the
        /// schema column `"col"` is declared as nullable `List(data_type)`.
        // NOTE(review): this looks inverted relative to ARRAY_AGG, whose
        // result is a list (`str_arr` expects `ScalarValue::List` from
        // `evaluate`). Several tests also feed flat `data_type` arrays via
        // `data`, not lists. Confirm this is intentional before relying on
        // the declared types.
        fn new(data_type: DataType) -> Self {
            Self {
                return_field: Field::new("f", data_type.clone(), true).into(),
                distinct: false,
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::new_list(data_type, true),
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }

        /// Enables DISTINCT for accumulators built afterwards.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on the schema column named `col`.
        ///
        /// The tests pass `sort_options` as
        /// `SortOptions::new(descending, nulls_first)`.
        ///
        /// # Panics
        /// Panics if `col` is not a column of the builder's schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let new_order = PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            );
            self.order_bys.push(new_order);
            self
        }

        /// Creates one accumulator from the current configuration.
        ///
        /// The aggregated expression is always column 0 (`"col"`),
        /// independent of any name passed to `order_by_col`. `ignore_nulls`
        /// and `is_reversed` are always `false`.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let expr = Arc::new(Column::new("col", 0));
            let expr_field = expr.return_field(&self.schema)?;
            ArrayAgg::default().accumulator(AccumulatorArgs {
                return_field: Arc::clone(&self.return_field),
                schema: &self.schema,
                expr_fields: &[expr_field],
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[expr],
            })
        }

        /// Creates two independently-built accumulators with identical
        /// configuration, e.g. to exercise `merge`.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Extracts the string elements of a `ScalarValue::List`.
    ///
    /// Reads `list.values()` directly, i.e. the whole child array without
    /// applying row offsets. This presumably assumes the list holds a single
    /// row, as an evaluated aggregate would.
    ///
    /// # Errors
    /// - Internal error if `value` is not a `ScalarValue::List`.
    /// - Error if the child values are not an `i32`-offset (`Utf8`) string array.
    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
        let ScalarValue::List(list) = value else {
            return internal_err!("ScalarValue was not a List");
        };
        Ok(as_generic_string_array::<i32>(list.values())?
            .iter()
            .map(|v| v.map(|v| v.to_string()))
            .collect())
    }

    /// Renders optional strings, replacing `None` with the literal `"NULL"` so
    /// results can be compared against plain string-literal vectors.
    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
        sort.into_iter()
            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
            .collect()
    }

    /// Builds a string `ListArray` in which each input `Vec` becomes one list
    /// row; every element is non-null.
    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
        let mut builder = ListBuilder::new(StringBuilder::new());
        for string_list in data.into_iter() {
            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>());
        }

        Arc::new(builder.finish())
    }

    /// Builds an `ArrayRef` from a fixed-size array of values convertible into
    /// `ScalarValue` (e.g. `&str` or `Option<&str>`).
    ///
    /// # Panics
    /// Panics if the values cannot be combined into a single array.
    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
    where
        ScalarValue: From<T>,
    {
        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
    }

    /// Merges `acc2` into `acc1` through `acc2`'s serialized state and
    /// returns `acc1`.
    ///
    /// `acc2`'s `state()` scalars are converted to arrays and passed to
    /// `acc1.merge_batch`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}