datafusion_functions_aggregate/
array_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val, take};
23use std::sync::Arc;
24
25use arrow::array::{
26    Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, new_empty_array,
27};
28use arrow::compute::{SortOptions, filter};
29use arrow::datatypes::{DataType, Field, FieldRef, Fields};
30
31use datafusion_common::cast::as_list_array;
32use datafusion_common::utils::{
33    SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args,
34};
35use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
40};
41use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
42use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
43use datafusion_functions_aggregate_common::utils::ordering_fields;
44use datafusion_macros::user_doc;
45use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
46
// Generates the `array_agg` expression-builder function and the
// `array_agg_udaf` singleton accessor for the [`ArrayAgg`] UDAF.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
54
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
    /// Accepts exactly one argument of any type (see `Default`).
    signature: Signature,
    /// Set through `with_beneficial_ordering`; when `true` the
    /// order-sensitive accumulator assumes its input already satisfies the
    /// ordering requirement and skips re-sorting.
    is_input_pre_ordered: bool,
}
84
85impl Default for ArrayAgg {
86    fn default() -> Self {
87        Self {
88            signature: Signature::any(1, Volatility::Immutable),
89            is_input_pre_ordered: false,
90        }
91    }
92}
93
impl AggregateUDFImpl for ArrayAgg {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "array_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always a nullable `List` whose element type is the
    /// (single) input argument type.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            arg_types[0].clone(),
            true,
        ))))
    }

    /// Intermediate (partial aggregation) state layout:
    /// - `DISTINCT`: a single list field holding the distinct values.
    /// - otherwise: a list field with the accumulated values, plus — only
    ///   when an ORDER BY is present — a second list of structs carrying the
    ///   ordering expression values for each accumulated row.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            return Ok(vec![
                Field::new_list(
                    format_state_name(args.name, "distinct_array_agg"),
                    // See COMMENTS.md to understand why nullable is set to true
                    Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                    true,
                )
                .into(),
            ]);
        }

        let mut fields = vec![
            Field::new_list(
                format_state_name(args.name, "array_agg"),
                // See COMMENTS.md to understand why nullable is set to true
                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                true,
            )
            .into(),
        ];

        if args.ordering_fields.is_empty() {
            return Ok(fields);
        }

        let orderings = args.ordering_fields.to_vec();
        fields.push(
            Field::new_list(
                format_state_name(args.name, "array_agg_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            )
            .into(),
        );

        Ok(fields)
    }

    /// A pre-sorted input is beneficial (it lets the accumulator skip
    /// sorting) but is not required for correctness.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::SoftRequirement
    }

    /// Records whether the planner will deliver pre-ordered input; the flag
    /// is forwarded to the order-sensitive accumulator so it can avoid
    /// re-sorting (see `is_input_pre_ordered`).
    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    /// Selects one of three accumulators depending on the query shape:
    /// [`DistinctArrayAggAccumulator`] for DISTINCT, plain
    /// [`ArrayAggAccumulator`] when there is no ORDER BY, and
    /// [`OrderSensitiveArrayAggAccumulator`] otherwise.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let field = &acc_args.expr_fields[0];
        let data_type = field.data_type();
        // Null filtering is only worthwhile when the input can actually be null.
        let ignore_nulls = acc_args.ignore_nulls && field.is_nullable();

        if acc_args.is_distinct {
            // Limitation similar to Postgres. The aggregation function can only mix
            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
            // also in the arguments of the function. This implies that if the
            // aggregation function only accepts one argument, only one argument
            // can be used in the ORDER BY, For example:
            //
            // ARRAY_AGG(DISTINCT col)
            //
            // can only be mixed with an ORDER BY if the order expression is "col".
            //
            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
            let sort_option = match acc_args.order_bys {
                [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
                [] => None,
                _ => {
                    return exec_err!(
                        "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
                    );
                }
            };
            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
                data_type,
                sort_option,
                ignore_nulls,
            )?));
        }

        // No ordering requirement: use the simple append-only accumulator.
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return Ok(Box::new(ArrayAggAccumulator::try_new(
                data_type,
                ignore_nulls,
            )?));
        };

        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        OrderSensitiveArrayAggAccumulator::try_new(
            data_type,
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.is_reversed,
            ignore_nulls,
        )
        .map(|acc| Box::new(acc) as _)
    }

    /// ARRAY_AGG is its own reverse: the reversed variant flips the emitted
    /// element order (see `OrderSensitiveArrayAggAccumulator::reverse`).
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
    }

    /// Supports `RESPECT NULLS` / `IGNORE NULLS` clauses.
    fn supports_null_handling_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
