datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use std::any::Any;
21use std::hash::Hash;
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{
29    as_generic_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{
32    internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue,
33};
34use datafusion_expr::function::AccumulatorArgs;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
38};
39use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Declares the `string_agg` expression helper (taking the `expr` and `delimiter`
// arguments) and the `string_agg_udaf()` accessor for the [`StringAgg`] UDAF;
// `string_agg_udaf()` is what `reverse_expr` below returns. See the macro
// definition for the exact expansion.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
50
51#[user_doc(
52    doc_section(label = "General Functions"),
53    description = "Concatenates the values of string expressions and places separator values between them. \
54If ordering is required, strings are concatenated in the specified order. \
55This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
56    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
57    sql_example = r#"```sql
58> SELECT string_agg(name, ', ') AS names_list
59  FROM employee;
60+--------------------------+
61| names_list               |
62+--------------------------+
63| Alice, Bob, Bob, Charlie |
64+--------------------------+
65> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
66  FROM employee;
67+--------------------------+
68| names_list               |
69+--------------------------+
70| Charlie, Bob, Bob, Alice |
71+--------------------------+
72> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
73  FROM employee;
74+--------------------------+
75| names_list               |
76+--------------------------+
77| Charlie, Bob, Alice |
78+--------------------------+
79```"#,
80    argument(
81        name = "expression",
82        description = "The string expression to concatenate. Can be a column or any valid string expression."
83    ),
84    argument(
85        name = "delimiter",
86        description = "A literal string used as a separator between the concatenated values."
87    )
88)]
89/// STRING_AGG aggregate expression
90#[derive(Debug, PartialEq, Eq, Hash)]
91pub struct StringAgg {
92    signature: Signature,
93    array_agg: ArrayAgg,
94}
95
96impl StringAgg {
97    /// Create a new StringAgg aggregate function
98    pub fn new() -> Self {
99        Self {
100            signature: Signature::one_of(
101                vec![
102                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114                ],
115                Volatility::Immutable,
116            ),
117            array_agg: Default::default(),
118        }
119    }
120}
121
122impl Default for StringAgg {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128/// If there is no `distinct` and `order by` required by the `string_agg` call, a
129/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
130impl AggregateUDFImpl for StringAgg {
131    fn as_any(&self) -> &dyn Any {
132        self
133    }
134
135    fn name(&self) -> &str {
136        "string_agg"
137    }
138
139    fn signature(&self) -> &Signature {
140        &self.signature
141    }
142
143    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
144        Ok(DataType::LargeUtf8)
145    }
146
147    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
148        // See comments in `impl AggregateUDFImpl ...` for more detail
149        let no_order_no_distinct =
150            (args.ordering_fields.is_empty()) && (!args.is_distinct);
151        if no_order_no_distinct {
152            // Case `SimpleStringAggAccumulator`
153            Ok(vec![Field::new(
154                format_state_name(args.name, "string_agg"),
155                DataType::LargeUtf8,
156                true,
157            )
158            .into()])
159        } else {
160            // Case `StringAggAccumulator`
161            self.array_agg.state_fields(args)
162        }
163    }
164
165    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
166        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
167            return not_impl_err!(
168                "The second argument of the string_agg function must be a string literal"
169            );
170        };
171
172        let delimiter = if lit.value().is_null() {
173            // If the second argument (the delimiter that joins strings) is NULL, join
174            // on an empty string. (e.g. [a, b, c] => "abc").
175            ""
176        } else if let Some(lit_string) = lit.value().try_as_str() {
177            lit_string.unwrap_or("")
178        } else {
179            return not_impl_err!(
180                "StringAgg not supported for delimiter \"{}\"",
181                lit.value()
182            );
183        };
184
185        // See comments in `impl AggregateUDFImpl ...` for more detail
186        let no_order_no_distinct =
187            acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
188
189        if no_order_no_distinct {
190            // simple case (more efficient)
191            Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
192        } else {
193            // general case
194            let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
195                return_field: Field::new(
196                    "f",
197                    DataType::new_list(acc_args.return_field.data_type().clone(), true),
198                    true,
199                )
200                .into(),
201                exprs: &filter_index(acc_args.exprs, 1),
202                expr_fields: &filter_index(acc_args.expr_fields, 1),
203                // Unchanged below; we list each field explicitly in case we ever add more
204                // fields to AccumulatorArgs making it easier to see if changes are also
205                // needed here.
206                schema: acc_args.schema,
207                ignore_nulls: acc_args.ignore_nulls,
208                order_bys: acc_args.order_bys,
209                is_reversed: acc_args.is_reversed,
210                name: acc_args.name,
211                is_distinct: acc_args.is_distinct,
212            })?;
213
214            Ok(Box::new(StringAggAccumulator::new(
215                array_agg_acc,
216                delimiter,
217            )))
218        }
219    }
220
221    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
222        datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
223    }
224
225    fn documentation(&self) -> Option<&Documentation> {
226        self.doc()
227    }
228}
229
230/// StringAgg accumulator for the general case (with order or distinct specified)
231#[derive(Debug)]
232pub(crate) struct StringAggAccumulator {
233    array_agg_acc: Box<dyn Accumulator>,
234    delimiter: String,
235}
236
237impl StringAggAccumulator {
238    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
239        Self {
240            array_agg_acc,
241            delimiter: delimiter.to_string(),
242        }
243    }
244}
245
246impl Accumulator for StringAggAccumulator {
247    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
248        self.array_agg_acc.update_batch(&filter_index(values, 1))
249    }
250
251    fn evaluate(&mut self) -> Result<ScalarValue> {
252        let scalar = self.array_agg_acc.evaluate()?;
253
254        let ScalarValue::List(list) = scalar else {
255            return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
256        };
257
258        let string_arr: Vec<_> = match list.value_type() {
259            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
260                .iter()
261                .flatten()
262                .collect(),
263            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
264                .iter()
265                .flatten()
266                .collect(),
267            DataType::Utf8View => as_string_view_array(list.values())?
268                .iter()
269                .flatten()
270                .collect(),
271            _ => {
272                return internal_err!(
273                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
274                    list.value_type()
275                )
276            }
277        };
278
279        if string_arr.is_empty() {
280            return Ok(ScalarValue::LargeUtf8(None));
281        }
282
283        Ok(ScalarValue::LargeUtf8(Some(
284            string_arr.join(&self.delimiter),
285        )))
286    }
287
288    fn size(&self) -> usize {
289        size_of_val(self) - size_of_val(&self.array_agg_acc)
290            + self.array_agg_acc.size()
291            + self.delimiter.capacity()
292    }
293
294    fn state(&mut self) -> Result<Vec<ScalarValue>> {
295        self.array_agg_acc.state()
296    }
297
298    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
299        self.array_agg_acc.merge_batch(values)
300    }
301}
302
/// Returns a copy of `values` with the element at position `index` removed.
///
/// When `index` is out of bounds every element is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    if index >= values.len() {
        return values.to_vec();
    }

    let (before, from_index) = values.split_at(index);
    let mut kept = Vec::with_capacity(values.len() - 1);
    kept.extend_from_slice(before);
    // `from_index[0]` is the element being dropped.
    kept.extend_from_slice(&from_index[1..]);
    kept
}
312
/// StringAgg accumulator for the simple case (no order or distinct specified).
///
/// This is cheaper than `StringAggAccumulator`: each value is appended to one
/// growing string as it arrives, whereas `StringAggAccumulator` first buffers
/// all values in an `ArrayAggAccumulator`.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Separator inserted between consecutive values.
    delimiter: String,
    /// Updated during `update_batch()`. e.g. "foo,bar"
    accumulated_string: String,
    /// Whether at least one non-null value has been appended; this separates
    /// "no values" from "only empty-string values".
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that joins values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        Self {
            delimiter: delimiter.to_owned(),
            accumulated_string: String::new(),
            has_value: false,
        }
    }

    /// Appends every non-null item of `iter`, writing the delimiter before
    /// each one except the very first value seen by this accumulator.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        for item in iter {
            let Some(text) = item else {
                continue;
            };

            if self.has_value {
                self.accumulated_string.push_str(&self.delimiter);
            } else {
                self.has_value = true;
            }
            self.accumulated_string.push_str(text);
        }
    }
}
349
350impl Accumulator for SimpleStringAggAccumulator {
351    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
352        let string_arr = values.first().ok_or_else(|| {
353            internal_datafusion_err!(
354                "Planner should ensure its first arg is Utf8/Utf8View"
355            )
356        })?;
357
358        match string_arr.data_type() {
359            DataType::Utf8 => {
360                let array = as_string_array(string_arr)?;
361                self.append_strings(array.iter());
362            }
363            DataType::LargeUtf8 => {
364                let array = as_generic_string_array::<i64>(string_arr)?;
365                self.append_strings(array.iter());
366            }
367            DataType::Utf8View => {
368                let array = as_string_view_array(string_arr)?;
369                self.append_strings(array.iter());
370            }
371            other => {
372                return internal_err!(
373                    "Planner should ensure string_agg first argument is Utf8-like, found {other}"
374                );
375            }
376        }
377
378        Ok(())
379    }
380
381    fn evaluate(&mut self) -> Result<ScalarValue> {
382        let result = if self.has_value {
383            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
384        } else {
385            ScalarValue::LargeUtf8(None)
386        };
387
388        self.has_value = false;
389        Ok(result)
390    }
391
392    fn size(&self) -> usize {
393        size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
394    }
395
396    fn state(&mut self) -> Result<Vec<ScalarValue>> {
397        let result = if self.has_value {
398            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
399        } else {
400            ScalarValue::LargeUtf8(None)
401        };
402        self.has_value = false;
403
404        Ok(vec![result])
405    }
406
407    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
408        self.update_batch(values)
409    }
410}
411
#[cfg(test)]
mod tests {
    // Tests for the `string_agg` accumulators. Each test builds two accumulators,
    // feeds each one batch, merges the state of the second into the first
    // (simulating partial -> final aggregation) and checks the evaluated result.
    //
    // Every batch passes two columns: the values and a delimiter column
    // (`data([","])`). The accumulators never read the delimiter column
    // (the simple one reads only the first column, the general one filters out
    // index 1), so its length does not need to match the values column.
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    // No DISTINCT and no ORDER BY: exercises `SimpleStringAggAccumulator`.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // DISTINCT without ORDER BY gives no output order guarantee, so the
    // result is sorted before comparing.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // Without DISTINCT, duplicates across the two accumulators are kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,a,b,c");

        Ok(())
    }

    // With DISTINCT, duplicates across the two accumulators are removed.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    // DISTINCT + ORDER BY col ASC: the merged output must be fully ordered.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // DISTINCT + ORDER BY col DESC (`SortOptions::new(true, ..)` = descending).
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "f,e,d,c,b,a");

        Ok(())
    }

    // DISTINCT + ORDER BY ASC with duplicates both within and across batches.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    // DISTINCT + ORDER BY DESC with duplicates across batches.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "c,b,a");

        Ok(())
    }

    /// Test helper that builds `string_agg` accumulators through
    /// `StringAgg::accumulator`, over a one-column (`col: LargeUtf8`) schema.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        fn new(sep: &str) -> Self {
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::LargeUtf8,
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        // Adds an ORDER BY on `col` (must exist in the builder's schema).
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            self.order_bys.extend([PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            )]);
            self
        }

        // Arguments are (Column `col` at index 0, Utf8 literal separator).
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                expr_fields: &[
                    Field::new("col", DataType::LargeUtf8, true).into(),
                    Field::new("lit", DataType::Utf8, false).into(),
                ],
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
                ],
            })
        }

        // Two independent accumulators with the same configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    // Unwraps a non-NULL `LargeUtf8` result, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        str(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    // Like `some_str`, but sorts the `sep`-separated parts; used where the
    // output order is not specified (DISTINCT without ORDER BY).
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let value = some_str(value);
        let mut parts: Vec<&str> = value.split(sep).collect();
        parts.sort();
        parts.join(sep)
    }

    // Extracts the optional string from a `LargeUtf8` scalar; errors on any
    // other variant.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(v) => Ok(v),
            _ => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            ),
        }
    }

    // Builds a non-null `LargeStringArray` column from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    // Merges `acc2` into `acc1` via `acc2.state()` -> arrays -> `acc1.merge_batch`,
    // mirroring how partial states are combined, and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}