// datafusion_functions_aggregate/array_agg.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
19
20use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val, take};
23use std::sync::Arc;
24
25use arrow::array::{
26    Array, ArrayRef, AsArray, BooleanArray, ListArray, NullBufferBuilder, StructArray,
27    UInt32Array, new_empty_array,
28};
29use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
30use arrow::compute::{SortOptions, filter};
31use arrow::datatypes::{DataType, Field, FieldRef, Fields};
32
33use datafusion_common::cast::as_list_array;
34use datafusion_common::utils::{
35    SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args,
36};
37use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::{
41    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature,
42    Volatility,
43};
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_nulls;
45use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
46use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
47use datafusion_functions_aggregate_common::utils::ordering_fields;
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
50
// Generates the `array_agg` expression-builder function and the
// `array_agg_udaf` accessor for a shared [`ArrayAgg`] UDAF instance.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
58
#[user_doc(
    doc_section(label = "General Functions"),
    description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"
```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column)  |
+-----------------------------------------------+
| [element1, element2, element3]                |
+-----------------------------------------------+
> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
+--------------------------------------------------------+
| array_agg(DISTINCT column_name ORDER BY column_name)  |
+--------------------------------------------------------+
| [element1, element2, element3]                         |
+--------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
    /// Accepts any single argument with `Immutable` volatility (see
    /// [`Default::default`]).
    signature: Signature,
    /// Whether the input is known to already arrive in the beneficial
    /// (requested) order; set via `with_beneficial_ordering` and forwarded to
    /// the order-sensitive accumulator.
    is_input_pre_ordered: bool,
}
88
89impl Default for ArrayAgg {
90    fn default() -> Self {
91        Self {
92            signature: Signature::any(1, Volatility::Immutable),
93            is_input_pre_ordered: false,
94        }
95    }
96}
97
impl AggregateUDFImpl for ArrayAgg {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "array_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always a nullable `List` whose element type is the
    /// input type.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            arg_types[0].clone(),
            true,
        ))))
    }

    /// Intermediate state layout: one list field holding the collected
    /// values; for a non-distinct aggregate with an ORDER BY, a second list
    /// of structs carrying the ordering-expression values per row.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            return Ok(vec![
                Field::new_list(
                    format_state_name(args.name, "distinct_array_agg"),
                    // See COMMENTS.md to understand why nullable is set to true
                    Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                    true,
                )
                .into(),
            ]);
        }

        let mut fields = vec![
            Field::new_list(
                format_state_name(args.name, "array_agg"),
                // See COMMENTS.md to understand why nullable is set to true
                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
                true,
            )
            .into(),
        ];

        if args.ordering_fields.is_empty() {
            return Ok(fields);
        }

        let orderings = args.ordering_fields.to_vec();
        fields.push(
            Field::new_list(
                format_state_name(args.name, "array_agg_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            )
            .into(),
        );

        Ok(fields)
    }

    /// An input ordering is beneficial but not required for correctness.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::SoftRequirement
    }

    /// Returns a copy of this UDAF that remembers whether the input will
    /// already arrive pre-ordered.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    /// Selects one of three accumulators depending on the call shape:
    /// distinct, plain (no ORDER BY), or order-sensitive.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let field = &acc_args.expr_fields[0];
        let data_type = field.data_type();
        // Only filter nulls when the field can actually contain them
        let ignore_nulls = acc_args.ignore_nulls && field.is_nullable();

        if acc_args.is_distinct {
            // Limitation similar to Postgres. The aggregation function can only mix
            // DISTINCT and ORDER BY if all the expressions in the ORDER BY appear
            // also in the arguments of the function. This implies that if the
            // aggregation function only accepts one argument, only one argument
            // can be used in the ORDER BY, For example:
            //
            // ARRAY_AGG(DISTINCT col)
            //
            // can only be mixed with an ORDER BY if the order expression is "col".
            //
            // ARRAY_AGG(DISTINCT col ORDER BY col)                         <- Valid
            // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid
            // ARRAY_AGG(DISTINCT col ORDER BY other_col)                   <- Invalid
            // ARRAY_AGG(DISTINCT col ORDER BY concat(col, ''))             <- Invalid
            let sort_option = match acc_args.order_bys {
                [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
                [] => None,
                _ => {
                    return exec_err!(
                        "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
                    );
                }
            };
            return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
                data_type,
                sort_option,
                ignore_nulls,
            )?));
        }

        // No ORDER BY clause: the simple accumulator suffices
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return Ok(Box::new(ArrayAggAccumulator::try_new(
                data_type,
                ignore_nulls,
            )?));
        };

        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        OrderSensitiveArrayAggAccumulator::try_new(
            data_type,
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.is_reversed,
            ignore_nulls,
        )
        .map(|acc| Box::new(acc) as _)
    }

    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
    }

    /// The grouped accumulator only handles the plain case: no DISTINCT, no
    /// ORDER BY.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct && args.order_bys.is_empty()
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let field = &args.expr_fields[0];
        let data_type = field.data_type().clone();
        // Only filter nulls when the field can actually contain them
        let ignore_nulls = args.ignore_nulls && field.is_nullable();
        Ok(Box::new(ArrayAggGroupsAccumulator::new(
            data_type,
            ignore_nulls,
        )))
    }

    /// Supports `IGNORE NULLS` / `RESPECT NULLS` in SQL.
    fn supports_null_handling_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
260
/// Accumulator for `ARRAY_AGG` without DISTINCT or ORDER BY: collects input
/// arrays and concatenates them into a single list on `evaluate`.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    /// Non-empty input (or merged state) arrays collected so far.
    values: Vec<ArrayRef>,
    /// Element data type of the aggregated values.
    datatype: DataType,
    /// If true, null input values are filtered out before being stored.
    ignore_nulls: bool,
}
267
impl ArrayAggAccumulator {
    /// Creates a new `array_agg` accumulator for the given item data type.
    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
        Ok(Self {
            values: vec![],
            datatype: datatype.clone(),
            ignore_nulls,
        })
    }

    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value points to a non-empty list)
    /// If there are gaps, but only at the end of the list array, the function will return the values without the trailing null values
    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
        let offsets = list_array.value_offsets();
        // Offsets always have at least 1 value
        let initial_offset = offsets[0];
        let null_count = list_array.null_count();

        // If there are no nulls then just use the fast path
        // This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
        if null_count == 0 {
            // According to the Arrow specification, the first offset can be non-zero
            let list_values = list_array.values().slice(
                initial_offset as usize,
                (offsets[offsets.len() - 1] - initial_offset) as usize,
            );
            return Some(list_values);
        }

        // If all the values are null then just return an empty values array
        if list_array.null_count() == list_array.len() {
            return Some(list_array.values().slice(0, 0));
        }

        // According to the Arrow spec, null values can point to non-empty lists
        // So this checks that every null value between the first and last valid values points to a 0-length list, in which case we can just slice the underlying values

        // Unwrapping is safe as we just checked that there is a null value
        let nulls = list_array.nulls().unwrap();

        let mut valid_slices_iter = nulls.valid_slices();

        // This is safe as we validated that there is at least 1 valid value in the array
        let (start, end) = valid_slices_iter.next().unwrap();

        let start_offset = offsets[start];

        // End is exclusive, so it already points to the last offset value
        // This is valid as the length of the array is always 1 less than the length of the offsets
        let mut end_offset_of_last_valid_value = offsets[end];

        for (start, end) in valid_slices_iter {
            // If there is a null value that points to a non-empty list then the start offset of the valid value
            // will differ from the end offset of the last valid value
            if offsets[start] != end_offset_of_last_valid_value {
                return None;
            }

            // End is exclusive, so it already points to the last offset value
            // This is valid as the length of the array is always 1 less than the length of the offsets
            end_offset_of_last_valid_value = offsets[end];
        }

        let consecutive_valid_values = list_array.values().slice(
            start_offset as usize,
            (end_offset_of_last_valid_value - start_offset) as usize,
        );

        Some(consecutive_valid_values)
    }
}
339
impl Accumulator for ArrayAggAccumulator {
    /// Appends the (optionally null-filtered) input array to the collected
    /// values. Empty and all-null (when `ignore_nulls`) batches are dropped.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Append value like Int64Array(1,2,3)
        if values.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(values.len(), 1, "expects single batch");

        let val = &values[0];
        // Only compute a null mask when nulls should be skipped
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let val = match nulls {
            // All values null: nothing to store for this batch
            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
            // Keep only the valid rows (the validity buffer doubles as a
            // boolean filter mask)
            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
            None => Arc::clone(val),
        };

        if !val.is_empty() {
            self.values.push(val)
        }

        Ok(())
    }

    /// Merges a partial-aggregate state (a list array) into this accumulator,
    /// preferring a single zero-copy slice of the backing values when the
    /// list layout allows it.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
        if states.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(states.len(), 1, "expects single state");

        let list_arr = as_list_array(&states[0])?;

        match Self::get_optional_values_to_merge_as_is(list_arr) {
            Some(values) => {
                // Make sure we don't insert empty lists
                if !values.is_empty() {
                    self.values.push(values);
                }
            }
            None => {
                // Fall back to pushing each non-null list element separately
                for arr in list_arr.iter().flatten() {
                    self.values.push(arr);
                }
            }
        }

        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    /// Concatenates all collected arrays into a single one-row list scalar;
    /// returns a null list if nothing was collected.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Transform Vec<ListArr> to ListArr
        let element_arrays: Vec<&dyn Array> =
            self.values.iter().map(|a| a.as_ref()).collect();

        if element_arrays.is_empty() {
            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
        }

        let concated_array = arrow::compute::concat(&element_arrays)?;

        Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
    }

    fn size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ArrayRef>() * self.values.capacity())
            + self
                .values
                .iter()
                // Each ArrayRef might be just a reference to a bigger array, and many
                // ArrayRefs here might be referencing exactly the same array, so if we
                // were to call `arr.get_array_memory_size()`, we would be double-counting
                // the same underlying data many times.
                //
                // Instead, we do an approximation by estimating how much memory each
                // ArrayRef would occupy if its underlying data was fully owned by this
                // accumulator.
                //
                // Note that this is just an estimation, but the reality is that this
                // accumulator might not own any data.
                .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
                .sum::<usize>()
            + self.datatype.size()
            - size_of_val(&self.datatype)
    }
}
437
/// `GroupsAccumulator` for `ARRAY_AGG` without DISTINCT or ORDER BY: records
/// `(group, row)` entries against shared source batches and gathers them into
/// per-group lists on `evaluate`.
#[derive(Debug)]
struct ArrayAggGroupsAccumulator {
    /// Element data type of the aggregated values.
    datatype: DataType,
    /// If true, null input values are skipped when recording entries.
    ignore_nulls: bool,
    /// Source arrays — input arrays (from update_batch) or list backing
    /// arrays (from merge_batch).
    batches: Vec<ArrayRef>,
    /// Per-batch list of (group_idx, row_idx) pairs.
    batch_entries: Vec<Vec<(u32, u32)>>,
    /// Total number of groups tracked.
    num_groups: usize,
}
450
impl ArrayAggGroupsAccumulator {
    /// Creates an empty grouped accumulator for values of `datatype`.
    fn new(datatype: DataType, ignore_nulls: bool) -> Self {
        Self {
            datatype,
            ignore_nulls,
            batches: Vec::new(),
            batch_entries: Vec::new(),
            num_groups: 0,
        }
    }

    /// Drops all per-group state (used after emitting via `EmitTo::All`).
    fn clear_state(&mut self) {
        // `size()` measures Vec capacity rather than len, so allocate new
        // buffers instead of using `clear()`.
        self.batches = Vec::new();
        self.batch_entries = Vec::new();
        self.num_groups = 0;
    }

    /// Rebuilds the retained state after the first `emit_groups` groups have
    /// been emitted via `EmitTo::First`, renumbering surviving groups to
    /// start from 0 and compacting partially-emitted batches.
    fn compact_retained_state(&mut self, emit_groups: usize) -> Result<()> {
        // EmitTo::First is used to recover from memory pressure. Simply
        // removing emitted entries in place is not enough because mixed batches
        // would continue to pin their original Array arrays, even if only a few
        // retained rows remain.
        //
        // Rebuild the retained state from scratch so fully emitted batches are
        // dropped, mixed batches are compacted to arrays containing only the
        // surviving rows, and retained metadata is right-sized.
        let emit_groups = emit_groups as u32;
        let old_batches = take(&mut self.batches);
        let old_batch_entries = take(&mut self.batch_entries);

        let mut batches = Vec::new();
        let mut batch_entries = Vec::new();

        for (batch, entries) in old_batches.into_iter().zip(old_batch_entries) {
            let retained_len = entries.iter().filter(|(g, _)| *g >= emit_groups).count();

            // Batch belonged entirely to emitted groups: drop it
            if retained_len == 0 {
                continue;
            }

            if retained_len == entries.len() {
                // Nothing was emitted from this batch, so we keep the existing
                // array and only renumber the remaining group IDs so that they
                // start from 0.
                let mut retained_entries = entries;
                for (g, _) in &mut retained_entries {
                    *g -= emit_groups;
                }
                retained_entries.shrink_to_fit();
                batches.push(batch);
                batch_entries.push(retained_entries);
                continue;
            }

            let mut retained_entries = Vec::with_capacity(retained_len);
            let mut retained_rows = Vec::with_capacity(retained_len);

            for (g, r) in entries {
                if g >= emit_groups {
                    // Compute the new `(group_idx, row_idx)` pair for a
                    // retained row. `group_idx` is renumbered to start from
                    // 0, and `row_idx` points into the new dense batch we are
                    // building.
                    retained_entries.push((g - emit_groups, retained_rows.len() as u32));
                    retained_rows.push(r);
                }
            }

            debug_assert_eq!(retained_entries.len(), retained_len);
            debug_assert_eq!(retained_rows.len(), retained_len);

            let batch = if retained_len == batch.len() {
                batch
            } else {
                // Compact mixed batches so retained rows no longer pin the
                // original array.
                let retained_rows = UInt32Array::from(retained_rows);
                arrow::compute::take(batch.as_ref(), &retained_rows, None)?
            };

            batches.push(batch);
            batch_entries.push(retained_entries);
        }

        self.batches = batches;
        self.batch_entries = batch_entries;
        self.num_groups -= emit_groups as usize;

        Ok(())
    }
}
544
impl GroupsAccumulator for ArrayAggGroupsAccumulator {
    /// Store a reference to the input batch, plus a `(group_idx, row_idx)` pair
    /// for every row.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let input = &values[0];

        self.num_groups = self.num_groups.max(total_num_groups);

        // Only compute a null mask when nulls should be skipped
        let nulls = if self.ignore_nulls {
            input.logical_nulls()
        } else {
            None
        };

        let mut entries = Vec::new();

        for (row_idx, &group_idx) in group_indices.iter().enumerate() {
            // Skip filtered rows
            if let Some(filter) = opt_filter
                && (filter.is_null(row_idx) || !filter.value(row_idx))
            {
                continue;
            }

            // Skip null values when ignore_nulls is set
            if let Some(ref nulls) = nulls
                && nulls.is_null(row_idx)
            {
                continue;
            }

            entries.push((group_idx as u32, row_idx as u32));
        }

        // We only need to record the batch if it was non-empty.
        if !entries.is_empty() {
            self.batches.push(Arc::clone(input));
            self.batch_entries.push(entries);
        }

        Ok(())
    }

    /// Produce a `ListArray` ordered by group index: the list at
    /// position N contains the aggregated values for group N.
    ///
    /// Uses a counting sort to rearrange the stored `(group, row)`
    /// entries into group order, then calls `interleave` to gather
    /// the values into a flat array that backs the output `ListArray`.
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let emit_groups = match emit_to {
            EmitTo::All => self.num_groups,
            EmitTo::First(n) => n,
        };

        // Step 1: Count entries per group. For EmitTo::First(n), only groups
        // 0..n are counted; the rest are retained to be emitted in the future.
        let mut counts = vec![0u32; emit_groups];
        for entries in &self.batch_entries {
            for &(g, _) in entries {
                let g = g as usize;
                if g < emit_groups {
                    counts[g] += 1;
                }
            }
        }

        // Step 2: Do a prefix sum over the counts and use it to build ListArray
        // offsets, null buffer, and write positions for the counting sort.
        let mut offsets = Vec::<i32>::with_capacity(emit_groups + 1);
        offsets.push(0);
        let mut nulls_builder = NullBufferBuilder::new(emit_groups);
        let mut write_positions = Vec::with_capacity(emit_groups);
        let mut cur_offset = 0u32;
        for &count in &counts {
            // Groups that received no values produce a null list entry
            if count == 0 {
                nulls_builder.append_null();
            } else {
                nulls_builder.append_non_null();
            }
            write_positions.push(cur_offset);
            cur_offset += count;
            offsets.push(cur_offset as i32);
        }
        let total_rows = cur_offset as usize;

        // Step 3: Scatter entries into group order using the counting sort. The
        // batch index is implicit from the outer loop position.
        let flat_values = if total_rows == 0 {
            new_empty_array(&self.datatype)
        } else {
            let mut interleave_indices = vec![(0usize, 0usize); total_rows];
            for (batch_idx, entries) in self.batch_entries.iter().enumerate() {
                for &(g, r) in entries {
                    let g = g as usize;
                    if g < emit_groups {
                        let wp = write_positions[g] as usize;
                        interleave_indices[wp] = (batch_idx, r as usize);
                        write_positions[g] += 1;
                    }
                }
            }

            let sources: Vec<&dyn Array> =
                self.batches.iter().map(|b| b.as_ref()).collect();
            arrow::compute::interleave(&sources, &interleave_indices)?
        };

        // Step 4: Release state for emitted groups.
        match emit_to {
            EmitTo::All => self.clear_state(),
            EmitTo::First(_) => self.compact_retained_state(emit_groups)?,
        }

        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
        let field = Arc::new(Field::new_list_field(self.datatype.clone(), true));
        let result = ListArray::new(field, offsets, flat_values, nulls_builder.finish());

        Ok(Arc::new(result))
    }

    /// The serialized state is the same `ListArray` that `evaluate` produces.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        Ok(vec![self.evaluate(emit_to)?])
    }

    /// Merges partial-aggregate state: each non-null list element contributes
    /// its rows (as positions into the list's backing values array) to the
    /// corresponding group.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        let input_list = values[0].as_list::<i32>();

        self.num_groups = self.num_groups.max(total_num_groups);

        // Push the ListArray's backing values array as a single batch.
        let list_values = input_list.values();
        let list_offsets = input_list.offsets();

        let mut entries = Vec::new();

        for (row_idx, &group_idx) in group_indices.iter().enumerate() {
            // Null list entries carry no values (e.g. filtered rows from
            // convert_to_state)
            if input_list.is_null(row_idx) {
                continue;
            }
            let start = list_offsets[row_idx] as u32;
            let end = list_offsets[row_idx + 1] as u32;
            for pos in start..end {
                entries.push((group_idx as u32, pos));
            }
        }

        if !entries.is_empty() {
            self.batches.push(Arc::clone(list_values));
            self.batch_entries.push(entries);
        }

        Ok(())
    }

    /// Converts raw input directly to the intermediate state form: every row
    /// becomes a single-element (or null) list entry.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "one argument to convert_to_state");

        let input = &values[0];

        // Each row becomes a 1-element list: offsets are [0, 1, 2, ..., n].
        let offsets = OffsetBuffer::from_repeated_length(1, input.len());

        // Filtered rows become null list entries, which merge_batch will skip.
        let filter_nulls = opt_filter.and_then(filter_to_nulls);

        // With ignore_nulls, null values also become null list entries. Without
        // ignore_nulls, null values stay as [NULL] so merge_batch retains them.
        let nulls = if self.ignore_nulls {
            let logical = input.logical_nulls();
            NullBuffer::union(filter_nulls.as_ref(), logical.as_ref())
        } else {
            filter_nulls
        };

        let field = Arc::new(Field::new_list_field(self.datatype.clone(), true));
        let list_array = ListArray::new(field, offsets, Arc::clone(input), nulls);

        Ok(vec![Arc::new(list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    /// Estimated memory usage: sliced data size of each stored batch (to
    /// avoid double-counting shared buffers) plus entry/metadata capacity.
    fn size(&self) -> usize {
        self.batches
            .iter()
            .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
            .sum::<usize>()
            + self.batches.capacity() * size_of::<ArrayRef>()
            + self
                .batch_entries
                .iter()
                .map(|e| e.capacity() * size_of::<(u32, u32)>())
                .sum::<usize>()
            + self.batch_entries.capacity() * size_of::<Vec<(u32, u32)>>()
    }
}
762
/// Accumulator for `ARRAY_AGG(DISTINCT ...)`: deduplicates values in a hash
/// set and optionally sorts the final list.
#[derive(Debug)]
pub struct DistinctArrayAggAccumulator {
    /// Set of distinct values seen so far.
    values: HashSet<ScalarValue>,
    /// Element data type of the aggregated values.
    datatype: DataType,
    /// When `Some`, the output list is sorted with these options on
    /// `evaluate`.
    sort_options: Option<SortOptions>,
    /// If true, null input values are not inserted into the set.
    ignore_nulls: bool,
}
770
771impl DistinctArrayAggAccumulator {
772    pub fn try_new(
773        datatype: &DataType,
774        sort_options: Option<SortOptions>,
775        ignore_nulls: bool,
776    ) -> Result<Self> {
777        Ok(Self {
778            values: HashSet::new(),
779            datatype: datatype.clone(),
780            sort_options,
781            ignore_nulls,
782        })
783    }
784}
785
impl Accumulator for DistinctArrayAggAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    /// Inserts each (optionally non-null) input value into the distinct set.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let val = &values[0];
        // Only compute a null mask when nulls should be skipped
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let nulls = nulls.as_ref();
        // Skip the whole batch when every value is null and nulls are ignored
        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
            for i in 0..val.len() {
                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
                    // `compacted` detaches the scalar from the (possibly
                    // large) source array before it is stored in the set
                    self.values
                        .insert(ScalarValue::try_from_array(val, i)?.compacted());
                }
            }
        }

        Ok(())
    }

    /// Merges partial state by feeding each non-null list element back
    /// through `update_batch`, which deduplicates via the set.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        assert_eq_or_internal_err!(states.len(), 1, "expects single state");

        states[0]
            .as_list::<i32>()
            .iter()
            .flatten()
            .try_for_each(|val| self.update_batch(&[val]))
    }

    /// Produces the distinct values as a one-row list scalar, sorted when
    /// `sort_options` is set; returns a null list if no values were seen.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
        if values.is_empty() {
            return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
        }

        if let Some(opts) = self.sort_options {
            // `try_cmp` is fallible; the comparator cannot return an error,
            // so remember the first failure and surface it after the sort
            let mut delayed_cmp_err = Ok(());
            values.sort_by(|a, b| {
                // Nulls sort to the front or back per `nulls_first`
                if a.is_null() {
                    return match opts.nulls_first {
                        true => Ordering::Less,
                        false => Ordering::Greater,
                    };
                }
                if b.is_null() {
                    return match opts.nulls_first {
                        true => Ordering::Greater,
                        false => Ordering::Less,
                    };
                }
                match opts.descending {
                    true => b.try_cmp(a),
                    false => a.try_cmp(b),
                }
                .unwrap_or_else(|err| {
                    delayed_cmp_err = Err(err);
                    Ordering::Equal
                })
            });
            delayed_cmp_err?;
        };

        let arr = ScalarValue::new_list(&values, &self.datatype, true);
        Ok(ScalarValue::List(arr))
    }

    /// Estimated memory usage: struct size plus owned set contents, with the
    /// shallow sizes already included in `size_of_val(self)` subtracted out.
    fn size(&self) -> usize {
        size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
            - size_of_val(&self.values)
            + self.datatype.size()
            - size_of_val(&self.datatype)
            - size_of_val(&self.sort_options)
            + size_of::<Option<SortOptions>>()
    }
}
876
/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ///
    /// Kept parallel to `values`: `ordering_values[i]` is the ordering row
    /// for `values[i]`.
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions. Index 0 is the value type; indices `1..` are the ordering
    /// expression types.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the input is known to be pre-ordered
    is_input_pre_ordered: bool,
    /// Whether the aggregation is running in reverse.
    reverse: bool,
    /// Whether the aggregation should ignore null values.
    ignore_nulls: bool,
}
901
impl OrderSensitiveArrayAggAccumulator {
    /// Create a new order-sensitive ARRAY_AGG accumulator based on the given
    /// item data type.
    pub fn try_new(
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        reverse: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        // `datatypes[0]` is the value type; the rest are ordering expression
        // types, in the same order as `ordering_req`.
        let mut datatypes = vec![datatype.clone()];
        datatypes.extend(ordering_dtypes.iter().cloned());
        Ok(Self {
            values: vec![],
            ordering_values: vec![],
            datatypes,
            ordering_req,
            is_input_pre_ordered,
            reverse,
            ignore_nulls,
        })
    }

    /// Sort `values` and the parallel `ordering_values` by the ordering
    /// requirement. Comparison errors cannot escape `sort_by`'s closure, so
    /// they are stashed in `delayed_cmp_err` (silently dropped here; callers
    /// re-sort or merge, where errors surface via `compare_rows`).
    fn sort(&mut self) {
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        // Zip the two parallel vectors so a single sort keeps them aligned.
        let mut values = take(&mut self.values)
            .into_iter()
            .zip(take(&mut self.ordering_values))
            .collect::<Vec<_>>();
        let mut delayed_cmp_err = Ok(());
        values.sort_by(|(_, left_ordering), (_, right_ordering)| {
            compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
                |err| {
                    delayed_cmp_err = Err(err);
                    Ordering::Equal
                },
            )
        });
        (self.values, self.ordering_values) = values.into_iter().unzip();
    }

    /// Encode `ordering_values` as a single-row list-of-struct scalar; this is
    /// the second field of the intermediate aggregation state.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);

        // An empty accumulator still needs correctly-typed (empty) columns.
        let column_wise_ordering_values = if self.ordering_values.is_empty() {
            fields
                .iter()
                .map(|f| new_empty_array(f.data_type()))
                .collect::<Vec<_>>()
        } else {
            // Transpose row-wise ordering values into one array per column.
            (0..fields.len())
                .map(|i| {
                    let column_values = self.ordering_values.iter().map(|x| x[i].clone());
                    ScalarValue::iter_to_array(column_values)
                })
                .collect::<Result<_>>()?
        };

        let ordering_array = StructArray::try_new(
            Fields::from(fields),
            column_wise_ordering_values,
            None,
        )?;
        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
    }
}
973
impl Accumulator for OrderSensitiveArrayAggAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        // First column is the aggregated value; the remaining columns hold
        // the ORDER BY expression values for the same rows.
        let val = &values[0];
        let ord = &values[1..];
        // Only materialize the null mask when nulls must be skipped.
        let nulls = if self.ignore_nulls {
            val.logical_nulls()
        } else {
            None
        };

        let nulls = nulls.as_ref();
        // Fast path: skip the loop entirely when every row is null
        // (and nulls are being ignored).
        if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
            for i in 0..val.len() {
                if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
                    // `compacted()` detaches scalars from the input buffers so
                    // the accumulator does not keep whole batches alive.
                    self.values
                        .push(ScalarValue::try_from_array(val, i)?.compacted());
                    self.ordering_values.push(
                        get_row_at_idx(ord, i)?
                            .into_iter()
                            .map(|v| v.compacted())
                            .collect(),
                    )
                }
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        // First entry in the state is the aggregation result. Second entry
        // stores values received for ordering requirement columns for each
        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
        // received from its ordering requirement expression. (This information
        // is necessary for during merging).
        let [array_agg_values, agg_orderings] =
            take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Stores ARRAY_AGG results coming from each partition
        let mut partition_values = vec![];
        // Stores ordering requirement expression results coming from each partition
        let mut partition_ordering_values = vec![];

        // Existing values should be merged also.
        if !self.is_input_pre_ordered {
            self.sort();
        }
        // The accumulator's own (now sorted) contents become partition 0.
        partition_values.push(take(&mut self.values).into());
        partition_ordering_values.push(take(&mut self.ordering_values).into());

        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
        for maybe_v in array_agg_res.into_iter() {
            if let Some(v) = maybe_v {
                partition_values.push(v.into());
            } else {
                // A null list (empty partition) still occupies a slot so the
                // partition indices stay aligned with the ordering values.
                partition_values.push(vec![].into());
            }
        }

        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        for partition_ordering_rows in orderings.into_iter().flatten() {
            // Extract value from struct to ordering_rows for each group/partition
            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
                    if let ScalarValue::Struct(s) = ordering_row {
                        let mut ordering_columns_per_row = vec![];

                        for column in s.columns() {
                            let sv = ScalarValue::try_from_array(column, 0)?;
                            ordering_columns_per_row.push(sv);
                        }

                        Ok(ordering_columns_per_row)
                    } else {
                        exec_err!(
                            "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
                            ordering_row.data_type()
                        )
                    }
                }).collect::<Result<VecDeque<_>>>()?;

            partition_ordering_values.push(ordering_value);
        }

        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();

        // K-way merge of the per-partition sorted runs back into `self`.
        (self.values, self.ordering_values) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;

        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        if !self.is_input_pre_ordered {
            self.sort();
        }

        // State layout: [aggregated values, ordering rows] — consumed by
        // `merge_batch` above.
        let mut result = vec![self.evaluate()?];
        result.push(self.evaluate_orderings()?);

        Ok(result)
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        if !self.is_input_pre_ordered {
            self.sort();
        }

        if self.values.is_empty() {
            return Ok(ScalarValue::new_null_list(
                self.datatypes[0].clone(),
                true,
                1,
            ));
        }

        // Clone instead of take: `state()` still needs `values` aligned with
        // `ordering_values` after this call.
        let values = self.values.clone();
        let array = if self.reverse {
            ScalarValue::new_list_from_iter(
                values.into_iter().rev(),
                &self.datatypes[0],
                true,
            )
        } else {
            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
        };
        Ok(ScalarValue::List(array))
    }

    fn size(&self) -> usize {
        // Deep sizes include the container's shallow size, which
        // `size_of_val(self)` already counted — subtract to avoid
        // double counting.
        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
            - size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}
1144
1145#[cfg(test)]
1146mod tests {
1147    use super::*;
1148    use arrow::array::{ListBuilder, StringBuilder};
1149    use arrow::datatypes::{FieldRef, Schema};
1150    use datafusion_common::cast::as_generic_string_array;
1151    use datafusion_common::internal_err;
1152    use datafusion_physical_expr::PhysicalExpr;
1153    use datafusion_physical_expr::expressions::Column;
1154    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1155    use std::sync::Arc;
1156
    // Non-distinct merge concatenates partition results in merge order.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    // DISTINCT with disjoint inputs keeps every value exactly once.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["d", "e", "f"])])?;
        acc1 = merge(acc1, acc2)?;

        // DISTINCT output order is unspecified without ORDER BY; sort before
        // asserting.
        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    // Without DISTINCT, duplicated values across partitions are all kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);

        Ok(())
    }

    // DISTINCT collapses cross-partition duplicates to a single copy.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"])])?;
        acc2.update_batch(&[data(["a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    // Duplicates arriving only in the merged-in state are still deduplicated.
    #[test]
    fn duplicates_on_second_batch_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "c"])])?;
        acc2.update_batch(&[data(["d", "a", "b", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();

        assert_eq!(result, vec!["a", "b", "c", "d"]);

        Ok(())
    }
1240
    // DISTINCT + ORDER BY ASC yields a globally ascending result.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);

        Ok(())
    }

    // DISTINCT + ORDER BY DESC yields a globally descending result.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"])])?;
        acc2.update_batch(&[data(["f", "a", "c"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);

        Ok(())
    }

    // Deduplication and ASC ordering combine across partitions.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "c"]);

        Ok(())
    }

    // Deduplication and DESC ordering combine across partitions.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["c", "b", "a"]);

        Ok(())
    }

    // ASC + NULLS FIRST: the single deduplicated NULL leads the result.
    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);

        Ok(())
    }

    // ASC + NULLS LAST: the single deduplicated NULL trails the result.
    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);

        Ok(())
    }

    // DESC + NULLS FIRST: NULL leads, remaining values descending.
    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, true))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);

        Ok(())
    }

    // DESC + NULLS LAST: values descending, NULL trails.
    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data([Some("e"), Some("b"), None])])?;
        acc2.update_batch(&[data([Some("f"), Some("a"), None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);

        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);

        Ok(())
    }
1384
    // An all-null first batch must not poison the distinct set: the later
    // non-null value and the single NULL both survive.
    #[test]
    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data([Some("a"), None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let mut result = print_nulls(str_arr(acc1.evaluate()?)?);
        result.sort();
        assert_eq!(result, vec!["NULL", "a"]);
        Ok(())
    }

    // All-null input on every partition collapses to a single NULL entry.
    #[test]
    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data::<Option<&str>, 3>([None, None, None])])?;
        acc2.update_batch(&[data::<Option<&str>, 4>([None, None, None, None])])?;
        acc1 = merge(acc1, acc2)?;

        let result = print_nulls(str_arr(acc1.evaluate()?)?);
        assert_eq!(result, vec!["NULL"]);
        Ok(())
    }
1415
    // Pins the reported `size()` so accidental memory-accounting regressions
    // (e.g. retaining whole input batches) are caught by an exact number.
    #[test]
    fn does_not_over_account_memory() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"])])?;
        acc2.update_batch(&[data(["b", "c", "a"])])?;
        acc1 = merge(acc1, acc2)?;

        assert_eq!(acc1.size(), 266);

        Ok(())
    }
    // Verifies that `compacted()` keeps the DISTINCT accumulator from holding
    // onto the original list buffers.
    #[test]
    fn does_not_over_account_memory_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .build_two()?;

        acc1.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["d", "e", "f"],
        ])])?;
        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
        acc1 = merge(acc1, acc2)?;

        // without compaction, the size is 16660
        assert_eq!(acc1.size(), 1660);

        Ok(())
    }

    // Same compaction guarantee for the order-sensitive accumulator.
    #[test]
    fn does_not_over_account_memory_ordered() -> Result<()> {
        let mut acc = ArrayAggAccumulatorBuilder::string()
            .order_by_col("col", SortOptions::new(false, false))
            .build()?;

        acc.update_batch(&[string_list_data([
            vec!["a", "b", "c"],
            vec!["c", "d", "e"],
            vec!["b", "c", "d"],
        ])])?;

        // without compaction, the size is 17112
        assert_eq!(acc.size(), 2224);

        Ok(())
    }
1464
    /// Test helper that assembles `AccumulatorArgs` for `ArrayAgg` with a
    /// single input column named "col".
    struct ArrayAggAccumulatorBuilder {
        return_field: FieldRef,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl ArrayAggAccumulatorBuilder {
        /// Shorthand for a builder over `Utf8` values.
        fn string() -> Self {
            Self::new(DataType::Utf8)
        }

        /// Start a builder for the given item data type (non-distinct,
        /// unordered by default).
        fn new(data_type: DataType) -> Self {
            Self {
                return_field: Field::new("f", data_type.clone(), true).into(),
                distinct: false,
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::new_list(data_type, true),
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }

        /// Request `ARRAY_AGG(DISTINCT ...)` semantics.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Append an ORDER BY on `col` with the given sort options.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let new_order = PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            );
            self.order_bys.push(new_order);
            self
        }

        /// Build one accumulator from the collected options.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let expr = Arc::new(Column::new("col", 0));
            let expr_field = expr.return_field(&self.schema)?;
            ArrayAgg::default().accumulator(AccumulatorArgs {
                return_field: Arc::clone(&self.return_field),
                schema: &self.schema,
                expr_fields: &[expr_field],
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[expr],
            })
        }

        /// Build two identically-configured accumulators (for merge tests).
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }
1530
1531    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
1532        let ScalarValue::List(list) = value else {
1533            return internal_err!("ScalarValue was not a List");
1534        };
1535        Ok(as_generic_string_array::<i32>(list.values())?
1536            .iter()
1537            .map(|v| v.map(|v| v.to_string()))
1538            .collect())
1539    }
1540
1541    fn print_nulls(sort: Vec<Option<String>>) -> Vec<String> {
1542        sort.into_iter()
1543            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
1544            .collect()
1545    }
1546
1547    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
1548        let mut builder = ListBuilder::new(StringBuilder::new());
1549        for string_list in data.into_iter() {
1550            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>());
1551        }
1552
1553        Arc::new(builder.finish())
1554    }
1555
1556    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
1557    where
1558        ScalarValue: From<T>,
1559    {
1560        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
1561        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
1562    }
1563
1564    fn merge(
1565        mut acc1: Box<dyn Accumulator>,
1566        mut acc2: Box<dyn Accumulator>,
1567    ) -> Result<Box<dyn Accumulator>> {
1568        let intermediate_state = acc2.state().and_then(|e| {
1569            e.iter()
1570                .map(|v| v.to_array())
1571                .collect::<Result<Vec<ArrayRef>>>()
1572        })?;
1573        acc1.merge_batch(&intermediate_state)?;
1574        Ok(acc1)
1575    }
1576
1577    // ---- GroupsAccumulator tests ----
1578
1579    use arrow::array::Int32Array;
1580
1581    fn list_array_to_i32_vecs(list: &ListArray) -> Vec<Option<Vec<Option<i32>>>> {
1582        (0..list.len())
1583            .map(|i| {
1584                if list.is_null(i) {
1585                    None
1586                } else {
1587                    let arr = list.value(i);
1588                    let vals: Vec<Option<i32>> = arr
1589                        .as_any()
1590                        .downcast_ref::<Int32Array>()
1591                        .unwrap()
1592                        .iter()
1593                        .collect();
1594                    Some(vals)
1595                }
1596            })
1597            .collect()
1598    }
1599
    /// Evaluate the groups accumulator and decode the result into nested
    /// `Option<Vec<Option<i32>>>` rows via `list_array_to_i32_vecs`.
    fn eval_i32_lists(
        acc: &mut ArrayAggGroupsAccumulator,
        emit_to: EmitTo,
    ) -> Result<Vec<Option<Vec<Option<i32>>>>> {
        let result = acc.evaluate(emit_to)?;
        Ok(list_array_to_i32_vecs(result.as_list::<i32>()))
    }
1607
    // Values interleave across batches but stay in arrival order per group.
    #[test]
    fn groups_accumulator_multiple_batches() -> Result<()> {
        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);

        // First batch
        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        acc.update_batch(&[values], &[0, 1, 0], None, 2)?;

        // Second batch
        let values: ArrayRef = Arc::new(Int32Array::from(vec![4, 5]));
        acc.update_batch(&[values], &[1, 0], None, 2)?;

        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
        assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)]));
        assert_eq!(vals[1], Some(vec![Some(2), Some(4)]));

        Ok(())
    }

    // `EmitTo::First(n)` emits the first n groups and shifts the rest down.
    #[test]
    fn groups_accumulator_emit_first() -> Result<()> {
        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);

        let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
        acc.update_batch(&[values], &[0, 1, 2], None, 3)?;

        // Emit first 2 groups
        let vals = eval_i32_lists(&mut acc, EmitTo::First(2))?;
        assert_eq!(vals.len(), 2);
        assert_eq!(vals[0], Some(vec![Some(10)]));
        assert_eq!(vals[1], Some(vec![Some(20)]));

        // Remaining group (was index 2, now shifted to 0)
        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
        assert_eq!(vals.len(), 1);
        assert_eq!(vals[0], Some(vec![Some(30)]));

        Ok(())
    }
1647
    // Emitting a prefix of groups must release the batches that only served
    // those groups, and shrink shared batches to the surviving rows.
    #[test]
    fn groups_accumulator_emit_first_frees_batches() -> Result<()> {
        // Batch 0 has rows only for group 0; batch 1 has rows for
        // both groups. After emitting group 0, batch 0 should be
        // dropped entirely and batch 1 should be compacted to the
        // retained row(s).
        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);

        let batch0: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        acc.update_batch(&[batch0], &[0, 0], None, 2)?;

        let batch1: ArrayRef = Arc::new(Int32Array::from(vec![30, 40]));
        acc.update_batch(&[batch1], &[0, 1], None, 2)?;

        assert_eq!(acc.batches.len(), 2);
        assert!(!acc.batches[0].is_empty());
        assert!(!acc.batches[1].is_empty());

        // Emit group 0. Batch 0 is only referenced by group 0, so it
        // should be removed. Batch 1 is mixed, so it should be compacted
        // to contain only the retained row for group 1.
        let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?;
        assert_eq!(vals[0], Some(vec![Some(10), Some(20), Some(30)]));

        assert_eq!(acc.batches.len(), 1);
        let retained = acc.batches[0]
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(retained.values(), &[40]);
        assert_eq!(acc.batch_entries, vec![vec![(0, 0)]]);

        // Emit remaining group 1
        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
        assert_eq!(vals[0], Some(vec![Some(40)]));

        assert!(acc.batches.is_empty());
        assert_eq!(acc.size(), 0);

        Ok(())
    }

    // A single batch shared by emitted and retained groups is compacted in
    // place, and the reported size shrinks accordingly.
    #[test]
    fn groups_accumulator_emit_first_compacts_mixed_batches() -> Result<()> {
        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);

        let batch: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40]));
        acc.update_batch(&[batch], &[0, 1, 0, 1], None, 2)?;

        let size_before = acc.size();
        let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?;
        assert_eq!(vals[0], Some(vec![Some(10), Some(30)]));

        assert_eq!(acc.num_groups, 1);
        assert_eq!(acc.batches.len(), 1);
        let retained = acc.batches[0]
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(retained.values(), &[20, 40]);
        assert_eq!(acc.batch_entries, vec![vec![(0, 0), (0, 1)]]);
        assert!(acc.size() < size_before);

        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
        assert_eq!(vals[0], Some(vec![Some(20), Some(40)]));
        assert_eq!(acc.size(), 0);

        Ok(())
    }

    // After `EmitTo::All` the accumulator must drop all buffers AND their
    // spare capacity, reporting zero retained memory.
    #[test]
    fn groups_accumulator_emit_all_releases_capacity() -> Result<()> {
        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);

        let batch: ArrayRef = Arc::new(Int32Array::from_iter_values(0..64));
        acc.update_batch(
            &[batch],
            &(0..64).map(|i| i % 4).collect::<Vec<_>>(),
            None,
            4,
        )?;

        assert!(acc.size() > 0);
        let _ = eval_i32_lists(&mut acc, EmitTo::All)?;

        assert_eq!(acc.size(), 0);
        assert_eq!(acc.batches.capacity(), 0);
        assert_eq!(acc.batch_entries.capacity(), 0);

        Ok(())
    }
1739
1740    #[test]
1741    fn groups_accumulator_null_groups() -> Result<()> {
1742        // Groups that never receive values should produce null
1743        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1744
1745        let values: ArrayRef = Arc::new(Int32Array::from(vec![1]));
1746        // Only group 0 gets a value, groups 1 and 2 are empty
1747        acc.update_batch(&[values], &[0], None, 3)?;
1748
1749        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
1750        assert_eq!(vals, vec![Some(vec![Some(1)]), None, None]);
1751
1752        Ok(())
1753    }
1754
1755    #[test]
1756    fn groups_accumulator_ignore_nulls() -> Result<()> {
1757        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true);
1758
1759        let values: ArrayRef =
1760            Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None]));
1761        acc.update_batch(&[values], &[0, 0, 1, 1], None, 2)?;
1762
1763        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
1764        // Group 0: only non-null value is 1
1765        assert_eq!(vals[0], Some(vec![Some(1)]));
1766        // Group 1: only non-null value is 3
1767        assert_eq!(vals[1], Some(vec![Some(3)]));
1768
1769        Ok(())
1770    }
1771
1772    #[test]
1773    fn groups_accumulator_opt_filter() -> Result<()> {
1774        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1775
1776        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1777        // Use a mix of false and null to filter out rows — both should
1778        // be skipped.
1779        let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]);
1780        acc.update_batch(&[values], &[0, 0, 1, 1], Some(&filter), 2)?;
1781
1782        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
1783        assert_eq!(vals[0], Some(vec![Some(1)])); // row 1 filtered (null)
1784        assert_eq!(vals[1], Some(vec![Some(3)])); // row 3 filtered (false)
1785
1786        Ok(())
1787    }
1788
1789    #[test]
1790    fn groups_accumulator_state_merge_roundtrip() -> Result<()> {
1791        // Accumulator 1: update_batch, then merge, then update_batch again.
1792        // Verifies that values appear in chronological insertion order.
1793        let mut acc1 = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1794        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1795        acc1.update_batch(&[values], &[0, 1], None, 2)?;
1796
1797        // Accumulator 2
1798        let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1799        let values: ArrayRef = Arc::new(Int32Array::from(vec![3, 4]));
1800        acc2.update_batch(&[values], &[0, 1], None, 2)?;
1801
1802        // Merge acc2's state into acc1
1803        let state = acc2.state(EmitTo::All)?;
1804        acc1.merge_batch(&state, &[0, 1], None, 2)?;
1805
1806        // Another update_batch on acc1 after the merge
1807        let values: ArrayRef = Arc::new(Int32Array::from(vec![5, 6]));
1808        acc1.update_batch(&[values], &[0, 1], None, 2)?;
1809
1810        // Each group's values in insertion order:
1811        // group 0: update(1), merge(3), update(5) → [1, 3, 5]
1812        // group 1: update(2), merge(4), update(6) → [2, 4, 6]
1813        let vals = eval_i32_lists(&mut acc1, EmitTo::All)?;
1814        assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)]));
1815        assert_eq!(vals[1], Some(vec![Some(2), Some(4), Some(6)]));
1816
1817        Ok(())
1818    }
1819
1820    #[test]
1821    fn groups_accumulator_convert_to_state() -> Result<()> {
1822        let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1823
1824        let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(30)]));
1825        let state = acc.convert_to_state(&[values], None)?;
1826
1827        assert_eq!(state.len(), 1);
1828        let vals = list_array_to_i32_vecs(state[0].as_list::<i32>());
1829        assert_eq!(
1830            vals,
1831            vec![
1832                Some(vec![Some(10)]),
1833                Some(vec![None]), // null preserved inside list, not promoted
1834                Some(vec![Some(30)]),
1835            ]
1836        );
1837
1838        Ok(())
1839    }
1840
1841    #[test]
1842    fn groups_accumulator_convert_to_state_with_filter() -> Result<()> {
1843        let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1844
1845        let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
1846        let filter = BooleanArray::from(vec![true, false, true]);
1847        let state = acc.convert_to_state(&[values], Some(&filter))?;
1848
1849        let vals = list_array_to_i32_vecs(state[0].as_list::<i32>());
1850        assert_eq!(
1851            vals,
1852            vec![
1853                Some(vec![Some(10)]),
1854                None, // filtered
1855                Some(vec![Some(30)]),
1856            ]
1857        );
1858
1859        Ok(())
1860    }
1861
1862    #[test]
1863    fn groups_accumulator_convert_to_state_merge_preserves_nulls() -> Result<()> {
1864        // Verifies that null values survive the convert_to_state -> merge_batch
1865        // round-trip when ignore_nulls is false (default null handling).
1866        let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1867
1868        let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1869        let state = acc.convert_to_state(&[values], None)?;
1870
1871        // Feed state into a new accumulator via merge_batch
1872        let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1873        acc2.merge_batch(&state, &[0, 0, 1], None, 2)?;
1874
1875        // Group 0 received rows 0 ([1]) and 1 ([NULL]) → [1, NULL]
1876        let vals = eval_i32_lists(&mut acc2, EmitTo::All)?;
1877        assert_eq!(vals[0], Some(vec![Some(1), None]));
1878        // Group 1 received row 2 ([3]) → [3]
1879        assert_eq!(vals[1], Some(vec![Some(3)]));
1880
1881        Ok(())
1882    }
1883
1884    #[test]
1885    fn groups_accumulator_convert_to_state_merge_ignore_nulls() -> Result<()> {
1886        // Verifies that null values are dropped in the convert_to_state ->
1887        // merge_batch round-trip when ignore_nulls is true.
1888        let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true);
1889
1890        let values: ArrayRef =
1891            Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None]));
1892        let state = acc.convert_to_state(&[values], None)?;
1893
1894        let list = state[0].as_list::<i32>();
1895        // Rows 0 and 2 are valid lists; rows 1 and 3 are null list entries
1896        assert!(!list.is_null(0));
1897        assert!(list.is_null(1));
1898        assert!(!list.is_null(2));
1899        assert!(list.is_null(3));
1900
1901        // Feed state into a new accumulator via merge_batch
1902        let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, true);
1903        acc2.merge_batch(&state, &[0, 0, 1, 1], None, 2)?;
1904
1905        // Group 0: received [1] and null (skipped) → [1]
1906        let vals = eval_i32_lists(&mut acc2, EmitTo::All)?;
1907        assert_eq!(vals[0], Some(vec![Some(1)]));
1908        // Group 1: received [3] and null (skipped) → [3]
1909        assert_eq!(vals[1], Some(vec![Some(3)]));
1910
1911        Ok(())
1912    }
1913
1914    #[test]
1915    fn groups_accumulator_all_groups_empty() -> Result<()> {
1916        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false);
1917
1918        // Create groups but don't add any values (all filtered out)
1919        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1920        let filter = BooleanArray::from(vec![false, false]);
1921        acc.update_batch(&[values], &[0, 1], Some(&filter), 2)?;
1922
1923        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
1924        assert_eq!(vals, vec![None, None]);
1925
1926        Ok(())
1927    }
1928
1929    #[test]
1930    fn groups_accumulator_ignore_nulls_all_null_group() -> Result<()> {
1931        // When ignore_nulls is true and a group receives only nulls,
1932        // it should produce a null output
1933        let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true);
1934
1935        let values: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1), None]));
1936        acc.update_batch(&[values], &[0, 1, 0], None, 2)?;
1937
1938        let vals = eval_i32_lists(&mut acc, EmitTo::All)?;
1939        assert_eq!(vals[0], None); // group 0 got only nulls, all filtered
1940        assert_eq!(vals[1], Some(vec![Some(1)])); // group 1 got value 1
1941
1942        Ok(())
1943    }
1944}