datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use crate::array_agg::ArrayAgg;
21use arrow::array::ArrayRef;
22use arrow::datatypes::{DataType, Field};
23use datafusion_common::cast::as_generic_string_array;
24use datafusion_common::Result;
25use datafusion_common::{internal_err, not_impl_err, ScalarValue};
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_expr::{
28    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
29};
30use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
31use datafusion_macros::user_doc;
32use datafusion_physical_expr::expressions::Literal;
33use std::any::Any;
34use std::mem::size_of_val;
35
// Registers `StringAgg` with the crate's UDAF helper macro. The arguments are
// the implementing type, the `string_agg` function name, the two argument
// names (`expr`, `delimiter`), a one-line description, and the
// `string_agg_udaf` identifier.
// NOTE(review): presumably this expands to a `string_agg(expr, delimiter)`
// expression builder plus a `string_agg_udaf()` accessor — confirm against the
// macro definition, which is not visible in this file.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
43
44#[user_doc(
45    doc_section(label = "General Functions"),
46    description = "Concatenates the values of string expressions and places separator values between them. \
47If ordering is required, strings are concatenated in the specified order. \
48This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
49    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
50    sql_example = r#"```sql
51> SELECT string_agg(name, ', ') AS names_list
52  FROM employee;
53+--------------------------+
54| names_list               |
55+--------------------------+
56| Alice, Bob, Bob, Charlie |
57+--------------------------+
58> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
59  FROM employee;
60+--------------------------+
61| names_list               |
62+--------------------------+
63| Charlie, Bob, Bob, Alice |
64+--------------------------+
65> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
66  FROM employee;
67+--------------------------+
68| names_list               |
69+--------------------------+
70| Charlie, Bob, Alice |
71+--------------------------+
72```"#,
73    argument(
74        name = "expression",
75        description = "The string expression to concatenate. Can be a column or any valid string expression."
76    ),
77    argument(
78        name = "delimiter",
79        description = "A literal string used as a separator between the concatenated values."
80    )
81)]
82/// STRING_AGG aggregate expression
83#[derive(Debug)]
84pub struct StringAgg {
85    signature: Signature,
86    array_agg: ArrayAgg,
87}
88
89impl StringAgg {
90    /// Create a new StringAgg aggregate function
91    pub fn new() -> Self {
92        Self {
93            signature: Signature::one_of(
94                vec![
95                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
96                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
97                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
98                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
99                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
100                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
101                ],
102                Volatility::Immutable,
103            ),
104            array_agg: Default::default(),
105        }
106    }
107}
108
109impl Default for StringAgg {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl AggregateUDFImpl for StringAgg {
116    fn as_any(&self) -> &dyn Any {
117        self
118    }
119
120    fn name(&self) -> &str {
121        "string_agg"
122    }
123
124    fn signature(&self) -> &Signature {
125        &self.signature
126    }
127
128    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
129        Ok(DataType::LargeUtf8)
130    }
131
132    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
133        self.array_agg.state_fields(args)
134    }
135
136    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
137        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
138            return not_impl_err!(
139                "The second argument of the string_agg function must be a string literal"
140            );
141        };
142
143        let delimiter = if lit.value().is_null() {
144            // If the second argument (the delimiter that joins strings) is NULL, join
145            // on an empty string. (e.g. [a, b, c] => "abc").
146            ""
147        } else if let Some(lit_string) = lit.value().try_as_str() {
148            lit_string.unwrap_or("")
149        } else {
150            return not_impl_err!(
151                "StringAgg not supported for delimiter \"{}\"",
152                lit.value()
153            );
154        };
155
156        let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
157            return_type: &DataType::new_list(acc_args.return_type.clone(), true),
158            exprs: &filter_index(acc_args.exprs, 1),
159            ..acc_args
160        })?;
161
162        Ok(Box::new(StringAggAccumulator::new(
163            array_agg_acc,
164            delimiter,
165        )))
166    }
167
168    fn documentation(&self) -> Option<&Documentation> {
169        self.doc()
170    }
171}
172
173#[derive(Debug)]
174pub(crate) struct StringAggAccumulator {
175    array_agg_acc: Box<dyn Accumulator>,
176    delimiter: String,
177}
178
179impl StringAggAccumulator {
180    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
181        Self {
182            array_agg_acc,
183            delimiter: delimiter.to_string(),
184        }
185    }
186}
187
188impl Accumulator for StringAggAccumulator {
189    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
190        self.array_agg_acc.update_batch(&filter_index(values, 1))
191    }
192
193    fn evaluate(&mut self) -> Result<ScalarValue> {
194        let scalar = self.array_agg_acc.evaluate()?;
195
196        let ScalarValue::List(list) = scalar else {
197            return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
198        };
199
200        let string_arr: Vec<_> = match list.value_type() {
201            DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
202                .iter()
203                .flatten()
204                .collect(),
205            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
206                .iter()
207                .flatten()
208                .collect(),
209            _ => {
210                return internal_err!(
211                    "Expected elements to of type Utf8 or LargeUtf8, but got {}",
212                    list.value_type()
213                )
214            }
215        };
216
217        if string_arr.is_empty() {
218            return Ok(ScalarValue::LargeUtf8(None));
219        }
220
221        Ok(ScalarValue::LargeUtf8(Some(
222            string_arr.join(&self.delimiter),
223        )))
224    }
225
226    fn size(&self) -> usize {
227        size_of_val(self) - size_of_val(&self.array_agg_acc)
228            + self.array_agg_acc.size()
229            + self.delimiter.capacity()
230    }
231
232    fn state(&mut self) -> Result<Vec<ScalarValue>> {
233        self.array_agg_acc.state()
234    }
235
236    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
237        self.array_agg_acc.merge_batch(values)
238    }
239}
240
/// Returns a copy of `values` without the element at position `index`.
///
/// An `index` past the end removes nothing, so the whole slice is cloned.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::new();
    for (position, value) in values.iter().enumerate() {
        if position == index {
            continue;
        }
        kept.push(value.clone());
    }
    kept
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use std::sync::Arc;

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");

        let output = run_two_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        assert_eq!(some_str(output), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();

        let output = run_two_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        // Without ORDER BY the distinct output order is unspecified, so the
        // parts are sorted before comparing.
        assert_eq!(some_str_sorted(output, ","), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");

        let output = run_two_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        assert_eq!(some_str(output), "a,b,c,a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();

        let output = run_two_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        assert_eq!(some_str_sorted(output, ","), "a,b,c");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false));

        let output = run_two_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(some_str(output), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false));

        let output = run_two_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(some_str(output), "f,e,d,c,b,a");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false));

        let output = run_two_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(some_str(output), "a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false));

        let output = run_two_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(some_str(output), "c,b,a");
        Ok(())
    }

    /// Test-only builder for `string_agg` accumulators over a single nullable
    /// `LargeUtf8` column named `col`.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        ordering: LexOrdering,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a non-distinct, unordered builder that joins on `sep`.
        fn new(sep: &str) -> Self {
            let schema = Schema::new(vec![Field::new("col", DataType::LargeUtf8, true)]);
            Self {
                sep: sep.to_string(),
                distinct: false,
                ordering: LexOrdering::default(),
                schema,
            }
        }

        /// Turns on DISTINCT aggregation.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on schema column `col`; panics if the column
        /// is not part of the builder's schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.ordering
                .extend([PhysicalSortExpr::new(Arc::new(column), sort_options)]);
            self
        }

        /// Creates one accumulator from the current configuration.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let delimiter = ScalarValue::Utf8(Some(self.sep.clone()));
            StringAgg::new().accumulator(AccumulatorArgs {
                return_type: &DataType::LargeUtf8,
                schema: &self.schema,
                ignore_nulls: false,
                ordering_req: &self.ordering,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(delimiter)),
                ],
            })
        }

        /// Creates two independent accumulators with identical configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Feeds `first` and `second` into two separate accumulators, merges the
    /// second into the first through its intermediate state, and evaluates
    /// the merged accumulator.
    fn run_two_partitions<const N: usize, const M: usize>(
        builder: &StringAggAccumulatorBuilder,
        first: [&str; N],
        second: [&str; M],
    ) -> Result<ScalarValue> {
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(first), data([builder.sep.as_str()])])?;
        right.update_batch(&[data(second), data([builder.sep.as_str()])])?;

        let mut merged = merge(left, right)?;
        merged.evaluate()
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        let maybe_string = large_utf8(value).expect("ScalarValue was not a String");
        maybe_string.expect("ScalarValue was None")
    }

    /// Like `some_str`, but with the `sep`-separated parts sorted so that the
    /// comparison does not depend on output order.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let joined = some_str(value);
        let mut pieces = joined.split(sep).collect::<Vec<&str>>();
        pieces.sort();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar; any other variant is an
    /// internal error.
    fn large_utf8(value: ScalarValue) -> Result<Option<String>> {
        let ScalarValue::LargeUtf8(inner) = value else {
            return internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            );
        };
        Ok(inner)
    }

    /// Builds a `LargeStringArray` column from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        let values: Vec<&str> = list.to_vec();
        Arc::new(LargeStringArray::from(values))
    }

    /// Merges the intermediate state of `acc2` into `acc1` and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let state = acc2.state()?;
        let state_arrays = state
            .iter()
            .map(|scalar| scalar.to_array())
            .collect::<Result<Vec<ArrayRef>>>()?;
        acc1.merge_batch(&state_arrays)?;
        Ok(acc1)
    }
}
497}