datafusion_functions_aggregate/
array_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
21use arrow::compute::SortOptions;
22use arrow::datatypes::{DataType, Field, Fields};
23
24use datafusion_common::cast::as_list_array;
25use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
26use datafusion_common::{exec_err, ScalarValue};
27use datafusion_common::{internal_err, Result};
28use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion_expr::utils::format_state_name;
30use datafusion_expr::{Accumulator, Signature, Volatility};
31use datafusion_expr::{AggregateUDFImpl, Documentation};
32use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
33use datafusion_functions_aggregate_common::utils::ordering_fields;
34use datafusion_macros::user_doc;
35use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
36use std::cmp::Ordering;
37use std::collections::{HashSet, VecDeque};
38use std::mem::{size_of, size_of_val};
39use std::sync::Arc;
40
// Declares the `array_agg` expression helper (single `expression` argument) and
// the `array_agg_udaf()` accessor for [`ArrayAgg`]; `array_agg_udaf()` is also
// used by `ArrayAgg::reverse_expr` below.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
48
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
/// ARRAY_AGG aggregate expression.
///
/// Collects the values of its single argument (including nulls) into a list.
/// Depending on the call it is backed by one of three accumulators: a plain
/// one, a `DISTINCT` one (optionally ordered by the argument itself), or an
/// order-sensitive one for `ORDER BY`.
pub struct ArrayAgg {
    /// Accepts exactly one argument of any type.
    signature: Signature,
}
77
78impl Default for ArrayAgg {
79    fn default() -> Self {
80        Self {
81            signature: Signature::any(1, Volatility::Immutable),
82        }
83    }
84}
85
86impl AggregateUDFImpl for ArrayAgg {
87    fn as_any(&self) -> &dyn std::any::Any {
88        self
89    }
90
91    fn name(&self) -> &str {
92        "array_agg"
93    }
94
95    fn aliases(&self) -> &[String] {
96        &[]
97    }
98
99    fn signature(&self) -> &Signature {
100        &self.signature
101    }
102
103    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
104        Ok(DataType::List(Arc::new(Field::new_list_field(
105            arg_types[0].clone(),
106            true,
107        ))))
108    }
109
110    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
111        if args.is_distinct {
112            return Ok(vec![Field::new_list(
113                format_state_name(args.name, "distinct_array_agg"),
114                // See COMMENTS.md to understand why nullable is set to true
115                Field::new_list_field(args.input_types[0].clone(), true),
116                true,
117            )]);
118        }
119
120        let mut fields = vec![Field::new_list(
121            format_state_name(args.name, "array_agg"),
122            // See COMMENTS.md to understand why nullable is set to true
123            Field::new_list_field(args.input_types[0].clone(), true),
124            true,
125        )];
126
127        if args.ordering_fields.is_empty() {
128            return Ok(fields);
129        }
130
131        let orderings = args.ordering_fields.to_vec();
132        fields.push(Field::new_list(
133            format_state_name(args.name, "array_agg_orderings"),
134            Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
135            false,
136        ));
137
138        Ok(fields)
139    }
140
141    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
142        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
143
144        if acc_args.is_distinct {
145            // Limitation similar to Postgres. The aggregation function can only mix
146            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
147            // also in the arguments of the function. This implies that if the
148            // aggregation function only accepts one argument, only one argument
149            // can be used in the ORDER BY, For example:
150            //
151            // ARRAY_AGG(DISTINCT col)
152            //
153            // can only be mixed with an ORDER BY if the order expression is "col".
154            //
155            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
156            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
157            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
158            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
159            if acc_args.ordering_req.len() > 1 {
160                return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list");
161            }
162            let mut sort_option: Option<SortOptions> = None;
163            if let Some(order) = acc_args.ordering_req.first() {
164                if !order.expr.eq(&acc_args.exprs[0]) {
165                    return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list");
166                }
167                sort_option = Some(order.options)
168            }
169            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
170                &data_type,
171                sort_option,
172            )?));
173        }
174
175        if acc_args.ordering_req.is_empty() {
176            return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
177        }
178
179        let ordering_dtypes = acc_args
180            .ordering_req
181            .iter()
182            .map(|e| e.expr.data_type(acc_args.schema))
183            .collect::<Result<Vec<_>>>()?;
184
185        OrderSensitiveArrayAggAccumulator::try_new(
186            &data_type,
187            &ordering_dtypes,
188            acc_args.ordering_req.clone(),
189            acc_args.is_reversed,
190        )
191        .map(|acc| Box::new(acc) as _)
192    }
193
194    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
195        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
196    }
197
198    fn documentation(&self) -> Option<&Documentation> {
199        self.doc()
200    }
201}
202
/// Accumulator for `ARRAY_AGG` without DISTINCT and without ORDER BY.
///
/// Input arrays are buffered as-is and only concatenated into a single list
/// when the result (or state) is evaluated.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    /// Arrays (or slices of merged state) buffered by `update_batch` / `merge_batch`.
    values: Vec<ArrayRef>,
    /// Data type of the list items produced by this accumulator.
    datatype: DataType,
}
208
209impl ArrayAggAccumulator {
210    /// new array_agg accumulator based on given item data type
211    pub fn try_new(datatype: &DataType) -> Result<Self> {
212        Ok(Self {
213            values: vec![],
214            datatype: datatype.clone(),
215        })
216    }
217
218    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
219    /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
220    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
221        let offsets = list_array.value_offsets();
222        // Offsets always have at least 1 value
223        let initial_offset = offsets[0];
224        let null_count = list_array.null_count();
225
226        // If no nulls than just use the fast path
227        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
228        if null_count == 0 {
229            // According to Arrow specification, the first offset can be non-zero
230            let list_values = list_array.values().slice(
231                initial_offset as usize,
232                (offsets[offsets.len() - 1] - initial_offset) as usize,
233            );
234            return Some(list_values);
235        }
236
237        // If all the values are null than just return an empty values array
238        if list_array.null_count() == list_array.len() {
239            return Some(list_array.values().slice(0, 0));
240        }
241
242        // According to the Arrow spec, null values can point to non empty lists
243        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
244
245        // Unwrapping is safe as we just checked if there is a null value
246        let nulls = list_array.nulls().unwrap();
247
248        let mut valid_slices_iter = nulls.valid_slices();
249
250        // This is safe as we validated that that are at least 1 valid value in the array
251        let (start, end) = valid_slices_iter.next().unwrap();
252
253        let start_offset = offsets[start];
254
255        // End is exclusive, so it already point to the last offset value
256        // This is valid as the length of the array is always 1 less than the length of the offsets
257        let mut end_offset_of_last_valid_value = offsets[end];
258
259        for (start, end) in valid_slices_iter {
260            // If there is a null value that point to a non empty list than the start offset of the valid value
261            // will be different that the end offset of the last valid value
262            if offsets[start] != end_offset_of_last_valid_value {
263                return None;
264            }
265
266            // End is exclusive, so it already point to the last offset value
267            // This is valid as the length of the array is always 1 less than the length of the offsets
268            end_offset_of_last_valid_value = offsets[end];
269        }
270
271        let consecutive_valid_values = list_array.values().slice(
272            start_offset as usize,
273            (end_offset_of_last_valid_value - start_offset) as usize,
274        );
275
276        Some(consecutive_valid_values)
277    }
278}
279
280impl Accumulator for ArrayAggAccumulator {
281    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
282        // Append value like Int64Array(1,2,3)
283        if values.is_empty() {
284            return Ok(());
285        }
286
287        if values.len() != 1 {
288            return internal_err!("expects single batch");
289        }
290
291        let val = Arc::clone(&values[0]);
292        if !val.is_empty() {
293            self.values.push(val);
294        }
295        Ok(())
296    }
297
298    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
299        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
300        if states.is_empty() {
301            return Ok(());
302        }
303
304        if states.len() != 1 {
305            return internal_err!("expects single state");
306        }
307
308        let list_arr = as_list_array(&states[0])?;
309
310        match Self::get_optional_values_to_merge_as_is(list_arr) {
311            Some(values) => {
312                // Make sure we don't insert empty lists
313                if !values.is_empty() {
314                    self.values.push(values);
315                }
316            }
317            None => {
318                for arr in list_arr.iter().flatten() {
319                    self.values.push(arr);
320                }
321            }
322        }
323
324        Ok(())
325    }
326
327    fn state(&mut self) -> Result<Vec<ScalarValue>> {
328        Ok(vec![self.evaluate()?])
329    }
330
331    fn evaluate(&mut self) -> Result<ScalarValue> {
332        // Transform Vec<ListArr> to ListArr
333        let element_arrays: Vec<&dyn Array> =
334            self.values.iter().map(|a| a.as_ref()).collect();
335
336        if element_arrays.is_empty() {
337            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
338        }
339
340        let concated_array = arrow::compute::concat(&element_arrays)?;
341
342        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
343    }
344
345    fn size(&self) -> usize {
346        size_of_val(self)
347            + (size_of::<ArrayRef>() * self.values.capacity())
348            + self
349                .values
350                .iter()
351                .map(|arr| arr.get_array_memory_size())
352                .sum::<usize>()
353            + self.datatype.size()
354            - size_of_val(&self.datatype)
355    }
356}
357
/// Accumulator for `ARRAY_AGG(DISTINCT ...)`, optionally combined with an
/// `ORDER BY` on the aggregated expression itself.
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    /// Distinct values seen so far; null rows are inserted like any other value.
    values: HashSet<ScalarValue>,
    /// Data type of the list items produced by this accumulator.
    datatype: DataType,
    /// Sort applied to the output in `evaluate`; `None` leaves the output in
    /// `HashSet` iteration order.
    sort_options: Option<SortOptions>,
}
364
365impl DistinctArrayAggAccumulator {
366    pub fn try_new(
367        datatype: &DataType,
368        sort_options: Option<SortOptions>,
369    ) -> Result<Self> {
370        Ok(Self {
371            values: HashSet::new(),
372            datatype: datatype.clone(),
373            sort_options,
374        })
375    }
376}
377
378impl Accumulator for DistinctArrayAggAccumulator {
379    fn state(&mut self) -> Result<Vec<ScalarValue>> {
380        Ok(vec![self.evaluate()?])
381    }
382
383    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
384        if values.is_empty() {
385            return Ok(());
386        }
387
388        let array = &values[0];
389
390        for i in 0..array.len() {
391            let scalar = ScalarValue::try_from_array(&array, i)?;
392            self.values.insert(scalar);
393        }
394
395        Ok(())
396    }
397
398    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
399        if states.is_empty() {
400            return Ok(());
401        }
402
403        if states.len() != 1 {
404            return internal_err!("expects single state");
405        }
406
407        states[0]
408            .as_list::<i32>()
409            .iter()
410            .flatten()
411            .try_for_each(|val| self.update_batch(&[val]))
412    }
413
414    fn evaluate(&mut self) -> Result<ScalarValue> {
415        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
416        if values.is_empty() {
417            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
418        }
419
420        if let Some(opts) = self.sort_options {
421            values.sort_by(|a, b| {
422                if a.is_null() {
423                    return match opts.nulls_first {
424                        true => Ordering::Less,
425                        false => Ordering::Greater,
426                    };
427                }
428                if b.is_null() {
429                    return match opts.nulls_first {
430                        true => Ordering::Greater,
431                        false => Ordering::Less,
432                    };
433                }
434                match opts.descending {
435                    true => b.partial_cmp(a).unwrap_or(Ordering::Equal),
436                    false => a.partial_cmp(b).unwrap_or(Ordering::Equal),
437                }
438            });
439        };
440
441        let arr = ScalarValue::new_list(&values, &self.datatype, true);
442        Ok(ScalarValue::List(arr))
443    }
444
445    fn size(&self) -> usize {
446        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
447            - size_of_val(&self.values)
448            + self.datatype.size()
449            - size_of_val(&self.datatype)
450            - size_of_val(&self.sort_options)
451            + size_of::<Option<SortOptions>>()
452    }
453}
454
/// Accumulator for an `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values` (same length, same index). This information is used
    /// when merging results from different partitions. For detailed
    /// information on how merging is done, see [`merge_ordered_arrays`].
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions: index 0 is the aggregated value, indices `1..` are the
    /// ordering expressions.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the aggregation is running in reverse; when set, `evaluate`
    /// emits `values` in reverse order.
    reverse: bool,
}
475
476impl OrderSensitiveArrayAggAccumulator {
477    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
478    /// item data type.
479    pub fn try_new(
480        datatype: &DataType,
481        ordering_dtypes: &[DataType],
482        ordering_req: LexOrdering,
483        reverse: bool,
484    ) -> Result<Self> {
485        let mut datatypes = vec![datatype.clone()];
486        datatypes.extend(ordering_dtypes.iter().cloned());
487        Ok(Self {
488            values: vec![],
489            ordering_values: vec![],
490            datatypes,
491            ordering_req,
492            reverse,
493        })
494    }
495}
496
497impl Accumulator for OrderSensitiveArrayAggAccumulator {
498    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
499        if values.is_empty() {
500            return Ok(());
501        }
502
503        let n_row = values[0].len();
504        for index in 0..n_row {
505            let row = get_row_at_idx(values, index)?;
506            self.values.push(row[0].clone());
507            self.ordering_values.push(row[1..].to_vec());
508        }
509
510        Ok(())
511    }
512
513    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
514        if states.is_empty() {
515            return Ok(());
516        }
517
518        // First entry in the state is the aggregation result. Second entry
519        // stores values received for ordering requirement columns for each
520        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
521        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
522        // received from its ordering requirement expression. (This information
523        // is necessary for during merging).
524        let [array_agg_values, agg_orderings, ..] = &states else {
525            return exec_err!("State should have two elements");
526        };
527        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
528            return exec_err!("Expects to receive a list array");
529        };
530
531        // Stores ARRAY_AGG results coming from each partition
532        let mut partition_values = vec![];
533        // Stores ordering requirement expression results coming from each partition
534        let mut partition_ordering_values = vec![];
535
536        // Existing values should be merged also.
537        partition_values.push(self.values.clone().into());
538        partition_ordering_values.push(self.ordering_values.clone().into());
539
540        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
541        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
542        for v in array_agg_res.into_iter() {
543            partition_values.push(v.into());
544        }
545
546        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
547
548        for partition_ordering_rows in orderings.into_iter() {
549            // Extract value from struct to ordering_rows for each group/partition
550            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
551                    if let ScalarValue::Struct(s) = ordering_row {
552                        let mut ordering_columns_per_row = vec![];
553
554                        for column in s.columns() {
555                            let sv = ScalarValue::try_from_array(column, 0)?;
556                            ordering_columns_per_row.push(sv);
557                        }
558
559                        Ok(ordering_columns_per_row)
560                    } else {
561                        exec_err!(
562                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
563                            ordering_row.data_type()
564                        )
565                    }
566                }).collect::<Result<VecDeque<_>>>()?;
567
568            partition_ordering_values.push(ordering_value);
569        }
570
571        let sort_options = self
572            .ordering_req
573            .iter()
574            .map(|sort_expr| sort_expr.options)
575            .collect::<Vec<_>>();
576
577        (self.values, self.ordering_values) = merge_ordered_arrays(
578            &mut partition_values,
579            &mut partition_ordering_values,
580            &sort_options,
581        )?;
582
583        Ok(())
584    }
585
586    fn state(&mut self) -> Result<Vec<ScalarValue>> {
587        let mut result = vec![self.evaluate()?];
588        result.push(self.evaluate_orderings()?);
589
590        Ok(result)
591    }
592
593    fn evaluate(&mut self) -> Result<ScalarValue> {
594        if self.values.is_empty() {
595            return Ok(ScalarValue::new_null_list(
596                self.datatypes[0].clone(),
597                true,
598                1,
599            ));
600        }
601
602        let values = self.values.clone();
603        let array = if self.reverse {
604            ScalarValue::new_list_from_iter(
605                values.into_iter().rev(),
606                &self.datatypes[0],
607                true,
608            )
609        } else {
610            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
611        };
612        Ok(ScalarValue::List(array))
613    }
614
615    fn size(&self) -> usize {
616        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
617            - size_of_val(&self.values);
618
619        // Add size of the `self.ordering_values`
620        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
621        for row in &self.ordering_values {
622            total += ScalarValue::size_of_vec(row) - size_of_val(row);
623        }
624
625        // Add size of the `self.datatypes`
626        total += size_of::<DataType>() * self.datatypes.capacity();
627        for dtype in &self.datatypes {
628            total += dtype.size() - size_of_val(dtype);
629        }
630
631        // Add size of the `self.ordering_req`
632        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
633        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
634        total
635    }
636}
637
638impl OrderSensitiveArrayAggAccumulator {
639    fn evaluate_orderings(&self) -> Result<ScalarValue> {
640        let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
641        let num_columns = fields.len();
642        let struct_field = Fields::from(fields.clone());
643
644        let mut column_wise_ordering_values = vec![];
645        for i in 0..num_columns {
646            let column_values = self
647                .ordering_values
648                .iter()
649                .map(|x| x[i].clone())
650                .collect::<Vec<_>>();
651            let array = if column_values.is_empty() {
652                new_empty_array(fields[i].data_type())
653            } else {
654                ScalarValue::iter_to_array(column_values.into_iter())?
655            };
656            column_wise_ordering_values.push(array);
657        }
658
659        let ordering_array =
660            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
661        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
662    }
663}
664
665#[cfg(test)]
666mod tests {
667    use super::*;
668    use arrow::datatypes::{FieldRef, Schema};
669    use datafusion_common::cast::as_generic_string_array;
670    use datafusion_common::internal_err;
671    use datafusion_physical_expr::expressions::Column;
672    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
673    use std::sync::Arc;
674
675    #[test]
676    fn no_duplicates_no_distinct() -> Result<()> {
677        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
678
679        acc1.update_batch(&[data(["a", "b", "c"])])?;
680        acc2.update_batch(&[data(["d", "e", "f"])])?;
681        acc1 = merge(acc1, acc2)?;
682
683        let result = print_nulls(str_arr(acc1.evaluate()?)?);
684
685        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
686
687        Ok(())
688    }
689
690    #[test]
691    fn no_duplicates_distinct() -> Result<()> {
692        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
693            .distinct()
694            .build_two()?;
695
696        acc1.update_batch(&[data(["a", "b", "c"])])?;
697        acc2.update_batch(&[data(["d", "e", "f"])])?;
698        acc1 = merge(acc1, acc2)?;
699
700        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
701        result.sort();
702
703        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
704
705        Ok(())
706    }
707
708    #[test]
709    fn duplicates_no_distinct() -> Result<()> {
710        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;
711
712        acc1.update_batch(&[data(["a", "b", "c"])])?;
713        acc2.update_batch(&[data(["a", "b", "c"])])?;
714        acc1 = merge(acc1, acc2)?;
715
716        let result = print_nulls(str_arr(acc1.evaluate()?)?);
717
718        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);
719
720        Ok(())
721    }
722
723    #[test]
724    fn duplicates_distinct() -> Result<()> {
725        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
726            .distinct()
727            .build_two()?;
728
729        acc1.update_batch(&[data(["a", "b", "c"])])?;
730        acc2.update_batch(&[data(["a", "b", "c"])])?;
731        acc1 = merge(acc1, acc2)?;
732
733        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
734        result.sort();
735
736        assert_eq!(result, vec!["a", "b", "c"]);
737
738        Ok(())
739    }
740
741    #[test]
742    fn duplicates_on_second_batch_distinct() -> Result<()> {
743        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
744            .distinct()
745            .build_two()?;
746
747        acc1.update_batch(&[data(["a", "c"])])?;
748        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
749        acc1 = merge(acc1, acc2)?;
750
751        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
752        result.sort();
753
754        assert_eq!(result, vec!["a", "b", "c", "d"]);
755
756        Ok(())
757    }
758
759    #[test]
760    fn no_duplicates_distinct_sort_asc() -> Result<()> {
761        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
762            .distinct()
763            .order_by_col("col", SortOptions::new(false, false))
764            .build_two()?;
765
766        acc1.update_batch(&[data(["e", "b", "d"])])?;
767        acc2.update_batch(&[data(["f", "a", "c"])])?;
768        acc1 = merge(acc1, acc2)?;
769
770        let result = print_nulls(str_arr(acc1.evaluate()?)?);
771
772        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
773
774        Ok(())
775    }
776
777    #[test]
778    fn no_duplicates_distinct_sort_desc() -> Result<()> {
779        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
780            .distinct()
781            .order_by_col("col", SortOptions::new(true, false))
782            .build_two()?;
783
784        acc1.update_batch(&[data(["e", "b", "d"])])?;
785        acc2.update_batch(&[data(["f", "a", "c"])])?;
786        acc1 = merge(acc1, acc2)?;
787
788        let result = print_nulls(str_arr(acc1.evaluate()?)?);
789
790        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);
791
792        Ok(())
793    }
794
795    #[test]
796    fn duplicates_distinct_sort_asc() -> Result<()> {
797        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
798            .distinct()
799            .order_by_col("col", SortOptions::new(false, false))
800            .build_two()?;
801
802        acc1.update_batch(&[data(["a", "c", "b"])])?;
803        acc2.update_batch(&[data(["b", "c", "a"])])?;
804        acc1 = merge(acc1, acc2)?;
805
806        let result = print_nulls(str_arr(acc1.evaluate()?)?);
807
808        assert_eq!(result, vec!["a", "b", "c"]);
809
810        Ok(())
811    }
812
813    #[test]
814    fn duplicates_distinct_sort_desc() -> Result<()> {
815        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
816            .distinct()
817            .order_by_col("col", SortOptions::new(true, false))
818            .build_two()?;
819
820        acc1.update_batch(&[data(["a", "c", "b"])])?;
821        acc2.update_batch(&[data(["b", "c", "a"])])?;
822        acc1 = merge(acc1, acc2)?;
823
824        let result = print_nulls(str_arr(acc1.evaluate()?)?);
825
826        assert_eq!(result, vec!["c", "b", "a"]);
827
828        Ok(())
829    }
830
831    #[test]
832    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
833        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
834            .distinct()
835            .order_by_col("col", SortOptions::new(false, true))
836            .build_two()?;
837
838        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
839        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
840        acc1 = merge(acc1, acc2)?;
841
842        let result = print_nulls(str_arr(acc1.evaluate()?)?);
843
844        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);
845
846        Ok(())
847    }
848
849    #[test]
850    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
851        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
852            .distinct()
853            .order_by_col("col", SortOptions::new(false, false))
854            .build_two()?;
855
856        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
857        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
858        acc1 = merge(acc1, acc2)?;
859
860        let result = print_nulls(str_arr(acc1.evaluate()?)?);
861
862        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);
863
864        Ok(())
865    }
866
867    #[test]
868    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
869        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
870            .distinct()
871            .order_by_col("col", SortOptions::new(true, true))
872            .build_two()?;
873
874        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
875        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
876        acc1 = merge(acc1, acc2)?;
877
878        let result = print_nulls(str_arr(acc1.evaluate()?)?);
879
880        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);
881
882        Ok(())
883    }
884
885    #[test]
886    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
887        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
888            .distinct()
889            .order_by_col("col", SortOptions::new(true, false))
890            .build_two()?;
891
892        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
893        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
894        acc1 = merge(acc1, acc2)?;
895
896        let result = print_nulls(str_arr(acc1.evaluate()?)?);
897
898        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);
899
900        Ok(())
901    }
902
903    #[test]
904    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
905        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
906            .distinct()
907            .build_two()?;
908
909        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
910        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
911        acc1 = merge(acc1, acc2)?;
912
913        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
914        result.sort();
915        assert_eq!(result, vec!["NULL", "a"]);
916        Ok(())
917    }
918
919    #[test]
920    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
921        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
922            .distinct()
923            .build_two()?;
924
925        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
926        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
927        acc1 = merge(acc1, acc2)?;
928
929        let result = print_nulls(str_arr(acc1.evaluate()?)?);
930        assert_eq!(result, vec!["NULL"]);
931        Ok(())
932    }
933
934    struct ArrayAggAccumulatorBuilder {
935        data_type: DataType,
936        distinct: bool,
937        ordering: LexOrdering,
938        schema: Schema,
939    }
940
941    impl ArrayAggAccumulatorBuilder {
942        fn string() -> Self {
943            Self::new(DataType::Utf8)
944        }
945
946        fn new(data_type: DataType) -> Self {
947            Self {
948                data_type: data_type.clone(),
949                distinct: Default::default(),
950                ordering: Default::default(),
951                schema: Schema {
952                    fields: Fields::from(vec![Field::new(
953                        "col",
954                        DataType::List(FieldRef::new(Field::new(
955                            "item", data_type, true,
956                        ))),
957                        true,
958                    )]),
959                    metadata: Default::default(),
960                },
961            }
962        }
963
964        fn distinct(mut self) -> Self {
965            self.distinct = true;
966            self
967        }
968
969        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
970            self.ordering.extend([PhysicalSortExpr::new(
971                Arc::new(
972                    Column::new_with_schema(col, &self.schema)
973                        .expect("column not available in schema"),
974                ),
975                sort_options,
976            )]);
977            self
978        }
979
980        fn build(&self) -> Result<Box<dyn Accumulator>> {
981            ArrayAgg::default().accumulator(AccumulatorArgs {
982                return_type: &self.data_type,
983                schema: &self.schema,
984                ignore_nulls: false,
985                ordering_req: &self.ordering,
986                is_reversed: false,
987                name: "",
988                is_distinct: self.distinct,
989                exprs: &[Arc::new(Column::new("col", 0))],
990            })
991        }
992
993        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
994            Ok((self.build()?, self.build()?))
995        }
996    }
997
998    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
999        let ScalarValue::List(list) = value else {
1000            return internal_err!("ScalarValue was not a List");
1001        };
1002        Ok(as_generic_string_array::<i32>(list.values())?
1003            .iter()
1004            .map(|v| v.map(|v| v.to_string()))
1005            .collect())
1006    }
1007
1008    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
1009        sort.into_iter()
1010            .map(|v| v.unwrap_or("NULL".to_string()))
1011            .collect()
1012    }
1013
1014    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
1015    where
1016        ScalarValue: From<T>,
1017    {
1018        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
1019        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
1020    }
1021
1022    fn merge(
1023        mut acc1: Box<dyn Accumulator>,
1024        mut acc2: Box<dyn Accumulator>,
1025    ) -> Result<Box<dyn Accumulator>> {
1026        let intermediate_state = acc2.state().and_then(|e| {
1027            e.iter()
1028                .map(|v| v.to_array())
1029                .collect::<Result<Vec<ArrayRef>>>()
1030        })?;
1031        acc1.merge_batch(&intermediate_state)?;
1032        Ok(acc1)
1033    }
1034}