1use crate::array_agg::ArrayAgg;
21use arrow::array::ArrayRef;
22use arrow::datatypes::{DataType, Field, FieldRef};
23use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
24use datafusion_common::Result;
25use datafusion_common::{internal_err, not_impl_err, ScalarValue};
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_expr::{
28 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
29};
30use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
31use datafusion_macros::user_doc;
32use datafusion_physical_expr::expressions::Literal;
33use std::any::Any;
34use std::mem::size_of_val;
35
// Registers `StringAgg` through the crate's `make_udaf_expr_and_func!` macro.
// NOTE(review): judging by the arguments, this is expected to produce a
// `string_agg(expr, delimiter)` expression builder and a `string_agg_udaf()`
// accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
43
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Concatenates the values of string expressions and places separator values between them. \
If ordering is required, strings are concatenated in the specified order. \
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT string_agg(name, ', ') AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Alice, Bob, Bob, Charlie |
+--------------------------+
> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Charlie, Bob, Bob, Alice |
+--------------------------+
> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
 FROM employee;
+--------------------------+
| names_list |
+--------------------------+
| Charlie, Bob, Alice |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The string expression to concatenate. Can be a column or any valid string expression."
    ),
    argument(
        name = "delimiter",
        description = "A literal string used as a separator between the concatenated values."
    )
)]
#[derive(Debug)]
/// The `string_agg` aggregate UDF.
///
/// Values are collected by an inner `ArrayAgg` accumulator; the collected
/// strings are joined with the (literal) delimiter when the aggregate is
/// evaluated. The result type is always `LargeUtf8`.
pub struct StringAgg {
    /// Accepted (value, delimiter) argument type combinations.
    signature: Signature,
    /// `array_agg` implementation used to build the state fields and the
    /// inner accumulator that collects the values.
    array_agg: ArrayAgg,
}
88
89impl StringAgg {
90 pub fn new() -> Self {
92 Self {
93 signature: Signature::one_of(
94 vec![
95 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
96 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
97 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
98 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
99 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
100 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
101 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
102 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
103 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
104 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
105 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
106 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
107 ],
108 Volatility::Immutable,
109 ),
110 array_agg: Default::default(),
111 }
112 }
113}
114
115impl Default for StringAgg {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl AggregateUDFImpl for StringAgg {
122 fn as_any(&self) -> &dyn Any {
123 self
124 }
125
126 fn name(&self) -> &str {
127 "string_agg"
128 }
129
130 fn signature(&self) -> &Signature {
131 &self.signature
132 }
133
134 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
135 Ok(DataType::LargeUtf8)
136 }
137
138 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
139 self.array_agg.state_fields(args)
140 }
141
142 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
143 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
144 return not_impl_err!(
145 "The second argument of the string_agg function must be a string literal"
146 );
147 };
148
149 let delimiter = if lit.value().is_null() {
150 ""
153 } else if let Some(lit_string) = lit.value().try_as_str() {
154 lit_string.unwrap_or("")
155 } else {
156 return not_impl_err!(
157 "StringAgg not supported for delimiter \"{}\"",
158 lit.value()
159 );
160 };
161
162 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
163 return_field: Field::new(
164 "f",
165 DataType::new_list(acc_args.return_field.data_type().clone(), true),
166 true,
167 )
168 .into(),
169 exprs: &filter_index(acc_args.exprs, 1),
170 ..acc_args
171 })?;
172
173 Ok(Box::new(StringAggAccumulator::new(
174 array_agg_acc,
175 delimiter,
176 )))
177 }
178
179 fn documentation(&self) -> Option<&Documentation> {
180 self.doc()
181 }
182}
183
/// Accumulator for `string_agg`.
///
/// Values are collected by the wrapped accumulator (built from `ArrayAgg` by
/// `StringAgg::accumulator`). On `evaluate`, the collected non-null strings
/// are joined with `delimiter`.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Inner accumulator that stores the values. It was built with the
    /// caller's distinct/ordering settings.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator inserted between consecutive values in the result.
    delimiter: String,
}
189
190impl StringAggAccumulator {
191 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
192 Self {
193 array_agg_acc,
194 delimiter: delimiter.to_string(),
195 }
196 }
197}
198
199impl Accumulator for StringAggAccumulator {
200 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
201 self.array_agg_acc.update_batch(&filter_index(values, 1))
202 }
203
204 fn evaluate(&mut self) -> Result<ScalarValue> {
205 let scalar = self.array_agg_acc.evaluate()?;
206
207 let ScalarValue::List(list) = scalar else {
208 return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
209 };
210
211 let string_arr: Vec<_> = match list.value_type() {
212 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
213 .iter()
214 .flatten()
215 .collect(),
216 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
217 .iter()
218 .flatten()
219 .collect(),
220 DataType::Utf8View => as_string_view_array(list.values())?
221 .iter()
222 .flatten()
223 .collect(),
224 _ => {
225 return internal_err!(
226 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
227 list.value_type()
228 )
229 }
230 };
231
232 if string_arr.is_empty() {
233 return Ok(ScalarValue::LargeUtf8(None));
234 }
235
236 Ok(ScalarValue::LargeUtf8(Some(
237 string_arr.join(&self.delimiter),
238 )))
239 }
240
241 fn size(&self) -> usize {
242 size_of_val(self) - size_of_val(&self.array_agg_acc)
243 + self.array_agg_acc.size()
244 + self.delimiter.capacity()
245 }
246
247 fn state(&mut self) -> Result<Vec<ScalarValue>> {
248 self.array_agg_acc.state()
249 }
250
251 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
252 self.array_agg_acc.merge_batch(values)
253 }
254}
255
/// Returns a clone of `values` with the element at position `index` removed.
///
/// If `index` is out of bounds, every element is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let before = values.iter().take(index);
    // `saturating_add` keeps `usize::MAX` from overflowing; skipping that many
    // simply yields nothing.
    let after = values.iter().skip(index.saturating_add(1));
    before.chain(after).cloned().collect()
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use std::sync::Arc;

    // Each test builds two accumulators, feeds one batch to each, merges the
    // second into the first, and checks the evaluated string.
    //
    // The second column passed to `update_batch` is the delimiter column.
    // `StringAggAccumulator::update_batch` drops it, so its length does not
    // need to match the value column.

    /// Without DISTINCT or ORDER BY, acc1's values come first, followed by
    /// the values merged from acc2.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    /// DISTINCT without duplicates keeps every value. The result is sorted
    /// before comparing because the test does not rely on the output order
    /// of DISTINCT without ORDER BY.
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    /// Without DISTINCT, duplicates across partial accumulators are kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,a,b,c");

        Ok(())
    }

    /// DISTINCT removes duplicates that span the two partial accumulators.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    /// DISTINCT with ascending ORDER BY on the value column.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    /// DISTINCT with descending ORDER BY on the value column.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "f,e,d,c,b,a");

        Ok(())
    }

    /// DISTINCT with ascending ORDER BY, where duplicates span both
    /// accumulators.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    /// DISTINCT with descending ORDER BY, where duplicates span both
    /// accumulators.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "c,b,a");

        Ok(())
    }

    /// Test helper that builds `StringAgg` accumulators over a single
    /// nullable `LargeUtf8` column named "col".
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        ordering: LexOrdering,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Non-distinct, unordered builder using `sep` as the literal delimiter.
        fn new(sep: &str) -> Self {
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                ordering: Default::default(),
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::LargeUtf8,
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }
        /// Enables DISTINCT.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on column `col` (which must exist in the
        /// schema; panics otherwise).
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            self.ordering.extend([PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            )]);
            self
        }

        /// Builds one accumulator. The arguments are column 0 and a
        /// `Utf8` literal holding the separator.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                ignore_nulls: false,
                ordering_req: &self.ordering,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
                ],
            })
        }

        /// Builds two independent accumulators with the same configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Unwraps a `LargeUtf8(Some(_))` scalar; panics on any other value.
    fn some_str(value: ScalarValue) -> String {
        str(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    /// Like `some_str`, but sorts the `sep`-separated parts, so the
    /// comparison does not depend on their order.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let value = some_str(value);
        let mut parts: Vec<&str> = value.split(sep).collect();
        parts.sort();
        parts.join(sep)
    }

    /// Extracts the optional string from a `LargeUtf8` scalar. Any other
    /// variant is an internal error.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(v) => Ok(v),
            _ => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            ),
        }
    }

    /// Builds a non-null `LargeStringArray` from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    /// Simulates a two-phase aggregation. `acc2`'s state is converted to
    /// arrays and merged into `acc1`, which is returned.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}