239
/// Accumulator for `ARRAY_AGG` without DISTINCT and without ORDER BY:
/// collects input batches and concatenates them at evaluation time.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    /// Input batches collected so far; concatenated in `evaluate`.
    values: Vec<ArrayRef>,
    /// Element type of the resulting list.
    datatype: DataType,
    /// If `true`, null input values are filtered out before being collected.
    ignore_nulls: bool,
}
246
247impl ArrayAggAccumulator {
248    /// new array_agg accumulator based on given item data type
249    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
250        Ok(Self {
251            values: vec![],
252            datatype: datatype.clone(),
253            ignore_nulls,
254        })
255    }
256
257    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
258    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
259    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
260        let offsets = list_array.value_offsets();
261        // Offsets always have at least 1 value
262        let initial_offset = offsets[0];
263        let null_count = list_array.null_count();
264
265        // If no nulls than just use the fast path
266        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
267        if null_count == 0 {
268            // According to Arrow specification, the first offset can be non-zero
269            let list_values = list_array.values().slice(
270                initial_offset as usize,
271                (offsets[offsets.len() - 1] - initial_offset) as usize,
272            );
273            return Some(list_values);
274        }
275
276        // If all the values are null than just return an empty values array
277        if list_array.null_count() == list_array.len() {
278            return Some(list_array.values().slice(0, 0));
279        }
280
281        // According to the Arrow spec, null values can point to non-empty lists
282        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
283
284        // Unwrapping is safe as we just checked if there is a null value
285        let nulls = list_array.nulls().unwrap();
286
287        let mut valid_slices_iter = nulls.valid_slices();
288
289        // This is safe as we validated that there is at least 1 valid value in the array
290        let (start, end) = valid_slices_iter.next().unwrap();
291
292        let start_offset = offsets[start];
293
294        // End is exclusive, so it already point to the last offset value
295        // This is valid as the length of the array is always 1 less than the length of the offsets
296        let mut end_offset_of_last_valid_value = offsets[end];
297
298        for (start, end) in valid_slices_iter {
299            // If there is a null value that point to a non-empty list than the start offset of the valid value
300            // will be different that the end offset of the last valid value
301            if offsets[start] != end_offset_of_last_valid_value {
302                return None;
303            }
304
305            // End is exclusive, so it already point to the last offset value
306            // This is valid as the length of the array is always 1 less than the length of the offsets
307            end_offset_of_last_valid_value = offsets[end];
308        }
309
310        let consecutive_valid_values = list_array.values().slice(
311            start_offset as usize,
312            (end_offset_of_last_valid_value - start_offset) as usize,
313        );
314
315        Some(consecutive_valid_values)
316    }
317}
318
impl Accumulator for ArrayAggAccumulator {
    /// Collects one input batch, filtering out nulls when `ignore_nulls` is set.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Append value like Int64Array(1,2,3)
        if values.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(values.len(), 1, "expects single batch");

        let val = &values[0];
        // Only materialize the null mask when nulls must be filtered out.
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let val = match nulls {
            // Entirely null batch: nothing to accumulate.
            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
            // Keep only the valid (non-null) elements.
            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
            None => Arc::clone(val),
        };

        if !val.is_empty() {
            self.values.push(val)
        }

