1use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
21use arrow::datatypes::{DataType, Field, Fields};
22
23use datafusion_common::cast::as_list_array;
24use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
25use datafusion_common::{exec_err, ScalarValue};
26use datafusion_common::{internal_err, Result};
27use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
28use datafusion_expr::utils::format_state_name;
29use datafusion_expr::{Accumulator, Signature, Volatility};
30use datafusion_expr::{AggregateUDFImpl, Documentation};
31use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
32use datafusion_functions_aggregate_common::utils::ordering_fields;
33use datafusion_macros::user_doc;
34use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
35use std::collections::{HashSet, VecDeque};
36use std::mem::{size_of, size_of_val};
37use std::sync::Arc;
38
// Generates the `array_agg` expression-builder function and the
// `array_agg_udaf` singleton accessor for the [`ArrayAgg`] aggregate UDF.
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
46
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.",
    syntax_example = "array_agg(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| array_agg(column_name ORDER BY other_column) |
+-----------------------------------------------+
| [element1, element2, element3] |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
/// `ARRAY_AGG` aggregate UDF: collects the input values (nulls included)
/// into a single list, optionally honoring an `ORDER BY` clause.
pub struct ArrayAgg {
    // Accepts exactly one argument of any type (see `Default` impl).
    signature: Signature,
}
66
67impl Default for ArrayAgg {
68 fn default() -> Self {
69 Self {
70 signature: Signature::any(1, Volatility::Immutable),
71 }
72 }
73}
74
75impl AggregateUDFImpl for ArrayAgg {
76 fn as_any(&self) -> &dyn std::any::Any {
77 self
78 }
79
80 fn name(&self) -> &str {
81 "array_agg"
82 }
83
84 fn aliases(&self) -> &[String] {
85 &[]
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93 Ok(DataType::List(Arc::new(Field::new_list_field(
94 arg_types[0].clone(),
95 true,
96 ))))
97 }
98
99 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
100 if args.is_distinct {
101 return Ok(vec![Field::new_list(
102 format_state_name(args.name, "distinct_array_agg"),
103 Field::new_list_field(args.input_types[0].clone(), true),
105 true,
106 )]);
107 }
108
109 let mut fields = vec![Field::new_list(
110 format_state_name(args.name, "array_agg"),
111 Field::new_list_field(args.input_types[0].clone(), true),
113 true,
114 )];
115
116 if args.ordering_fields.is_empty() {
117 return Ok(fields);
118 }
119
120 let orderings = args.ordering_fields.to_vec();
121 fields.push(Field::new_list(
122 format_state_name(args.name, "array_agg_orderings"),
123 Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
124 false,
125 ));
126
127 Ok(fields)
128 }
129
130 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
131 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
132
133 if acc_args.is_distinct {
134 return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?));
135 }
136
137 if acc_args.ordering_req.is_empty() {
138 return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
139 }
140
141 let ordering_dtypes = acc_args
142 .ordering_req
143 .iter()
144 .map(|e| e.expr.data_type(acc_args.schema))
145 .collect::<Result<Vec<_>>>()?;
146
147 OrderSensitiveArrayAggAccumulator::try_new(
148 &data_type,
149 &ordering_dtypes,
150 acc_args.ordering_req.clone(),
151 acc_args.is_reversed,
152 )
153 .map(|acc| Box::new(acc) as _)
154 }
155
156 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
157 datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf())
158 }
159
160 fn documentation(&self) -> Option<&Documentation> {
161 self.doc()
162 }
163}
164
/// Accumulator for the plain (non-distinct, unordered) `ARRAY_AGG`.
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    // Batches received so far; concatenated lazily in `evaluate`.
    values: Vec<ArrayRef>,
    // Element type of the aggregated column.
    datatype: DataType,
}
170
impl ArrayAggAccumulator {
    /// Creates an empty accumulator for elements of `datatype`.
    pub fn try_new(datatype: &DataType) -> Result<Self> {
        Ok(Self {
            values: vec![],
            datatype: datatype.clone(),
        })
    }

    /// Returns a single slice of the list array's child values buffer that
    /// covers every valid entry, when the valid entries are stored
    /// contiguously — letting the caller merge without per-row copies.
    /// Returns `None` when valid runs are separated by offset gaps (e.g. a
    /// null row that still occupies offset space), in which case rows must
    /// be appended individually.
    fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
        let offsets = list_array.value_offsets();
        let initial_offset = offsets[0];
        let null_count = list_array.null_count();

