datafusion_functions_aggregate/
array_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val};
23use std::sync::Arc;
24
25use arrow::array::{
26    make_array, new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray,
27    StructArray,
28};
29use arrow::compute::{filter, SortOptions};
30use arrow::datatypes::{DataType, Field, FieldRef, Fields};
31
32use datafusion_common::cast::as_list_array;
33use datafusion_common::scalar::copy_array_data;
34use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
35use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
40};
41use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
42use datafusion_functions_aggregate_common::utils::ordering_fields;
43use datafusion_macros::user_doc;
44use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
45
// Invokes the crate's `make_udaf_expr_and_func!` macro for `ArrayAgg`.
// `array_agg_udaf` is the function that `ArrayAgg::reverse_expr` calls to
// obtain the UDAF; `array_agg` is presumably the generated expression-builder
// taking a single `expression` argument (confirm against the macro definition).
// The string literal is the description passed to the macro.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
53
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
/// `ARRAY_AGG` aggregate expression.
///
/// Collects the values of its single argument into a list. Which accumulator
/// does the work depends on the `DISTINCT` / `ORDER BY` modifiers and is
/// decided in [`AggregateUDFImpl::accumulator`].
pub struct ArrayAgg {
    /// Signature of the function; the `Default` impl sets it to exactly one
    /// argument of any type with immutable volatility.
    signature: Signature,
}
82
83impl Default for ArrayAgg {
84    fn default() -> Self {
85        Self {
86            signature: Signature::any(1, Volatility::Immutable),
87        }
88    }
89}
90
91impl AggregateUDFImpl for ArrayAgg {
92    fn as_any(&self) -> &dyn std::any::Any {
93        self
94    }
95
96    fn name(&self) -> &str {
97        "array_agg"
98    }
99
100    fn signature(&self) -> &Signature {
101        &self.signature
102    }
103
104    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
105        Ok(DataType::List(Arc::new(Field::new_list_field(
106            arg_types[0].clone(),
107            true,
108        ))))
109    }
110
111    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
112        if args.is_distinct {
113            return Ok(vec![Field::new_list(
114                format_state_name(args.name, "distinct_array_agg"),
115                // See COMMENTS.md to understand why nullable is set to true
116                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
117                true,
118            )
119            .into()]);
120        }
121
122        let mut fields = vec![Field::new_list(
123            format_state_name(args.name, "array_agg"),
124            // See COMMENTS.md to understand why nullable is set to true
125            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
126            true,
127        )
128        .into()];
129
130        if args.ordering_fields.is_empty() {
131            return Ok(fields);
132        }
133
134        let orderings = args.ordering_fields.to_vec();
135        fields.push(
136            Field::new_list(
137                format_state_name(args.name, "array_agg_orderings"),
138                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
139                false,
140            )
141            .into(),
142        );
143
144        Ok(fields)
145    }
146
147    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
148        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
149        let ignore_nulls =
150            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
151
152        if acc_args.is_distinct {
153            // Limitation similar to Postgres. The aggregation function can only mix
154            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
155            // also in the arguments of the function. This implies that if the
156            // aggregation function only accepts one argument, only one argument
157            // can be used in the ORDER BY, For example:
158            //
159            // ARRAY_AGG(DISTINCT col)
160            //
161            // can only be mixed with an ORDER BY if the order expression is "col".
162            //
163            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
164            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
165            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
166            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
167            let sort_option = match acc_args.order_bys {
168                [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
169                [] => None,
170                _ => {
171                    return exec_err!(
172                        "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
173                    );
174                }
175            };
176            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
177                &data_type,
178                sort_option,
179                ignore_nulls,
180            )?));
181        }
182
183        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
184            return Ok(Box::new(ArrayAggAccumulator::try_new(
185                &data_type,
186                ignore_nulls,
187            )?));
188        };
189
190        let ordering_dtypes = ordering
191            .iter()
192            .map(|e| e.expr.data_type(acc_args.schema))
193            .collect::<Result<Vec<_>>>()?;
194
195        OrderSensitiveArrayAggAccumulator::try_new(
196            &data_type,
197            &ordering_dtypes,
198            ordering,
199            acc_args.is_reversed,
200            ignore_nulls,
201        )
202        .map(|acc| Box::new(acc) as _)
203    }
204
205    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
206        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
207    }
208
209    fn documentation(&self) -> Option<&Documentation> {
210        self.doc()
211    }
212}
213
/// Accumulator used by `ARRAY_AGG` when the aggregate is neither `DISTINCT`
/// nor ordered.
///
/// Incoming arrays are buffered as-is and concatenated into a single list only
/// when the result is evaluated.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    /// Arrays collected so far, in arrival order.
    values: Vec<ArrayRef>,
    /// Item type of the resulting list; used to build the null result when
    /// nothing was collected.
    datatype: DataType,
    /// When `true`, null input values are filtered out before buffering.
    ignore_nulls: bool,
}
220
221impl ArrayAggAccumulator {
222    /// new array_agg accumulator based on given item data type
223    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
224        Ok(Self {
225            values: vec![],
226            datatype: datatype.clone(),
227            ignore_nulls,
228        })
229    }
230
231    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
232    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
233    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
234        let offsets = list_array.value_offsets();
235        // Offsets always have at least 1 value
236        let initial_offset = offsets[0];
237        let null_count = list_array.null_count();
238
239        // If no nulls than just use the fast path
240        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
241        if null_count == 0 {
242            // According to Arrow specification, the first offset can be non-zero
243            let list_values = list_array.values().slice(
244                initial_offset as usize,
245                (offsets[offsets.len() - 1] - initial_offset) as usize,
246            );
247            return Some(list_values);
248        }
249
250        // If all the values are null than just return an empty values array
251        if list_array.null_count() == list_array.len() {
252            return Some(list_array.values().slice(0, 0));
253        }
254
255        // According to the Arrow spec, null values can point to non-empty lists
256        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
257
258        // Unwrapping is safe as we just checked if there is a null value
259        let nulls = list_array.nulls().unwrap();
260
261        let mut valid_slices_iter = nulls.valid_slices();
262
263        // This is safe as we validated that there is at least 1 valid value in the array
264        let (start, end) = valid_slices_iter.next().unwrap();
265
266        let start_offset = offsets[start];
267
268        // End is exclusive, so it already point to the last offset value
269        // This is valid as the length of the array is always 1 less than the length of the offsets
270        let mut end_offset_of_last_valid_value = offsets[end];
271
272        for (start, end) in valid_slices_iter {
273            // If there is a null value that point to a non-empty list than the start offset of the valid value
274            // will be different that the end offset of the last valid value
275            if offsets[start] != end_offset_of_last_valid_value {
276                return None;
277            }
278
279            // End is exclusive, so it already point to the last offset value
280            // This is valid as the length of the array is always 1 less than the length of the offsets
281            end_offset_of_last_valid_value = offsets[end];
282        }
283
284        let consecutive_valid_values = list_array.values().slice(
285            start_offset as usize,
286            (end_offset_of_last_valid_value - start_offset) as usize,
287        );
288
289        Some(consecutive_valid_values)
290    }
291}
292
293impl Accumulator for ArrayAggAccumulator {
294    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
295        // Append value like Int64Array(1,2,3)
296        if values.is_empty() {
297            return Ok(());
298        }
299
300        if values.len() != 1 {
301            return internal_err!("expects single batch");
302        }
303
304        let val = &values[0];
305        let nulls = if self.ignore_nulls {
306            val.logical_nulls()
307        } else {
308            None
309        };
310
311        let val = match nulls {
312            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
313            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
314            None => Arc::clone(val),
315        };
316
317        if !val.is_empty() {
318            // The ArrayRef might be holding a reference to its original input buffer, so
319            // storing it here directly copied/compacted avoids over accounting memory
320            // not used here.
321            self.values
322                .push(make_array(copy_array_data(&val.to_data())));
323        }
324
325        Ok(())
326    }
327
328    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
329        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
330        if states.is_empty() {
331            return Ok(());
332        }
333
334        if states.len() != 1 {
335            return internal_err!("expects single state");
336        }
337
338        let list_arr = as_list_array(&states[0])?;
339
340        match Self::get_optional_values_to_merge_as_is(list_arr) {
341            Some(values) => {
342                // Make sure we don't insert empty lists
343                if !values.is_empty() {
344                    self.values.push(values);
345                }
346            }
347            None => {
348                for arr in list_arr.iter().flatten() {
349                    self.values.push(arr);
350                }
351            }
352        }
353
354        Ok(())
355    }
356
357    fn state(&mut self) -> Result<Vec<ScalarValue>> {
358        Ok(vec![self.evaluate()?])
359    }
360
361    fn evaluate(&mut self) -> Result<ScalarValue> {
362        // Transform Vec<ListArr> to ListArr
363        let element_arrays: Vec<&dyn Array> =
364            self.values.iter().map(|a| a.as_ref()).collect();
365
366        if element_arrays.is_empty() {
367            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
368        }
369
370        let concated_array = arrow::compute::concat(&element_arrays)?;
371
372        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
373    }
374
375    fn size(&self) -> usize {
376        size_of_val(self)
377            + (size_of::<ArrayRef>() * self.values.capacity())
378            + self
379                .values
380                .iter()
381                .map(|arr| arr.get_array_memory_size())
382                .sum::<usize>()
383            + self.datatype.size()
384            - size_of_val(&self.datatype)
385    }
386}
387
/// Accumulator used by `ARRAY_AGG(DISTINCT ...)`.
///
/// Values are deduplicated in a hash set and only sorted, if requested, when
/// the result is evaluated.
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    /// Distinct values seen so far.
    values: HashSet<ScalarValue>,
    /// Item type of the resulting list.
    datatype: DataType,
    /// Sort options of the `ORDER BY` on the argument itself, if one was given;
    /// `None` leaves the output in hash-set iteration order.
    sort_options: Option<SortOptions>,
    /// When `true`, null input values are not inserted.
    ignore_nulls: bool,
}
395
396impl DistinctArrayAggAccumulator {
397    pub fn try_new(
398        datatype: &DataType,
399        sort_options: Option<SortOptions>,
400        ignore_nulls: bool,
401    ) -> Result<Self> {
402        Ok(Self {
403            values: HashSet::new(),
404            datatype: datatype.clone(),
405            sort_options,
406            ignore_nulls,
407        })
408    }
409}
410
411impl Accumulator for DistinctArrayAggAccumulator {
412    fn state(&mut self) -> Result<Vec<ScalarValue>> {
413        Ok(vec![self.evaluate()?])
414    }
415
416    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
417        if values.is_empty() {
418            return Ok(());
419        }
420
421        let val = &values[0];
422        let nulls = if self.ignore_nulls {
423            val.logical_nulls()
424        } else {
425            None
426        };
427
428        let nulls = nulls.as_ref();
429        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
430            for i in 0..val.len() {
431                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
432                    self.values
433                        .insert(ScalarValue::try_from_array(val, i)?.compacted());
434                }
435            }
436        }
437
438        Ok(())
439    }
440
441    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
442        if states.is_empty() {
443            return Ok(());
444        }
445
446        if states.len() != 1 {
447            return internal_err!("expects single state");
448        }
449
450        states[0]
451            .as_list::<i32>()
452            .iter()
453            .flatten()
454            .try_for_each(|val| self.update_batch(&[val]))
455    }
456
457    fn evaluate(&mut self) -> Result<ScalarValue> {
458        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
459        if values.is_empty() {
460            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
461        }
462
463        if let Some(opts) = self.sort_options {
464            let mut delayed_cmp_err = Ok(());
465            values.sort_by(|a, b| {
466                if a.is_null() {
467                    return match opts.nulls_first {
468                        true => Ordering::Less,
469                        false => Ordering::Greater,
470                    };
471                }
472                if b.is_null() {
473                    return match opts.nulls_first {
474                        true => Ordering::Greater,
475                        false => Ordering::Less,
476                    };
477                }
478                match opts.descending {
479                    true => b.try_cmp(a),
480                    false => a.try_cmp(b),
481                }
482                .unwrap_or_else(|err| {
483                    delayed_cmp_err = Err(err);
484                    Ordering::Equal
485                })
486            });
487            delayed_cmp_err?;
488        };
489
490        let arr = ScalarValue::new_list(&values, &self.datatype, true);
491        Ok(ScalarValue::List(arr))
492    }
493
494    fn size(&self) -> usize {
495        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
496            - size_of_val(&self.values)
497            + self.datatype.size()
498            - size_of_val(&self.datatype)
499            - size_of_val(&self.sort_options)
500            + size_of::<Option<SortOptions>>()
501    }
502}
503
/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values` (one inner vector per entry, one element per ordering
    /// expression). This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions: index 0 is the item type, the remaining entries are the
    /// types of the ordering expressions, in order.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the aggregation is running in reverse; when set, `evaluate`
    /// emits `values` in reverse order.
    reverse: bool,
    /// Whether the aggregation should ignore null values.
    ignore_nulls: bool,
}
526
527impl OrderSensitiveArrayAggAccumulator {
528    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
529    /// item data type.
530    pub fn try_new(
531        datatype: &DataType,
532        ordering_dtypes: &[DataType],
533        ordering_req: LexOrdering,
534        reverse: bool,
535        ignore_nulls: bool,
536    ) -> Result<Self> {
537        let mut datatypes = vec![datatype.clone()];
538        datatypes.extend(ordering_dtypes.iter().cloned());
539        Ok(Self {
540            values: vec![],
541            ordering_values: vec![],
542            datatypes,
543            ordering_req,
544            reverse,
545            ignore_nulls,
546        })
547    }
548
549    fn evaluate_orderings(&self) -> Result<ScalarValue> {
550        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
551
552        let column_wise_ordering_values = if self.ordering_values.is_empty() {
553            fields
554                .iter()
555                .map(|f| new_empty_array(f.data_type()))
556                .collect::<Vec<_>>()
557        } else {
558            (0..fields.len())
559                .map(|i| {
560                    let column_values = self.ordering_values.iter().map(|x| x[i].clone());
561                    ScalarValue::iter_to_array(column_values)
562                })
563                .collect::<Result<_>>()?
564        };
565
566        let ordering_array = StructArray::try_new(
567            Fields::from(fields),
568            column_wise_ordering_values,
569            None,
570        )?;
571        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
572    }
573}
574
575impl Accumulator for OrderSensitiveArrayAggAccumulator {
576    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
577        if values.is_empty() {
578            return Ok(());
579        }
580
581        let val = &values[0];
582        let ord = &values[1..];
583        let nulls = if self.ignore_nulls {
584            val.logical_nulls()
585        } else {
586            None
587        };
588
589        let nulls = nulls.as_ref();
590        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
591            for i in 0..val.len() {
592                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
593                    self.values
594                        .push(ScalarValue::try_from_array(val, i)?.compacted());
595                    self.ordering_values.push(
596                        get_row_at_idx(ord, i)?
597                            .into_iter()
598                            .map(|v| v.compacted())
599                            .collect(),
600                    )
601                }
602            }
603        }
604
605        Ok(())
606    }
607
608    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
609        if states.is_empty() {
610            return Ok(());
611        }
612
613        // First entry in the state is the aggregation result. Second entry
614        // stores values received for ordering requirement columns for each
615        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
616        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
617        // received from its ordering requirement expression. (This information
618        // is necessary for during merging).
619        let [array_agg_values, agg_orderings, ..] = &states else {
620            return exec_err!("State should have two elements");
621        };
622        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
623            return exec_err!("Expects to receive a list array");
624        };
625
626        // Stores ARRAY_AGG results coming from each partition
627        let mut partition_values = vec![];
628        // Stores ordering requirement expression results coming from each partition
629        let mut partition_ordering_values = vec![];
630
631        // Existing values should be merged also.
632        partition_values.push(self.values.clone().into());
633        partition_ordering_values.push(self.ordering_values.clone().into());
634
635        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
636        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
637        for v in array_agg_res.into_iter() {
638            partition_values.push(v.into());
639        }
640
641        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
642
643        for partition_ordering_rows in orderings.into_iter() {
644            // Extract value from struct to ordering_rows for each group/partition
645            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
646                    if let ScalarValue::Struct(s) = ordering_row {
647                        let mut ordering_columns_per_row = vec![];
648
649                        for column in s.columns() {
650                            let sv = ScalarValue::try_from_array(column, 0)?;
651                            ordering_columns_per_row.push(sv);
652                        }
653
654                        Ok(ordering_columns_per_row)
655                    } else {
656                        exec_err!(
657                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
658                            ordering_row.data_type()
659                        )
660                    }
661                }).collect::<Result<VecDeque<_>>>()?;
662
663            partition_ordering_values.push(ordering_value);
664        }
665
666        let sort_options = self
667            .ordering_req
668            .iter()
669            .map(|sort_expr| sort_expr.options)
670            .collect::<Vec<_>>();
671
672        (self.values, self.ordering_values) = merge_ordered_arrays(
673            &mut partition_values,
674            &mut partition_ordering_values,
675            &sort_options,
676        )?;
677
678        Ok(())
679    }
680
681    fn state(&mut self) -> Result<Vec<ScalarValue>> {
682        let mut result = vec![self.evaluate()?];
683        result.push(self.evaluate_orderings()?);
684
685        Ok(result)
686    }
687
688    fn evaluate(&mut self) -> Result<ScalarValue> {
689        if self.values.is_empty() {
690            return Ok(ScalarValue::new_null_list(
691                self.datatypes[0].clone(),
692                true,
693                1,
694            ));
695        }
696
697        let values = self.values.clone();
698        let array = if self.reverse {
699            ScalarValue::new_list_from_iter(
700                values.into_iter().rev(),
701                &self.datatypes[0],
702                true,
703            )
704        } else {
705            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
706        };
707        Ok(ScalarValue::List(array))
708    }
709
710    fn size(&self) -> usize {
711        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
712            - size_of_val(&self.values);
713
714        // Add size of the `self.ordering_values`
715        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
716        for row in &self.ordering_values {
717            total += ScalarValue::size_of_vec(row) - size_of_val(row);
718        }
719
720        // Add size of the `self.datatypes`
721        total += size_of::<DataType>() * self.datatypes.capacity();
722        for dtype in &self.datatypes {
723            total += dtype.size() - size_of_val(dtype);
724        }
725
726        // Add size of the `self.ordering_req`
727        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
728        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
729        total
730    }
731}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736    use arrow::array::{ListBuilder, StringBuilder};
737    use arrow::datatypes::{FieldRef, Schema};
738    use datafusion_common::cast::as_generic_string_array;
739    use datafusion_common::internal_err;
740    use datafusion_physical_expr::expressions::Column;
741    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
742    use std::sync::Arc;
743
744    #[test]
745    fn no_duplicates_no_distinct() -> Result<()> {
746        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
747
748        acc1.update_batch(&[data(["a", "b", "c"])])?;
749        acc2.update_batch(&[data(["d", "e", "f"])])?;
750        acc1 = merge(acc1, acc2)?;
751
752        let result = print_nulls(str_arr(acc1.evaluate()?)?);
753
754        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
755
756        Ok(())
757    }
758
759    #[test]
760    fn no_duplicates_distinct() -> Result<()> {
761        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
762            .distinct()
763            .build_two()?;
764
765        acc1.update_batch(&[data(["a", "b", "c"])])?;
766        acc2.update_batch(&[data(["d", "e", "f"])])?;
767        acc1 = merge(acc1, acc2)?;
768
769        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
770        result.sort();
771
772        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
773
774        Ok(())
775    }
776
777    #[test]
778    fn duplicates_no_distinct() -> Result<()> {
779        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
780
781        acc1.update_batch(&[data(["a", "b", "c"])])?;
782        acc2.update_batch(&[data(["a", "b", "c"])])?;
783        acc1 = merge(acc1, acc2)?;
784
785        let result = print_nulls(str_arr(acc1.evaluate()?)?);
786
787        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);
788
789        Ok(())
790    }
791
792    #[test]
793    fn duplicates_distinct() -> Result<()> {
794        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
795            .distinct()
796            .build_two()?;
797
798        acc1.update_batch(&[data(["a", "b", "c"])])?;
799        acc2.update_batch(&[data(["a", "b", "c"])])?;
800        acc1 = merge(acc1, acc2)?;
801
802        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
803        result.sort();
804
805        assert_eq!(result, vec!["a", "b", "c"]);
806
807        Ok(())
808    }
809
810    #[test]
811    fn duplicates_on_second_batch_distinct() -> Result<()> {
812        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
813            .distinct()
814            .build_two()?;
815
816        acc1.update_batch(&[data(["a", "c"])])?;
817        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
818        acc1 = merge(acc1, acc2)?;
819
820        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
821        result.sort();
822
823        assert_eq!(result, vec!["a", "b", "c", "d"]);
824
825        Ok(())
826    }
827
828    #[test]
829    fn no_duplicates_distinct_sort_asc() -> Result<()> {
830        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
831            .distinct()
832            .order_by_col("col", SortOptions::new(false, false))
833            .build_two()?;
834
835        acc1.update_batch(&[data(["e", "b", "d"])])?;
836        acc2.update_batch(&[data(["f", "a", "c"])])?;
837        acc1 = merge(acc1, acc2)?;
838
839        let result = print_nulls(str_arr(acc1.evaluate()?)?);
840
841        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
842
843        Ok(())
844    }
845
846    #[test]
847    fn no_duplicates_distinct_sort_desc() -> Result<()> {
848        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
849            .distinct()
850            .order_by_col("col", SortOptions::new(true, false))
851            .build_two()?;
852
853        acc1.update_batch(&[data(["e", "b", "d"])])?;
854        acc2.update_batch(&[data(["f", "a", "c"])])?;
855        acc1 = merge(acc1, acc2)?;
856
857        let result = print_nulls(str_arr(acc1.evaluate()?)?);
858
859        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);
860
861        Ok(())
862    }
863
864    #[test]
865    fn duplicates_distinct_sort_asc() -> Result<()> {
866        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
867            .distinct()
868            .order_by_col("col", SortOptions::new(false, false))
869            .build_two()?;
870
871        acc1.update_batch(&[data(["a", "c", "b"])])?;
872        acc2.update_batch(&[data(["b", "c", "a"])])?;
873        acc1 = merge(acc1, acc2)?;
874
875        let result = print_nulls(str_arr(acc1.evaluate()?)?);
876
877        assert_eq!(result, vec!["a", "b", "c"]);
878
879        Ok(())
880    }
881
882    #[test]
883    fn duplicates_distinct_sort_desc() -> Result<()> {
884        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
885            .distinct()
886            .order_by_col("col", SortOptions::new(true, false))
887            .build_two()?;
888
889        acc1.update_batch(&[data(["a", "c", "b"])])?;
890        acc2.update_batch(&[data(["b", "c", "a"])])?;
891        acc1 = merge(acc1, acc2)?;
892
893        let result = print_nulls(str_arr(acc1.evaluate()?)?);
894
895        assert_eq!(result, vec!["c", "b", "a"]);
896
897        Ok(())
898    }
899
900    #[test]
901    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
902        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
903            .distinct()
904            .order_by_col("col", SortOptions::new(false, true))
905            .build_two()?;
906
907        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
908        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
909        acc1 = merge(acc1, acc2)?;
910
911        let result = print_nulls(str_arr(acc1.evaluate()?)?);
912
913        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);
914
915        Ok(())
916    }
917
918    #[test]
919    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
920        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
921            .distinct()
922            .order_by_col("col", SortOptions::new(false, false))
923            .build_two()?;
924
925        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
926        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
927        acc1 = merge(acc1, acc2)?;
928
929        let result = print_nulls(str_arr(acc1.evaluate()?)?);
930
931        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);
932
933        Ok(())
934    }
935
936    #[test]
937    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
938        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
939            .distinct()
940            .order_by_col("col", SortOptions::new(true, true))
941            .build_two()?;
942
943        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
944        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
945        acc1 = merge(acc1, acc2)?;
946
947        let result = print_nulls(str_arr(acc1.evaluate()?)?);
948
949        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);
950
951        Ok(())
952    }
953
954    #[test]
955    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
956        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
957            .distinct()
958            .order_by_col("col", SortOptions::new(true, false))
959            .build_two()?;
960
961        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
962        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
963        acc1 = merge(acc1, acc2)?;
964
965        let result = print_nulls(str_arr(acc1.evaluate()?)?);
966
967        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);
968
969        Ok(())
970    }
971
972    #[test]
973    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
974        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
975            .distinct()
976            .build_two()?;
977
978        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
979        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
980        acc1 = merge(acc1, acc2)?;
981
982        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
983        result.sort();
984        assert_eq!(result, vec!["NULL", "a"]);
985        Ok(())
986    }
987
988    #[test]
989    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
990        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
991            .distinct()
992            .build_two()?;
993
994        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
995        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
996        acc1 = merge(acc1, acc2)?;
997
998        let result = print_nulls(str_arr(acc1.evaluate()?)?);
999        assert_eq!(result, vec!["NULL"]);
1000        Ok(())
1001    }
1002
1003    #[test]
1004    fn does_not_over_account_memory() -> Result<()> {
1005        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
1006
1007        acc1.update_batch(&[data(["a", "c", "b"])])?;
1008        acc2.update_batch(&[data(["b", "c", "a"])])?;
1009        acc1 = merge(acc1, acc2)?;
1010
1011        // without compaction, the size is 2652.
1012        assert_eq!(acc1.size(), 732);
1013
1014        Ok(())
1015    }
1016    #[test]
1017    fn does_not_over_account_memory_distinct() -> Result<()> {
1018        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
1019            .distinct()
1020            .build_two()?;
1021
1022        acc1.update_batch(&[string_list_data([
1023            vec!["a", "b", "c"],
1024            vec!["d", "e", "f"],
1025        ])])?;
1026        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
1027        acc1 = merge(acc1, acc2)?;
1028
1029        // without compaction, the size is 16660
1030        assert_eq!(acc1.size(), 1660);
1031
1032        Ok(())
1033    }
1034
1035    #[test]
1036    fn does_not_over_account_memory_ordered() -> Result<()> {
1037        let mut acc = ArrayAggAccumulatorBuilder::string()
1038            .order_by_col("col", SortOptions::new(false, false))
1039            .build()?;
1040
1041        acc.update_batch(&[string_list_data([
1042            vec!["a", "b", "c"],
1043            vec!["c", "d", "e"],
1044            vec!["b", "c", "d"],
1045        ])])?;
1046
1047        // without compaction, the size is 17112
1048        assert_eq!(acc.size(), 2112);
1049
1050        Ok(())
1051    }
1052
1053    struct ArrayAggAccumulatorBuilder {
1054        return_field: FieldRef,
1055        distinct: bool,
1056        order_bys: Vec<PhysicalSortExpr>,
1057        schema: Schema,
1058    }
1059
1060    impl ArrayAggAccumulatorBuilder {
1061        fn string() -> Self {
1062            Self::new(DataType::Utf8)
1063        }
1064
1065        fn new(data_type: DataType) -> Self {
1066            Self {
1067                return_field: Field::new("f", data_type.clone(), true).into(),
1068                distinct: false,
1069                order_bys: vec![],
1070                schema: Schema {
1071                    fields: Fields::from(vec![Field::new(
1072                        "col",
1073                        DataType::new_list(data_type, true),
1074                        true,
1075                    )]),
1076                    metadata: Default::default(),
1077                },
1078            }
1079        }
1080
1081        fn distinct(mut self) -> Self {
1082            self.distinct = true;
1083            self
1084        }
1085
1086        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
1087            let new_order = PhysicalSortExpr::new(
1088                Arc::new(
1089                    Column::new_with_schema(col, &self.schema)
1090                        .expect("column not available in schema"),
1091                ),
1092                sort_options,
1093            );
1094            self.order_bys.push(new_order);
1095            self
1096        }
1097
1098        fn build(&self) -> Result<Box<dyn Accumulator>> {
1099            ArrayAgg::default().accumulator(AccumulatorArgs {
1100                return_field: Arc::clone(&self.return_field),
1101                schema: &self.schema,
1102                ignore_nulls: false,
1103                order_bys: &self.order_bys,
1104                is_reversed: false,
1105                name: "",
1106                is_distinct: self.distinct,
1107                exprs: &[Arc::new(Column::new("col", 0))],
1108            })
1109        }
1110
1111        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
1112            Ok((self.build()?, self.build()?))
1113        }
1114    }
1115
1116    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
1117        let ScalarValue::List(list) = value else {
1118            return internal_err!("ScalarValue was not a List");
1119        };
1120        Ok(as_generic_string_array::<i32>(list.values())?
1121            .iter()
1122            .map(|v| v.map(|v| v.to_string()))
1123            .collect())
1124    }
1125
1126    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
1127        sort.into_iter()
1128            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
1129            .collect()
1130    }
1131
1132    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
1133        let mut builder = ListBuilder::new(StringBuilder::new());
1134        for string_list in data.into_iter() {
1135            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>());
1136        }
1137
1138        Arc::new(builder.finish())
1139    }
1140
1141    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
1142    where
1143        ScalarValue: From<T>,
1144    {
1145        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
1146        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
1147    }
1148
1149    fn merge(
1150        mut acc1: Box<dyn Accumulator>,
1151        mut acc2: Box<dyn Accumulator>,
1152    ) -> Result<Box<dyn Accumulator>> {
1153        let intermediate_state = acc2.state().and_then(|e| {
1154            e.iter()
1155                .map(|v| v.to_array())
1156                .collect::<Result<Vec<ArrayRef>>>()
1157        })?;
1158        acc1.merge_batch(&intermediate_state)?;
1159        Ok(acc1)
1160    }
1161}