datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use std::any::Any;
21use std::hash::{DefaultHasher, Hash, Hasher};
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
30use datafusion_expr::function::AccumulatorArgs;
31use datafusion_expr::{
32    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
33};
34use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
35use datafusion_macros::user_doc;
36use datafusion_physical_expr::expressions::Literal;
37
// Invokes the `make_udaf_expr_and_func!` macro for `StringAgg`.
// NOTE(review): the macro definition is not visible in this file; presumably it
// generates a `string_agg(expr, delimiter)` expression helper together with the
// `string_agg_udaf()` accessor that `StringAgg::reverse_expr` calls — confirm
// against the macro's definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
45
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Concatenates the values of string expressions and places separator values between them. \
If ordering is required, strings are concatenated in the specified order. \
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT string_agg(name, ', ') AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Alice, Bob, Bob, Charlie |
+--------------------------+
> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Bob, Alice |
+--------------------------+
> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Alice |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The string expression to concatenate. Can be a column or any valid string expression."
    ),
    argument(
        name = "delimiter",
        description = "A literal string used as a separator between the concatenated values."
    )
)]
/// STRING_AGG aggregate expression.
///
/// Collecting the input values is delegated to an inner [`ArrayAgg`]; the
/// accumulator created by this type joins the collected strings with the
/// delimiter and always reports a `LargeUtf8` result.
#[derive(Debug)]
pub struct StringAgg {
    /// Argument type combinations accepted by `string_agg`.
    signature: Signature,
    /// `ArrayAgg` instance that state fields and value accumulation are delegated to.
    array_agg: ArrayAgg,
}
90
91impl StringAgg {
92    /// Create a new StringAgg aggregate function
93    pub fn new() -> Self {
94        Self {
95            signature: Signature::one_of(
96                vec![
97                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
98                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
99                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
100                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
101                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
102                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
103                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
104                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
105                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
106                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
107                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
108                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
109                ],
110                Volatility::Immutable,
111            ),
112            array_agg: Default::default(),
113        }
114    }
115}
116
117impl Default for StringAgg {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl AggregateUDFImpl for StringAgg {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn name(&self) -> &str {
129        "string_agg"
130    }
131
132    fn signature(&self) -> &Signature {
133        &self.signature
134    }
135
136    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
137        Ok(DataType::LargeUtf8)
138    }
139
140    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
141        self.array_agg.state_fields(args)
142    }
143
144    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
146            return not_impl_err!(
147                "The second argument of the string_agg function must be a string literal"
148            );
149        };
150
151        let delimiter = if lit.value().is_null() {
152            // If the second argument (the delimiter that joins strings) is NULL, join
153            // on an empty string. (e.g. [a, b, c] => "abc").
154            ""
155        } else if let Some(lit_string) = lit.value().try_as_str() {
156            lit_string.unwrap_or("")
157        } else {
158            return not_impl_err!(
159                "StringAgg not supported for delimiter \"{}\"",
160                lit.value()
161            );
162        };
163
164        let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
165            return_field: Field::new(
166                "f",
167                DataType::new_list(acc_args.return_field.data_type().clone(), true),
168                true,
169            )
170            .into(),
171            exprs: &filter_index(acc_args.exprs, 1),
172            ..acc_args
173        })?;
174
175        Ok(Box::new(StringAggAccumulator::new(
176            array_agg_acc,
177            delimiter,
178        )))
179    }
180
181    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
182        datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
183    }
184
185    fn documentation(&self) -> Option<&Documentation> {
186        self.doc()
187    }
188
189    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
190        let Some(other) = other.as_any().downcast_ref::<Self>() else {
191            return false;
192        };
193        let Self {
194            signature,
195            array_agg,
196        } = self;
197        signature == &other.signature && array_agg.equals(&other.array_agg)
198    }
199
200    fn hash_value(&self) -> u64 {
201        let Self {
202            signature,
203            array_agg,
204        } = self;
205        let mut hasher = DefaultHasher::new();
206        std::any::type_name::<Self>().hash(&mut hasher);
207        signature.hash(&mut hasher);
208        hasher.write_u64(array_agg.hash_value());
209        hasher.finish()
210    }
211}
212
/// Accumulator behind `string_agg`: wraps an inner accumulator that gathers the
/// input values and joins the gathered strings with `delimiter` on evaluation.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Inner accumulator; `evaluate` expects it to produce a `ScalarValue::List`.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator placed between the concatenated values.
    delimiter: String,
}
218
219impl StringAggAccumulator {
220    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
221        Self {
222            array_agg_acc,
223            delimiter: delimiter.to_string(),
224        }
225    }
226}
227
228impl Accumulator for StringAggAccumulator {
229    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
230        self.array_agg_acc.update_batch(&filter_index(values, 1))
231    }
232
233    fn evaluate(&mut self) -> Result<ScalarValue> {
234        let scalar = self.array_agg_acc.evaluate()?;
235
236        let ScalarValue::List(list) = scalar else {
237            return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
238        };
239
240        let string_arr: Vec<_> = match list.value_type() {
241            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
242                .iter()
243                .flatten()
244                .collect(),
245            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
246                .iter()
247                .flatten()
248                .collect(),
249            DataType::Utf8View => as_string_view_array(list.values())?
250                .iter()
251                .flatten()
252                .collect(),
253            _ => {
254                return internal_err!(
255                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
256                    list.value_type()
257                )
258            }
259        };
260
261        if string_arr.is_empty() {
262            return Ok(ScalarValue::LargeUtf8(None));
263        }
264
265        Ok(ScalarValue::LargeUtf8(Some(
266            string_arr.join(&self.delimiter),
267        )))
268    }
269
270    fn size(&self) -> usize {
271        size_of_val(self) - size_of_val(&self.array_agg_acc)
272            + self.array_agg_acc.size()
273            + self.delimiter.capacity()
274    }
275
276    fn state(&mut self) -> Result<Vec<ScalarValue>> {
277        self.array_agg_acc.state()
278    }
279
280    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
281        self.array_agg_acc.merge_batch(values)
282    }
283}
284
/// Returns a copy of `values` without the element at position `index`.
///
/// All other elements are cloned in their original order. An `index` past the
/// end of the slice removes nothing.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    values
        .iter()
        .enumerate()
        .filter_map(|(position, value)| (position != index).then(|| value.clone()))
        .collect()
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    /// Without DISTINCT, non-overlapping inputs are joined in arrival order.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut left, mut right) = StringAggAccumulatorBuilder::new(",").build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["d", "e", "f"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "a,b,c,d,e,f");

        Ok(())
    }

    /// With DISTINCT and no ordering, every value appears once; the output is
    /// sorted before comparing.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["d", "e", "f"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str_sorted(merged.evaluate()?, ",");
        assert_eq!(joined, "a,b,c,d,e,f");

        Ok(())
    }

    /// Without DISTINCT, repeated values are all kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut left, mut right) = StringAggAccumulatorBuilder::new(",").build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["a", "b", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "a,b,c,a,b,c");

        Ok(())
    }

    /// With DISTINCT, values repeated across the two accumulators collapse.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["a", "b", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str_sorted(merged.evaluate()?, ",");
        assert_eq!(joined, "a,b,c");

        Ok(())
    }

    /// DISTINCT with ascending ORDER BY over non-overlapping inputs.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let ascending = SortOptions::new(false, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", ascending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        right.update_batch(&[data(["f", "a", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "a,b,c,d,e,f");

        Ok(())
    }

    /// DISTINCT with descending ORDER BY over non-overlapping inputs.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let descending = SortOptions::new(true, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", descending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        right.update_batch(&[data(["f", "a", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "f,e,d,c,b,a");

        Ok(())
    }

    /// DISTINCT with ascending ORDER BY when both inputs hold the same values.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let ascending = SortOptions::new(false, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", ascending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        right.update_batch(&[data(["b", "c", "a"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "a,b,c");

        Ok(())
    }

    /// DISTINCT with descending ORDER BY when both inputs hold the same values.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let descending = SortOptions::new(true, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", descending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        right.update_batch(&[data(["b", "c", "a"]), data([","])])?;

        let mut merged = merge(left, right)?;
        let joined = some_str(merged.evaluate()?);
        assert_eq!(joined, "c,b,a");

        Ok(())
    }

    /// Test helper that assembles `string_agg` accumulators over a schema with
    /// a single nullable `LargeUtf8` column named `col`.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a builder using `sep` as the delimiter literal, with DISTINCT
        /// off and no ordering.
        fn new(sep: &str) -> Self {
            let col_field = Field::new("col", DataType::LargeUtf8, true);
            let schema = Schema {
                fields: Fields::from(vec![col_field]),
                metadata: Default::default(),
            };

            Self {
                sep: sep.to_string(),
                distinct: false,
                order_bys: Vec::new(),
                schema,
            }
        }

        /// Turns DISTINCT on.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on schema column `col`; panics if the column is
        /// not in the schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.order_bys
                .push(PhysicalSortExpr::new(Arc::new(column), sort_options));
            self
        }

        /// Creates one accumulator from the current settings.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let delimiter = Literal::new(ScalarValue::Utf8(Some(self.sep.clone())));

            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[Arc::new(Column::new("col", 0)), Arc::new(delimiter)],
            })
        }

        /// Creates two independent accumulators with identical settings.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        let maybe_string = str(value).expect("ScalarValue was not a String");
        maybe_string.expect("ScalarValue was None")
    }

    /// Like `some_str`, but splits on `sep`, sorts the pieces and re-joins
    /// them, so the comparison does not depend on output order.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let joined = some_str(value);
        let mut pieces = joined.split(sep).collect::<Vec<&str>>();
        pieces.sort();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar; any other variant is an
    /// internal error.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(inner) => Ok(inner),
            other => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                other.data_type()
            ),
        }
    }

    /// Builds a `LargeStringArray` from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        let strings = LargeStringArray::from(list.to_vec());
        Arc::new(strings)
    }

    /// Converts `acc2`'s state scalars to arrays, merges them into `acc1` and
    /// returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let mut state_arrays: Vec<ArrayRef> = Vec::new();
        for scalar in acc2.state()? {
            state_arrays.push(scalar.to_array()?);
        }

        acc1.merge_batch(&state_arrays)?;
        Ok(acc1)
    }
}
539        Ok(acc1)
540    }
541}