        // No nulls: every row is contiguous, so one slice covers them all.
        if null_count == 0 {
            let list_values = list_array.values().slice(
                initial_offset as usize,
                (offsets[offsets.len() - 1] - initial_offset) as usize,
            );
            return Some(list_values);
        }

        // All rows null: nothing to merge — return an empty slice.
        if list_array.null_count() == list_array.len() {
            return Some(list_array.values().slice(0, 0));
        }

        // Mixed valid/null rows from here on; the null buffer must exist.
        let nulls = list_array.nulls().unwrap();

        let mut valid_slices_iter = nulls.valid_slices();

        // Safe: the all-null case returned above, so at least one valid run exists.
        let (start, end) = valid_slices_iter.next().unwrap();

        let start_offset = offsets[start];

        let mut end_offset_of_last_valid_value = offsets[end];

        // Verify each subsequent valid run begins exactly where the previous
        // one ended in the offsets buffer; a gap means a null row owns offset
        // space and a single slice would wrongly include its values.
        for (start, end) in valid_slices_iter {
            if offsets[start] != end_offset_of_last_valid_value {
                return None;
            }

            end_offset_of_last_valid_value = offsets[end];
        }

        // All valid runs are contiguous: one slice spans them all.
        let consecutive_valid_values = list_array.values().slice(
            start_offset as usize,
            (end_offset_of_last_valid_value - start_offset) as usize,
        );

        Some(consecutive_valid_values)
    }
}
241
242impl Accumulator for ArrayAggAccumulator {
243 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
244 if values.is_empty() {
246 return Ok(());
247 }
248
249 if values.len() != 1 {
250 return internal_err!("expects single batch");
251 }
252
253 let val = Arc::clone(&values[0]);
254 if val.len() > 0 {
255 self.values.push(val);
256 }
257 Ok(())
258 }
259
260 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
261 if states.is_empty() {
263 return Ok(());
264 }
265
266 if states.len() != 1 {
267 return internal_err!("expects single state");
268 }
269
270 let list_arr = as_list_array(&states[0])?;
271
272 match Self::get_optional_values_to_merge_as_is(list_arr) {
273 Some(values) => {
274 if values.len() > 0 {
276 self.values.push(values);
277 }
278 }
279 None => {
280 for arr in list_arr.iter().flatten() {
281 self.values.push(arr);
282 }
283 }
284 }
285
286 Ok(())
287 }
288
289 fn state(&mut self) -> Result<Vec<ScalarValue>> {
290 Ok(vec![self.evaluate()?])
291 }
292
293 fn evaluate(&mut self) -> Result<ScalarValue> {
294 let element_arrays: Vec<&dyn Array> =
296 self.values.iter().map(|a| a.as_ref()).collect();
297
298 if element_arrays.is_empty() {
299 return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
300 }
301
302 let concated_array = arrow::compute::concat(&element_arrays)?;
303
304 Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar())
305 }
306
307 fn size(&self) -> usize {
308 size_of_val(self)
309 + (size_of::<ArrayRef>() * self.values.capacity())
310 + self
311 .values
312 .iter()
313 .map(|arr| arr.get_array_memory_size())
314 .sum::<usize>()
315 + self.datatype.size()
316 - size_of_val(&self.datatype)
317 }
318}
319
/// Accumulator for `ARRAY_AGG(DISTINCT ...)`; deduplicates via a `HashSet`.
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    // Distinct values seen so far (nulls included).
    values: HashSet<ScalarValue>,
    // Element type of the aggregated column.
    datatype: DataType,
}
325
326impl DistinctArrayAggAccumulator {
327 pub fn try_new(datatype: &DataType) -> Result<Self> {
328 Ok(Self {
329 values: HashSet::new(),
330 datatype: datatype.clone(),
331 })
332 }
333}
334
335impl Accumulator for DistinctArrayAggAccumulator {
336 fn state(&mut self) -> Result<Vec<ScalarValue>> {
337 Ok(vec![self.evaluate()?])
338 }
339
340 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
341 if values.len() != 1 {
342 return internal_err!("expects single batch");
343 }
344
345 let array = &values[0];
346
347 for i in 0..array.len() {
348 let scalar = ScalarValue::try_from_array(&array, i)?;
349 self.values.insert(scalar);
350 }
351
352 Ok(())
353 }
354
355 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
356 if states.is_empty() {
357 return Ok(());
358 }
359
360 if states.len() != 1 {
361 return internal_err!("expects single state");
362 }
363
364 states[0]
365 .as_list::<i32>()
366 .iter()
367 .flatten()
368 .try_for_each(|val| self.update_batch(&[val]))
369 }
370
371 fn evaluate(&mut self) -> Result<ScalarValue> {
372 let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
373 if values.is_empty() {
374 return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
375 }
376 let arr = ScalarValue::new_list(&values, &self.datatype, true);
377 Ok(ScalarValue::List(arr))
378 }
379
380 fn size(&self) -> usize {
381 size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
382 - size_of_val(&self.values)
383 + self.datatype.size()
384 - size_of_val(&self.datatype)
385 }
386}
387
/// Accumulator for `ARRAY_AGG(... ORDER BY ...)`: keeps each value together
/// with its ordering keys so partial states can be merged in order.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    // Aggregated values collected so far.
    values: Vec<ScalarValue>,
    // For each entry in `values`, one ScalarValue per ordering expression.
    ordering_values: Vec<Vec<ScalarValue>>,
    // Value type first, then the type of each ordering expression.
    datatypes: Vec<DataType>,
    // The ordering requirement this accumulator maintains.
    ordering_req: LexOrdering,
    // When true, `evaluate` emits the collected values in reverse order.
    reverse: bool,
}
408
409impl OrderSensitiveArrayAggAccumulator {
410 pub fn try_new(
413 datatype: &DataType,
414 ordering_dtypes: &[DataType],
415 ordering_req: LexOrdering,
416 reverse: bool,
417 ) -> Result<Self> {
418 let mut datatypes = vec![datatype.clone()];
419 datatypes.extend(ordering_dtypes.iter().cloned());
420 Ok(Self {
421 values: vec![],
422 ordering_values: vec![],
423 datatypes,
424 ordering_req,
425 reverse,
426 })
427 }
428}
429
impl Accumulator for OrderSensitiveArrayAggAccumulator {
    /// Appends each input row: column 0 is the aggregated value, the
    /// remaining columns are the corresponding ordering values.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_row = values[0].len();
        for index in 0..n_row {
            let row = get_row_at_idx(values, index)?;
            self.values.push(row[0].clone());
            self.ordering_values.push(row[1..].to_vec());
        }