        Ok(())
    }

    /// Merges a partial-aggregation state (a list array of per-group values),
    /// using a zero-copy slice of the underlying values when possible.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
        if states.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(states.len(), 1, "expects single state");

        let list_arr = as_list_array(&states[0])?;

        match Self::get_optional_values_to_merge_as_is(list_arr) {
            // Fast path: all valid values are contiguous, take them as one slice.
            Some(values) => {
                // Make sure we don't insert empty lists
                if !values.is_empty() {
                    self.values.push(values);
                }
            }
            // Slow path: append each (non-null) list element individually.
            None => {
                for arr in list_arr.iter().flatten() {
                    self.values.push(arr);
                }
            }
        }

        Ok(())
    }

    /// The serialized state is the same single-row list produced by `evaluate`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    /// Concatenates all collected batches into a single-row list scalar;
    /// returns a null list when nothing was accumulated.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Transform Vec<ListArr> to ListArr
        let element_arrays: Vec<&dyn Array> =
            self.values.iter().map(|a| a.as_ref()).collect();

        if element_arrays.is_empty() {
            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
        }

        let concated_array = arrow::compute::concat(&element_arrays)?;

        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
    }

    fn size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ArrayRef>() * self.values.capacity())
            + self
                .values
                .iter()
                // Each ArrayRef might be just a reference to a bigger array, and many
                // ArrayRefs here might be referencing exactly the same array, so if we
                // were to call `arr.get_array_memory_size()`, we would be double-counting
                // the same underlying data many times.
                //
                // Instead, we do an approximation by estimating how much memory each
                // ArrayRef would occupy if its underlying data was fully owned by this
                // accumulator.
                //
                // Note that this is just an estimation, but the reality is that this
                // accumulator might not own any data.
                .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
                .sum::<usize>()
            + self.datatype.size()
            - size_of_val(&self.datatype)
    }
}
416
/// Accumulator for `ARRAY_AGG(DISTINCT ...)`: deduplicates values through a
/// `HashSet` and optionally sorts the result at evaluation time.
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    /// Distinct values seen so far (compacted scalars).
    values: HashSet<ScalarValue>,
    /// Element type of the resulting list.
    datatype: DataType,
    /// `Some` when a (single, argument-matching) ORDER BY was specified.
    sort_options: Option<SortOptions>,
    /// If `true`, null input values are skipped.
    ignore_nulls: bool,
}
424
425impl DistinctArrayAggAccumulator {
426    pub fn try_new(
427        datatype: &DataType,
428        sort_options: Option<SortOptions>,
429        ignore_nulls: bool,
430    ) -> Result<Self> {
431        Ok(Self {
432            values: HashSet::new(),
433            datatype: datatype.clone(),
434            sort_options,
435            ignore_nulls,
436        })
437    }
438}
439
impl Accumulator for DistinctArrayAggAccumulator {
    /// The serialized state is the same single-row list produced by `evaluate`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    /// Inserts each (optionally non-null) value of the batch into the set,
    /// compacting scalars so they do not keep whole input arrays alive.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let val = &values[0];
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let nulls = nulls.as_ref();
        // Skip the whole batch if every value is null (and nulls are ignored).
        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
            for i in 0..val.len() {
                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
                    self.values
                        .insert(ScalarValue::try_from_array(val, i)?.compacted());
                }
            }
        }

        Ok(())
    }

    /// Merges a partial state by re-feeding each per-group list through
    /// `update_batch`, which deduplicates via the set.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(states.len(), 1, "expects single state");

        states[0]
            .as_list::<i32>()
            .iter()
            .flatten()
            .try_for_each(|val| self.update_batch(&[val]))
    }

    /// Produces the distinct values as a single-row list, sorted when an
    /// ORDER BY was specified; returns a null list if nothing was collected.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
        if values.is_empty() {
            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
        }

        if let Some(opts) = self.sort_options {
            // Comparison errors inside the sort closure are deferred and
            // surfaced once sorting finishes.
            let mut delayed_cmp_err = Ok(());
            values.sort_by(|a, b| {
                // NOTE(review): nulls are routed per `nulls_first`; two nulls are
                // never compared against each other here since `values` comes
                // from a set — presumably at most one null per datatype.
                if a.is_null() {
                    return match opts.nulls_first {
                        true => Ordering::Less,
                        false => Ordering::Greater,
                    };
                }
                if b.is_null() {
                    return match opts.nulls_first {
                        true => Ordering::Greater,
                        false => Ordering::Less,
                    };
                }
                match opts.descending {
                    true => b.try_cmp(a),
                    false => a.try_cmp(b),
                }
                .unwrap_or_else(|err| {
                    delayed_cmp_err = Err(err);
                    Ordering::Equal
                })
            });
            delayed_cmp_err?;
        };

        let arr = ScalarValue::new_list(&values, &self.datatype, true);
        Ok(ScalarValue::List(arr))
    }

    fn size(&self) -> usize {
        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
            - size_of_val(&self.values)
            + self.datatype.size()
            - size_of_val(&self.datatype)
            - size_of_val(&self.sort_options)
            + size_of::<Option<SortOptions>>()
    }
}
530
/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions. The first element is the value datatype; the remaining
    /// elements are the ordering expression datatypes (see `try_new`).
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the input is known to be pre-ordered
    is_input_pre_ordered: bool,
    /// Whether the aggregation is running in reverse.
    reverse: bool,
    /// Whether the aggregation should ignore null values.
    ignore_nulls: bool,
}
555
impl OrderSensitiveArrayAggAccumulator {
    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
    /// item data type.
    pub fn try_new(
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        reverse: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        // `datatypes[0]` is the value type; the rest are the ordering types.
        let mut datatypes = vec![datatype.clone()];
        datatypes.extend(ordering_dtypes.iter().cloned());
        Ok(Self {
            values: vec![],
            ordering_values: vec![],
            datatypes,
            ordering_req,
            is_input_pre_ordered,
            reverse,
            ignore_nulls,
        })
    }

