datafusion_functions_aggregate/
string_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`StringAgg`] accumulator for the `string_agg` function
19
20use arrow::array::ArrayRef;
21use arrow::datatypes::DataType;
22use datafusion_common::cast::as_generic_string_array;
23use datafusion_common::Result;
24use datafusion_common::{not_impl_err, ScalarValue};
25use datafusion_expr::function::AccumulatorArgs;
26use datafusion_expr::{
27    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
28};
29use datafusion_macros::user_doc;
30use datafusion_physical_expr::expressions::Literal;
31use std::any::Any;
32use std::mem::size_of_val;
33
// Invokes the project macro `make_udaf_expr_and_func!` for `StringAgg`.
// The arguments are, in order: the UDAF type, the name `string_agg`, the
// parameter names `expr` and `delimiter`, a description string, and the
// name `string_agg_udaf`.
// NOTE(review): presumably this generates a `string_agg(expr, delimiter)`
// expression builder and a `string_agg_udaf()` accessor — confirm against
// the macro definition, which is not visible in this file.
34make_udaf_expr_and_func!(
35    StringAgg,
36    string_agg,
37    expr delimiter,
38    "Concatenates the values of string expressions and places separator values between them",
39    string_agg_udaf
40);
41
42#[user_doc(
43    doc_section(label = "General Functions"),
44    description = "Concatenates the values of string expressions and places separator values between them.",
45    syntax_example = "string_agg(expression, delimiter)",
46    sql_example = r#"```sql
47> SELECT string_agg(name, ', ') AS names_list
48  FROM employee;
49+--------------------------+
50| names_list               |
51+--------------------------+
52| Alice, Bob, Charlie      |
53+--------------------------+
54```"#,
55    argument(
56        name = "expression",
57        description = "The string expression to concatenate. Can be a column or any valid string expression."
58    ),
59    argument(
60        name = "delimiter",
61        description = "A literal string used as a separator between the concatenated values."
62    )
63)]
64/// STRING_AGG aggregate expression
///
/// Aggregate UDF registered under the name `string_agg`. It takes a string
/// expression and a delimiter, and always reports `LargeUtf8` as its return
/// type. The delimiter must be a literal expression; the accumulator is built
/// from its value.
///
/// NOTE(review): the `#[user_doc]` attribute above presumably generates the
/// `doc()` method used by `documentation()` — confirm against the macro crate.
65#[derive(Debug)]
66pub struct StringAgg {
    /// Accepted argument-type combinations; populated in `new`.
67    signature: Signature,
68}
69
70impl StringAgg {
71    /// Create a new StringAgg aggregate function
72    pub fn new() -> Self {
73        Self {
74            signature: Signature::one_of(
75                vec![
76                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
77                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
78                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
79                ],
80                Volatility::Immutable,
81            ),
82        }
83    }
84}
85
86impl Default for StringAgg {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl AggregateUDFImpl for StringAgg {
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn name(&self) -> &str {
98        "string_agg"
99    }
100
101    fn signature(&self) -> &Signature {
102        &self.signature
103    }
104
105    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
106        Ok(DataType::LargeUtf8)
107    }
108
109    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
110        if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() {
111            return match lit.value().try_as_str() {
112                Some(Some(delimiter)) => {
113                    Ok(Box::new(StringAggAccumulator::new(delimiter)))
114                }
115                Some(None) => Ok(Box::new(StringAggAccumulator::new(""))),
116                None => {
117                    not_impl_err!("StringAgg not supported for delimiter {}", lit.value())
118                }
119            };
120        }
121
122        not_impl_err!("expect literal")
123    }
124
125    fn documentation(&self) -> Option<&Documentation> {
126        self.doc()
127    }
128}
129
/// Running state of a `string_agg` aggregation: the concatenation built so
/// far plus the separator placed between values.
130#[derive(Debug)]
131pub(crate) struct StringAggAccumulator {
    /// Concatenated result so far; stays `None` until at least one non-null
    /// input string has been appended, so an all-null input evaluates to NULL.
132    values: Option<String>,
    /// Separator inserted between consecutive values.
133    delimiter: String,
134}
135
136impl StringAggAccumulator {
137    pub fn new(delimiter: &str) -> Self {
138        Self {
139            values: None,
140            delimiter: delimiter.to_string(),
141        }
142    }
143}
144
145impl Accumulator for StringAggAccumulator {
146    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
147        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
148            .iter()
149            .filter_map(|v| v.as_ref().map(ToString::to_string))
150            .collect();
151        if !string_array.is_empty() {
152            let s = string_array.join(self.delimiter.as_str());
153            let v = self.values.get_or_insert("".to_string());
154            if !v.is_empty() {
155                v.push_str(self.delimiter.as_str());
156            }
157            v.push_str(s.as_str());
158        }
159        Ok(())
160    }
161
162    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
163        self.update_batch(values)?;
164        Ok(())
165    }
166
167    fn state(&mut self) -> Result<Vec<ScalarValue>> {
168        Ok(vec![self.evaluate()?])
169    }
170
171    fn evaluate(&mut self) -> Result<ScalarValue> {
172        Ok(ScalarValue::LargeUtf8(self.values.clone()))
173    }
174
175    fn size(&self) -> usize {
176        size_of_val(self)
177            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
178            + self.delimiter.capacity()
179    }
180}