1use crate::array_agg::ArrayAgg;
21use arrow::array::ArrayRef;
22use arrow::datatypes::{DataType, Field};
23use datafusion_common::cast::as_generic_string_array;
24use datafusion_common::Result;
25use datafusion_common::{internal_err, not_impl_err, ScalarValue};
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_expr::{
28 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
29};
30use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
31use datafusion_macros::user_doc;
32use datafusion_physical_expr::expressions::Literal;
33use std::any::Any;
34use std::mem::size_of_val;
35
// Registers `StringAgg` through the crate's `make_udaf_expr_and_func!` macro.
// NOTE(review): the macro is defined elsewhere; judging by the arguments it
// presumably emits a `string_agg(expr, delimiter)` expression builder and a
// `string_agg_udaf` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
43
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Concatenates the values of string expressions and places separator values between them. \
If ordering is required, strings are concatenated in the specified order. \
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT string_agg(name, ', ') AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Alice, Bob, Bob, Charlie |
+--------------------------+
> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Charlie, Bob, Bob, Alice |
+--------------------------+
> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Charlie, Bob, Alice |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The string expression to concatenate. Can be a column or any valid string expression."
    ),
    argument(
        name = "delimiter",
        description = "A literal string used as a separator between the concatenated values."
    )
)]
#[derive(Debug)]
/// STRING_AGG aggregate expression.
///
/// Value collection and intermediate state are delegated to an inner
/// [`ArrayAgg`]; the collected strings are joined with a literal delimiter
/// when the aggregate is evaluated.
pub struct StringAgg {
    /// Accepted `(expression, delimiter)` argument type combinations.
    signature: Signature,
    /// Inner `array_agg` implementation whose state fields and accumulator
    /// `string_agg` reuses.
    array_agg: ArrayAgg,
}
88
89impl StringAgg {
90 pub fn new() -> Self {
92 Self {
93 signature: Signature::one_of(
94 vec![
95 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
96 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
97 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
98 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
99 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
100 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
101 ],
102 Volatility::Immutable,
103 ),
104 array_agg: Default::default(),
105 }
106 }
107}
108
109impl Default for StringAgg {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl AggregateUDFImpl for StringAgg {
116 fn as_any(&self) -> &dyn Any {
117 self
118 }
119
120 fn name(&self) -> &str {
121 "string_agg"
122 }
123
124 fn signature(&self) -> &Signature {
125 &self.signature
126 }
127
128 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
129 Ok(DataType::LargeUtf8)
130 }
131
132 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
133 self.array_agg.state_fields(args)
134 }
135
136 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
137 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
138 return not_impl_err!(
139 "The second argument of the string_agg function must be a string literal"
140 );
141 };
142
143 let delimiter = if lit.value().is_null() {
144 ""
147 } else if let Some(lit_string) = lit.value().try_as_str() {
148 lit_string.unwrap_or("")
149 } else {
150 return not_impl_err!(
151 "StringAgg not supported for delimiter \"{}\"",
152 lit.value()
153 );
154 };
155
156 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
157 return_type: &DataType::new_list(acc_args.return_type.clone(), true),
158 exprs: &filter_index(acc_args.exprs, 1),
159 ..acc_args
160 })?;
161
162 Ok(Box::new(StringAggAccumulator::new(
163 array_agg_acc,
164 delimiter,
165 )))
166 }
167
168 fn documentation(&self) -> Option<&Documentation> {
169 self.doc()
170 }
171}
172
/// Accumulator for `string_agg`.
///
/// Values are collected by a wrapped accumulator and joined with `delimiter`
/// on `evaluate`.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Wrapped accumulator that collects the input strings; its `evaluate`
    /// result is expected to be a `ScalarValue::List`.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator inserted between the collected strings.
    delimiter: String,
}
178
179impl StringAggAccumulator {
180 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
181 Self {
182 array_agg_acc,
183 delimiter: delimiter.to_string(),
184 }
185 }
186}
187
188impl Accumulator for StringAggAccumulator {
189 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
190 self.array_agg_acc.update_batch(&filter_index(values, 1))
191 }
192
193 fn evaluate(&mut self) -> Result<ScalarValue> {
194 let scalar = self.array_agg_acc.evaluate()?;
195
196 let ScalarValue::List(list) = scalar else {
197 return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
198 };
199
200 let string_arr: Vec<_> = match list.value_type() {
201 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
202 .iter()
203 .flatten()
204 .collect(),
205 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
206 .iter()
207 .flatten()
208 .collect(),
209 _ => {
210 return internal_err!(
211 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
212 list.value_type()
213 )
214 }
215 };
216
217 if string_arr.is_empty() {
218 return Ok(ScalarValue::LargeUtf8(None));
219 }
220
221 Ok(ScalarValue::LargeUtf8(Some(
222 string_arr.join(&self.delimiter),
223 )))
224 }
225
226 fn size(&self) -> usize {
227 size_of_val(self) - size_of_val(&self.array_agg_acc)
228 + self.array_agg_acc.size()
229 + self.delimiter.capacity()
230 }
231
232 fn state(&mut self) -> Result<Vec<ScalarValue>> {
233 self.array_agg_acc.state()
234 }
235
236 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
237 self.array_agg_acc.merge_batch(values)
238 }
239}
240
/// Returns a copy of `values` with the element at position `index` left out.
///
/// An out-of-range `index` removes nothing, so the result is a full copy.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    values
        .iter()
        .enumerate()
        .filter_map(|(position, value)| (position != index).then(|| value.clone()))
        .collect()
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use std::sync::Arc;

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,a,b,c");

        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "f,e,d,c,b,a");

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "c,b,a");

        Ok(())
    }

    /// An accumulator that never saw a value evaluates to NULL, not "".
    #[test]
    fn empty_input_evaluates_to_null() -> Result<()> {
        let mut acc = StringAggAccumulatorBuilder::new(",").build()?;

        let result = str(acc.evaluate()?)?;

        assert_eq!(result, None);

        Ok(())
    }

    /// NULL input values are skipped rather than producing empty segments.
    #[test]
    fn null_values_are_skipped() -> Result<()> {
        let mut acc = StringAggAccumulatorBuilder::new(",").build()?;

        let values: ArrayRef =
            Arc::new(LargeStringArray::from(vec![Some("a"), None, Some("b")]));
        acc.update_batch(&[values, data([","])])?;

        let result = some_str(acc.evaluate()?);

        assert_eq!(result, "a,b");

        Ok(())
    }

    /// Builds `string_agg` accumulators over a single nullable `LargeUtf8`
    /// column named "col", with a literal `Utf8` delimiter.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        ordering: LexOrdering,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        fn new(sep: &str) -> Self {
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                ordering: Default::default(),
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::LargeUtf8,
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on column `col` (which must exist in the schema).
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            self.ordering.extend([PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            )]);
            self
        }

        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_type: &DataType::LargeUtf8,
                schema: &self.schema,
                ignore_nulls: false,
                ordering_req: &self.ordering,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
                ],
            })
        }

        /// Two identically configured accumulators, for exercising `merge`.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Unwraps a non-NULL `LargeUtf8` result, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        str(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    /// Like `some_str`, but sorts the `sep`-separated parts so results whose
    /// order is unspecified (e.g. DISTINCT without ORDER BY) compare stably.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let value = some_str(value);
        let mut parts: Vec<&str> = value.split(sep).collect();
        parts.sort();
        parts.join(sep)
    }

    /// Extracts the (possibly NULL) string from a `LargeUtf8` scalar.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(v) => Ok(v),
            _ => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            ),
        }
    }

    /// Builds a non-NULL `LargeStringArray` column from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    /// Simulates a two-phase aggregation: `acc2`'s state is merged into `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}