1use std::any::Any;
21use std::hash::Hash;
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{
29 as_generic_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{
32 internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue,
33};
34use datafusion_expr::function::AccumulatorArgs;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
38};
39use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
40use datafusion_macros::user_doc;
41use datafusion_physical_expr::expressions::Literal;
42
// Registers the `StringAgg` UDAF through the crate's `make_udaf_expr_and_func!`
// macro. The arguments are, in order: the implementing type, the expression
// helper name, that helper's parameter names, its documentation string, and the
// name of the UDAF accessor.
// NOTE(review): the macro is defined elsewhere; `string_agg_udaf()` is called in
// `reverse_expr` in this file, so it is presumably generated here — confirm
// against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
50
51#[user_doc(
52 doc_section(label = "General Functions"),
53 description = "Concatenates the values of string expressions and places separator values between them. \
54If ordering is required, strings are concatenated in the specified order. \
55This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
56 syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
57 sql_example = r#"```sql
58> SELECT string_agg(name, ', ') AS names_list
59 FROM employee;
60+--------------------------+
61| names_list |
62+--------------------------+
63| Alice, Bob, Bob, Charlie |
64+--------------------------+
65> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
66 FROM employee;
67+--------------------------+
68| names_list |
69+--------------------------+
70| Charlie, Bob, Bob, Alice |
71+--------------------------+
72> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
73 FROM employee;
74+--------------------------+
75| names_list |
76+--------------------------+
77| Charlie, Bob, Alice |
78+--------------------------+
79```"#,
80 argument(
81 name = "expression",
82 description = "The string expression to concatenate. Can be a column or any valid string expression."
83 ),
84 argument(
85 name = "delimiter",
86 description = "A literal string used as a separator between the concatenated values."
87 )
88)]
89#[derive(Debug, PartialEq, Eq, Hash)]
91pub struct StringAgg {
92 signature: Signature,
93 array_agg: ArrayAgg,
94}
95
96impl StringAgg {
97 pub fn new() -> Self {
99 Self {
100 signature: Signature::one_of(
101 vec![
102 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
103 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
104 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
105 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
106 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
107 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
108 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
109 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
110 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
111 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
112 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
113 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
114 ],
115 Volatility::Immutable,
116 ),
117 array_agg: Default::default(),
118 }
119 }
120}
121
122impl Default for StringAgg {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl AggregateUDFImpl for StringAgg {
131 fn as_any(&self) -> &dyn Any {
132 self
133 }
134
135 fn name(&self) -> &str {
136 "string_agg"
137 }
138
139 fn signature(&self) -> &Signature {
140 &self.signature
141 }
142
143 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
144 Ok(DataType::LargeUtf8)
145 }
146
147 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
148 let no_order_no_distinct =
150 (args.ordering_fields.is_empty()) && (!args.is_distinct);
151 if no_order_no_distinct {
152 Ok(vec![Field::new(
154 format_state_name(args.name, "string_agg"),
155 DataType::LargeUtf8,
156 true,
157 )
158 .into()])
159 } else {
160 self.array_agg.state_fields(args)
162 }
163 }
164
165 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
166 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
167 return not_impl_err!(
168 "The second argument of the string_agg function must be a string literal"
169 );
170 };
171
172 let delimiter = if lit.value().is_null() {
173 ""
176 } else if let Some(lit_string) = lit.value().try_as_str() {
177 lit_string.unwrap_or("")
178 } else {
179 return not_impl_err!(
180 "StringAgg not supported for delimiter \"{}\"",
181 lit.value()
182 );
183 };
184
185 let no_order_no_distinct =
187 acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
188
189 if no_order_no_distinct {
190 Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
192 } else {
193 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
195 return_field: Field::new(
196 "f",
197 DataType::new_list(acc_args.return_field.data_type().clone(), true),
198 true,
199 )
200 .into(),
201 exprs: &filter_index(acc_args.exprs, 1),
202 expr_fields: &filter_index(acc_args.expr_fields, 1),
203 schema: acc_args.schema,
207 ignore_nulls: acc_args.ignore_nulls,
208 order_bys: acc_args.order_bys,
209 is_reversed: acc_args.is_reversed,
210 name: acc_args.name,
211 is_distinct: acc_args.is_distinct,
212 })?;
213
214 Ok(Box::new(StringAggAccumulator::new(
215 array_agg_acc,
216 delimiter,
217 )))
218 }
219 }
220
221 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
222 datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
223 }
224
225 fn documentation(&self) -> Option<&Documentation> {
226 self.doc()
227 }
228}
229
230#[derive(Debug)]
232pub(crate) struct StringAggAccumulator {
233 array_agg_acc: Box<dyn Accumulator>,
234 delimiter: String,
235}
236
237impl StringAggAccumulator {
238 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
239 Self {
240 array_agg_acc,
241 delimiter: delimiter.to_string(),
242 }
243 }
244}
245
246impl Accumulator for StringAggAccumulator {
247 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
248 self.array_agg_acc.update_batch(&filter_index(values, 1))
249 }
250
251 fn evaluate(&mut self) -> Result<ScalarValue> {
252 let scalar = self.array_agg_acc.evaluate()?;
253
254 let ScalarValue::List(list) = scalar else {
255 return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
256 };
257
258 let string_arr: Vec<_> = match list.value_type() {
259 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
260 .iter()
261 .flatten()
262 .collect(),
263 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
264 .iter()
265 .flatten()
266 .collect(),
267 DataType::Utf8View => as_string_view_array(list.values())?
268 .iter()
269 .flatten()
270 .collect(),
271 _ => {
272 return internal_err!(
273 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
274 list.value_type()
275 )
276 }
277 };
278
279 if string_arr.is_empty() {
280 return Ok(ScalarValue::LargeUtf8(None));
281 }
282
283 Ok(ScalarValue::LargeUtf8(Some(
284 string_arr.join(&self.delimiter),
285 )))
286 }
287
288 fn size(&self) -> usize {
289 size_of_val(self) - size_of_val(&self.array_agg_acc)
290 + self.array_agg_acc.size()
291 + self.delimiter.capacity()
292 }
293
294 fn state(&mut self) -> Result<Vec<ScalarValue>> {
295 self.array_agg_acc.state()
296 }
297
298 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
299 self.array_agg_acc.merge_batch(values)
300 }
301}
302
/// Returns a clone of `values` without the element at position `index`.
///
/// Every element is kept when `index` is out of bounds.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::with_capacity(values.len());
    for (position, value) in values.iter().enumerate() {
        if position != index {
            kept.push(value.clone());
        }
    }
    kept
}
312
/// Accumulator for `string_agg` without `ORDER BY` and `DISTINCT`: every
/// non-null value is appended to one growing string as it arrives.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Separator inserted between consecutive values.
    delimiter: String,
    /// Concatenation of all values appended so far.
    accumulated_string: String,
    /// Whether at least one non-null value has been appended. Tracked
    /// separately because an empty `accumulated_string` may be a real result.
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that joins values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        Self {
            delimiter: delimiter.to_owned(),
            accumulated_string: String::new(),
            has_value: false,
        }
    }

    /// Appends every `Some` string of `iter`, writing the delimiter before
    /// each value except the very first one; `None` entries are skipped.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        let Self {
            delimiter,
            accumulated_string,
            has_value,
        } = self;

        for text in iter.flatten() {
            if *has_value {
                accumulated_string.push_str(delimiter.as_str());
            } else {
                *has_value = true;
            }
            accumulated_string.push_str(text);
        }
    }
}
349
350impl Accumulator for SimpleStringAggAccumulator {
351 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
352 let string_arr = values.first().ok_or_else(|| {
353 internal_datafusion_err!(
354 "Planner should ensure its first arg is Utf8/Utf8View"
355 )
356 })?;
357
358 match string_arr.data_type() {
359 DataType::Utf8 => {
360 let array = as_string_array(string_arr)?;
361 self.append_strings(array.iter());
362 }
363 DataType::LargeUtf8 => {
364 let array = as_generic_string_array::<i64>(string_arr)?;
365 self.append_strings(array.iter());
366 }
367 DataType::Utf8View => {
368 let array = as_string_view_array(string_arr)?;
369 self.append_strings(array.iter());
370 }
371 other => {
372 return internal_err!(
373 "Planner should ensure string_agg first argument is Utf8-like, found {other}"
374 );
375 }
376 }
377
378 Ok(())
379 }
380
381 fn evaluate(&mut self) -> Result<ScalarValue> {
382 let result = if self.has_value {
383 ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
384 } else {
385 ScalarValue::LargeUtf8(None)
386 };
387
388 self.has_value = false;
389 Ok(result)
390 }
391
392 fn size(&self) -> usize {
393 size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
394 }
395
396 fn state(&mut self) -> Result<Vec<ScalarValue>> {
397 let result = if self.has_value {
398 ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
399 } else {
400 ScalarValue::LargeUtf8(None)
401 };
402 self.has_value = false;
403
404 Ok(vec![result])
405 }
406
407 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
408 self.update_batch(values)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    /// Plain aggregation keeps every value, in arrival order, across a merge.
    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["d", "e", "f"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "a,b,c,d,e,f");

        Ok(())
    }

    /// DISTINCT over disjoint inputs keeps all values (order is not asserted).
    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["d", "e", "f"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str_sorted(merged.evaluate()?, ","), "a,b,c,d,e,f");

        Ok(())
    }

    /// Without DISTINCT, repeated values are all kept.
    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["a", "b", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "a,b,c,a,b,c");

        Ok(())
    }

    /// DISTINCT removes values repeated across the two partial accumulators.
    #[test]
    fn duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "b", "c"]), data([","])])?;
        right.update_batch(&[data(["a", "b", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str_sorted(merged.evaluate()?, ","), "a,b,c");

        Ok(())
    }

    /// DISTINCT with an ascending ORDER BY on the aggregated column.
    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let ascending = SortOptions::new(false, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", ascending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        right.update_batch(&[data(["f", "a", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "a,b,c,d,e,f");

        Ok(())
    }

    /// DISTINCT with a descending ORDER BY on the aggregated column.
    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let descending = SortOptions::new(true, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", descending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["e", "b", "d"]), data([","])])?;
        right.update_batch(&[data(["f", "a", "c"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "f,e,d,c,b,a");

        Ok(())
    }

    /// Ascending ORDER BY plus DISTINCT de-duplicates across accumulators.
    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let ascending = SortOptions::new(false, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", ascending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        right.update_batch(&[data(["b", "c", "a"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "a,b,c");

        Ok(())
    }

    /// Descending ORDER BY plus DISTINCT de-duplicates across accumulators.
    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let descending = SortOptions::new(true, false);
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", descending);
        let (mut left, mut right) = builder.build_two()?;

        left.update_batch(&[data(["a", "c", "b"]), data([","])])?;
        right.update_batch(&[data(["b", "c", "a"]), data([","])])?;

        let mut merged = merge(left, right)?;
        assert_eq!(some_str(merged.evaluate()?), "c,b,a");

        Ok(())
    }

    /// Test helper that assembles `string_agg` accumulators over a one-column
    /// (`col: LargeUtf8`) schema.
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a builder with delimiter `sep`, no DISTINCT and no ORDER BY.
        fn new(sep: &str) -> Self {
            let column = Field::new("col", DataType::LargeUtf8, true);
            let schema = Schema {
                fields: Fields::from(vec![column]),
                metadata: Default::default(),
            };

            Self {
                sep: sep.to_owned(),
                distinct: false,
                order_bys: Vec::new(),
                schema,
            }
        }

        /// Turns DISTINCT on.
        fn distinct(self) -> Self {
            Self {
                distinct: true,
                ..self
            }
        }

        /// Adds an ORDER BY on schema column `col`; panics if it is unknown.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.order_bys
                .push(PhysicalSortExpr::new(Arc::new(column), sort_options));
            self
        }

        /// Builds one accumulator from the current configuration.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let expr_fields: [FieldRef; 2] = [
                Field::new("col", DataType::LargeUtf8, true).into(),
                Field::new("lit", DataType::Utf8, false).into(),
            ];
            let delimiter = Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())));

            StringAgg::new().accumulator(AccumulatorArgs {
                name: "",
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                exprs: &[Arc::new(Column::new("col", 0)), Arc::new(delimiter)],
                expr_fields: &expr_fields,
                order_bys: &self.order_bys,
                is_distinct: self.distinct,
                is_reversed: false,
                ignore_nulls: false,
            })
        }

        /// Builds two independent accumulators with identical configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking on anything else.
    fn some_str(value: ScalarValue) -> String {
        let maybe_string = str(value).expect("ScalarValue was not a String");
        maybe_string.expect("ScalarValue was None")
    }

    /// Like `some_str`, but sorts the `sep`-separated parts so that results
    /// with unspecified order can be compared.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let joined = some_str(value);
        let mut pieces = joined.split(sep).collect::<Vec<&str>>();
        pieces.sort_unstable();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar; errors for other types.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        match value {
            ScalarValue::LargeUtf8(inner) => Ok(inner),
            other => internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                other.data_type()
            ),
        }
    }

    /// Builds a `LargeUtf8` array from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        let strings = LargeStringArray::from(Vec::from(list));
        Arc::new(strings)
    }

    /// Feeds the state of `donor` into `target` and returns `target`.
    fn merge(
        mut target: Box<dyn Accumulator>,
        mut donor: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let state = donor.state()?;
        let state_arrays = state
            .iter()
            .map(ScalarValue::to_array)
            .collect::<Result<Vec<ArrayRef>>>()?;

        target.merge_batch(&state_arrays)?;
        Ok(target)
    }
}