1use std::cmp::Ordering;
21use std::collections::{HashSet, VecDeque};
22use std::mem::{size_of, size_of_val, take};
23use std::sync::Arc;
24
25use arrow::array::{
26 new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
27};
28use arrow::compute::{filter, SortOptions};
29use arrow::datatypes::{DataType, Field, FieldRef, Fields};
30
31use datafusion_common::cast::as_list_array;
32use datafusion_common::utils::{
33 compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder,
34};
35use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
40};
41use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
42use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
43use datafusion_functions_aggregate_common::utils::ordering_fields;
44use datafusion_macros::user_doc;
45use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
46
47make_udaf_expr_and_func!(
48 ArrayAgg,
49 array_agg,
50 expression,
51 "input values, including nulls, concatenated into an array",
52 array_agg_udaf
53);
54
55#[user_doc(
56 doc_section(label = "General Functions"),
57 description = r#"Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
58This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the argument expression."#,
59 syntax_example = "array_agg(expression [ORDER BY expression])",
60 sql_example = r#"
61```sql
62> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
63+-----------------------------------------------+
64| array_agg(column_name ORDER BY other_column) |
65+-----------------------------------------------+
66| [element1, element2, element3] |
67+-----------------------------------------------+
68> SELECT array_agg(DISTINCT column_name ORDER BY column_name) FROM table_name;
69+--------------------------------------------------------+
70| array_agg(DISTINCT column_name ORDER BY column_name) |
71+--------------------------------------------------------+
72| [element1, element2, element3] |
73+--------------------------------------------------------+
74```
75"#,
76 standard_argument(name = "expression",)
77)]
78#[derive(Debug, PartialEq, Eq, Hash)]
79pub struct ArrayAgg {
81 signature: Signature,
82 is_input_pre_ordered: bool,
83}
84
85impl Default for ArrayAgg {
86 fn default() -> Self {
87 Self {
88 signature: Signature::any(1, Volatility::Immutable),
89 is_input_pre_ordered: false,
90 }
91 }
92}
93
94impl AggregateUDFImpl for ArrayAgg {
95 fn as_any(&self) -> &dyn std::any::Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "array_agg"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108 Ok(DataType::List(Arc::new(Field::new_list_field(
109 arg_types[0].clone(),
110 true,
111 ))))
112 }
113
114 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115 if args.is_distinct {
116 return Ok(vec![Field::new_list(
117 format_state_name(args.name, "distinct_array_agg"),
118 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
120 true,
121 )
122 .into()]);
123 }
124
125 let mut fields = vec![Field::new_list(
126 format_state_name(args.name, "array_agg"),
127 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
129 true,
130 )
131 .into()];
132
133 if args.ordering_fields.is_empty() {
134 return Ok(fields);
135 }
136
137 let orderings = args.ordering_fields.to_vec();
138 fields.push(
139 Field::new_list(
140 format_state_name(args.name, "array_agg_orderings"),
141 Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
142 false,
143 )
144 .into(),
145 );
146
147 Ok(fields)
148 }
149
150 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
151 AggregateOrderSensitivity::SoftRequirement
152 }
153
154 fn with_beneficial_ordering(
155 self: Arc<Self>,
156 beneficial_ordering: bool,
157 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
158 Ok(Some(Arc::new(Self {
159 signature: self.signature.clone(),
160 is_input_pre_ordered: beneficial_ordering,
161 })))
162 }
163
164 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
165 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
166 let ignore_nulls =
167 acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
168
169 if acc_args.is_distinct {
170 let sort_option = match acc_args.order_bys {
185 [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options),
186 [] => None,
187 _ => {
188 return exec_err!(
189 "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"
190 );
191 }
192 };
193 return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
194 &data_type,
195 sort_option,
196 ignore_nulls,
197 )?));
198 }
199
200 let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
201 return Ok(Box::new(ArrayAggAccumulator::try_new(
202 &data_type,
203 ignore_nulls,
204 )?));
205 };
206
207 let ordering_dtypes = ordering
208 .iter()
209 .map(|e| e.expr.data_type(acc_args.schema))
210 .collect::<Result<Vec<_>>>()?;
211
212 OrderSensitiveArrayAggAccumulator::try_new(
213 &data_type,
214 &ordering_dtypes,
215 ordering,
216 self.is_input_pre_ordered,
217 acc_args.is_reversed,
218 ignore_nulls,
219 )
220 .map(|acc| Box::new(acc) as _)
221 }
222
223 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
224 datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
225 }
226
227 fn documentation(&self) -> Option<&Documentation> {
228 self.doc()
229 }
230}
231
232#[derive(Debug)]
233pub struct ArrayAggAccumulator {
234 values: Vec<ArrayRef>,
235 datatype: DataType,
236 ignore_nulls: bool,
237}
238
239impl ArrayAggAccumulator {
240 pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
242 Ok(Self {
243 values: vec![],
244 datatype: datatype.clone(),
245 ignore_nulls,
246 })
247 }
248
249 fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
252 let offsets = list_array.value_offsets();
253 let initial_offset = offsets[0];
255 let null_count = list_array.null_count();
256
257 if null_count == 0 {
260 let list_values = list_array.values().slice(
262 initial_offset as usize,
263 (offsets[offsets.len() - 1] - initial_offset) as usize,
264 );
265 return Some(list_values);
266 }
267
268 if list_array.null_count() == list_array.len() {
270 return Some(list_array.values().slice(0, 0));
271 }
272
273 let nulls = list_array.nulls().unwrap();
278
279 let mut valid_slices_iter = nulls.valid_slices();
280
281 let (start, end) = valid_slices_iter.next().unwrap();
283
284 let start_offset = offsets[start];
285
286 let mut end_offset_of_last_valid_value = offsets[end];
289
290 for (start, end) in valid_slices_iter {
291 if offsets[start] != end_offset_of_last_valid_value {
294 return None;
295 }
296
297 end_offset_of_last_valid_value = offsets[end];
300 }
301
302 let consecutive_valid_values = list_array.values().slice(
303 start_offset as usize,
304 (end_offset_of_last_valid_value - start_offset) as usize,
305 );
306
307 Some(consecutive_valid_values)
308 }
309}
310
311impl Accumulator for ArrayAggAccumulator {
312 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
313 if values.is_empty() {
315 return Ok(());
316 }
317
318 if values.len() != 1 {
319 return internal_err!("expects single batch");
320 }
321
322 let val = &values[0];
323 let nulls = if self.ignore_nulls {
324 val.logical_nulls()
325 } else {
326 None
327 };
328
329 let val = match nulls {
330 Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
331 Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
332 None => Arc::clone(val),
333 };
334
335 if !val.is_empty() {
336 self.values.push(val)
337 }
338
339 Ok(())
340 }
341
342 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
343 if states.is_empty() {
345 return Ok(());
346 }
347
348 if states.len() != 1 {
349 return internal_err!("expects single state");
350 }
351
352 let list_arr = as_list_array(&states[0])?;
353
354 match Self::get_optional_values_to_merge_as_is(list_arr) {
355 Some(values) => {
356 if !values.is_empty() {
358 self.values.push(values);
359 }
360 }
361 None => {
362 for arr in list_arr.iter().flatten() {
363 self.values.push(arr);
364 }
365 }
366 }
367
368 Ok(())
369 }
370
371 fn state(&mut self) -> Result<Vec<ScalarValue>> {
372 Ok(vec![self.evaluate()?])
373 }
374
375 fn evaluate(&mut self) -> Result<ScalarValue> {
376 let element_arrays: Vec<&dyn Array> =
378 self.values.iter().map(|a| a.as_ref()).collect();
379
380 if element_arrays.is_empty() {
381 return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
382 }
383
384 let concated_array = arrow::compute::concat(&element_arrays)?;
385
386 Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
387 }
388
389 fn size(&self) -> usize {
390 size_of_val(self)
391 + (size_of::<ArrayRef>() * self.values.capacity())
392 + self
393 .values
394 .iter()
395 .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
407 .sum::<usize>()
408 + self.datatype.size()
409 - size_of_val(&self.datatype)
410 }
411}
412
413#[derive(Debug)]
414struct DistinctArrayAggAccumulator {
415 values: HashSet<ScalarValue>,
416 datatype: DataType,
417 sort_options: Option<SortOptions>,
418 ignore_nulls: bool,
419}
420
421impl DistinctArrayAggAccumulator {
422 pub fn try_new(
423 datatype: &DataType,
424 sort_options: Option<SortOptions>,
425 ignore_nulls: bool,
426 ) -> Result<Self> {
427 Ok(Self {
428 values: HashSet::new(),
429 datatype: datatype.clone(),
430 sort_options,
431 ignore_nulls,
432 })
433 }
434}
435
436impl Accumulator for DistinctArrayAggAccumulator {
437 fn state(&mut self) -> Result<Vec<ScalarValue>> {
438 Ok(vec![self.evaluate()?])
439 }
440
441 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442 if values.is_empty() {
443 return Ok(());
444 }
445
446 let val = &values[0];
447 let nulls = if self.ignore_nulls {
448 val.logical_nulls()
449 } else {
450 None
451 };
452
453 let nulls = nulls.as_ref();
454 if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
455 for i in 0..val.len() {
456 if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
457 self.values
458 .insert(ScalarValue::try_from_array(val, i)?.compacted());
459 }
460 }
461 }
462
463 Ok(())
464 }
465
466 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
467 if states.is_empty() {
468 return Ok(());
469 }
470
471 if states.len() != 1 {
472 return internal_err!("expects single state");
473 }
474
475 states[0]
476 .as_list::<i32>()
477 .iter()
478 .flatten()
479 .try_for_each(|val| self.update_batch(&[val]))
480 }
481
482 fn evaluate(&mut self) -> Result<ScalarValue> {
483 let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
484 if values.is_empty() {
485 return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
486 }
487
488 if let Some(opts) = self.sort_options {
489 let mut delayed_cmp_err = Ok(());
490 values.sort_by(|a, b| {
491 if a.is_null() {
492 return match opts.nulls_first {
493 true => Ordering::Less,
494 false => Ordering::Greater,
495 };
496 }
497 if b.is_null() {
498 return match opts.nulls_first {
499 true => Ordering::Greater,
500 false => Ordering::Less,
501 };
502 }
503 match opts.descending {
504 true => b.try_cmp(a),
505 false => a.try_cmp(b),
506 }
507 .unwrap_or_else(|err| {
508 delayed_cmp_err = Err(err);
509 Ordering::Equal
510 })
511 });
512 delayed_cmp_err?;
513 };
514
515 let arr = ScalarValue::new_list(&values, &self.datatype, true);
516 Ok(ScalarValue::List(arr))
517 }
518
519 fn size(&self) -> usize {
520 size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
521 - size_of_val(&self.values)
522 + self.datatype.size()
523 - size_of_val(&self.datatype)
524 - size_of_val(&self.sort_options)
525 + size_of::<Option<SortOptions>>()
526 }
527}
528
529#[derive(Debug)]
533pub(crate) struct OrderSensitiveArrayAggAccumulator {
534 values: Vec<ScalarValue>,
536 ordering_values: Vec<Vec<ScalarValue>>,
541 datatypes: Vec<DataType>,
544 ordering_req: LexOrdering,
546 is_input_pre_ordered: bool,
548 reverse: bool,
550 ignore_nulls: bool,
552}
553
554impl OrderSensitiveArrayAggAccumulator {
555 pub fn try_new(
558 datatype: &DataType,
559 ordering_dtypes: &[DataType],
560 ordering_req: LexOrdering,
561 is_input_pre_ordered: bool,
562 reverse: bool,
563 ignore_nulls: bool,
564 ) -> Result<Self> {
565 let mut datatypes = vec![datatype.clone()];
566 datatypes.extend(ordering_dtypes.iter().cloned());
567 Ok(Self {
568 values: vec![],
569 ordering_values: vec![],
570 datatypes,
571 ordering_req,
572 is_input_pre_ordered,
573 reverse,
574 ignore_nulls,
575 })
576 }
577
578 fn sort(&mut self) {
579 let sort_options = self
580 .ordering_req
581 .iter()
582 .map(|sort_expr| sort_expr.options)
583 .collect::<Vec<_>>();
584 let mut values = take(&mut self.values)
585 .into_iter()
586 .zip(take(&mut self.ordering_values))
587 .collect::<Vec<_>>();
588 let mut delayed_cmp_err = Ok(());
589 values.sort_by(|(_, left_ordering), (_, right_ordering)| {
590 compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
591 |err| {
592 delayed_cmp_err = Err(err);
593 Ordering::Equal
594 },
595 )
596 });
597 (self.values, self.ordering_values) = values.into_iter().unzip();
598 }
599
600 fn evaluate_orderings(&self) -> Result<ScalarValue> {
601 let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
602
603 let column_wise_ordering_values = if self.ordering_values.is_empty() {
604 fields
605 .iter()
606 .map(|f| new_empty_array(f.data_type()))
607 .collect::<Vec<_>>()
608 } else {
609 (0..fields.len())
610 .map(|i| {
611 let column_values = self.ordering_values.iter().map(|x| x[i].clone());
612 ScalarValue::iter_to_array(column_values)
613 })
614 .collect::<Result<_>>()?
615 };
616
617 let ordering_array = StructArray::try_new(
618 Fields::from(fields),
619 column_wise_ordering_values,
620 None,
621 )?;
622 Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
623 }
624}
625
626impl Accumulator for OrderSensitiveArrayAggAccumulator {
627 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
628 if values.is_empty() {
629 return Ok(());
630 }
631
632 let val = &values[0];
633 let ord = &values[1..];
634 let nulls = if self.ignore_nulls {
635 val.logical_nulls()
636 } else {
637 None
638 };
639
640 let nulls = nulls.as_ref();
641 if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
642 for i in 0..val.len() {
643 if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
644 self.values
645 .push(ScalarValue::try_from_array(val, i)?.compacted());
646 self.ordering_values.push(
647 get_row_at_idx(ord, i)?
648 .into_iter()
649 .map(|v| v.compacted())
650 .collect(),
651 )
652 }
653 }
654 }
655
656 Ok(())
657 }
658
659 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
660 if states.is_empty() {
661 return Ok(());
662 }
663
664 let [array_agg_values, agg_orderings] =
671 take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
672 let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
673 return exec_err!("Expects to receive a list array");
674 };
675
676 let mut partition_values = vec![];
678 let mut partition_ordering_values = vec![];
680
681 if !self.is_input_pre_ordered {
683 self.sort();
684 }
685 partition_values.push(take(&mut self.values).into());
686 partition_ordering_values.push(take(&mut self.ordering_values).into());
687
688 let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
690 for v in array_agg_res.into_iter() {
691 partition_values.push(v.into());
692 }
693
694 let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
695
696 for partition_ordering_rows in orderings.into_iter() {
697 let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
699 if let ScalarValue::Struct(s) = ordering_row {
700 let mut ordering_columns_per_row = vec![];
701
702 for column in s.columns() {
703 let sv = ScalarValue::try_from_array(column, 0)?;
704 ordering_columns_per_row.push(sv);
705 }
706
707 Ok(ordering_columns_per_row)
708 } else {
709 exec_err!(
710 "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
711 ordering_row.data_type()
712 )
713 }
714 }).collect::<Result<VecDeque<_>>>()?;
715
716 partition_ordering_values.push(ordering_value);
717 }
718
719 let sort_options = self
720 .ordering_req
721 .iter()
722 .map(|sort_expr| sort_expr.options)
723 .collect::<Vec<_>>();
724
725 (self.values, self.ordering_values) = merge_ordered_arrays(
726 &mut partition_values,
727 &mut partition_ordering_values,
728 &sort_options,
729 )?;
730
731 Ok(())
732 }
733
734 fn state(&mut self) -> Result<Vec<ScalarValue>> {
735 if !self.is_input_pre_ordered {
736 self.sort();
737 }
738
739 let mut result = vec![self.evaluate()?];
740 result.push(self.evaluate_orderings()?);
741
742 Ok(result)
743 }
744
745 fn evaluate(&mut self) -> Result<ScalarValue> {
746 if !self.is_input_pre_ordered {
747 self.sort();
748 }
749
750 if self.values.is_empty() {
751 return Ok(ScalarValue::new_null_list(
752 self.datatypes[0].clone(),
753 true,
754 1,
755 ));
756 }
757
758 let values = self.values.clone();
759 let array = if self.reverse {
760 ScalarValue::new_list_from_iter(
761 values.into_iter().rev(),
762 &self.datatypes[0],
763 true,
764 )
765 } else {
766 ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
767 };
768 Ok(ScalarValue::List(array))
769 }
770
771 fn size(&self) -> usize {
772 let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
773 - size_of_val(&self.values);
774
775 total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
777 for row in &self.ordering_values {
778 total += ScalarValue::size_of_vec(row) - size_of_val(row);
779 }
780
781 total += size_of::<DataType>() * self.datatypes.capacity();
783 for dtype in &self.datatypes {
784 total += dtype.size() - size_of_val(dtype);
785 }
786
787 total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
789 total
791 }
792}
793
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ListBuilder, StringBuilder};
    use arrow::datatypes::{FieldRef, Schema};
    use datafusion_common::cast::as_generic_string_array;
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string(),
            data(["a", "b", "c"]),
            data(["d", "e", "f"]),
        )?;
        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            data(["a", "b", "c"]),
            data(["d", "e", "f"]),
        )?;
        assert_eq!(sorted(result), vec!["a", "b", "c", "d", "e", "f"]);
        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string(),
            data(["a", "b", "c"]),
            data(["a", "b", "c"]),
        )?;
        assert_eq!(result, vec!["a", "b", "c", "a", "b", "c"]);
        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            data(["a", "b", "c"]),
            data(["a", "b", "c"]),
        )?;
        assert_eq!(sorted(result), vec!["a", "b", "c"]);
        Ok(())
    }

    #[test]
    fn duplicates_on_second_batch_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            data(["a", "c"]),
            data(["d", "a", "b", "c"]),
        )?;
        assert_eq!(sorted(result), vec!["a", "b", "c", "d"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false));
        let result = merged_values(&builder, data(["e", "b", "d"]), data(["f", "a", "c"]))?;
        assert_eq!(result, vec!["a", "b", "c", "d", "e", "f"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false));
        let result = merged_values(&builder, data(["e", "b", "d"]), data(["f", "a", "c"]))?;
        assert_eq!(result, vec!["f", "e", "d", "c", "b", "a"]);
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false));
        let result = merged_values(&builder, data(["a", "c", "b"]), data(["b", "c", "a"]))?;
        assert_eq!(result, vec!["a", "b", "c"]);
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false));
        let result = merged_values(&builder, data(["a", "c", "b"]), data(["b", "c", "a"]))?;
        assert_eq!(result, vec!["c", "b", "a"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_first() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, true));
        let result = merged_values(
            &builder,
            data([Some("e"), Some("b"), None]),
            data([Some("f"), Some("a"), None]),
        )?;
        assert_eq!(result, vec!["NULL", "a", "b", "e", "f"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc_nulls_last() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(false, false));
        let result = merged_values(
            &builder,
            data([Some("e"), Some("b"), None]),
            data([Some("f"), Some("a"), None]),
        )?;
        assert_eq!(result, vec!["a", "b", "e", "f", "NULL"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_first() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, true));
        let result = merged_values(
            &builder,
            data([Some("e"), Some("b"), None]),
            data([Some("f"), Some("a"), None]),
        )?;
        assert_eq!(result, vec!["NULL", "f", "e", "b", "a"]);
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc_nulls_last() -> Result<()> {
        let builder = ArrayAggAccumulatorBuilder::string()
            .distinct()
            .order_by_col("col", SortOptions::new(true, false));
        let result = merged_values(
            &builder,
            data([Some("e"), Some("b"), None]),
            data([Some("f"), Some("a"), None]),
        )?;
        assert_eq!(result, vec!["f", "e", "b", "a", "NULL"]);
        Ok(())
    }

    #[test]
    fn all_nulls_on_first_batch_with_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            data::<Option<&str>, 3>([None, None, None]),
            data([Some("a"), None, None, None]),
        )?;
        assert_eq!(sorted(result), vec!["NULL", "a"]);
        Ok(())
    }

    #[test]
    fn all_nulls_on_both_batches_with_distinct() -> Result<()> {
        let result = merged_values(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            data::<Option<&str>, 3>([None, None, None]),
            data::<Option<&str>, 4>([None, None, None, None]),
        )?;
        assert_eq!(result, vec!["NULL"]);
        Ok(())
    }

    #[test]
    fn does_not_over_account_memory() -> Result<()> {
        let acc = merged(
            &ArrayAggAccumulatorBuilder::string(),
            data(["a", "c", "b"]),
            data(["b", "c", "a"]),
        )?;
        assert_eq!(acc.size(), 266);
        Ok(())
    }

    #[test]
    fn does_not_over_account_memory_distinct() -> Result<()> {
        let acc = merged(
            &ArrayAggAccumulatorBuilder::string().distinct(),
            string_list_data([vec!["a", "b", "c"], vec!["d", "e", "f"]]),
            string_list_data([vec!["e", "f", "g"]]),
        )?;
        assert_eq!(acc.size(), 1660);
        Ok(())
    }

    #[test]
    fn does_not_over_account_memory_ordered() -> Result<()> {
        let mut acc = ArrayAggAccumulatorBuilder::string()
            .order_by_col("col", SortOptions::new(false, false))
            .build()?;

        let input = string_list_data([
            vec!["a", "b", "c"],
            vec!["c", "d", "e"],
            vec!["b", "c", "d"],
        ]);
        acc.update_batch(&[input])?;

        assert_eq!(acc.size(), 2112);
        Ok(())
    }

    /// Builds `array_agg` accumulators over a single nullable column `col`.
    struct ArrayAggAccumulatorBuilder {
        return_field: FieldRef,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl ArrayAggAccumulatorBuilder {
        fn string() -> Self {
            Self::new(DataType::Utf8)
        }

        fn new(data_type: DataType) -> Self {
            let column = Field::new("col", DataType::new_list(data_type.clone(), true), true);
            Self {
                return_field: Field::new("f", data_type, true).into(),
                distinct: false,
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![column]),
                    metadata: Default::default(),
                },
            }
        }

        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.order_bys
                .push(PhysicalSortExpr::new(Arc::new(column), sort_options));
            self
        }

        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let args = AccumulatorArgs {
                return_field: Arc::clone(&self.return_field),
                schema: &self.schema,
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[Arc::new(Column::new("col", 0))],
            };
            ArrayAgg::default().accumulator(args)
        }

        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Feeds `left` and `right` to two fresh accumulators and merges the second
    /// into the first.
    fn merged(
        builder: &ArrayAggAccumulatorBuilder,
        left: ArrayRef,
        right: ArrayRef,
    ) -> Result<Box<dyn Accumulator>> {
        let (mut acc1, mut acc2) = builder.build_two()?;
        acc1.update_batch(&[left])?;
        acc2.update_batch(&[right])?;
        merge(acc1, acc2)
    }

    /// Like `merged`, then evaluates and renders the result with `NULL` placeholders.
    fn merged_values(
        builder: &ArrayAggAccumulatorBuilder,
        left: ArrayRef,
        right: ArrayRef,
    ) -> Result<Vec<String>> {
        let mut acc = merged(builder, left, right)?;
        Ok(print_nulls(str_arr(acc.evaluate()?)?))
    }

    fn sorted(mut values: Vec<String>) -> Vec<String> {
        values.sort();
        values
    }

    fn str_arr(value: ScalarValue) -> Result<Vec<Option<String>>> {
        let ScalarValue::List(list) = value else {
            return internal_err!("ScalarValue was not a List");
        };
        let strings = as_generic_string_array::<i32>(list.values())?;
        Ok(strings.iter().map(|v| v.map(str::to_string)).collect())
    }

    fn print_nulls(values: Vec<Option<String>>) -> Vec<String> {
        values
            .into_iter()
            .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
            .collect()
    }

    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef {
        let mut builder = ListBuilder::new(StringBuilder::new());
        data.into_iter().for_each(|string_list| {
            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>());
        });
        Arc::new(builder.finish())
    }

    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
    where
        ScalarValue: From<T>,
    {
        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
        ScalarValue::iter_to_array(values).expect("Cannot convert to array")
    }

    /// Merges the intermediate state of `acc2` into `acc1` and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2
            .state()?
            .iter()
            .map(|v| v.to_array())
            .collect::<Result<Vec<ArrayRef>>>()?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}
1221}