1use std::any::Any;
21use std::hash::Hash;
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{
29 as_generic_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{
32 Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
33};
34use datafusion_expr::function::AccumulatorArgs;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
38};
39use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Declares the `string_agg` expression helper (arguments: `expr`, `delimiter`)
// and the `string_agg_udaf` accessor for the `StringAgg` UDAF; the latter is
// what `reverse_expr` returns. NOTE(review): exact expansion lives in the
// `make_udaf_expr_and_func!` macro definition — confirm there.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
50
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Concatenates the values of string expressions and places separator values between them. \
If ordering is required, strings are concatenated in the specified order. \
This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
    syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT string_agg(name, ', ') AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Alice, Bob, Bob, Charlie |
+--------------------------+
> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Bob, Alice |
+--------------------------+
> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
  FROM employee;
+--------------------------+
| names_list               |
+--------------------------+
| Charlie, Bob, Alice      |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The string expression to concatenate. Can be a column or any valid string expression."
    ),
    argument(
        name = "delimiter",
        description = "A literal string used as a separator between the concatenated values."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// `STRING_AGG` aggregate function: joins string values with a literal
/// delimiter and always returns `LargeUtf8`.
///
/// Without `ORDER BY` and `DISTINCT` the values are appended directly into a
/// single string; otherwise collection is delegated to an `ArrayAgg`
/// accumulator and the collected values are joined at evaluation time.
pub struct StringAgg {
    /// Accepted `(value, delimiter)` type combinations.
    signature: Signature,
    /// Delegate used for the `ORDER BY` / `DISTINCT` path.
    array_agg: ArrayAgg,
}
95
96impl StringAgg {
97 pub fn new() -> Self {
99 Self {
100 signature: Signature::one_of(
101 vec![
102 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114 ],
115 Volatility::Immutable,
116 ),
117 array_agg: Default::default(),
118 }
119 }
120}
121
122impl Default for StringAgg {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl AggregateUDFImpl for StringAgg {
131 fn as_any(&self) -> &dyn Any {
132 self
133 }
134
135 fn name(&self) -> &str {
136 "string_agg"
137 }
138
139 fn signature(&self) -> &Signature {
140 &self.signature
141 }
142
143 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
144 Ok(DataType::LargeUtf8)
145 }
146
147 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
148 let no_order_no_distinct =
150 (args.ordering_fields.is_empty()) && (!args.is_distinct);
151 if no_order_no_distinct {
152 Ok(vec![
154 Field::new(
155 format_state_name(args.name, "string_agg"),
156 DataType::LargeUtf8,
157 true,
158 )
159 .into(),
160 ])
161 } else {
162 self.array_agg.state_fields(args)
164 }
165 }
166
167 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
168 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
169 return not_impl_err!(
170 "The second argument of the string_agg function must be a string literal"
171 );
172 };
173
174 let delimiter = if lit.value().is_null() {
175 ""
178 } else if let Some(lit_string) = lit.value().try_as_str() {
179 lit_string.unwrap_or("")
180 } else {
181 return not_impl_err!(
182 "StringAgg not supported for delimiter \"{}\"",
183 lit.value()
184 );
185 };
186
187 let no_order_no_distinct =
189 acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
190
191 if no_order_no_distinct {
192 Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
194 } else {
195 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
197 return_field: Field::new(
198 "f",
199 DataType::new_list(acc_args.return_field.data_type().clone(), true),
200 true,
201 )
202 .into(),
203 exprs: &filter_index(acc_args.exprs, 1),
204 expr_fields: &filter_index(acc_args.expr_fields, 1),
205 schema: acc_args.schema,
209 ignore_nulls: acc_args.ignore_nulls,
210 order_bys: acc_args.order_bys,
211 is_reversed: acc_args.is_reversed,
212 name: acc_args.name,
213 is_distinct: acc_args.is_distinct,
214 })?;
215
216 Ok(Box::new(StringAggAccumulator::new(
217 array_agg_acc,
218 delimiter,
219 )))
220 }
221 }
222
223 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
224 datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
225 }
226
227 fn documentation(&self) -> Option<&Documentation> {
228 self.doc()
229 }
230}
231
/// Accumulator for `string_agg` calls that use `ORDER BY` and/or `DISTINCT`.
///
/// Values are collected by a wrapped `ArrayAgg` accumulator, which receives
/// the ordering/distinct settings. On `evaluate`, the collected values are
/// joined with `delimiter`.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Accumulator built by `ArrayAgg::accumulator` for this call.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator inserted between consecutive non-null values.
    delimiter: String,
}
238
239impl StringAggAccumulator {
240 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
241 Self {
242 array_agg_acc,
243 delimiter: delimiter.to_string(),
244 }
245 }
246}
247
248impl Accumulator for StringAggAccumulator {
249 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
250 self.array_agg_acc.update_batch(&filter_index(values, 1))
251 }
252
253 fn evaluate(&mut self) -> Result<ScalarValue> {
254 let scalar = self.array_agg_acc.evaluate()?;
255
256 let ScalarValue::List(list) = scalar else {
257 return internal_err!(
258 "Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}",
259 scalar.data_type()
260 );
261 };
262
263 let string_arr: Vec<_> = match list.value_type() {
264 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
265 .iter()
266 .flatten()
267 .collect(),
268 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
269 .iter()
270 .flatten()
271 .collect(),
272 DataType::Utf8View => as_string_view_array(list.values())?
273 .iter()
274 .flatten()
275 .collect(),
276 _ => {
277 return internal_err!(
278 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
279 list.value_type()
280 );
281 }
282 };
283
284 if string_arr.is_empty() {
285 return Ok(ScalarValue::LargeUtf8(None));
286 }
287
288 Ok(ScalarValue::LargeUtf8(Some(
289 string_arr.join(&self.delimiter),
290 )))
291 }
292
293 fn size(&self) -> usize {
294 size_of_val(self) - size_of_val(&self.array_agg_acc)
295 + self.array_agg_acc.size()
296 + self.delimiter.capacity()
297 }
298
299 fn state(&mut self) -> Result<Vec<ScalarValue>> {
300 self.array_agg_acc.state()
301 }
302
303 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
304 self.array_agg_acc.merge_batch(values)
305 }
306}
307
/// Returns clones of every element of `values` except the one at `index`.
///
/// An out-of-range `index` removes nothing.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::new();
    for (position, value) in values.iter().enumerate() {
        if position == index {
            continue;
        }
        kept.push(value.clone());
    }
    kept
}
317
/// Accumulator for `string_agg` without `ORDER BY` or `DISTINCT`.
///
/// Non-null values are appended straight into `accumulated_string`, separated
/// by `delimiter`. No per-value buffer is kept.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    delimiter: String,
    /// Concatenation of every non-null value appended so far.
    accumulated_string: String,
    /// Whether any non-null value has been appended. This separates "only
    /// empty strings seen" from "no input at all".
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that separates values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        Self {
            accumulated_string: String::new(),
            delimiter: delimiter.to_owned(),
            has_value: false,
        }
    }

    /// Appends each `Some` item, skipping `None`.
    ///
    /// The delimiter is written only between two appended values, never before
    /// the first one.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        iter.flatten().for_each(|piece| {
            if self.has_value {
                self.accumulated_string.push_str(&self.delimiter);
            }
            self.accumulated_string.push_str(piece);
            self.has_value = true;
        });
    }
}
354
355impl Accumulator for SimpleStringAggAccumulator {
356 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
357 let string_arr = values.first().ok_or_else(|| {
358 internal_datafusion_err!(
359 "Planner should ensure its first arg is Utf8/Utf8View"
360 )
361 })?;
362
363 match string_arr.data_type() {
364 DataType::Utf8 => {
365 let array = as_string_array(string_arr)?;
366 self.append_strings(array.iter());
367 }
368 DataType::LargeUtf8 => {
369 let array = as_generic_string_array::<i64>(string_arr)?;
370 self.append_strings(array.iter());
371 }
372 DataType::Utf8View => {
373 let array = as_string_view_array(string_arr)?;
374 self.append_strings(array.iter());
375 }
376 other => {
377 return internal_err!(
378 "Planner should ensure string_agg first argument is Utf8-like, found {other}"
379 );
380 }
381 }
382
383 Ok(())
384 }
385
386 fn evaluate(&mut self) -> Result<ScalarValue> {
387 let result = if self.has_value {
388 ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
389 } else {
390 ScalarValue::LargeUtf8(None)
391 };
392
393 self.has_value = false;
394 Ok(result)
395 }
396
397 fn size(&self) -> usize {
398 size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
399 }
400
401 fn state(&mut self) -> Result<Vec<ScalarValue>> {
402 let result = if self.has_value {
403 ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
404 } else {
405 ScalarValue::LargeUtf8(None)
406 };
407 self.has_value = false;
408
409 Ok(vec![result])
410 }
411
412 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
413 self.update_batch(values)
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,a,b,c");

        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .build_two()?;

        acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str_sorted(acc1.evaluate()?, ",");

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c,d,e,f");

        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "f,e,d,c,b,a");

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(false, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "a,b,c");

        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", SortOptions::new(true, false))
            .build_two()?;

        acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?;
        acc1 = merge(acc1, acc2)?;

        let result = some_str(acc1.evaluate()?);

        assert_eq!(result, "c,b,a");

        Ok(())
    }

    /// NULL inputs contribute neither a value nor a delimiter.
    #[test]
    fn nulls_are_skipped() -> Result<()> {
        let mut acc = StringAggAccumulatorBuilder::new(",").build()?;

        let values: ArrayRef =
            Arc::new(LargeStringArray::from(vec![Some("a"), None, Some("b")]));
        acc.update_batch(&[values, data([","])])?;

        assert_eq!(some_str(acc.evaluate()?), "a,b");

        Ok(())
    }

    /// An accumulator that never saw a value evaluates to NULL, not "".
    #[test]
    fn no_input_evaluates_to_null() -> Result<()> {
        let mut acc = StringAggAccumulatorBuilder::new(",").build()?;

        assert_eq!(large_utf8_value(acc.evaluate()?)?, None);

        Ok(())
    }

    /// Builds `string_agg` accumulators over a single nullable `LargeUtf8`
    /// column named `col`, with a literal `Utf8` separator.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        fn new(sep: &str) -> Self {
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![Field::new(
                        "col",
                        DataType::LargeUtf8,
                        true,
                    )]),
                    metadata: Default::default(),
                },
            }
        }

        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            self.order_bys.extend([PhysicalSortExpr::new(
                Arc::new(
                    Column::new_with_schema(col, &self.schema)
                        .expect("column not available in schema"),
                ),
                sort_options,
            )]);
            self
        }

        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                expr_fields: &[
                    Field::new("col", DataType::LargeUtf8, true).into(),
                    Field::new("lit", DataType::Utf8, false).into(),
                ],
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))),
                ],
            })
        }

        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            Ok((self.build()?, self.build()?))
        }
    }

    /// Unwraps a non-null `LargeUtf8` result, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        large_utf8_value(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    /// Like `some_str`, but sorts the `sep`-separated parts. Use this for
    /// results whose order is unspecified (e.g. DISTINCT without ORDER BY).
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let value = some_str(value);
        let mut parts: Vec<&str> = value.split(sep).collect();
        parts.sort();
        parts.join(sep)
    }

    /// Extracts the payload of a `ScalarValue::LargeUtf8`, erroring on any
    /// other variant.
    fn large_utf8_value(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(v) => Ok(v),
            _ => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            ),
        }
    }

    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    /// Simulates partial/final aggregation by merging `acc2`'s state into `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let intermediate_state = acc2.state().and_then(|e| {
            e.iter()
                .map(|v| v.to_array())
                .collect::<Result<Vec<ArrayRef>>>()
        })?;
        acc1.merge_batch(&intermediate_state)?;
        Ok(acc1)
    }
}
667}