    /// Sorts `values` together with their `ordering_values` according to the
    /// ordering requirement. Comparison errors raised inside the sort closure
    /// are deferred; callers observe them as a poisoned (but consistent) sort.
    fn sort(&mut self) {
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        // Zip values with their ordering rows so both are permuted together.
        let mut values = take(&mut self.values)
            .into_iter()
            .zip(take(&mut self.ordering_values))
            .collect::<Vec<_>>();
        let mut delayed_cmp_err = Ok(());
        values.sort_by(|(_, left_ordering), (_, right_ordering)| {
            compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
                |err| {
                    delayed_cmp_err = Err(err);
                    Ordering::Equal
                },
            )
        });
        (self.values, self.ordering_values) = values.into_iter().unzip();
    }

    /// Packs the collected ordering rows into a single-row list of structs;
    /// this becomes the second state field emitted by `state`.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);

        // Transpose the row-wise ordering values into one array per column.
        let column_wise_ordering_values = if self.ordering_values.is_empty() {
            fields
                .iter()
                .map(|f| new_empty_array(f.data_type()))
                .collect::<Vec<_>>()
        } else {
            (0..fields.len())
                .map(|i| {
                    let column_values = self.ordering_values.iter().map(|x| x[i].clone());
                    ScalarValue::iter_to_array(column_values)
                })
                .collect::<Result<_>>()?
        };

        let ordering_array = StructArray::try_new(
            Fields::from(fields),
            column_wise_ordering_values,
            None,
        )?;
        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
    }
}
627
impl Accumulator for OrderSensitiveArrayAggAccumulator {
    /// Collects each (optionally non-null) value along with the corresponding
    /// row of ordering expression values (`values[1..]`).
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        // First array is the aggregated expression; the rest are the ORDER BY
        // expression values for the same rows.
        let val = &values[0];
        let ord = &values[1..];
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let nulls = nulls.as_ref();
        // Skip the whole batch if every value is null (and nulls are ignored).
        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
            for i in 0..val.len() {
                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
                    // Compact scalars so they don't keep input arrays alive.
                    self.values
                        .push(ScalarValue::try_from_array(val, i)?.compacted());
                    self.ordering_values.push(
                        get_row_at_idx(ord, i)?
                            .into_iter()
                            .map(|v| v.compacted())
                            .collect(),
                    )
                }
            }
        }

        Ok(())
    }

    /// Merges partial states from other partitions, interleaving their
    /// (already sorted) values with ours via [`merge_ordered_arrays`].
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        // First entry in the state is the aggregation result. Second entry
        // stores values received for ordering requirement columns for each
        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
        // received from its ordering requirement expression. (This information
        // is necessary for during merging).
        let [array_agg_values, agg_orderings] =
            take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Stores ARRAY_AGG results coming from each partition
        let mut partition_values = vec![];
        // Stores ordering requirement expression results coming from each partition
        let mut partition_ordering_values = vec![];

        // Existing values should be merged also.
        if !self.is_input_pre_ordered {
            self.sort();
        }
        partition_values.push(take(&mut self.values).into());
        partition_ordering_values.push(take(&mut self.ordering_values).into());

        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
        for maybe_v in array_agg_res.into_iter() {
            if let Some(v) = maybe_v {
                partition_values.push(v.into());
            } else {
                // A null group contributes no values.
                partition_values.push(vec![].into());
            }
        }

        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        for partition_ordering_rows in orderings.into_iter().flatten() {
            // Extract value from struct to ordering_rows for each group/partition
            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
                    if let ScalarValue::Struct(s) = ordering_row {
                        let mut ordering_columns_per_row = vec![];

                        for column in s.columns() {
                            let sv = ScalarValue::try_from_array(column, 0)?;
                            ordering_columns_per_row.push(sv);
                        }

                        Ok(ordering_columns_per_row)
                    } else {
                        exec_err!(
                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
                            ordering_row.data_type()
                        )
                    }
                }).collect::<Result<VecDeque<_>>>()?;

            partition_ordering_values.push(ordering_value);
        }

        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();

        (self.values, self.ordering_values) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;

        Ok(())
    }

    /// Emits the (sorted) values list plus the companion ordering-values list
    /// produced by `evaluate_orderings`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        if !self.is_input_pre_ordered {
            self.sort();
        }

        let mut result = vec![self.evaluate()?];
        result.push(self.evaluate_orderings()?);

        Ok(result)
    }

    /// Produces the final single-row list, reversing the element order when
    /// the aggregate runs as the reversed variant; returns a null list if
    /// nothing was accumulated.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if !self.is_input_pre_ordered {
            self.sort();
        }

        if self.values.is_empty() {
            return Ok(ScalarValue::new_null_list(
                self.datatypes[0].clone(),
                true,
                1,
            ));
        }

        let values = self.values.clone();
        let array = if self.reverse {
            ScalarValue::new_list_from_iter(
                values.into_iter().rev(),
                &self.datatypes[0],
                true,
            )
        } else {
            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
        };
        Ok(ScalarValue::List(array))
    }

    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
            - size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}
