1use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray};
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::{
30 Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
31};
32use datafusion_expr::function::AccumulatorArgs;
33use datafusion_expr::utils::format_state_name;
34use datafusion_expr::{
35 Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature,
36 TypeSignature, Volatility,
37};
38use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
39use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Invokes `make_udaf_expr_and_func!` for `StringAgg` with the arguments `expr` and
// `delimiter`. NOTE(review): `string_agg_udaf()`, which `reverse_expr` in this file calls,
// is presumably generated by this invocation — confirm against the macro definition.
43make_udaf_expr_and_func!(
44 StringAgg,
45 string_agg,
46 expr delimiter,
47 "Concatenates the values of string expressions and places separator values between them",
48 string_agg_udaf
49);
50
51#[user_doc(
52 doc_section(label = "General Functions"),
53 description = "Concatenates the values of string expressions and places separator values between them. \
54If ordering is required, strings are concatenated in the specified order. \
55This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
56 syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
57 sql_example = r#"```sql
58> SELECT string_agg(name, ', ') AS names_list
59 FROM employee;
60+--------------------------+
61| names_list |
62+--------------------------+
63| Alice, Bob, Bob, Charlie |
64+--------------------------+
65> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
66 FROM employee;
67+--------------------------+
68| names_list |
69+--------------------------+
70| Charlie, Bob, Bob, Alice |
71+--------------------------+
72> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
73 FROM employee;
74+--------------------------+
75| names_list |
76+--------------------------+
77| Charlie, Bob, Alice |
78+--------------------------+
79```"#,
80 argument(
81 name = "expression",
82 description = "The string expression to concatenate. Can be a column or any valid string expression."
83 ),
84 argument(
85 name = "delimiter",
86 description = "A literal string used as a separator between the concatenated values."
87 )
88)]
89#[derive(Debug, PartialEq, Eq, Hash)]
/// `string_agg` aggregate function: concatenates string values, placing a literal
/// delimiter between them. The result type is always `LargeUtf8`.
91pub struct StringAgg {
    /// Accepted `(value, delimiter)` argument type combinations.
92 signature: Signature,
    /// Wrapped `ArrayAgg`, used for state fields and accumulation whenever DISTINCT
    /// or ORDER BY is requested.
93 array_agg: ArrayAgg,
94}
95
96impl StringAgg {
97 pub fn new() -> Self {
99 Self {
100 signature: Signature::one_of(
101 vec![
102 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114 ],
115 Volatility::Immutable,
116 ),
117 array_agg: Default::default(),
118 }
119 }
120
121 fn extract_delimiter(args: &AccumulatorArgs) -> Result<String> {
123 let Some(lit) = args.exprs[1].downcast_ref::<Literal>() else {
124 return not_impl_err!("string_agg delimiter must be a string literal");
125 };
126
127 if lit.value().is_null() {
128 return Ok(String::new());
129 }
130
131 match lit.value().try_as_str() {
132 Some(s) => Ok(s.unwrap_or("").to_string()),
133 None => {
134 not_impl_err!(
135 "string_agg not supported for delimiter \"{}\"",
136 lit.value()
137 )
138 }
139 }
140 }
141}
142
143impl Default for StringAgg {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl AggregateUDFImpl for StringAgg {
154 fn name(&self) -> &str {
155 "string_agg"
156 }
157
158 fn signature(&self) -> &Signature {
159 &self.signature
160 }
161
162 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
163 Ok(DataType::LargeUtf8)
164 }
165
166 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
167 if !args.is_distinct && args.ordering_fields.is_empty() {
168 Ok(vec![
169 Field::new(
170 format_state_name(args.name, "string_agg"),
171 DataType::LargeUtf8,
172 true,
173 )
174 .into(),
175 ])
176 } else {
177 self.array_agg.state_fields(args)
178 }
179 }
180
181 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
182 let delimiter = Self::extract_delimiter(&acc_args)?;
183
184 if !acc_args.is_distinct && acc_args.order_bys.is_empty() {
185 Ok(Box::new(SimpleStringAggAccumulator::new(&delimiter)))
186 } else {
187 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
188 return_field: Field::new(
189 "f",
190 DataType::new_list(acc_args.return_field.data_type().clone(), true),
191 true,
192 )
193 .into(),
194 exprs: &filter_index(acc_args.exprs, 1),
195 expr_fields: &filter_index(acc_args.expr_fields, 1),
196 schema: acc_args.schema,
200 ignore_nulls: acc_args.ignore_nulls,
201 order_bys: acc_args.order_bys,
202 is_reversed: acc_args.is_reversed,
203 name: acc_args.name,
204 is_distinct: acc_args.is_distinct,
205 })?;
206
207 Ok(Box::new(StringAggAccumulator::new(
208 array_agg_acc,
209 &delimiter,
210 )))
211 }
212 }
213
214 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
215 datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
216 }
217
218 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
219 !args.is_distinct && args.order_bys.is_empty()
220 }
221
222 fn create_groups_accumulator(
223 &self,
224 args: AccumulatorArgs,
225 ) -> Result<Box<dyn GroupsAccumulator>> {
226 let delimiter = Self::extract_delimiter(&args)?;
227 Ok(Box::new(StringAggGroupsAccumulator::new(delimiter)))
228 }
229
230 fn documentation(&self) -> Option<&Documentation> {
231 self.doc()
232 }
233}
234
235#[derive(Debug)]
237pub(crate) struct StringAggAccumulator {
238 array_agg_acc: Box<dyn Accumulator>,
239 delimiter: String,
240}
241
242impl StringAggAccumulator {
243 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
244 Self {
245 array_agg_acc,
246 delimiter: delimiter.to_string(),
247 }
248 }
249}
250
251impl Accumulator for StringAggAccumulator {
252 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
253 self.array_agg_acc.update_batch(&filter_index(values, 1))
254 }
255
256 fn evaluate(&mut self) -> Result<ScalarValue> {
257 let scalar = self.array_agg_acc.evaluate()?;
258
259 let ScalarValue::List(list) = scalar else {
260 return internal_err!(
261 "Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}",
262 scalar.data_type()
263 );
264 };
265
266 let string_arr: Vec<_> = match list.value_type() {
267 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
268 .iter()
269 .flatten()
270 .collect(),
271 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
272 .iter()
273 .flatten()
274 .collect(),
275 DataType::Utf8View => as_string_view_array(list.values())?
276 .iter()
277 .flatten()
278 .collect(),
279 _ => {
280 return internal_err!(
281 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
282 list.value_type()
283 );
284 }
285 };
286
287 if string_arr.is_empty() {
288 return Ok(ScalarValue::LargeUtf8(None));
289 }
290
291 Ok(ScalarValue::LargeUtf8(Some(
292 string_arr.join(&self.delimiter),
293 )))
294 }
295
296 fn size(&self) -> usize {
297 size_of_val(self) - size_of_val(&self.array_agg_acc)
298 + self.array_agg_acc.size()
299 + self.delimiter.capacity()
300 }
301
302 fn state(&mut self) -> Result<Vec<ScalarValue>> {
303 self.array_agg_acc.state()
304 }
305
306 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307 self.array_agg_acc.merge_batch(values)
308 }
309}
310
/// Returns a copy of `values` without the element at position `index`.
///
/// If `index` is out of range, every element is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    // Split right before the element to drop; clamping keeps out-of-range indices valid.
    let (before, from_index) = values.split_at(index.min(values.len()));
    before
        .iter()
        .chain(from_index.iter().skip(1))
        .cloned()
        .collect()
}
320
/// Groups accumulator for `string_agg` that keeps one concatenated string per group.
#[derive(Debug)]
struct StringAggGroupsAccumulator {
    /// Separator inserted between consecutive values of the same group.
    delimiter: String,
    /// Concatenated string per group, indexed by group index; `None` until the
    /// group receives its first non-null value.
    values: Vec<Option<String>>,
    /// Running total of string bytes stored in `values`.
    total_data_bytes: usize,
}

impl StringAggGroupsAccumulator {
    /// Creates an accumulator with no groups yet.
    fn new(delimiter: String) -> Self {
        Self {
            values: Vec::new(),
            total_data_bytes: 0,
            delimiter,
        }
    }

    /// Appends each non-null value from `iter` to the group given by the
    /// corresponding entry of `group_indices`.
    ///
    /// `self.values` must already be large enough for every group index used;
    /// indexing panics otherwise.
    fn append_batch<'a>(
        &mut self,
        iter: impl Iterator<Item = Option<&'a str>>,
        group_indices: &[usize],
    ) {
        for (maybe_value, &group_idx) in iter.zip(group_indices) {
            // NULL inputs contribute nothing, not even a delimiter.
            let Some(value) = maybe_value else {
                continue;
            };

            let slot = &mut self.values[group_idx];
            let grown_by = if let Some(existing) = slot.as_mut() {
                let extra = self.delimiter.len() + value.len();
                existing.reserve(extra);
                existing.push_str(&self.delimiter);
                existing.push_str(value);
                extra
            } else {
                // First value of the group: no leading delimiter.
                *slot = Some(value.to_owned());
                value.len()
            };
            self.total_data_bytes += grown_by;
        }
    }
}
368
369impl GroupsAccumulator for StringAggGroupsAccumulator {
370 fn update_batch(
371 &mut self,
372 values: &[ArrayRef],
373 group_indices: &[usize],
374 opt_filter: Option<&BooleanArray>,
375 total_num_groups: usize,
376 ) -> Result<()> {
377 self.values.resize(total_num_groups, None);
378 let array = apply_filter_as_nulls(&values[0], opt_filter)?;
379 match array.data_type() {
380 DataType::Utf8 => {
381 self.append_batch(array.as_string::<i32>().iter(), group_indices)
382 }
383 DataType::LargeUtf8 => {
384 self.append_batch(array.as_string::<i64>().iter(), group_indices)
385 }
386 DataType::Utf8View => {
387 self.append_batch(array.as_string_view().iter(), group_indices)
388 }
389 other => {
390 return internal_err!("string_agg unexpected data type: {other}");
391 }
392 }
393 Ok(())
394 }
395
396 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
397 let to_emit = emit_to.take_needed(&mut self.values);
398 let emitted_bytes: usize = to_emit
399 .iter()
400 .filter_map(|opt| opt.as_ref().map(|s| s.len()))
401 .sum();
402 self.total_data_bytes -= emitted_bytes;
403
404 let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit));
405 Ok(result)
406 }
407
408 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
409 self.evaluate(emit_to).map(|arr| vec![arr])
410 }
411
412 fn merge_batch(
413 &mut self,
414 values: &[ArrayRef],
415 group_indices: &[usize],
416 opt_filter: Option<&BooleanArray>,
417 total_num_groups: usize,
418 ) -> Result<()> {
419 self.update_batch(values, group_indices, opt_filter, total_num_groups)
421 }
422
423 fn convert_to_state(
424 &self,
425 values: &[ArrayRef],
426 opt_filter: Option<&BooleanArray>,
427 ) -> Result<Vec<ArrayRef>> {
428 let input = apply_filter_as_nulls(&values[0], opt_filter)?;
429 let result = if input.data_type() == &DataType::LargeUtf8 {
430 input
431 } else {
432 arrow::compute::cast(&input, &DataType::LargeUtf8)?
433 };
434 Ok(vec![result])
435 }
436
437 fn supports_convert_to_state(&self) -> bool {
438 true
439 }
440
441 fn size(&self) -> usize {
442 self.total_data_bytes
443 + self.values.capacity() * size_of::<Option<String>>()
444 + self.delimiter.capacity()
445 + size_of_val(self)
446 }
447}
448
/// Accumulator for `string_agg` without DISTINCT or ORDER BY: values are appended
/// to a single growing string as they arrive.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Separator placed between appended values.
    delimiter: String,
    /// Everything concatenated so far.
    accumulated_string: String,
    /// Whether at least one non-null value has been appended.
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that joins values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        Self {
            has_value: false,
            accumulated_string: String::new(),
            delimiter: String::from(delimiter),
        }
    }

    /// Appends every non-null string from `iter`, writing the delimiter before
    /// each value except the very first one.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        for piece in iter.flatten() {
            // Mark the accumulator as non-empty and learn whether it already was.
            let already_started = std::mem::replace(&mut self.has_value, true);
            if already_started {
                self.accumulated_string.push_str(&self.delimiter);
            }
            self.accumulated_string.push_str(piece);
        }
    }
}
483
484impl Accumulator for SimpleStringAggAccumulator {
485 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
486 let string_arr = values.first().ok_or_else(|| {
487 internal_datafusion_err!(
488 "Planner should ensure its first arg is Utf8/Utf8View"
489 )
490 })?;
491
492 match string_arr.data_type() {
493 DataType::Utf8 => self.append_strings(string_arr.as_string::<i32>().iter()),
494 DataType::LargeUtf8 => {
495 self.append_strings(string_arr.as_string::<i64>().iter())
496 }
497 DataType::Utf8View => self.append_strings(string_arr.as_string_view().iter()),
498 other => {
499 return internal_err!(
500 "Planner should ensure string_agg first argument is Utf8-like, found {other}"
501 );
502 }
503 }
504
505 Ok(())
506 }
507
508 fn evaluate(&mut self) -> Result<ScalarValue> {
509 if self.has_value {
510 Ok(ScalarValue::LargeUtf8(Some(
511 self.accumulated_string.clone(),
512 )))
513 } else {
514 Ok(ScalarValue::LargeUtf8(None))
515 }
516 }
517
518 fn size(&self) -> usize {
519 size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
520 }
521
522 fn state(&mut self) -> Result<Vec<ScalarValue>> {
523 let result = if self.has_value {
524 ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
525 } else {
526 ScalarValue::LargeUtf8(None)
527 };
528 self.has_value = false;
529
530 Ok(vec![result])
531 }
532
533 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
534 self.update_batch(values)
535 }
536}
537
538#[cfg(test)]
539mod tests {
540 use super::*;
541 use arrow::array::LargeStringArray;
542 use arrow::compute::SortOptions;
543 use arrow::datatypes::{Fields, Schema};
544 use datafusion_physical_expr::expressions::Column;
545 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
546 use std::sync::Arc;
547
548 #[test]
549 fn no_duplicates_no_distinct() -> Result<()> {
550 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;
551
552 acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
553 acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
554 acc1 = merge(acc1, acc2)?;
555
556 let result = some_str(acc1.evaluate()?);
557
558 assert_eq!(result, "a,b,c,d,e,f");
559
560 Ok(())
561 }
562
563 #[test]
564 fn no_duplicates_distinct() -> Result<()> {
565 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
566 .distinct()
567 .build_two()?;
568
569 acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
570 acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
571 acc1 = merge(acc1, acc2)?;
572
573 let result = some_str_sorted(acc1.evaluate()?, ",");
574
575 assert_eq!(result, "a,b,c,d,e,f");
576
577 Ok(())
578 }
579
580 #[test]
581 fn duplicates_no_distinct() -> Result<()> {
582 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;
583
584 acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
585 acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
586 acc1 = merge(acc1, acc2)?;
587
588 let result = some_str(acc1.evaluate()?);
589
590 assert_eq!(result, "a,b,c,a,b,c");
591
592 Ok(())
593 }
594
595 #[test]
596 fn duplicates_distinct() -> Result<()> {
597 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
598 .distinct()
599 .build_two()?;
600
601 acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
602 acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
603 acc1 = merge(acc1, acc2)?;
604
605 let result = some_str_sorted(acc1.evaluate()?, ",");
606
607 assert_eq!(result, "a,b,c");
608
609 Ok(())
610 }
611
612 #[test]
613 fn no_duplicates_distinct_sort_asc() -> Result<()> {
614 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
615 .distinct()
616 .order_by_col("col", SortOptions::new(false, false))
617 .build_two()?;
618
619 acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
620 acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
621 acc1 = merge(acc1, acc2)?;
622
623 let result = some_str(acc1.evaluate()?);
624
625 assert_eq!(result, "a,b,c,d,e,f");
626
627 Ok(())
628 }
629
630 #[test]
631 fn no_duplicates_distinct_sort_desc() -> Result<()> {
632 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
633 .distinct()
634 .order_by_col("col", SortOptions::new(true, false))
635 .build_two()?;
636
637 acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
638 acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
639 acc1 = merge(acc1, acc2)?;
640
641 let result = some_str(acc1.evaluate()?);
642
643 assert_eq!(result, "f,e,d,c,b,a");
644
645 Ok(())
646 }
647
648 #[test]
649 fn duplicates_distinct_sort_asc() -> Result<()> {
650 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
651 .distinct()
652 .order_by_col("col", SortOptions::new(false, false))
653 .build_two()?;
654
655 acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
656 acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
657 acc1 = merge(acc1, acc2)?;
658
659 let result = some_str(acc1.evaluate()?);
660
661 assert_eq!(result, "a,b,c");
662
663 Ok(())
664 }
665
666 #[test]
667 fn duplicates_distinct_sort_desc() -> Result<()> {
668 let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
669 .distinct()
670 .order_by_col("col", SortOptions::new(true, false))
671 .build_two()?;
672
673 acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
674 acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
675 acc1 = merge(acc1, acc2)?;
676
677 let result = some_str(acc1.evaluate()?);
678
679 assert_eq!(result, "c,b,a");
680
681 Ok(())
682 }
683
684 struct StringAggAccumulatorBuilder {
685 sep: String,
686 distinct: bool,
687 order_bys: Vec<PhysicalSortExpr>,
688 schema: Schema,
689 }
690
691 impl StringAggAccumulatorBuilder {
692 fn new(sep: &str) -> Self {
693 Self {
694 sep: sep.to_string(),
695 distinct: Default::default(),
696 order_bys: vec![],
697 schema: Schema {
698 fields: Fields::from(vec![Field::new(
699 "col",
700 DataType::LargeUtf8,
701 true,
702 )]),
703 metadata: Default::default(),
704 },
705 }
706 }
707 fn distinct(mut self) -> Self {
708 self.distinct = true;
709 self
710 }
711
712 fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
713 self.order_bys.extend([PhysicalSortExpr::new(
714 Arc::new(
715 Column::new_with_schema(col, &self.schema)
716 .expect("column not available in schema"),
717 ),
718 sort_options,
719 )]);
720 self
721 }
722
723 fn build(&self) -> Result<Box<dyn Accumulator>> {
724 StringAgg::new().accumulator(AccumulatorArgs {
725 return_field: Field::new("f", DataType::LargeUtf8, true).into(),
726 schema: &self.schema,
727 expr_fields: &[
728 Field::new("col", DataType::LargeUtf8, true).into(),
729 Field::new("lit", DataType::Utf8, false).into(),
730 ],
731 ignore_nulls: false,
732 order_bys: &self.order_bys,
733 is_reversed: false,
734 name: "",
735 is_distinct: self.distinct,
736 exprs: &[
737 Arc::new(Column::new("col", 0)),
738 Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
739 ],
740 })
741 }
742
743 fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
744 Ok((self.build()?, self.build()?))
745 }
746 }
747
748 fn some_str(value: ScalarValue) -> String {
749 str(value)
750 .expect("ScalarValue was not a String")
751 .expect("ScalarValue was None")
752 }
753
754 fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
755 let value = some_str(value);
756 let mut parts: Vec<&str> = value.split(sep).collect();
757 parts.sort();
758 parts.join(sep)
759 }
760
761 fn str(value: ScalarValue) -> Result<Option<String>> {
762 match value {
763 ScalarValue::LargeUtf8(v) => Ok(v),
764 _ => internal_err!(
765 "Expected ScalarValue::LargeUtf8, got {}",
766 value.data_type()
767 ),
768 }
769 }
770
771 fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
772 Arc::new(LargeStringArray::from(list.to_vec()))
773 }
774
775 fn merge(
776 mut acc1: Box<dyn Accumulator>,
777 mut acc2: Box<dyn Accumulator>,
778 ) -> Result<Box<dyn Accumulator>> {
779 let intermediate_state = acc2.state().and_then(|e| {
780 e.iter()
781 .map(|v| v.to_array())
782 .collect::<Result<Vec<ArrayRef>>>()
783 })?;
784 acc1.merge_batch(&intermediate_state)?;
785 Ok(acc1)
786 }
787
788 fn make_groups_acc(delimiter: &str) -> StringAggGroupsAccumulator {
793 StringAggGroupsAccumulator::new(delimiter.to_string())
794 }
795
796 fn evaluate_groups(
798 acc: &mut StringAggGroupsAccumulator,
799 emit_to: EmitTo,
800 ) -> Vec<Option<String>> {
801 let result = acc.evaluate(emit_to).unwrap();
802 let arr = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
803 arr.iter().map(|v| v.map(|s| s.to_string())).collect()
804 }
805
806 #[test]
807 fn groups_basic() -> Result<()> {
808 let mut acc = make_groups_acc(",");
809
810 let values: ArrayRef =
812 Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"]));
813 let group_indices = vec![0, 1, 2, 0, 1, 2];
814 acc.update_batch(&[values], &group_indices, None, 3)?;
815
816 let result = evaluate_groups(&mut acc, EmitTo::All);
817 assert_eq!(
818 result,
819 vec![
820 Some("a,d".to_string()),
821 Some("b,e".to_string()),
822 Some("c,f".to_string()),
823 ]
824 );
825 Ok(())
826 }
827
828 #[test]
829 fn groups_with_nulls() -> Result<()> {
830 let mut acc = make_groups_acc("|");
831
832 let values: ArrayRef = Arc::new(LargeStringArray::from(vec![
836 Some("a"),
837 None,
838 Some("c"),
839 None,
840 Some("b"),
841 None,
842 ]));
843 let group_indices = vec![0, 1, 0, 2, 1, 2];
844 acc.update_batch(&[values], &group_indices, None, 3)?;
845
846 let result = evaluate_groups(&mut acc, EmitTo::All);
847 assert_eq!(
848 result,
849 vec![Some("a|c".to_string()), Some("b".to_string()), None,]
850 );
851 Ok(())
852 }
853
854 #[test]
855 fn groups_with_filter() -> Result<()> {
856 let mut acc = make_groups_acc(",");
857
858 let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d"]));
859 let group_indices = vec![0, 0, 1, 1];
860 let filter = BooleanArray::from(vec![true, false, false, true]);
862 acc.update_batch(&[values], &group_indices, Some(&filter), 2)?;
863
864 let result = evaluate_groups(&mut acc, EmitTo::All);
865 assert_eq!(result, vec![Some("a".to_string()), Some("d".to_string())]);
866 Ok(())
867 }
868
869 #[test]
870 fn groups_emit_first() -> Result<()> {
871 let mut acc = make_groups_acc(",");
872
873 let values: ArrayRef =
874 Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"]));
875 let group_indices = vec![0, 1, 2, 0, 1, 2];
876 acc.update_batch(&[values], &group_indices, None, 3)?;
877
878 let result = evaluate_groups(&mut acc, EmitTo::First(2));
880 assert_eq!(
881 result,
882 vec![Some("a,d".to_string()), Some("b,e".to_string())]
883 );
884
885 let result = evaluate_groups(&mut acc, EmitTo::All);
887 assert_eq!(result, vec![Some("c,f".to_string())]);
888 Ok(())
889 }
890
891 #[test]
892 fn groups_merge_batch() -> Result<()> {
893 let mut acc = make_groups_acc(",");
894
895 let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
897 acc.update_batch(&[values], &[0, 1], None, 2)?;
898
899 let partial_state: ArrayRef = Arc::new(LargeStringArray::from(vec!["c,d", "e"]));
901 acc.merge_batch(&[partial_state], &[0, 1], None, 2)?;
902
903 let result = evaluate_groups(&mut acc, EmitTo::All);
904 assert_eq!(
905 result,
906 vec![Some("a,c,d".to_string()), Some("b,e".to_string())]
907 );
908 Ok(())
909 }
910
911 #[test]
912 fn groups_empty_groups() -> Result<()> {
913 let mut acc = make_groups_acc(",");
914
915 let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
917 acc.update_batch(&[values], &[0, 2], None, 4)?;
918
919 let result = evaluate_groups(&mut acc, EmitTo::All);
920 assert_eq!(
921 result,
922 vec![
923 Some("a".to_string()),
924 None, Some("b".to_string()),
926 None, ]
928 );
929 Ok(())
930 }
931
932 #[test]
933 fn groups_multiple_batches() -> Result<()> {
934 let mut acc = make_groups_acc("|");
935
936 let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
938 acc.update_batch(&[values], &[0, 1], None, 2)?;
939
940 let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["c", "d", "e"]));
942 acc.update_batch(&[values], &[0, 1, 2], None, 3)?;
943
944 let result = evaluate_groups(&mut acc, EmitTo::All);
945 assert_eq!(
946 result,
947 vec![
948 Some("a|c".to_string()),
949 Some("b|d".to_string()),
950 Some("e".to_string()),
951 ]
952 );
953 Ok(())
954 }
955}