// datafusion_functions_aggregate/array_agg.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]

20use arrow::array::{
21    new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
22};
23use arrow::compute::{filter, SortOptions};
24use arrow::datatypes::{DataType, Field, FieldRef, Fields};
25
26use datafusion_common::cast::as_list_array;
27use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
28use datafusion_common::{exec_err, ScalarValue};
29use datafusion_common::{internal_err, Result};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::utils::format_state_name;
32use datafusion_expr::{Accumulator, Signature, Volatility};
33use datafusion_expr::{AggregateUDFImpl, Documentation};
34use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
35use datafusion_functions_aggregate_common::utils::ordering_fields;
36use datafusion_macros::user_doc;
37use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
38use std::cmp::Ordering;
39use std::collections::{HashSet, VecDeque};
40use std::mem::{size_of, size_of_val};
41use std::sync::Arc;
42
43make_udaf_expr_and_func!(
44    ArrayAgg,
45    array_agg,
46    expression,
47    "input values, including nulls, concatenated into an array",
48    array_agg_udaf
49);
50
51#[user_doc(
52    doc_section(label = "General Functions"),
53    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
54This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
55    syntax_example = "array_agg(expression [ORDER BY expression])",
56    sql_example = r#"
57```sql
58> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
59+-----------------------------------------------+
60| array_agg(column_name ORDER BY other_column)  |
61+-----------------------------------------------+
62| [element1, element2, element3]                |
63+-----------------------------------------------+
64> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
65+--------------------------------------------------------+
66| array_agg(DISTINCT column_name ORDER BY column_name)  |
67+--------------------------------------------------------+
68| [element1, element2, element3]                         |
69+--------------------------------------------------------+
70```
71"#,
72    standard_argument(name = "expression",)
73)]
74#[derive(Debug)]
75/// ARRAY_AGG aggregate expression
76pub struct ArrayAgg {
77    signature: Signature,
78}
79
80impl Default for ArrayAgg {
81    fn default() -> Self {
82        Self {
83            signature: Signature::any(1, Volatility::Immutable),
84        }
85    }
86}
87
88impl AggregateUDFImpl for ArrayAgg {
89    fn as_any(&self) -> &dyn std::any::Any {
90        self
91    }
92
93    fn name(&self) -> &str {
94        "array_agg"
95    }
96
97    fn aliases(&self) -> &[String] {
98        &[]
99    }
100
101    fn signature(&self) -> &Signature {
102        &self.signature
103    }
104
105    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106        Ok(DataType::List(Arc::new(Field::new_list_field(
107            arg_types[0].clone(),
108            true,
109        ))))
110    }
111
112    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
113        if args.is_distinct {
114            return Ok(vec![Field::new_list(
115                format_state_name(args.name, "distinct_array_agg"),
116                // See COMMENTS.md to understand why nullable is set to true
117                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
118                true,
119            )
120            .into()]);
121        }
122
123        let mut fields = vec![Field::new_list(
124            format_state_name(args.name, "array_agg"),
125            // See COMMENTS.md to understand why nullable is set to true
126            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
127            true,
128        )
129        .into()];
130
131        if args.ordering_fields.is_empty() {
132            return Ok(fields);
133        }
134
135        let orderings = args.ordering_fields.to_vec();
136        fields.push(
137            Field::new_list(
138                format_state_name(args.name, "array_agg_orderings"),
139                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
140                false,
141            )
142            .into(),
143        );
144
145        Ok(fields)
146    }
147
148    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
149        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
150        let ignore_nulls =
151            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
152
153        if acc_args.is_distinct {
154            // Limitation similar to Postgres. The aggregation function can only mix
155            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
156            // also in the arguments of the function. This implies that if the
157            // aggregation function only accepts one argument, only one argument
158            // can be used in the ORDER BY, For example:
159            //
160            // ARRAY_AGG(DISTINCT col)
161            //
162            // can only be mixed with an ORDER BY if the order expression is "col".
163            //
164            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
165            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
166            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
167            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
168            if acc_args.ordering_req.len() > 1 {
169                return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list");
170            }
171            let mut sort_option: Option<SortOptions> = None;
172            if let Some(order) = acc_args.ordering_req.first() {
173                if !order.expr.eq(&acc_args.exprs[0]) {
174                    return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list");
175                }
176                sort_option = Some(order.options)
177            }
178
179            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
180                &data_type,
181                sort_option,
182                ignore_nulls,
183            )?));
184        }
185
186        if acc_args.ordering_req.is_empty() {
187            return Ok(Box::new(ArrayAggAccumulator::try_new(
188                &data_type,
189                ignore_nulls,
190            )?));
191        }
192
193        let ordering_dtypes = acc_args
194            .ordering_req
195            .iter()
196            .map(|e| e.expr.data_type(acc_args.schema))
197            .collect::<Result<Vec<_>>>()?;
198
199        OrderSensitiveArrayAggAccumulator::try_new(
200            &data_type,
201            &ordering_dtypes,
202            acc_args.ordering_req.clone(),
203            acc_args.is_reversed,
204            ignore_nulls,
205        )
206        .map(|acc| Box::new(acc) as _)
207    }
208
209    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
210        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
211    }
212
213    fn documentation(&self) -> Option<&Documentation> {
214        self.doc()
215    }
216}
217
218#[derive(Debug)]
219pub struct ArrayAggAccumulator {
220    values: Vec<ArrayRef>,
221    datatype: DataType,
222    ignore_nulls: bool,
223}
224
225impl ArrayAggAccumulator {
226    /// new array_agg accumulator based on given item data type
227    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
228        Ok(Self {
229            values: vec![],
230            datatype: datatype.clone(),
231            ignore_nulls,
232        })
233    }
234
235    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
236    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
237    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
238        let offsets = list_array.value_offsets();
239        // Offsets always have at least 1 value
240        let initial_offset = offsets[0];
241        let null_count = list_array.null_count();
242
243        // If no nulls than just use the fast path
244        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
245        if null_count == 0 {
246            // According to Arrow specification, the first offset can be non-zero
247            let list_values = list_array.values().slice(
248                initial_offset as usize,
249                (offsets[offsets.len() - 1] - initial_offset) as usize,
250            );
251            return Some(list_values);
252        }
253
254        // If all the values are null than just return an empty values array
255        if list_array.null_count() == list_array.len() {
256            return Some(list_array.values().slice(0, 0));
257        }
258
259        // According to the Arrow spec, null values can point to non-empty lists
260        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
261
262        // Unwrapping is safe as we just checked if there is a null value
263        let nulls = list_array.nulls().unwrap();
264
265        let mut valid_slices_iter = nulls.valid_slices();
266
267        // This is safe as we validated that there is at least 1 valid value in the array
268        let (start, end) = valid_slices_iter.next().unwrap();
269
270        let start_offset = offsets[start];
271
272        // End is exclusive, so it already point to the last offset value
273        // This is valid as the length of the array is always 1 less than the length of the offsets
274        let mut end_offset_of_last_valid_value = offsets[end];
275
276        for (start, end) in valid_slices_iter {
277            // If there is a null value that point to a non-empty list than the start offset of the valid value
278            // will be different that the end offset of the last valid value
279            if offsets[start] != end_offset_of_last_valid_value {
280                return None;
281            }
282
283            // End is exclusive, so it already point to the last offset value
284            // This is valid as the length of the array is always 1 less than the length of the offsets
285            end_offset_of_last_valid_value = offsets[end];
286        }
287
288        let consecutive_valid_values = list_array.values().slice(
289            start_offset as usize,
290            (end_offset_of_last_valid_value - start_offset) as usize,
291        );
292
293        Some(consecutive_valid_values)
294    }
295}
296
297impl Accumulator for ArrayAggAccumulator {
298    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
299        // Append value like Int64Array(1,2,3)
300        if values.is_empty() {
301            return Ok(());
302        }
303
304        if values.len() != 1 {
305            return internal_err!("expects single batch");
306        }
307
308        let val = &values[0];
309        let nulls = if self.ignore_nulls {
310            val.logical_nulls()
311        } else {
312            None
313        };
314
315        let val = match nulls {
316            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
317            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
318            None => Arc::clone(val),
319        };
320
321        if !val.is_empty() {
322            self.values.push(val);
323        }
324
325        Ok(())
326    }
327
328    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
329        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
330        if states.is_empty() {
331            return Ok(());
332        }
333
334        if states.len() != 1 {
335            return internal_err!("expects single state");
336        }
337
338        let list_arr = as_list_array(&states[0])?;
339
340        match Self::get_optional_values_to_merge_as_is(list_arr) {
341            Some(values) => {
342                // Make sure we don't insert empty lists
343                if !values.is_empty() {
344                    self.values.push(values);
345                }
346            }
347            None => {
348                for arr in list_arr.iter().flatten() {
349                    self.values.push(arr);
350                }
351            }
352        }
353
354        Ok(())
355    }
356
357    fn state(&mut self) -> Result<Vec<ScalarValue>> {
358        Ok(vec![self.evaluate()?])
359    }
360
361    fn evaluate(&mut self) -> Result<ScalarValue> {
362        // Transform Vec<ListArr> to ListArr
363        let element_arrays: Vec<&dyn Array> =
364            self.values.iter().map(|a| a.as_ref()).collect();
365
366        if element_arrays.is_empty() {
367            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
368        }
369
370        let concated_array = arrow::compute::concat(&element_arrays)?;
371
372        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
373    }
374
375    fn size(&self) -> usize {
376        size_of_val(self)
377            + (size_of::<ArrayRef>() * self.values.capacity())
378            + self
379                .values
380                .iter()
381                .map(|arr| arr.get_array_memory_size())
382                .sum::<usize>()
383            + self.datatype.size()
384            - size_of_val(&self.datatype)
385    }
386}
387
388#[derive(Debug)]
389struct DistinctArrayAggAccumulator {
390    values: HashSet<ScalarValue>,
391    datatype: DataType,
392    sort_options: Option<SortOptions>,
393    ignore_nulls: bool,
394}
395
396impl DistinctArrayAggAccumulator {
397    pub fn try_new(
398        datatype: &DataType,
399        sort_options: Option<SortOptions>,
400        ignore_nulls: bool,
401    ) -> Result<Self> {
402        Ok(Self {
403            values: HashSet::new(),
404            datatype: datatype.clone(),
405            sort_options,
406            ignore_nulls,
407        })
408    }
409}
410
411impl Accumulator for DistinctArrayAggAccumulator {
412    fn state(&mut self) -> Result<Vec<ScalarValue>> {
413        Ok(vec![self.evaluate()?])
414    }
415
416    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
417        if values.is_empty() {
418            return Ok(());
419        }
420
421        let val = &values[0];
422        let nulls = if self.ignore_nulls {
423            val.logical_nulls()
424        } else {
425            None
426        };
427
428        let nulls = nulls.as_ref();
429        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
430            for i in 0..val.len() {
431                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
432                    self.values.insert(ScalarValue::try_from_array(val, i)?);
433                }
434            }
435        }
436
437        Ok(())
438    }
439
440    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
441        if states.is_empty() {
442            return Ok(());
443        }
444
445        if states.len() != 1 {
446            return internal_err!("expects single state");
447        }
448
449        states[0]
450            .as_list::<i32>()
451            .iter()
452            .flatten()
453            .try_for_each(|val| self.update_batch(&[val]))
454    }
455
456    fn evaluate(&mut self) -> Result<ScalarValue> {
457        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
458        if values.is_empty() {
459            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
460        }
461
462        if let Some(opts) = self.sort_options {
463            values.sort_by(|a, b| {
464                if a.is_null() {
465                    return match opts.nulls_first {
466                        true => Ordering::Less,
467                        false => Ordering::Greater,
468                    };
469                }
470                if b.is_null() {
471                    return match opts.nulls_first {
472                        true => Ordering::Greater,
473                        false => Ordering::Less,
474                    };
475                }
476                match opts.descending {
477                    true => b.partial_cmp(a).unwrap_or(Ordering::Equal),
478                    false => a.partial_cmp(b).unwrap_or(Ordering::Equal),
479                }
480            });
481        };
482
483        let arr = ScalarValue::new_list(&values, &self.datatype, true);
484        Ok(ScalarValue::List(arr))
485    }
486
487    fn size(&self) -> usize {
488        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
489            - size_of_val(&self.values)
490            + self.datatype.size()
491            - size_of_val(&self.datatype)
492            - size_of_val(&self.sort_options)
493            + size_of::<Option<SortOptions>>()
494    }
495}
496
497/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
498/// partition setting, partial aggregations are computed for every partition,
499/// and then their results are merged.
500#[derive(Debug)]
501pub(crate) struct OrderSensitiveArrayAggAccumulator {
502    /// Stores entries in the `ARRAY_AGG` result.
503    values: Vec<ScalarValue>,
504    /// Stores values of ordering requirement expressions corresponding to each
505    /// entry in `values`. This information is used when merging results from
506    /// different partitions. For detailed information how merging is done, see
507    /// [`merge_ordered_arrays`].
508    ordering_values: Vec<Vec<ScalarValue>>,
509    /// Stores datatypes of expressions inside values and ordering requirement
510    /// expressions.
511    datatypes: Vec<DataType>,
512    /// Stores the ordering requirement of the `Accumulator`.
513    ordering_req: LexOrdering,
514    /// Whether the aggregation is running in reverse.
515    reverse: bool,
516    /// Whether the aggregation should ignore null values.
517    ignore_nulls: bool,
518}
519
520impl OrderSensitiveArrayAggAccumulator {
521    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
522    /// item data type.
523    pub fn try_new(
524        datatype: &DataType,
525        ordering_dtypes: &[DataType],
526        ordering_req: LexOrdering,
527        reverse: bool,
528        ignore_nulls: bool,
529    ) -> Result<Self> {
530        let mut datatypes = vec![datatype.clone()];
531        datatypes.extend(ordering_dtypes.iter().cloned());
532        Ok(Self {
533            values: vec![],
534            ordering_values: vec![],
535            datatypes,
536            ordering_req,
537            reverse,
538            ignore_nulls,
539        })
540    }
541}
542
543impl Accumulator for OrderSensitiveArrayAggAccumulator {
544    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
545        if values.is_empty() {
546            return Ok(());
547        }
548
549        let val = &values[0];
550        let ord = &values[1..];
551        let nulls = if self.ignore_nulls {
552            val.logical_nulls()
553        } else {
554            None
555        };
556
557        let nulls = nulls.as_ref();
558        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
559            for i in 0..val.len() {
560                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
561                    self.values.push(ScalarValue::try_from_array(val, i)?);
562                    self.ordering_values.push(get_row_at_idx(ord, i)?)
563                }
564            }
565        }
566
567        Ok(())
568    }
569
570    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
571        if states.is_empty() {
572            return Ok(());
573        }
574
575        // First entry in the state is the aggregation result. Second entry
576        // stores values received for ordering requirement columns for each
577        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
578        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
579        // received from its ordering requirement expression. (This information
580        // is necessary for during merging).
581        let [array_agg_values, agg_orderings, ..] = &states else {
582            return exec_err!("State should have two elements");
583        };
584        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
585            return exec_err!("Expects to receive a list array");
586        };
587
588        // Stores ARRAY_AGG results coming from each partition
589        let mut partition_values = vec![];
590        // Stores ordering requirement expression results coming from each partition
591        let mut partition_ordering_values = vec![];
592
593        // Existing values should be merged also.
594        partition_values.push(self.values.clone().into());
595        partition_ordering_values.push(self.ordering_values.clone().into());
596
597        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
598        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
599        for v in array_agg_res.into_iter() {
600            partition_values.push(v.into());
601        }
602
603        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
604
605        for partition_ordering_rows in orderings.into_iter() {
606            // Extract value from struct to ordering_rows for each group/partition
607            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
608                    if let ScalarValue::Struct(s) = ordering_row {
609                        let mut ordering_columns_per_row = vec![];
610
611                        for column in s.columns() {
612                            let sv = ScalarValue::try_from_array(column, 0)?;
613                            ordering_columns_per_row.push(sv);
614                        }
615
616                        Ok(ordering_columns_per_row)
617                    } else {
618                        exec_err!(
619                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
620                            ordering_row.data_type()
621                        )
622                    }
623                }).collect::<Result<VecDeque<_>>>()?;
624
625            partition_ordering_values.push(ordering_value);
626        }
627
628        let sort_options = self
629            .ordering_req
630            .iter()
631            .map(|sort_expr| sort_expr.options)
632            .collect::<Vec<_>>();
633
634        (self.values, self.ordering_values) = merge_ordered_arrays(
635            &mut partition_values,
636            &mut partition_ordering_values,
637            &sort_options,
638        )?;
639
640        Ok(())
641    }
642
643    fn state(&mut self) -> Result<Vec<ScalarValue>> {
644        let mut result = vec![self.evaluate()?];
645        result.push(self.evaluate_orderings()?);
646
647        Ok(result)
648    }
649
650    fn evaluate(&mut self) -> Result<ScalarValue> {
651        if self.values.is_empty() {
652            return Ok(ScalarValue::new_null_list(
653                self.datatypes[0].clone(),
654                true,
655                1,
656            ));
657        }
658
659        let values = self.values.clone();
660        let array = if self.reverse {
661            ScalarValue::new_list_from_iter(
662                values.into_iter().rev(),
663                &self.datatypes[0],
664                true,
665            )
666        } else {
667            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
668        };
669        Ok(ScalarValue::List(array))
670    }
671
672    fn size(&self) -> usize {
673        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
674            - size_of_val(&self.values);
675
676        // Add size of the `self.ordering_values`
677        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
678        for row in &self.ordering_values {
679            total += ScalarValue::size_of_vec(row) - size_of_val(row);
680        }
681
682        // Add size of the `self.datatypes`
683        total += size_of::<DataType>() * self.datatypes.capacity();
684        for dtype in &self.datatypes {
685            total += dtype.size() - size_of_val(dtype);
686        }
687
688        // Add size of the `self.ordering_req`
689        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
690        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
691        total
692    }
693}
694
695impl OrderSensitiveArrayAggAccumulator {
696    fn evaluate_orderings(&self) -> Result<ScalarValue> {
697        let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
698        let num_columns = fields.len();
699
700        let mut column_wise_ordering_values = vec![];
701        for i in 0..num_columns {
702            let column_values = self
703                .ordering_values
704                .iter()
705                .map(|x| x[i].clone())
706                .collect::<Vec<_>>();
707            let array = if column_values.is_empty() {
708                new_empty_array(fields[i].data_type())
709            } else {
710                ScalarValue::iter_to_array(column_values.into_iter())?
711            };
712            column_wise_ordering_values.push(array);
713        }
714
715        let struct_field = Fields::from(fields);
716        let ordering_array =
717            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
718        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
719    }
720}
721
722#[cfg(test)]
723mod tests {
724    use super::*;
725    use arrow::datatypes::{FieldRef, Schema};
726    use datafusion_common::cast::as_generic_string_array;
727    use datafusion_common::internal_err;
728    use datafusion_physical_expr::expressions::Column;
729    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
730    use std::sync::Arc;
731
732    #[test]
733    fn no_duplicates_no_distinct() -> Result<()> {
734        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
735
736        acc1.update_batch(&[data(["a", "b", "c"])])?;
737        acc2.update_batch(&[data(["d", "e", "f"])])?;
738        acc1 = merge(acc1, acc2)?;
739
740        let result = print_nulls(str_arr(acc1.evaluate()?)?);
741
742        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
743
744        Ok(())
745    }
746
747    #[test]
748    fn no_duplicates_distinct() -> Result<()> {
749        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
750            .distinct()
751            .build_two()?;
752
753        acc1.update_batch(&[data(["a", "b", "c"])])?;
754        acc2.update_batch(&[data(["d", "e", "f"])])?;
755        acc1 = merge(acc1, acc2)?;
756
757        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
758        result.sort();
759
760        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
761
762        Ok(())
763    }
764
765    #[test]
766    fn duplicates_no_distinct() -> Result<()> {
767        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
768
769        acc1.update_batch(&[data(["a", "b", "c"])])?;
770        acc2.update_batch(&[data(["a", "b", "c"])])?;
771        acc1 = merge(acc1, acc2)?;
772
773        let result = print_nulls(str_arr(acc1.evaluate()?)?);
774
775        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);
776
777        Ok(())
778    }
779
780    #[test]
781    fn duplicates_distinct() -> Result<()> {
782        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
783            .distinct()
784            .build_two()?;
785
786        acc1.update_batch(&[data(["a", "b", "c"])])?;
787        acc2.update_batch(&[data(["a", "b", "c"])])?;
788        acc1 = merge(acc1, acc2)?;
789
790        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
791        result.sort();
792
793        assert_eq!(result, vec!["a", "b", "c"]);
794
795        Ok(())
796    }
797
798    #[test]
799    fn duplicates_on_second_batch_distinct() -> Result<()> {
800        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
801            .distinct()
802            .build_two()?;
803
804        acc1.update_batch(&[data(["a", "c"])])?;
805        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
806        acc1 = merge(acc1, acc2)?;
807
808        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
809        result.sort();
810
811        assert_eq!(result, vec!["a", "b", "c", "d"]);
812
813        Ok(())
814    }
815
816    #[test]
817    fn no_duplicates_distinct_sort_asc() -> Result<()> {
818        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
819            .distinct()
820            .order_by_col("col", SortOptions::new(false, false))
821            .build_two()?;
822
823        acc1.update_batch(&[data(["e", "b", "d"])])?;
824        acc2.update_batch(&[data(["f", "a", "c"])])?;
825        acc1 = merge(acc1, acc2)?;
826
827        let result = print_nulls(str_arr(acc1.evaluate()?)?);
828
829        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
830
831        Ok(())
832    }
833
834    #[test]
835    fn no_duplicates_distinct_sort_desc() -> Result<()> {
836        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
837            .distinct()
838            .order_by_col("col", SortOptions::new(true, false))
839            .build_two()?;
840
841        acc1.update_batch(&[data(["e", "b", "d"])])?;
842        acc2.update_batch(&[data(["f", "a", "c"])])?;
843        acc1 = merge(acc1, acc2)?;
844
845        let result = print_nulls(str_arr(acc1.evaluate()?)?);
846
847        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);
848
849        Ok(())
850    }
851
852    #[test]
853    fn duplicates_distinct_sort_asc() -> Result<()> {
854        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
855            .distinct()
856            .order_by_col("col", SortOptions::new(false, false))
857            .build_two()?;
858
859        acc1.update_batch(&[data(["a", "c", "b"])])?;
860        acc2.update_batch(&[data(["b", "c", "a"])])?;
861        acc1 = merge(acc1, acc2)?;
862
863        let result = print_nulls(str_arr(acc1.evaluate()?)?);
864
865        assert_eq!(result, vec!["a", "b", "c"]);
866
867        Ok(())
868    }
869
870    #[test]
871    fn duplicates_distinct_sort_desc() -> Result<()> {
872        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
873            .distinct()
874            .order_by_col("col", SortOptions::new(true, false))
875            .build_two()?;
876
877        acc1.update_batch(&[data(["a", "c", "b"])])?;
878        acc2.update_batch(&[data(["b", "c", "a"])])?;
879        acc1 = merge(acc1, acc2)?;
880
881        let result = print_nulls(str_arr(acc1.evaluate()?)?);
882
883        assert_eq!(result, vec!["c", "b", "a"]);
884
885        Ok(())
886    }
887
888    #[test]
889    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
890        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
891            .distinct()
892            .order_by_col("col", SortOptions::new(false, true))
893            .build_two()?;
894
895        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
896        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
897        acc1 = merge(acc1, acc2)?;
898
899        let result = print_nulls(str_arr(acc1.evaluate()?)?);
900
901        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);
902
903        Ok(())
904    }
905
906    #[test]
907    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
908        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
909            .distinct()
910            .order_by_col("col", SortOptions::new(false, false))
911            .build_two()?;
912
913        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
914        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
915        acc1 = merge(acc1, acc2)?;
916
917        let result = print_nulls(str_arr(acc1.evaluate()?)?);
918
919        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);
920
921        Ok(())
922    }
923
924    #[test]
925    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
926        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
927            .distinct()
928            .order_by_col("col", SortOptions::new(true, true))
929            .build_two()?;
930
931        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
932        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
933        acc1 = merge(acc1, acc2)?;
934
935        let result = print_nulls(str_arr(acc1.evaluate()?)?);
936
937        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);
938
939        Ok(())
940    }
941
942    #[test]
943    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
944        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
945            .distinct()
946            .order_by_col("col", SortOptions::new(true, false))
947            .build_two()?;
948
949        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
950        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
951        acc1 = merge(acc1, acc2)?;
952
953        let result = print_nulls(str_arr(acc1.evaluate()?)?);
954
955        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);
956
957        Ok(())
958    }
959
960    #[test]
961    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
962        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
963            .distinct()
964            .build_two()?;
965
966        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
967        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
968        acc1 = merge(acc1, acc2)?;
969
970        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
971        result.sort();
972        assert_eq!(result, vec!["NULL", "a"]);
973        Ok(())
974    }
975
976    #[test]
977    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
978        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
979            .distinct()
980            .build_two()?;
981
982        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
983        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
984        acc1 = merge(acc1, acc2)?;
985
986        let result = print_nulls(str_arr(acc1.evaluate()?)?);
987        assert_eq!(result, vec!["NULL"]);
988        Ok(())
989    }
990
991    struct ArrayAggAccumulatorBuilder {
992        return_field: FieldRef,
993        distinct: bool,
994        ordering: LexOrdering,
995        schema: Schema,
996    }
997
998    impl ArrayAggAccumulatorBuilder {
999        fn string() -> Self {
1000            Self::new(DataType::Utf8)
1001        }
1002
1003        fn new(data_type: DataType) -> Self {
1004            Self {
1005                return_field: Field::new("f", data_type.clone(), true).into(),
1006                distinct: false,
1007                ordering: Default::default(),
1008                schema: Schema {
1009                    fields: Fields::from(vec![Field::new(
1010                        "col",
1011                        DataType::new_list(data_type, true),
1012                        true,
1013                    )]),
1014                    metadata: Default::default(),
1015                },
1016            }
1017        }
1018
1019        fn distinct(mut self) -> Self {
1020            self.distinct = true;
1021            self
1022        }
1023
1024        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
1025            self.ordering.extend([PhysicalSortExpr::new(
1026                Arc::new(
1027                    Column::new_with_schema(col, &self.schema)
1028                        .expect("column not available in schema"),
1029                ),
1030                sort_options,
1031            )]);
1032            self
1033        }
1034
1035        fn build(&self) -> Result<Box<dyn Accumulator>> {
1036            ArrayAgg::default().accumulator(AccumulatorArgs {
1037                return_field: Arc::clone(&self.return_field),
1038                schema: &self.schema,
1039                ignore_nulls: false,
1040                ordering_req: &self.ordering,
1041                is_reversed: false,
1042                name: "",
1043                is_distinct: self.distinct,
1044                exprs: &[Arc::new(Column::new("col", 0))],
1045            })
1046        }
1047
1048        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
1049            Ok((self.build()?, self.build()?))
1050        }
1051    }
1052
1053    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
1054        let ScalarValue::List(list) = value else {
1055            return internal_err!("ScalarValue was not a List");
1056        };
1057        Ok(as_generic_string_array::<i32>(list.values())?
1058            .iter()
1059            .map(|v| v.map(|v| v.to_string()))
1060            .collect())
1061    }
1062
1063    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
1064        sort.into_iter()
1065            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
1066            .collect()
1067    }
1068
1069    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
1070    where
1071        ScalarValue: From<T>,
1072    {
1073        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
1074        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
1075    }
1076
1077    fn merge(
1078        mut acc1: Box<dyn Accumulator>,
1079        mut acc2: Box<dyn Accumulator>,
1080    ) -> Result<Box<dyn Accumulator>> {
1081        let intermediate_state = acc2.state().and_then(|e| {
1082            e.iter()
1083                .map(|v| v.to_array())
1084                .collect::<Result<Vec<ArrayRef>>>()
1085        })?;
1086        acc1.merge_batch(&intermediate_state)?;
1087        Ok(acc1)
1088    }
1089}