798
799#[cfg(test)]
800mod tests {
801    use super::*;
802    use arrow::array::{ListBuilder, StringBuilder};
803    use arrow::datatypes::{FieldRef, Schema};
804    use datafusion_common::cast::as_generic_string_array;
805    use datafusion_common::internal_err;
806    use datafusion_physical_expr::PhysicalExpr;
807    use datafusion_physical_expr::expressions::Column;
808    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
809    use std::sync::Arc;
810
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        // Merging two plain (non-distinct) accumulators concatenates their
        // batches in merge order.
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }
825
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        // DISTINCT with disjoint inputs keeps every value; result is sorted
        // before asserting because set iteration order is unspecified.
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }
843
844    #[test]
845    fn duplicates_no_distinct() -> Result<()> {
846        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
847
848        acc1.update_batch(&[data(["a", "b", "c"])])?;
849        acc2.update_batch(&[data(["a", "b", "c"])])?;
850        acc1 = merge(acc1, acc2)?;
851
852        let result = print_nulls(str_arr(acc1.evaluate()?)?);
853
854        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);
855
856        Ok(())
857    }
858
859    #[test]
860    fn duplicates_distinct() -> Result<()> {
861        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
862            .distinct()
863            .build_two()?;
864
865        acc1.update_batch(&[data(["a", "b", "c"])])?;
866        acc2.update_batch(&[data(["a", "b", "c"])])?;
867        acc1 = merge(acc1, acc2)?;
868
869        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
870        result.sort();
871
872        assert_eq!(result, vec!["a", "b", "c"]);
873
874        Ok(())
875    }
876
877    #[test]
878    fn duplicates_on_second_batch_distinct() -> Result<()> {
879        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
880            .distinct()
881            .build_two()?;
882
883        acc1.update_batch(&[data(["a", "c"])])?;
884        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
885        acc1 = merge(acc1, acc2)?;
886
887        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
888        result.sort();
889
890        assert_eq!(result, vec!["a", "b", "c", "d"]);
891
892        Ok(())
893    }
894
895    #[test]
896    fn no_duplicates_distinct_sort_asc() -> Result<()> {
897        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
898            .distinct()
899            .order_by_col("col", SortOptions::new(false, false))
900            .build_two()?;
901
902        acc1.update_batch(&[data(["e", "b", "d"])])?;
903        acc2.update_batch(&[data(["f", "a", "c"])])?;
904        acc1 = merge(acc1, acc2)?;
905
906        let result = print_nulls(str_arr(acc1.evaluate()?)?);
907
908        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
909
910        Ok(())
911    }
912
913    #[test]
914    fn no_duplicates_distinct_sort_desc() -> Result<()> {
915        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
916            .distinct()
917            .order_by_col("col", SortOptions::new(true, false))
918            .build_two()?;
919
920        acc1.update_batch(&[data(["e", "b", "d"])])?;
921        acc2.update_batch(&[data(["f", "a", "c"])])?;
922        acc1 = merge(acc1, acc2)?;
923
924        let result = print_nulls(str_arr(acc1.evaluate()?)?);
925
926        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);
927
928        Ok(())
929    }
930
931    #[test]
932    fn duplicates_distinct_sort_asc() -> Result<()> {
933        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
934            .distinct()
935            .order_by_col("col", SortOptions::new(false, false))
936            .build_two()?;
937
938        acc1.update_batch(&[data(["a", "c", "b"])])?;
939        acc2.update_batch(&[data(["b", "c", "a"])])?;
940        acc1 = merge(acc1, acc2)?;
941
942        let result = print_nulls(str_arr(acc1.evaluate()?)?);
943
944        assert_eq!(result, vec!["a", "b", "c"]);
945
946        Ok(())
947    }
948
949    #[test]
950    fn duplicates_distinct_sort_desc() -> Result<()> {
951        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
952            .distinct()
953            .order_by_col("col", SortOptions::new(true, false))
954            .build_two()?;
955
956        acc1.update_batch(&[data(["a", "c", "b"])])?;
957        acc2.update_batch(&[data(["b", "c", "a"])])?;
958        acc1 = merge(acc1, acc2)?;
959
960        let result = print_nulls(str_arr(acc1.evaluate()?)?);
961
962        assert_eq!(result, vec!["c", "b", "a"]);
963
964        Ok(())
965    }
966
967    #[test]
968    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
969        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
970            .distinct()
971            .order_by_col("col", SortOptions::new(false, true))
972            .build_two()?;
973
974        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
975        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
976        acc1 = merge(acc1, acc2)?;
977
978        let result = print_nulls(str_arr(acc1.evaluate()?)?);
979
980        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);
981
982        Ok(())
983    }
984
985    #[test]
986    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
987        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
988            .distinct()
989            .order_by_col("col", SortOptions::new(false, false))
990            .build_two()?;
991
992        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
993        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
994        acc1 = merge(acc1, acc2)?;
995
996        let result = print_nulls(str_arr(acc1.evaluate()?)?);
997
998        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);
999
1000        Ok(())
1001    }
1002
1003    #[test]
1004    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
1005        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1006            .distinct()
1007            .order_by_col("col", SortOptions::new(true, true))
1008            .build_two()?;
1009
1010        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
1011        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
1012        acc1 = merge(acc1, acc2)?;
1013
1014        let result = print_nulls(str_arr(acc1.evaluate()?)?);
1015
1016        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);
1017
1018        Ok(())
1019    }
1020
1021    #[test]
1022    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
1023        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1024            .distinct()
1025            .order_by_col("col", SortOptions::new(true, false))
1026            .build_two()?;
1027
1028        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
1029        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
1030        acc1 = merge(acc1, acc2)?;
1031
1032        let result = print_nulls(str_arr(acc1.evaluate()?)?);
1033
1034        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);
1035
1036        Ok(())
1037    }
1038
1039    #[test]
1040    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
1041        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1042            .distinct()
1043            .build_two()?;
1044
1045        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
1046        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
1047        acc1 = merge(acc1, acc2)?;
1048
1049        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
1050        result.sort();
1051        assert_eq!(result, vec!["NULL", "a"]);
1052        Ok(())
1053    }
1054
1055    #[test]
1056    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
1057        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1058            .distinct()
1059            .build_two()?;
1060
1061        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
1062        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
1063        acc1 = merge(acc1, acc2)?;
1064
1065        let result = print_nulls(str_arr(acc1.evaluate()?)?);
1066        assert_eq!(result, vec!["NULL"]);
1067        Ok(())
1068    }
1069
1070    #[test]
1071    fn does_not_over_account_memory() -> Result<()> {
1072        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
1073
1074        acc1.update_batch(&[data(["a", "c", "b"])])?;
1075        acc2.update_batch(&[data(["b", "c", "a"])])?;
1076        acc1 = merge(acc1, acc2)?;
1077
1078        assert_eq!(acc1.size(), 266);
1079
1080        Ok(())
1081    }
1082    #[test]
1083    fn does_not_over_account_memory_distinct() -> Result<()> {
1084        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1085            .distinct()
1086            .build_two()?;
1087
1088        acc1.update_batch(&[string_list_data([
1089            vec!["a", "b", "c"],
1090            vec!["d", "e", "f"],
1091        ])])?;
1092        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
1093        acc1 = merge(acc1, acc2)?;
1094
1095        // without compaction, the size is 16660
1096        assert_eq!(acc1.size(), 1660);
1097
1098        Ok(())
1099    }
1100
1101    #[test]
1102    fn does_not_over_account_memory_ordered() -> Result<()> {
1103        let mut acc = ArrayAggAccumulatorBuilder::string()
1104            .order_by_col("col", SortOptions::new(false, false))
1105            .build()?;
1106
1107        acc.update_batch(&[string_list_data([
1108            vec!["a", "b", "c"],
1109            vec!["c", "d", "e"],
1110            vec!["b", "c", "d"],
1111        ])])?;
1112
1113        // without compaction, the size is 17112
1114        assert_eq!(acc.size(), 2224);
1115
1116        Ok(())
1117    }
1118
    /// Test helper that configures and constructs `ARRAY_AGG` accumulators.
    struct ArrayAggAccumulatorBuilder {
        // Field describing the accumulator's return values.
        return_field: FieldRef,
        // Whether to build a DISTINCT accumulator.
        distinct: bool,
        // Ordering requirement passed to the accumulator (empty = unordered).
        order_bys: Vec<PhysicalSortExpr>,
        // Input schema: a single nullable list column named "col".
        schema: Schema,
    }
