Skip to main content

datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray};
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::{
30    Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
31};
32use datafusion_expr::function::AccumulatorArgs;
33use datafusion_expr::utils::format_state_name;
34use datafusion_expr::{
35    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature,
36    TypeSignature, Volatility,
37};
38use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
39use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Registers `StringAgg` under the `string_agg` name with the two arguments
// `expr` and `delimiter`. `string_agg_udaf` is the accessor used by
// `reverse_expr` further down; presumably the macro is what defines it.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
50
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Concatenates the values of string expressions and places separator values between them. \
If ordering is required, strings are concatenated in the specified order. \
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT string_agg(name, ', ') AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Alice, Bob, Bob, Charlie |
+--------------------------+
> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Bob, Alice |
+--------------------------+
> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Alice |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The string expression to concatenate. Can be a column or any valid string expression."
    ),
    argument(
        name = "delimiter",
        description = "A literal string used as a separator between the concatenated values."
    )
)]
/// STRING_AGG aggregate expression
///
/// The result type is always `LargeUtf8`. Queries that use DISTINCT or
/// ORDER BY are served by delegating to the embedded [`ArrayAgg`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StringAgg {
    /// Accepted `(expression, delimiter)` argument type combinations.
    signature: Signature,
    /// Supplies state fields and the inner accumulator whenever DISTINCT or
    /// ORDER BY is requested.
    array_agg: ArrayAgg,
}
95
96impl StringAgg {
97    /// Create a new StringAgg aggregate function
98    pub fn new() -> Self {
99        Self {
100            signature: Signature::one_of(
101                vec![
102                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114                ],
115                Volatility::Immutable,
116            ),
117            array_agg: Default::default(),
118        }
119    }
120
121    /// Extract the delimiter string from the second argument expression.
122    fn extract_delimiter(args: &AccumulatorArgs) -> Result<String> {
123        let Some(lit) = args.exprs[1].downcast_ref::<Literal>() else {
124            return not_impl_err!("string_agg delimiter must be a string literal");
125        };
126
127        if lit.value().is_null() {
128            return Ok(String::new());
129        }
130
131        match lit.value().try_as_str() {
132            Some(s) => Ok(s.unwrap_or("").to_string()),
133            None => {
134                not_impl_err!(
135                    "string_agg not supported for delimiter \"{}\"",
136                    lit.value()
137                )
138            }
139        }
140    }
141}
142
143impl Default for StringAgg {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149/// Three accumulation strategies depending on query shape:
150/// - No DISTINCT / ORDER BY with GROUP BY: `StringAggGroupsAccumulator`
151/// - No DISTINCT / ORDER BY without GROUP BY: `SimpleStringAggAccumulator`
152/// - With DISTINCT or ORDER BY: `StringAggAccumulator` (delegates to `ArrayAgg`)
153impl AggregateUDFImpl for StringAgg {
154    fn name(&self) -> &str {
155        "string_agg"
156    }
157
158    fn signature(&self) -> &Signature {
159        &self.signature
160    }
161
162    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
163        Ok(DataType::LargeUtf8)
164    }
165
166    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
167        if !args.is_distinct && args.ordering_fields.is_empty() {
168            Ok(vec![
169                Field::new(
170                    format_state_name(args.name, "string_agg"),
171                    DataType::LargeUtf8,
172                    true,
173                )
174                .into(),
175            ])
176        } else {
177            self.array_agg.state_fields(args)
178        }
179    }
180
181    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
182        let delimiter = Self::extract_delimiter(&acc_args)?;
183
184        if !acc_args.is_distinct && acc_args.order_bys.is_empty() {
185            Ok(Box::new(SimpleStringAggAccumulator::new(&delimiter)))
186        } else {
187            let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
188                return_field: Field::new(
189                    "f",
190                    DataType::new_list(acc_args.return_field.data_type().clone(), true),
191                    true,
192                )
193                .into(),
194                exprs: &filter_index(acc_args.exprs, 1),
195                expr_fields: &filter_index(acc_args.expr_fields, 1),
196                // Unchanged below; we list each field explicitly in case we ever add more
197                // fields to AccumulatorArgs making it easier to see if changes are also
198                // needed here.
199                schema: acc_args.schema,
200                ignore_nulls: acc_args.ignore_nulls,
201                order_bys: acc_args.order_bys,
202                is_reversed: acc_args.is_reversed,
203                name: acc_args.name,
204                is_distinct: acc_args.is_distinct,
205            })?;
206
207            Ok(Box::new(StringAggAccumulator::new(
208                array_agg_acc,
209                &delimiter,
210            )))
211        }
212    }
213
214    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
215        datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
216    }
217
218    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
219        !args.is_distinct && args.order_bys.is_empty()
220    }
221
222    fn create_groups_accumulator(
223        &self,
224        args: AccumulatorArgs,
225    ) -> Result<Box<dyn GroupsAccumulator>> {
226        let delimiter = Self::extract_delimiter(&args)?;
227        Ok(Box::new(StringAggGroupsAccumulator::new(delimiter)))
228    }
229
230    fn documentation(&self) -> Option<&Documentation> {
231        self.doc()
232    }
233}
234
/// StringAgg accumulator for the general case (with order or distinct specified)
///
/// Values are collected by a wrapped accumulator (created from `ArrayAgg` in
/// `StringAgg::accumulator`) and only joined into one string on `evaluate`.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Inner accumulator that gathers the values as a list.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator inserted between values when the list is joined.
    delimiter: String,
}
241
242impl StringAggAccumulator {
243    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
244        Self {
245            array_agg_acc,
246            delimiter: delimiter.to_string(),
247        }
248    }
249}
250
251impl Accumulator for StringAggAccumulator {
252    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
253        self.array_agg_acc.update_batch(&filter_index(values, 1))
254    }
255
256    fn evaluate(&mut self) -> Result<ScalarValue> {
257        let scalar = self.array_agg_acc.evaluate()?;
258
259        let ScalarValue::List(list) = scalar else {
260            return internal_err!(
261                "Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}",
262                scalar.data_type()
263            );
264        };
265
266        let string_arr: Vec<_> = match list.value_type() {
267            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
268                .iter()
269                .flatten()
270                .collect(),
271            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
272                .iter()
273                .flatten()
274                .collect(),
275            DataType::Utf8View => as_string_view_array(list.values())?
276                .iter()
277                .flatten()
278                .collect(),
279            _ => {
280                return internal_err!(
281                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
282                    list.value_type()
283                );
284            }
285        };
286
287        if string_arr.is_empty() {
288            return Ok(ScalarValue::LargeUtf8(None));
289        }
290
291        Ok(ScalarValue::LargeUtf8(Some(
292            string_arr.join(&self.delimiter),
293        )))
294    }
295
296    fn size(&self) -> usize {
297        size_of_val(self) - size_of_val(&self.array_agg_acc)
298            + self.array_agg_acc.size()
299            + self.delimiter.capacity()
300    }
301
302    fn state(&mut self) -> Result<Vec<ScalarValue>> {
303        self.array_agg_acc.state()
304    }
305
306    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307        self.array_agg_acc.merge_batch(values)
308    }
309}
310
/// Returns a clone of `values` with the element at position `index` left out.
///
/// An `index` beyond the end of the slice removes nothing.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::new();
    for (position, value) in values.iter().enumerate() {
        if position != index {
            kept.push(value.clone());
        }
    }
    kept
}
320
/// GroupsAccumulator for `string_agg` without DISTINCT or ORDER BY.
#[derive(Debug)]
struct StringAggGroupsAccumulator {
    /// Separator inserted between the values of one group.
    delimiter: String,
    /// One slot per group holding the string built so far. A `None` slot
    /// means the group has not received a value yet and will be emitted as
    /// NULL.
    /// A potential improvement is to avoid this String allocation
    /// See <https://github.com/apache/datafusion/issues/21156>
    values: Vec<Option<String>>,
    /// Sum of the byte lengths of all strings currently held in `values`.
    total_data_bytes: usize,
}

