datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use crate::array_agg::ArrayAgg;
21use arrow::array::ArrayRef;
22use arrow::datatypes::{DataType, Field, FieldRef};
23use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
24use datafusion_common::Result;
25use datafusion_common::{internal_err, not_impl_err, ScalarValue};
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_expr::{
28    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
29};
30use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
31use datafusion_macros::user_doc;
32use datafusion_physical_expr::expressions::Literal;
33use std::any::Any;
34use std::mem::size_of_val;
35
// Declares the `string_agg` aggregate through the crate's UDAF helper macro, passing the
// implementing type, the function name, its two argument names (`expr`, `delimiter`),
// the user-facing description and the name `string_agg_udaf`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `string_agg(expr, delimiter)` expression builder and a `string_agg_udaf()` accessor
// for the shared UDAF instance — confirm against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
43
44#[user_doc(
45    doc_section(label = "General Functions"),
46    description = "Concatenates the values of string expressions and places separator values between them. \
47If ordering is required, strings are concatenated in the specified order. \
48This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
49    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
50    sql_example = r#"```sql
51> SELECT string_agg(name, ', ') AS names_list
52  FROM employee;
53+--------------------------+
54| names_list               |
55+--------------------------+
56| Alice, Bob, Bob, Charlie |
57+--------------------------+
58> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
59  FROM employee;
60+--------------------------+
61| names_list               |
62+--------------------------+
63| Charlie, Bob, Bob, Alice |
64+--------------------------+
65> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
66  FROM employee;
67+--------------------------+
68| names_list               |
69+--------------------------+
70| Charlie, Bob, Alice |
71+--------------------------+
72```"#,
73    argument(
74        name = "expression",
75        description = "The string expression to concatenate. Can be a column or any valid string expression."
76    ),
77    argument(
78        name = "delimiter",
79        description = "A literal string used as a separator between the concatenated values."
80    )
81)]
82/// STRING_AGG aggregate expression
83#[derive(Debug)]
84pub struct StringAgg {
85    signature: Signature,
86    array_agg: ArrayAgg,
87}
88
89impl StringAgg {
90    /// Create a new StringAgg aggregate function
91    pub fn new() -> Self {
92        Self {
93            signature: Signature::one_of(
94                vec![
95                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
96                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
97                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
98                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
99                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
100                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
101                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
102                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
103                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
104                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
105                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
106                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
107                ],
108                Volatility::Immutable,
109            ),
110            array_agg: Default::default(),
111        }
112    }
113}
114
115impl Default for StringAgg {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl AggregateUDFImpl for StringAgg {
122    fn as_any(&self) -> &dyn Any {
123        self
124    }
125
126    fn name(&self) -> &str {
127        "string_agg"
128    }
129
130    fn signature(&self) -> &Signature {
131        &self.signature
132    }
133
134    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
135        Ok(DataType::LargeUtf8)
136    }
137
138    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
139        self.array_agg.state_fields(args)
140    }
141
142    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
143        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
144            return not_impl_err!(
145                "The second argument of the string_agg function must be a string literal"
146            );
147        };
148
149        let delimiter = if lit.value().is_null() {
150            // If the second argument (the delimiter that joins strings) is NULL, join
151            // on an empty string. (e.g. [a, b, c] => "abc").
152            ""
153        } else if let Some(lit_string) = lit.value().try_as_str() {
154            lit_string.unwrap_or("")
155        } else {
156            return not_impl_err!(
157                "StringAgg not supported for delimiter \"{}\"",
158                lit.value()
159            );
160        };
161
162        let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
163            return_field: Field::new(
164                "f",
165                DataType::new_list(acc_args.return_field.data_type().clone(), true),
166                true,
167            )
168            .into(),
169            exprs: &filter_index(acc_args.exprs, 1),
170            ..acc_args
171        })?;
172
173        Ok(Box::new(StringAggAccumulator::new(
174            array_agg_acc,
175            delimiter,
176        )))
177    }
178
179    fn documentation(&self) -> Option<&Documentation> {
180        self.doc()
181    }
182}
183
184#[derive(Debug)]
185pub(crate) struct StringAggAccumulator {
186    array_agg_acc: Box<dyn Accumulator>,
187    delimiter: String,
188}
189
190impl StringAggAccumulator {
191    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
192        Self {
193            array_agg_acc,
194            delimiter: delimiter.to_string(),
195        }
196    }
197}
198
199impl Accumulator for StringAggAccumulator {
200    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
201        self.array_agg_acc.update_batch(&filter_index(values, 1))
202    }
203
204    fn evaluate(&mut self) -> Result<ScalarValue> {
205        let scalar = self.array_agg_acc.evaluate()?;
206
207        let ScalarValue::List(list) = scalar else {
208            return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
209        };
210
211        let string_arr: Vec<_> = match list.value_type() {
212            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
213                .iter()
214                .flatten()
215                .collect(),
216            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
217                .iter()
218                .flatten()
219                .collect(),
220            DataType::Utf8View => as_string_view_array(list.values())?
221                .iter()
222                .flatten()
223                .collect(),
224            _ => {
225                return internal_err!(
226                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
227                    list.value_type()
228                )
229            }
230        };
231
232        if string_arr.is_empty() {
233            return Ok(ScalarValue::LargeUtf8(None));
234        }
235
236        Ok(ScalarValue::LargeUtf8(Some(
237            string_arr.join(&self.delimiter),
238        )))
239    }
240
241    fn size(&self) -> usize {
242        size_of_val(self) - size_of_val(&self.array_agg_acc)
243            + self.array_agg_acc.size()
244            + self.delimiter.capacity()
245    }
246
247    fn state(&mut self) -> Result<Vec<ScalarValue>> {
248        self.array_agg_acc.state()
249    }
250
251    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
252        self.array_agg_acc.merge_batch(values)
253    }
254}
255
/// Returns a copy of `values` without the element at position `index`.
///
/// The relative order of the remaining elements is preserved. When `index` is out of
/// range, every element is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::with_capacity(values.len());
    for (position, value) in values.iter().enumerate() {
        if position == index {
            continue;
        }
        kept.push(value.clone());
    }
    kept
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use std::sync::Arc;

    /// Delimiter shared by every test case.
    const SEP: &str = ",";

    /// Ascending order, nulls last.
    const ASC: SortOptions = SortOptions {
        descending: false,
        nulls_first: false,
    };

    /// Descending order, nulls last.
    const DESC: SortOptions = SortOptions {
        descending: true,
        nulls_first: false,
    };

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP);

        let value = aggregate_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        assert_eq!(expect_string(value), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP).distinct();

        let value = aggregate_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        // Without an ORDER BY the output order is not asserted, so sort before comparing.
        assert_eq!(expect_sorted_string(value, SEP), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP);

        let value = aggregate_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        assert_eq!(expect_string(value), "a,b,c,a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP).distinct();

        let value = aggregate_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        // Without an ORDER BY the output order is not asserted, so sort before comparing.
        assert_eq!(expect_sorted_string(value, SEP), "a,b,c");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP)
            .distinct()
            .order_by_col("col", ASC);

        let value = aggregate_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(expect_string(value), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP)
            .distinct()
            .order_by_col("col", DESC);

        let value = aggregate_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(expect_string(value), "f,e,d,c,b,a");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP)
            .distinct()
            .order_by_col("col", ASC);

        let value = aggregate_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(expect_string(value), "a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(SEP)
            .distinct()
            .order_by_col("col", DESC);

        let value = aggregate_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(expect_string(value), "c,b,a");
        Ok(())
    }

    /// Test-only builder for `string_agg` accumulators over a single nullable
    /// `LargeUtf8` column named `col`.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        ordering: LexOrdering,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a builder that joins with `sep`, without DISTINCT and without ordering.
        fn new(sep: &str) -> Self {
            let column = Field::new("col", DataType::LargeUtf8, true);
            Self {
                sep: String::from(sep),
                distinct: false,
                ordering: LexOrdering::default(),
                schema: Schema::new(Fields::from(vec![column])),
            }
        }

        /// Requests DISTINCT aggregation.
        fn distinct(self) -> Self {
            Self {
                distinct: true,
                ..self
            }
        }

        /// Appends an ordering requirement on the schema column named `col`.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            let sort_expr = PhysicalSortExpr::new(Arc::new(column), sort_options);
            self.ordering.extend([sort_expr]);
            self
        }

        /// Creates one accumulator from the current configuration.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let udaf = StringAgg::new();
            udaf.accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                ignore_nulls: false,
                ordering_req: &self.ordering,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.clone())))),
                ],
            })
        }

        /// Creates two independent accumulators, standing in for two partitions.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Feeds `left` and `right` into two separately built accumulators, merges the second
    /// into the first through its intermediate state, and evaluates the merged result.
    fn aggregate_partitions<const N: usize>(
        builder: &StringAggAccumulatorBuilder,
        left: [&str; N],
        right: [&str; N],
    ) -> Result<ScalarValue> {
        let (mut first, mut second) = builder.build_two()?;

        first.update_batch(&[data(left), data([SEP])])?;
        second.update_batch(&[data(right), data([SEP])])?;

        let mut merged = merge(first, second)?;
        merged.evaluate()
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking on anything else.
    fn expect_string(value: ScalarValue) -> String {
        let maybe_text = large_utf8(value).expect("ScalarValue was not a String");
        maybe_text.expect("ScalarValue was None")
    }

    /// Like [`expect_string`], but splits on `sep`, sorts the pieces and joins them again
    /// so the comparison does not depend on the output order.
    fn expect_sorted_string(value: ScalarValue, sep: &str) -> String {
        let joined = expect_string(value);
        let mut pieces = joined.split(sep).collect::<Vec<&str>>();
        pieces.sort();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar, or fails for any other variant.
    fn large_utf8(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(text) => Ok(text),
            other => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                other.data_type()
            ),
        }
    }

    /// Builds a `LargeStringArray` from the given string slices.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        let array = LargeStringArray::from(Vec::from(list));
        Arc::new(array)
    }

    /// Merges the intermediate state of `acc2` into `acc1` and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let mut state_arrays: Vec<ArrayRef> = Vec::new();
        for scalar in acc2.state()? {
            state_arrays.push(scalar.to_array()?);
        }

        acc1.merge_batch(&state_arrays)?;
        Ok(acc1)
    }
}