1use std::any::Any;
21use std::hash::{DefaultHasher, Hash, Hasher};
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
30use datafusion_expr::function::AccumulatorArgs;
31use datafusion_expr::{
32 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
33};
34use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
35use datafusion_macros::user_doc;
36use datafusion_physical_expr::expressions::Literal;
37
// Registers the `StringAgg` UDAF. Judging by the macro's name and arguments,
// this declares the `string_agg(expr, delimiter)` expression helper and the
// `string_agg_udaf()` accessor that `StringAgg::reverse_expr` calls —
// NOTE(review): confirm the exact expansion against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
45
46#[user_doc(
47 doc_section(label = "General Functions"),
48 description = "Concatenates the values of string expressions and places separator values between them. \
49If ordering is required, strings are concatenated in the specified order. \
50This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
51 syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
52 sql_example = r#"```sql
53> SELECT string_agg(name, ', ') AS names_list
54 FROM employee;
55+--------------------------+
56| names_list |
57+--------------------------+
58| Alice, Bob, Bob, Charlie |
59+--------------------------+
60> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
61 FROM employee;
62+--------------------------+
63| names_list |
64+--------------------------+
65| Charlie, Bob, Bob, Alice |
66+--------------------------+
67> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
68 FROM employee;
69+--------------------------+
70| names_list |
71+--------------------------+
72| Charlie, Bob, Alice |
73+--------------------------+
74```"#,
75 argument(
76 name = "expression",
77 description = "The string expression to concatenate. Can be a column or any valid string expression."
78 ),
79 argument(
80 name = "delimiter",
81 description = "A literal string used as a separator between the concatenated values."
82 )
83)]
/// `string_agg` aggregate function.
///
/// Values are gathered by an inner [`ArrayAgg`] accumulator and joined with
/// the delimiter when the aggregate is evaluated.
#[derive(Debug)]
pub struct StringAgg {
    /// Accepted `(expression, delimiter)` argument type combinations.
    signature: Signature,
    /// Inner `array_agg` function that supplies the state fields and the
    /// accumulator this aggregate wraps.
    array_agg: ArrayAgg,
}
90
91impl StringAgg {
92 pub fn new() -> Self {
94 Self {
95 signature: Signature::one_of(
96 vec![
97 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
98 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
99 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
100 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
101 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
102 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
103 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
104 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
105 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
106 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
107 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
108 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
109 ],
110 Volatility::Immutable,
111 ),
112 array_agg: Default::default(),
113 }
114 }
115}
116
117impl Default for StringAgg {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl AggregateUDFImpl for StringAgg {
124 fn as_any(&self) -> &dyn Any {
125 self
126 }
127
128 fn name(&self) -> &str {
129 "string_agg"
130 }
131
132 fn signature(&self) -> &Signature {
133 &self.signature
134 }
135
136 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
137 Ok(DataType::LargeUtf8)
138 }
139
140 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
141 self.array_agg.state_fields(args)
142 }
143
144 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
146 return not_impl_err!(
147 "The second argument of the string_agg function must be a string literal"
148 );
149 };
150
151 let delimiter = if lit.value().is_null() {
152 ""
155 } else if let Some(lit_string) = lit.value().try_as_str() {
156 lit_string.unwrap_or("")
157 } else {
158 return not_impl_err!(
159 "StringAgg not supported for delimiter \"{}\"",
160 lit.value()
161 );
162 };
163
164 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
165 return_field: Field::new(
166 "f",
167 DataType::new_list(acc_args.return_field.data_type().clone(), true),
168 true,
169 )
170 .into(),
171 exprs: &filter_index(acc_args.exprs, 1),
172 ..acc_args
173 })?;
174
175 Ok(Box::new(StringAggAccumulator::new(
176 array_agg_acc,
177 delimiter,
178 )))
179 }
180
181 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
182 datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
183 }
184
185 fn documentation(&self) -> Option<&Documentation> {
186 self.doc()
187 }
188
189 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
190 let Some(other) = other.as_any().downcast_ref::<Self>() else {
191 return false;
192 };
193 let Self {
194 signature,
195 array_agg,
196 } = self;
197 signature == &other.signature && array_agg.equals(&other.array_agg)
198 }
199
200 fn hash_value(&self) -> u64 {
201 let Self {
202 signature,
203 array_agg,
204 } = self;
205 let mut hasher = DefaultHasher::new();
206 std::any::type_name::<Self>().hash(&mut hasher);
207 signature.hash(&mut hasher);
208 hasher.write_u64(array_agg.hash_value());
209 hasher.finish()
210 }
211}
212
/// Accumulator for `string_agg`: delegates collection of the values to a
/// wrapped accumulator and joins them with `delimiter` on evaluation.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Wrapped accumulator; `evaluate` expects it to yield a `ScalarValue::List`.
    array_agg_acc: Box<dyn Accumulator>,
    /// Separator inserted between the concatenated values.
    delimiter: String,
}
218
219impl StringAggAccumulator {
220 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
221 Self {
222 array_agg_acc,
223 delimiter: delimiter.to_string(),
224 }
225 }
226}
227
228impl Accumulator for StringAggAccumulator {
229 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
230 self.array_agg_acc.update_batch(&filter_index(values, 1))
231 }
232
233 fn evaluate(&mut self) -> Result<ScalarValue> {
234 let scalar = self.array_agg_acc.evaluate()?;
235
236 let ScalarValue::List(list) = scalar else {
237 return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
238 };
239
240 let string_arr: Vec<_> = match list.value_type() {
241 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
242 .iter()
243 .flatten()
244 .collect(),
245 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
246 .iter()
247 .flatten()
248 .collect(),
249 DataType::Utf8View => as_string_view_array(list.values())?
250 .iter()
251 .flatten()
252 .collect(),
253 _ => {
254 return internal_err!(
255 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
256 list.value_type()
257 )
258 }
259 };
260
261 if string_arr.is_empty() {
262 return Ok(ScalarValue::LargeUtf8(None));
263 }
264
265 Ok(ScalarValue::LargeUtf8(Some(
266 string_arr.join(&self.delimiter),
267 )))
268 }
269
270 fn size(&self) -> usize {
271 size_of_val(self) - size_of_val(&self.array_agg_acc)
272 + self.array_agg_acc.size()
273 + self.delimiter.capacity()
274 }
275
276 fn state(&mut self) -> Result<Vec<ScalarValue>> {
277 self.array_agg_acc.state()
278 }
279
280 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
281 self.array_agg_acc.merge_batch(values)
282 }
283}
284
/// Returns a clone of `values` without the element at position `index`.
///
/// When `index` is out of range every element is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    let mut kept = Vec::new();
    for (position, value) in values.iter().enumerate() {
        if position == index {
            continue;
        }
        kept.push(value.clone());
    }
    kept
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    // Every test builds two accumulators from the same configuration, updates
    // each with one batch, merges the second into the first and checks the
    // evaluated string. Unordered DISTINCT results are sorted before comparing.

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(","),
            ["a", "b", "c"],
            ["d", "e", "f"],
        )?;

        assert_eq!(some_str(aggregated), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",").distinct(),
            ["a", "b", "c"],
            ["d", "e", "f"],
        )?;

        assert_eq!(some_str_sorted(aggregated, ","), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(","),
            ["a", "b", "c"],
            ["a", "b", "c"],
        )?;

        assert_eq!(some_str(aggregated), "a,b,c,a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",").distinct(),
            ["a", "b", "c"],
            ["a", "b", "c"],
        )?;

        assert_eq!(some_str_sorted(aggregated, ","), "a,b,c");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",")
                .distinct()
                .order_by_col("col", SortOptions::new(false, false)),
            ["e", "b", "d"],
            ["f", "a", "c"],
        )?;

        assert_eq!(some_str(aggregated), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",")
                .distinct()
                .order_by_col("col", SortOptions::new(true, false)),
            ["e", "b", "d"],
            ["f", "a", "c"],
        )?;

        assert_eq!(some_str(aggregated), "f,e,d,c,b,a");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",")
                .distinct()
                .order_by_col("col", SortOptions::new(false, false)),
            ["a", "c", "b"],
            ["b", "c", "a"],
        )?;

        assert_eq!(some_str(aggregated), "a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let aggregated = aggregate_two_batches(
            StringAggAccumulatorBuilder::new(",")
                .distinct()
                .order_by_col("col", SortOptions::new(true, false)),
            ["a", "c", "b"],
            ["b", "c", "a"],
        )?;

        assert_eq!(some_str(aggregated), "c,b,a");
        Ok(())
    }

    /// Builds two accumulators from `builder`, updates the first with
    /// `first_batch` and the second with `second_batch` (each alongside a
    /// one-element `","` delimiter array), merges the second into the first
    /// and returns the evaluated result.
    fn aggregate_two_batches(
        builder: StringAggAccumulatorBuilder,
        first_batch: [&str; 3],
        second_batch: [&str; 3],
    ) -> Result<ScalarValue> {
        let (mut first, mut second) = builder.build_two()?;

        first.update_batch(&[data(first_batch), data([","])])?;
        second.update_batch(&[data(second_batch), data([","])])?;

        let mut merged = merge(first, second)?;
        merged.evaluate()
    }

    /// Test-only builder for `string_agg` accumulators over a single nullable
    /// `LargeUtf8` column named `col`.
    struct StringAggAccumulatorBuilder {
        /// Delimiter passed as the literal second argument.
        sep: String,
        /// Whether the accumulator is built with `is_distinct` set.
        distinct: bool,
        /// ORDER BY expressions handed to the accumulator.
        order_bys: Vec<PhysicalSortExpr>,
        /// Schema used to resolve columns.
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a non-distinct, unordered configuration with delimiter `sep`.
        fn new(sep: &str) -> Self {
            let column = Field::new("col", DataType::LargeUtf8, true);
            let schema = Schema {
                fields: Fields::from(vec![column]),
                metadata: Default::default(),
            };

            Self {
                sep: sep.to_string(),
                distinct: false,
                order_bys: Vec::new(),
                schema,
            }
        }

        /// Switches the configuration to DISTINCT.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on `col`, which must exist in the schema.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.order_bys
                .push(PhysicalSortExpr::new(Arc::new(column), sort_options));
            self
        }

        /// Builds one accumulator for `string_agg(col, <sep>)`.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            let delimiter = Literal::new(ScalarValue::Utf8(Some(self.sep.clone())));

            StringAgg::new().accumulator(AccumulatorArgs {
                name: "",
                schema: &self.schema,
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                exprs: &[Arc::new(Column::new("col", 0)), Arc::new(delimiter)],
                order_bys: &self.order_bys,
                is_distinct: self.distinct,
                is_reversed: false,
                ignore_nulls: false,
            })
        }

        /// Builds two independent accumulators with the same configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking otherwise.
    fn some_str(value: ScalarValue) -> String {
        let maybe_string = str(value).expect("ScalarValue was not a String");
        maybe_string.expect("ScalarValue was None")
    }

    /// Like `some_str`, but splits on `sep`, sorts the parts and joins them
    /// again so results with unspecified order compare deterministically.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let joined = some_str(value);
        let mut pieces = joined.split(sep).collect::<Vec<&str>>();
        pieces.sort_unstable();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar; any other variant is an
    /// internal error.
    fn str(value: ScalarValue) -> Result<Option<String>> {
        let ScalarValue::LargeUtf8(inner) = value else {
            return internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            );
        };
        Ok(inner)
    }

    /// Builds a `LargeStringArray` from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        let strings = LargeStringArray::from(list.to_vec());
        Arc::new(strings)
    }

    /// Merges the intermediate state of `acc2` into `acc1` and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let mut state_arrays: Vec<ArrayRef> = Vec::new();
        for scalar in acc2.state()?.iter() {
            state_arrays.push(scalar.to_array()?);
        }

        acc1.merge_batch(&state_arrays)?;
        Ok(acc1)
    }
}
541}