1125
1126    impl ArrayAggAccumulatorBuilder {
1127        fn string() -> Self {
1128            Self::new(DataType::Utf8)
1129        }
1130
1131        fn new(data_type: DataType) -> Self {
1132            Self {
1133                return_field: Field::new("f", data_type.clone(), true).into(),
1134                distinct: false,
1135                order_bys: vec![],
1136                schema: Schema {
1137                    fields: Fields::from(vec![Field::new(
1138                        "col",
1139                        DataType::new_list(data_type, true),
1140                        true,
1141                    )]),
1142                    metadata: Default::default(),
1143                },
1144            }
1145        }
1146
1147        fn distinct(mut self) -> Self {
1148            self.distinct = true;
1149            self
1150        }
1151
1152        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
1153            let new_order = PhysicalSortExpr::new(
1154                Arc::new(
1155                    Column::new_with_schema(col, &self.schema)
1156                        .expect("column not available in schema"),
1157                ),
1158                sort_options,
1159            );
1160            self.order_bys.push(new_order);
1161            self
1162        }
1163
1164        fn build(&self) -> Result<Box<dyn Accumulator>> {
1165            let expr = Arc::new(Column::new("col", 0));
1166            let expr_field = expr.return_field(&self.schema)?;
1167            ArrayAgg::default().accumulator(AccumulatorArgs {
1168                return_field: Arc::clone(&self.return_field),
1169                schema: &self.schema,
1170                expr_fields: &[expr_field],
1171                ignore_nulls: false,
1172                order_bys: &self.order_bys,
1173                is_reversed: false,
1174                name: "",
1175                is_distinct: self.distinct,
1176                exprs: &[expr],
1177            })
1178        }
1179
1180        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
1181            Ok((self.build()?, self.build()?))
1182        }
1183    }
1184
1185    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
1186        let ScalarValue::List(list) = value else {
1187            return internal_err!("ScalarValue was not a List");
1188        };
1189        Ok(as_generic_string_array::<i32>(list.values())?
1190            .iter()
1191            .map(|v| v.map(|v| v.to_string()))
1192            .collect())
1193    }
1194
1195    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
1196        sort.into_iter()
1197            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
1198            .collect()
1199    }
1200
1201    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
1202        let mut builder = ListBuilder::new(StringBuilder::new());
1203        for string_list in data.into_iter() {
1204            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>());
1205        }
1206
1207        Arc::new(builder.finish())
1208    }
1209
1210    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
1211    where
1212        ScalarValue: From<T>,
1213    {
1214        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
1215        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
1216    }
1217
1218    fn merge(
1219        mut acc1: Box<dyn Accumulator>,
1220        mut acc2: Box<dyn Accumulator>,
1221    ) -> Result<Box<dyn Accumulator>> {
1222        let intermediate_state = acc2.state().and_then(|e| {
1223            e.iter()
1224                .map(|v| v.to_array())
1225                .collect::<Result<Vec<ArrayRef>>>()
1226        })?;
1227        acc1.merge_batch(&intermediate_state)?;
1228        Ok(acc1)
1229    }
1230}