impl StringAggGroupsAccumulator {
    /// Creates an accumulator with no groups yet.
    fn new(delimiter: String) -> Self {
        Self {
            delimiter,
            values: Vec::new(),
            total_data_bytes: 0,
        }
    }

    /// Appends each non-null value from `iter` to the slot of the group found
    /// at the same position in `group_indices`.
    ///
    /// `self.values` must already be large enough for every group index;
    /// indexing panics otherwise.
    fn append_batch<'a>(
        &mut self,
        iter: impl Iterator<Item = Option<&'a str>>,
        group_indices: &[usize],
    ) {
        for (opt_value, &group_idx) in iter.zip(group_indices.iter()) {
            // NULL inputs contribute nothing, not even a delimiter.
            let Some(value) = opt_value else {
                continue;
            };

            let slot = &mut self.values[group_idx];
            let grown_by = match slot {
                None => {
                    *slot = Some(value.to_string());
                    value.len()
                }
                Some(existing) => {
                    let added = self.delimiter.len() + value.len();
                    existing.reserve(added);
                    existing.push_str(&self.delimiter);
                    existing.push_str(value);
                    added
                }
            };
            self.total_data_bytes += grown_by;
        }
    }
}
368
369impl GroupsAccumulator for StringAggGroupsAccumulator {
370    fn update_batch(
371        &mut self,
372        values: &[ArrayRef],
373        group_indices: &[usize],
374        opt_filter: Option<&BooleanArray>,
375        total_num_groups: usize,
376    ) -> Result<()> {
377        self.values.resize(total_num_groups, None);
378        let array = apply_filter_as_nulls(&values[0], opt_filter)?;
379        match array.data_type() {
380            DataType::Utf8 => {
381                self.append_batch(array.as_string::<i32>().iter(), group_indices)
382            }
383            DataType::LargeUtf8 => {
384                self.append_batch(array.as_string::<i64>().iter(), group_indices)
385            }
386            DataType::Utf8View => {
387                self.append_batch(array.as_string_view().iter(), group_indices)
388            }
389            other => {
390                return internal_err!("string_agg unexpected data type: {other}");
391            }
392        }
393        Ok(())
394    }
395
396    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
397        let to_emit = emit_to.take_needed(&mut self.values);
398        let emitted_bytes: usize = to_emit
399            .iter()
400            .filter_map(|opt| opt.as_ref().map(|s| s.len()))
401            .sum();
402        self.total_data_bytes -= emitted_bytes;
403
404        let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit));
405        Ok(result)
406    }
407
408    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
409        self.evaluate(emit_to).map(|arr| vec![arr])
410    }
411
412    fn merge_batch(
413        &mut self,
414        values: &[ArrayRef],
415        group_indices: &[usize],
416        opt_filter: Option<&BooleanArray>,
417        total_num_groups: usize,
418    ) -> Result<()> {
419        // State is always LargeUtf8, which update_batch already handles.
420        self.update_batch(values, group_indices, opt_filter, total_num_groups)
421    }
422
423    fn convert_to_state(
424        &self,
425        values: &[ArrayRef],
426        opt_filter: Option<&BooleanArray>,
427    ) -> Result<Vec<ArrayRef>> {
428        let input = apply_filter_as_nulls(&values[0], opt_filter)?;
429        let result = if input.data_type() == &DataType::LargeUtf8 {
430            input
431        } else {
432            arrow::compute::cast(&input, &DataType::LargeUtf8)?
433        };
434        Ok(vec![result])
435    }
436
437    fn supports_convert_to_state(&self) -> bool {
438        true
439    }
440
441    fn size(&self) -> usize {
442        self.total_data_bytes
443            + self.values.capacity() * size_of::<Option<String>>()
444            + self.delimiter.capacity()
445            + size_of_val(self)
446    }
447}
448
/// Per-row accumulator for `string_agg` without DISTINCT or ORDER BY.  Used for
/// non-grouped aggregation; grouped queries use [`StringAggGroupsAccumulator`].
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Separator inserted between consecutive values.
    delimiter: String,
    /// Updated during `update_batch()`. e.g. "foo,bar"
    accumulated_string: String,
    /// Whether any value has been appended yet. This is what tells an empty
    /// string result apart from "no input" (NULL).
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that joins values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        Self {
            delimiter: delimiter.to_owned(),
            accumulated_string: String::new(),
            has_value: false,
        }
    }

    /// Appends every non-null string from `iter`, writing the delimiter
    /// before each value except the very first one seen.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        iter.flatten().for_each(|piece| {
            if self.has_value {
                self.accumulated_string.push_str(&self.delimiter);
            } else {
                self.has_value = true;
            }
            self.accumulated_string.push_str(piece);
        });
    }
}
483
484impl Accumulator for SimpleStringAggAccumulator {
485    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
486        let string_arr = values.first().ok_or_else(|| {
487            internal_datafusion_err!(
488                "Planner should ensure its first arg is Utf8/Utf8View"
489            )
490        })?;
491
492        match string_arr.data_type() {
493            DataType::Utf8 => self.append_strings(string_arr.as_string::<i32>().iter()),
494            DataType::LargeUtf8 => {
495                self.append_strings(string_arr.as_string::<i64>().iter())
496            }
497            DataType::Utf8View => self.append_strings(string_arr.as_string_view().iter()),
498            other => {
499                return internal_err!(
500                    "Planner should ensure string_agg first argument is Utf8-like, found {other}"
501                );
502            }
503        }
504
505        Ok(())
506    }
507
508    fn evaluate(&mut self) -> Result<ScalarValue> {
509        if self.has_value {
510            Ok(ScalarValue::LargeUtf8(Some(
511                self.accumulated_string.clone(),
512            )))
513        } else {
514            Ok(ScalarValue::LargeUtf8(None))
515        }
516    }
517
518    fn size(&self) -> usize {
519        size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
520    }
521
522    fn state(&mut self) -> Result<Vec<ScalarValue>> {
523        let result = if self.has_value {
524            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
525        } else {
526            ScalarValue::LargeUtf8(None)
527        };
528        self.has_value = false;
529
530        Ok(vec![result])
531    }
532
533    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
534        self.update_batch(values)
535    }
536}
537
538#[cfg(test)]
539mod tests {
540    use super::*;
541    use arrow::array::LargeStringArray;
542    use arrow::compute::SortOptions;
543    use arrow::datatypes::{Fields, Schema};
544    use datafusion_physical_expr::expressions::Column;
545    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
546    use std::sync::Arc;
547
548    #[test]
549    fn no_duplicates_no_distinct() -> Result<()> {
550        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;
551
552        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
553        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
554        acc1 = merge(acc1, acc2)?;
555
556        let result = some_str(acc1.evaluate()?);
557
558        assert_eq!(result, "a,b,c,d,e,f");
559
560        Ok(())
561    }
562
563    #[test]
564    fn no_duplicates_distinct() -> Result<()> {
565        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
566            .distinct()
567            .build_two()?;
568
569        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
570        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
571        acc1 = merge(acc1, acc2)?;
572
573        let result = some_str_sorted(acc1.evaluate()?, ",");
574
575        assert_eq!(result, "a,b,c,d,e,f");
576
577        Ok(())
578    }
579
580    #[test]
581    fn duplicates_no_distinct() -> Result<()> {
582        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;
583
584        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
585        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
586        acc1 = merge(acc1, acc2)?;
587
588        let result = some_str(acc1.evaluate()?);
589
590        assert_eq!(result, "a,b,c,a,b,c");
591
592        Ok(())
593    }
594
595    #[test]
596    fn duplicates_distinct() -> Result<()> {
597        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
598            .distinct()
599            .build_two()?;
600
601        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
602        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
603        acc1 = merge(acc1, acc2)?;
604
605        let result = some_str_sorted(acc1.evaluate()?, ",");
606
607        assert_eq!(result, "a,b,c");
608
609        Ok(())
610    }
611
612    #[test]
613    fn no_duplicates_distinct_sort_asc() -> Result<()> {
614        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
615            .distinct()
616            .order_by_col("col", SortOptions::new(false, false))
617            .build_two()?;
618
619        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
620        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
621        acc1 = merge(acc1, acc2)?;
622
623        let result = some_str(acc1.evaluate()?);
624
625        assert_eq!(result, "a,b,c,d,e,f");
626
627        Ok(())
628    }
629
630    #[test]
631    fn no_duplicates_distinct_sort_desc() -> Result<()> {
632        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
633            .distinct()
634            .order_by_col("col", SortOptions::new(true, false))
635            .build_two()?;
636
637        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
638        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
639        acc1 = merge(acc1, acc2)?;
640
641        let result = some_str(acc1.evaluate()?);
642
643        assert_eq!(result, "f,e,d,c,b,a");
644
645        Ok(())
646    }
647
648    #[test]
649    fn duplicates_distinct_sort_asc() -> Result<()> {
650        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
651            .distinct()
652            .order_by_col("col", SortOptions::new(false, false))
653            .build_two()?;
654
655        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
656        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
657        acc1 = merge(acc1, acc2)?;
658
659        let result = some_str(acc1.evaluate()?);
660
661        assert_eq!(result, "a,b,c");
662
663        Ok(())
664    }
665
666    #[test]
667    fn duplicates_distinct_sort_desc() -> Result<()> {
668        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
669            .distinct()
670            .order_by_col("col", SortOptions::new(true, false))
671            .build_two()?;
672
673        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
674        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
675        acc1 = merge(acc1, acc2)?;
676
677        let result = some_str(acc1.evaluate()?);
678
679        assert_eq!(result, "c,b,a");
680
681        Ok(())
682    }
683
684    struct StringAggAccumulatorBuilder {
685        sep: String,
686        distinct: bool,
687        order_bys: Vec<PhysicalSortExpr>,
688        schema: Schema,
689    }
690
691    impl StringAggAccumulatorBuilder {
692        fn new(sep: &str) -> Self {
693            Self {
694                sep: sep.to_string(),
695                distinct: Default::default(),
696                order_bys: vec![],
697                schema: Schema {
698                    fields: Fields::from(vec![Field::new(
699                        "col",
700                        DataType::LargeUtf8,
701                        true,
702                    )]),
703                    metadata: Default::default(),
704                },
705            }
706        }
707        fn distinct(mut self) -> Self {
708            self.distinct = true;
709            self
710        }
711
712        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
713            self.order_bys.extend([PhysicalSortExpr::new(
714                Arc::new(
715                    Column::new_with_schema(col, &self.schema)
716                        .expect("column not available in schema"),
717                ),
718                sort_options,
719            )]);
720            self
721        }
722
723        fn build(&self) -> Result<Box<dyn Accumulator>> {
724            StringAgg::new().accumulator(AccumulatorArgs {
725                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
726                schema: &self.schema,
727                expr_fields: &[
728                    Field::new("col", DataType::LargeUtf8, true).into(),
729                    Field::new("lit", DataType::Utf8, false).into(),
730                ],
731                ignore_nulls: false,
732                order_bys: &self.order_bys,
733                is_reversed: false,
734                name: "",
735                is_distinct: self.distinct,
736                exprs: &[
737                    Arc::new(Column::new("col", 0)),
738                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
739                ],
740            })
741        }
742
743        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
744            Ok((self.build()?, self.build()?))
745        }
746    }
747
748    fn some_str(value: ScalarValue) -> String {
749        str(value)
750            .expect("ScalarValue was not a String")
751            .expect("ScalarValue was None")
752    }
753
754    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
755        let value = some_str(value);
756        let mut parts: Vec<&str> = value.split(sep).collect();
757        parts.sort();
758        parts.join(sep)
759    }
760
761    fn str(value: ScalarValue) -> Result<Option<String>> {
762        match value {
763            ScalarValue::LargeUtf8(v) => Ok(v),
764            _ => internal_err!(
765                "Expected ScalarValue::LargeUtf8, got {}",
766                value.data_type()
767            ),
768        }
769    }
770
771    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
772        Arc::new(LargeStringArray::from(list.to_vec()))
773    }
774
775    fn merge(
776        mut acc1: Box<dyn Accumulator>,
777        mut acc2: Box<dyn Accumulator>,
778    ) -> Result<Box<dyn Accumulator>> {
779        let intermediate_state = acc2.state().and_then(|e| {
780            e.iter()
781                .map(|v| v.to_array())
782                .collect::<Result<Vec<ArrayRef>>>()
783        })?;
784        acc1.merge_batch(&intermediate_state)?;
785        Ok(acc1)
786    }
787
788    // ---------------------------------------------------------------
789    // Tests for StringAggGroupsAccumulator
790    // ---------------------------------------------------------------
791
792    fn make_groups_acc(delimiter: &str) -> StringAggGroupsAccumulator {
793        StringAggGroupsAccumulator::new(delimiter.to_string())
794    }
795
796    /// Helper: evaluate and downcast to LargeStringArray
797    fn evaluate_groups(
798        acc: &mut StringAggGroupsAccumulator,
799        emit_to: EmitTo,
800    ) -> Vec<Option<String>> {
801        let result = acc.evaluate(emit_to).unwrap();
802        let arr = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
803        arr.iter().map(|v| v.map(|s| s.to_string())).collect()
804    }
805
806    #[test]
807    fn groups_basic() -> Result<()> {
808        let mut acc = make_groups_acc(",");
809
810        // 6 rows, 3 groups: group 0 gets "a","d"; group 1 gets "b","e"; group 2 gets "c","f"
811        let values: ArrayRef =
812            Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"]));
813        let group_indices = vec![0, 1, 2, 0, 1, 2];
814        acc.update_batch(&[values], &group_indices, None, 3)?;
815
816        let result = evaluate_groups(&mut acc, EmitTo::All);
817        assert_eq!(
818            result,
819            vec![
820                Some("a,d".to_string()),
821                Some("b,e".to_string()),
822                Some("c,f".to_string()),
823            ]
824        );
825        Ok(())
826    }
827
828    #[test]
829    fn groups_with_nulls() -> Result<()> {
830        let mut acc = make_groups_acc("|");
831
832        // Group 0: "a", NULL, "c" → "a|c"
833        // Group 1: NULL, "b"     → "b"
834        // Group 2: NULL only     → NULL
835        let values: ArrayRef = Arc::new(LargeStringArray::from(vec![
836            Some("a"),
837            None,
838            Some("c"),
839            None,
840            Some("b"),
841            None,
842        ]));
843        let group_indices = vec![0, 1, 0, 2, 1, 2];
844        acc.update_batch(&[values], &group_indices, None, 3)?;
845
846        let result = evaluate_groups(&mut acc, EmitTo::All);
847        assert_eq!(
848            result,
849            vec![Some("a|c".to_string()), Some("b".to_string()), None,]
850        );
851        Ok(())
852    }
853
854    #[test]
855    fn groups_with_filter() -> Result<()> {
856        let mut acc = make_groups_acc(",");
857
858        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d"]));
859        let group_indices = vec![0, 0, 1, 1];
860        // Filter: only rows 0 and 3 are included
861        let filter = BooleanArray::from(vec![true, false, false, true]);
862        acc.update_batch(&[values], &group_indices, Some(&filter), 2)?;
863
864        let result = evaluate_groups(&mut acc, EmitTo::All);
865        assert_eq!(result, vec![Some("a".to_string()), Some("d".to_string())]);
866        Ok(())
867    }
868
869    #[test]
870    fn groups_emit_first() -> Result<()> {
871        let mut acc = make_groups_acc(",");
872
873        let values: ArrayRef =
874            Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"]));
875        let group_indices = vec![0, 1, 2, 0, 1, 2];
876        acc.update_batch(&[values], &group_indices, None, 3)?;
877
878        // Emit only the first 2 groups
879        let result = evaluate_groups(&mut acc, EmitTo::First(2));
880        assert_eq!(
881            result,
882            vec![Some("a,d".to_string()), Some("b,e".to_string())]
883        );
884
885        // Group 2 (now shifted to index 0) should still be intact
886        let result = evaluate_groups(&mut acc, EmitTo::All);
887        assert_eq!(result, vec![Some("c,f".to_string())]);
888        Ok(())
889    }
890
891    #[test]
892    fn groups_merge_batch() -> Result<()> {
893        let mut acc = make_groups_acc(",");
894
895        // First batch: group 0 = "a", group 1 = "b"
896        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
897        acc.update_batch(&[values], &[0, 1], None, 2)?;
898
899        // Simulate a second accumulator's state (LargeUtf8 partial strings)
900        let partial_state: ArrayRef = Arc::new(LargeStringArray::from(vec!["c,d", "e"]));
901        acc.merge_batch(&[partial_state], &[0, 1], None, 2)?;
902
903        let result = evaluate_groups(&mut acc, EmitTo::All);
904        assert_eq!(
905            result,
906            vec![Some("a,c,d".to_string()), Some("b,e".to_string())]
907        );
908        Ok(())
909    }
910
911    #[test]
912    fn groups_empty_groups() -> Result<()> {
913        let mut acc = make_groups_acc(",");
914
915        // 4 groups total, but only groups 0 and 2 receive values
916        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
917        acc.update_batch(&[values], &[0, 2], None, 4)?;
918
919        let result = evaluate_groups(&mut acc, EmitTo::All);
920        assert_eq!(
921            result,
922            vec![
923                Some("a".to_string()),
924                None, // group 1: never received a value
925                Some("b".to_string()),
926                None, // group 3: never received a value
927            ]
928        );
929        Ok(())
930    }
931
932    #[test]
933    fn groups_multiple_batches() -> Result<()> {
934        let mut acc = make_groups_acc("|");
935
936        // Batch 1: 2 groups
937        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"]));
938        acc.update_batch(&[values], &[0, 1], None, 2)?;
939
940        // Batch 2: same groups, plus a new group
941        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["c", "d", "e"]));
942        acc.update_batch(&[values], &[0, 1, 2], None, 3)?;
943
944        let result = evaluate_groups(&mut acc, EmitTo::All);
945        assert_eq!(
946            result,
947            vec![
948                Some("a|c".to_string()),
949                Some("b|d".to_string()),
950                Some("e".to_string()),
951            ]
952        );
953        Ok(())
954    }
955}