        Ok(())
    }

    /// Merges partial states while preserving the required ordering: gathers
    /// this accumulator's buffers plus every incoming partition, then k-way
    /// merges them by the ordering columns.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        // State layout (see `state_fields`): [values, orderings].
        let [array_agg_values, agg_orderings, ..] = &states else {
            return exec_err!("State should have two elements");
        };
        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Partition buffers to merge; slot 0 holds this accumulator's data.
        let mut partition_values = vec![];
        let mut partition_ordering_values = vec![];

        partition_values.push(self.values.clone().into());
        partition_ordering_values.push(self.ordering_values.clone().into());

        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
        for v in array_agg_res.into_iter() {
            partition_values.push(v.into());
        }

        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;

        for partition_ordering_rows in orderings.into_iter() {
            // Each ordering row is a struct with one field per ordering
            // column; unpack it back into a plain Vec<ScalarValue>.
            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
                if let ScalarValue::Struct(s) = ordering_row {
                    let mut ordering_columns_per_row = vec![];

                    for column in s.columns() {
                        let sv = ScalarValue::try_from_array(column, 0)?;
                        ordering_columns_per_row.push(sv);
                    }

                    Ok(ordering_columns_per_row)
                } else {
                    exec_err!(
                        "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
                        ordering_row.data_type()
                    )
                }
            }).collect::<Result<VecDeque<_>>>()?;

            partition_ordering_values.push(ordering_value);
        }

        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();

        // Replace our buffers with the merged, globally ordered result.
        (self.values, self.ordering_values) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;

        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.evaluate()?];
        result.push(self.evaluate_orderings()?);

        Ok(result)
    }

    /// Returns the collected values as a list scalar, reversed when the
    /// accumulator runs as the reversed variant; an empty accumulator yields
    /// a single null list.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.values.is_empty() {
            return Ok(ScalarValue::new_null_list(
                self.datatypes[0].clone(),
                true,
                1,
            ));
        }

        let values = self.values.clone();
        let array = if self.reverse {
            ScalarValue::new_list_from_iter(
                values.into_iter().rev(),
                &self.datatypes[0],
                true,
            )
        } else {
            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
        };
        Ok(ScalarValue::List(array))
    }

    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values)
            - size_of_val(&self.values);

        // Heap usage of the ordering rows: outer Vec capacity plus each row.
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Heap usage of the stored data types.
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Ordering requirement: counts only the backing capacity here.
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        total
    }
}
570
571impl OrderSensitiveArrayAggAccumulator {
572 fn evaluate_orderings(&self) -> Result<ScalarValue> {
573 let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
574 let num_columns = fields.len();
575 let struct_field = Fields::from(fields.clone());
576
577 let mut column_wise_ordering_values = vec![];
578 for i in 0..num_columns {
579 let column_values = self
580 .ordering_values
581 .iter()
582 .map(|x| x[i].clone())
583 .collect::<Vec<_>>();
584 let array = if column_values.is_empty() {
585 new_empty_array(fields[i].data_type())
586 } else {
587 ScalarValue::iter_to_array(column_values.into_iter())?
588 };
589 column_wise_ordering_values.push(array);
590 }
591
592 let ordering_array =
593 StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
594 Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Arc;

    use arrow::array::Int64Array;
    use arrow::compute::SortOptions;

    use datafusion_common::utils::get_row_at_idx;
    use datafusion_common::{Result, ScalarValue};

    /// Merging two identical ascending partitions interleaves their rows
    /// pairwise while keeping the ordering columns ascending.
    #[test]
    fn test_merge_asc() -> Result<()> {
        // Left partition's ordering columns (two columns, ascending).
        let lhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])),
        ];
        let n_row = lhs_arrays[0].len();
        let lhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&lhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;

        // Right partition's ordering columns (same contents as the left).
        let rhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])),
        ];
        let n_row = rhs_arrays[0].len();
        let rhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&rhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;
        // Both ordering columns ascending, nulls last.
        let sort_options = vec![
            SortOptions {
                descending: false,
                nulls_first: false,
            },
            SortOptions {
                descending: false,
                nulls_first: false,
            },
        ];

        // Aggregated values carried alongside the ordering rows.
        let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
        let lhs_vals = (0..lhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;

        let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
        let rhs_vals = (0..rhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;
        // Identical partitions: each element appears twice, in order.
        let expected =
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef;
        let expected_ts = vec![
            Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef,
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef,
        ];

        let (merged_vals, merged_ts) = merge_ordered_arrays(
            &mut [lhs_vals, rhs_vals],
            &mut [lhs_orderings, rhs_orderings],
            &sort_options,
        )?;
        let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
        // Transpose merged ordering rows back into columns for comparison.
        let merged_ts = (0..merged_ts[0].len())
            .map(|col_idx| {
                ScalarValue::iter_to_array(
                    (0..merged_ts.len())
                        .map(|row_idx| merged_ts[row_idx][col_idx].clone()),
                )
            })
            .collect::<Result<Vec<_>>>()?;

        assert_eq!(&merged_vals, &expected);
        assert_eq!(&merged_ts, &expected_ts);

        Ok(())
    }

    /// Same scenario as `test_merge_asc`, but with descending ordering
    /// columns, verifying the sort options are honored during the merge.
    #[test]
    fn test_merge_desc() -> Result<()> {
        // Left partition's ordering columns (two columns, descending).
        let lhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])),
            Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])),
        ];
        let n_row = lhs_arrays[0].len();
        let lhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&lhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;

        // Right partition's ordering columns (same contents as the left).
        let rhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])),
            Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])),
        ];
        let n_row = rhs_arrays[0].len();
        let rhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&rhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;
        // Both ordering columns descending, nulls last.
        let sort_options = vec![
            SortOptions {
                descending: true,
                nulls_first: false,
            },
            SortOptions {
                descending: true,
                nulls_first: false,
            },
        ];

        // Aggregated values carried alongside the ordering rows.
        let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
        let lhs_vals = (0..lhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;

        let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
        let rhs_vals = (0..rhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;
        // Identical partitions: each element appears twice, in order.
        let expected =
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef;
        let expected_ts = vec![
            Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef,
            Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef,
        ];
        let (merged_vals, merged_ts) = merge_ordered_arrays(
            &mut [lhs_vals, rhs_vals],
            &mut [lhs_orderings, rhs_orderings],
            &sort_options,
        )?;
        let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
        // Transpose merged ordering rows back into columns for comparison.
        let merged_ts = (0..merged_ts[0].len())
            .map(|col_idx| {
                ScalarValue::iter_to_array(
                    (0..merged_ts.len())
                        .map(|row_idx| merged_ts[row_idx][col_idx].clone()),
                )
            })
            .collect::<Result<Vec<_>>>()?;

        assert_eq!(&merged_vals, &expected);
        assert_eq!(&merged_ts, &expected_ts);
        Ok(())
    }
}