datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use std::any::Any;
21use std::hash::Hash;
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{
29    as_generic_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{
32    Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
33};
34use datafusion_expr::function::AccumulatorArgs;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
38};
39use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Declares the `string_agg(expr, delimiter)` expression helper and the
// `string_agg_udaf()` accessor for `StringAgg` (the accessor is used by
// `reverse_expr` below). NOTE(review): the macro itself is defined outside this
// file; see its definition for the exact items generated.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
50
51#[user_doc(
52    doc_section(label = "General Functions"),
53    description = "Concatenates the values of string expressions and places separator values between them. \
54If ordering is required, strings are concatenated in the specified order. \
55This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
56    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
57    sql_example = r#"```sql
58> SELECT string_agg(name, ', ') AS names_list
59  FROM employee;
60+--------------------------+
61| names_list               |
62+--------------------------+
63| Alice, Bob, Bob, Charlie |
64+--------------------------+
65> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
66  FROM employee;
67+--------------------------+
68| names_list               |
69+--------------------------+
70| Charlie, Bob, Bob, Alice |
71+--------------------------+
72> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
73  FROM employee;
74+--------------------------+
75| names_list               |
76+--------------------------+
77| Charlie, Bob, Alice |
78+--------------------------+
79```"#,
80    argument(
81        name = "expression",
82        description = "The string expression to concatenate. Can be a column or any valid string expression."
83    ),
84    argument(
85        name = "delimiter",
86        description = "A literal string used as a separator between the concatenated values."
87    )
88)]
89/// STRING_AGG aggregate expression
90#[derive(Debug, PartialEq, Eq, Hash)]
91pub struct StringAgg {
92    signature: Signature,
93    array_agg: ArrayAgg,
94}
95
96impl StringAgg {
97    /// Create a new StringAgg aggregate function
98    pub fn new() -> Self {
99        Self {
100            signature: Signature::one_of(
101                vec![
102                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114                ],
115                Volatility::Immutable,
116            ),
117            array_agg: Default::default(),
118        }
119    }
120}
121
122impl Default for StringAgg {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128/// If there is no `distinct` and `order by` required by the `string_agg` call, a
129/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
130impl AggregateUDFImpl for StringAgg {
131    fn as_any(&self) -> &dyn Any {
132        self
133    }
134
135    fn name(&self) -> &str {
136        "string_agg"
137    }
138
139    fn signature(&self) -> &Signature {
140        &self.signature
141    }
142
143    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
144        Ok(DataType::LargeUtf8)
145    }
146
147    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
148        // See comments in `impl AggregateUDFImpl ...` for more detail
149        let no_order_no_distinct =
150            (args.ordering_fields.is_empty()) && (!args.is_distinct);
151        if no_order_no_distinct {
152            // Case `SimpleStringAggAccumulator`
153            Ok(vec![
154                Field::new(
155                    format_state_name(args.name, "string_agg"),
156                    DataType::LargeUtf8,
157                    true,
158                )
159                .into(),
160            ])
161        } else {
162            // Case `StringAggAccumulator`
163            self.array_agg.state_fields(args)
164        }
165    }
166
167    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
168        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
169            return not_impl_err!(
170                "The second argument of the string_agg function must be a string literal"
171            );
172        };
173
174        let delimiter = if lit.value().is_null() {
175            // If the second argument (the delimiter that joins strings) is NULL, join
176            // on an empty string. (e.g. [a, b, c] => "abc").
177            ""
178        } else if let Some(lit_string) = lit.value().try_as_str() {
179            lit_string.unwrap_or("")
180        } else {
181            return not_impl_err!(
182                "StringAgg not supported for delimiter \"{}\"",
183                lit.value()
184            );
185        };
186
187        // See comments in `impl AggregateUDFImpl ...` for more detail
188        let no_order_no_distinct =
189            acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
190
191        if no_order_no_distinct {
192            // simple case (more efficient)
193            Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
194        } else {
195            // general case
196            let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
197                return_field: Field::new(
198                    "f",
199                    DataType::new_list(acc_args.return_field.data_type().clone(), true),
200                    true,
201                )
202                .into(),
203                exprs: &filter_index(acc_args.exprs, 1),
204                expr_fields: &filter_index(acc_args.expr_fields, 1),
205                // Unchanged below; we list each field explicitly in case we ever add more
206                // fields to AccumulatorArgs making it easier to see if changes are also
207                // needed here.
208                schema: acc_args.schema,
209                ignore_nulls: acc_args.ignore_nulls,
210                order_bys: acc_args.order_bys,
211                is_reversed: acc_args.is_reversed,
212                name: acc_args.name,
213                is_distinct: acc_args.is_distinct,
214            })?;
215
216            Ok(Box::new(StringAggAccumulator::new(
217                array_agg_acc,
218                delimiter,
219            )))
220        }
221    }
222
223    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
224        datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
225    }
226
227    fn documentation(&self) -> Option<&Documentation> {
228        self.doc()
229    }
230}
231
232/// StringAgg accumulator for the general case (with order or distinct specified)
233#[derive(Debug)]
234pub(crate) struct StringAggAccumulator {
235    array_agg_acc: Box<dyn Accumulator>,
236    delimiter: String,
237}
238
239impl StringAggAccumulator {
240    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
241        Self {
242            array_agg_acc,
243            delimiter: delimiter.to_string(),
244        }
245    }
246}
247
248impl Accumulator for StringAggAccumulator {
249    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
250        self.array_agg_acc.update_batch(&filter_index(values, 1))
251    }
252
253    fn evaluate(&mut self) -> Result<ScalarValue> {
254        let scalar = self.array_agg_acc.evaluate()?;
255
256        let ScalarValue::List(list) = scalar else {
257            return internal_err!(
258                "Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}",
259                scalar.data_type()
260            );
261        };
262
263        let string_arr: Vec<_> = match list.value_type() {
264            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
265                .iter()
266                .flatten()
267                .collect(),
268            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
269                .iter()
270                .flatten()
271                .collect(),
272            DataType::Utf8View => as_string_view_array(list.values())?
273                .iter()
274                .flatten()
275                .collect(),
276            _ => {
277                return internal_err!(
278                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
279                    list.value_type()
280                );
281            }
282        };
283
284        if string_arr.is_empty() {
285            return Ok(ScalarValue::LargeUtf8(None));
286        }
287
288        Ok(ScalarValue::LargeUtf8(Some(
289            string_arr.join(&self.delimiter),
290        )))
291    }
292
293    fn size(&self) -> usize {
294        size_of_val(self) - size_of_val(&self.array_agg_acc)
295            + self.array_agg_acc.size()
296            + self.delimiter.capacity()
297    }
298
299    fn state(&mut self) -> Result<Vec<ScalarValue>> {
300        self.array_agg_acc.state()
301    }
302
303    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
304        self.array_agg_acc.merge_batch(values)
305    }
306}
307
/// Returns a copy of `values` with the element at position `index` left out.
///
/// An out-of-range `index` removes nothing.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    values
        .iter()
        .enumerate()
        .filter_map(|(position, item)| (position != index).then(|| item.clone()))
        .collect()
}
317
/// StringAgg accumulator for the simple case (no order or distinct specified).
///
/// Cheaper than `StringAggAccumulator`: rather than buffering values in an
/// `ArrayAggAccumulator`, it appends every value directly onto one string.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Separator inserted between consecutive values.
    delimiter: String,
    /// Updated during `update_batch()`. e.g. "foo,bar"
    accumulated_string: String,
    /// Whether any non-null value has been appended yet; lets an empty-string
    /// result be told apart from "no values" (NULL).
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    pub fn new(delimiter: &str) -> Self {
        Self {
            delimiter: delimiter.to_owned(),
            accumulated_string: String::new(),
            has_value: false,
        }
    }

    /// Appends each non-null string from `iter`, writing the delimiter before it
    /// unless nothing has been appended yet.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        for piece in iter.flatten() {
            if self.has_value {
                self.accumulated_string.push_str(&self.delimiter);
            } else {
                self.has_value = true;
            }
            self.accumulated_string.push_str(piece);
        }
    }
}
354
355impl Accumulator for SimpleStringAggAccumulator {
356    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
357        let string_arr = values.first().ok_or_else(|| {
358            internal_datafusion_err!(
359                "Planner should ensure its first arg is Utf8/Utf8View"
360            )
361        })?;
362
363        match string_arr.data_type() {
364            DataType::Utf8 => {
365                let array = as_string_array(string_arr)?;
366                self.append_strings(array.iter());
367            }
368            DataType::LargeUtf8 => {
369                let array = as_generic_string_array::<i64>(string_arr)?;
370                self.append_strings(array.iter());
371            }
372            DataType::Utf8View => {
373                let array = as_string_view_array(string_arr)?;
374                self.append_strings(array.iter());
375            }
376            other => {
377                return internal_err!(
378                    "Planner should ensure string_agg first argument is Utf8-like, found {other}"
379                );
380            }
381        }
382
383        Ok(())
384    }
385
386    fn evaluate(&mut self) -> Result<ScalarValue> {
387        let result = if self.has_value {
388            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
389        } else {
390            ScalarValue::LargeUtf8(None)
391        };
392
393        self.has_value = false;
394        Ok(result)
395    }
396
397    fn size(&self) -> usize {
398        size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
399    }
400
401    fn state(&mut self) -> Result<Vec<ScalarValue>> {
402        let result = if self.has_value {
403            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
404        } else {
405            ScalarValue::LargeUtf8(None)
406        };
407        self.has_value = false;
408
409        Ok(vec![result])
410    }
411
412    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
413        self.update_batch(values)
414    }
415}
416
// Unit tests for the string_agg accumulators. Builders without DISTINCT/ORDER BY
// exercise `SimpleStringAggAccumulator`; the others exercise the array_agg-backed
// `StringAggAccumulator`. Each test updates two accumulators, merges the second
// into the first via its state, then evaluates.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    // Simple path: values keep input order, acc2's values follow acc1's.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // DISTINCT without ORDER BY: compare sorted parts so the assertion does not
    // depend on output order.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // Without DISTINCT, duplicates across accumulators are all kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,a,b,c");

        Ok(())
    }

    // With DISTINCT, duplicates across accumulators collapse after merging.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    // DISTINCT + ascending ORDER BY on the value column.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    // DISTINCT + descending ORDER BY on the value column.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "f,e,d,c,b,a");

        Ok(())
    }

    // DISTINCT + ascending ORDER BY with duplicates spread over both accumulators.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    // DISTINCT + descending ORDER BY with duplicates spread over both accumulators.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "c,b,a");

        Ok(())
    }

    /// Test helper that builds `string_agg` accumulators over a one-column
    /// (`col: LargeUtf8`) schema with a literal `Utf8` delimiter.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        fn new(sep: &str) -> Self {
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::LargeUtf8,
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }
        /// Marks the call as `DISTINCT`.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Adds an `ORDER BY <col>` with the given sort options; panics if `col`
        /// is not in the schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            self.order_bys.extend([PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            )]);
            self
        }

        /// Builds one accumulator through `StringAgg::accumulator`, passing the
        /// value column and the delimiter literal as the two argument exprs.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                expr_fields: &[
                    Field::new("col", DataType::LargeUtf8, true).into(),
                    Field::new("lit", DataType::Utf8, false).into(),
                ],
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
                ],
            })
        }

        /// Builds two independent accumulators with identical configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Unwraps a non-null `LargeUtf8` result; panics otherwise.
    fn some_str(value: ScalarValue) -> String {
        str(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    /// Like `some_str`, but with the `sep`-separated parts sorted, for results
    /// whose order is not asserted.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let value = some_str(value);
        let mut parts: Vec<&str> = value.split(sep).collect();
        parts.sort();
        parts.join(sep)
    }

    /// Extracts the optional string from a `LargeUtf8` scalar; errors on any
    /// other variant.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(v) => Ok(v),
            _ => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            ),
        }
    }

    /// Builds a non-null `LargeStringArray` from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    /// Simulates a two-phase aggregation: converts `acc2`'s state to arrays and
    /// merges them into `acc1`, returning `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}