1use std::any::Any;
21use std::hash::{DefaultHasher, Hash, Hasher};
22use std::mem::size_of_val;
23
24use crate::array_agg::ArrayAgg;
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
30use datafusion_expr::function::AccumulatorArgs;
31use datafusion_expr::{
32 Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
33};
34use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
35use datafusion_macros::user_doc;
36use datafusion_physical_expr::expressions::Literal;
37
38make_udaf_expr_and_func!(
39 StringAgg,
40 string_agg,
41 expr delimiter,
42 "Concatenates the values of string expressions and places separator values between them",
43 string_agg_udaf
44);
45
46#[user_doc(
47 doc_section(label = "General Functions"),
48 description = "Concatenates the values of string expressions and places separator values between them. \
49If ordering is required, strings are concatenated in the specified order. \
50This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.",
51 syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])",
52 sql_example = r#"```sql
53> SELECT string_agg(name, ', ') AS names_list
54 FROM employee;
55+--------------------------+
56| names_list |
57+--------------------------+
58| Alice, Bob, Bob, Charlie |
59+--------------------------+
60> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list
61 FROM employee;
62+--------------------------+
63| names_list |
64+--------------------------+
65| Charlie, Bob, Bob, Alice |
66+--------------------------+
67> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list
68 FROM employee;
69+--------------------------+
70| names_list |
71+--------------------------+
72| Charlie, Bob, Alice |
73+--------------------------+
74```"#,
75 argument(
76 name = "expression",
77 description = "The string expression to concatenate. Can be a column or any valid string expression."
78 ),
79 argument(
80 name = "delimiter",
81 description = "A literal string used as a separator between the concatenated values."
82 )
83)]
84#[derive(Debug)]
86pub struct StringAgg {
87 signature: Signature,
88 array_agg: ArrayAgg,
89}
90
91impl StringAgg {
92 pub fn new() -> Self {
94 Self {
95 signature: Signature::one_of(
96 vec![
97 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
98 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
99 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
100 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
101 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
102 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
103 TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
104 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
105 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
106 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
107 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]),
108 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
109 ],
110 Volatility::Immutable,
111 ),
112 array_agg: Default::default(),
113 }
114 }
115}
116
117impl Default for StringAgg {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl AggregateUDFImpl for StringAgg {
124 fn as_any(&self) -> &dyn Any {
125 self
126 }
127
128 fn name(&self) -> &str {
129 "string_agg"
130 }
131
132 fn signature(&self) -> &Signature {
133 &self.signature
134 }
135
136 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
137 Ok(DataType::LargeUtf8)
138 }
139
140 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
141 self.array_agg.state_fields(args)
142 }
143
144 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145 let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
146 return not_impl_err!(
147 "The second argument of the string_agg function must be a string literal"
148 );
149 };
150
151 let delimiter = if lit.value().is_null() {
152 ""
155 } else if let Some(lit_string) = lit.value().try_as_str() {
156 lit_string.unwrap_or("")
157 } else {
158 return not_impl_err!(
159 "StringAgg not supported for delimiter \"{}\"",
160 lit.value()
161 );
162 };
163
164 let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
165 return_field: Field::new(
166 "f",
167 DataType::new_list(acc_args.return_field.data_type().clone(), true),
168 true,
169 )
170 .into(),
171 exprs: &filter_index(acc_args.exprs, 1),
172 ..acc_args
173 })?;
174
175 Ok(Box::new(StringAggAccumulator::new(
176 array_agg_acc,
177 delimiter,
178 )))
179 }
180
181 fn documentation(&self) -> Option<&Documentation> {
182 self.doc()
183 }
184
185 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
186 let Some(other) = other.as_any().downcast_ref::<Self>() else {
187 return false;
188 };
189 let Self {
190 signature,
191 array_agg,
192 } = self;
193 signature == &other.signature && array_agg.equals(&other.array_agg)
194 }
195
196 fn hash_value(&self) -> u64 {
197 let Self {
198 signature,
199 array_agg,
200 } = self;
201 let mut hasher = DefaultHasher::new();
202 std::any::type_name::<Self>().hash(&mut hasher);
203 signature.hash(&mut hasher);
204 hasher.write_u64(array_agg.hash_value());
205 hasher.finish()
206 }
207}
208
209#[derive(Debug)]
210pub(crate) struct StringAggAccumulator {
211 array_agg_acc: Box<dyn Accumulator>,
212 delimiter: String,
213}
214
215impl StringAggAccumulator {
216 pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
217 Self {
218 array_agg_acc,
219 delimiter: delimiter.to_string(),
220 }
221 }
222}
223
224impl Accumulator for StringAggAccumulator {
225 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
226 self.array_agg_acc.update_batch(&filter_index(values, 1))
227 }
228
229 fn evaluate(&mut self) -> Result<ScalarValue> {
230 let scalar = self.array_agg_acc.evaluate()?;
231
232 let ScalarValue::List(list) = scalar else {
233 return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type());
234 };
235
236 let string_arr: Vec<_> = match list.value_type() {
237 DataType::LargeUtf8 => as_generic_string_array::<i64>(list.values())?
238 .iter()
239 .flatten()
240 .collect(),
241 DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
242 .iter()
243 .flatten()
244 .collect(),
245 DataType::Utf8View => as_string_view_array(list.values())?
246 .iter()
247 .flatten()
248 .collect(),
249 _ => {
250 return internal_err!(
251 "Expected elements to of type Utf8 or LargeUtf8, but got {}",
252 list.value_type()
253 )
254 }
255 };
256
257 if string_arr.is_empty() {
258 return Ok(ScalarValue::LargeUtf8(None));
259 }
260
261 Ok(ScalarValue::LargeUtf8(Some(
262 string_arr.join(&self.delimiter),
263 )))
264 }
265
266 fn size(&self) -> usize {
267 size_of_val(self) - size_of_val(&self.array_agg_acc)
268 + self.array_agg_acc.size()
269 + self.delimiter.capacity()
270 }
271
272 fn state(&mut self) -> Result<Vec<ScalarValue>> {
273 self.array_agg_acc.state()
274 }
275
276 fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
277 self.array_agg_acc.merge_batch(values)
278 }
279}
280
/// Returns a copy of `values` without the element at position `index`.
///
/// All other elements keep their relative order. If `index` is out of range, every element
/// is kept.
fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
    values
        .iter()
        .enumerate()
        .filter_map(|(position, value)| (position != index).then(|| value.clone()))
        .collect()
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::LargeStringArray;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{Fields, Schema};
    use datafusion_common::internal_err;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use std::sync::Arc;

    // Every test below follows the same scenario: two accumulators each receive one batch,
    // the second is merged into the first, and the first is evaluated.

    #[test]
    fn no_duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");

        let output = aggregate_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        assert_eq!(some_str(output), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();

        let output = aggregate_partitions(&builder, ["a", "b", "c"], ["d", "e", "f"])?;

        // Without ORDER BY the parts are sorted before being compared.
        assert_eq!(some_str_sorted(output, ","), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn duplicates_no_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",");

        let output = aggregate_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        assert_eq!(some_str(output), "a,b,c,a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",").distinct();

        let output = aggregate_partitions(&builder, ["a", "b", "c"], ["a", "b", "c"])?;

        // Without ORDER BY the parts are sorted before being compared.
        assert_eq!(some_str_sorted(output, ","), "a,b,c");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", sort_order(false));

        let output = aggregate_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(some_str(output), "a,b,c,d,e,f");
        Ok(())
    }

    #[test]
    fn no_duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", sort_order(true));

        let output = aggregate_partitions(&builder, ["e", "b", "d"], ["f", "a", "c"])?;

        assert_eq!(some_str(output), "f,e,d,c,b,a");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_asc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", sort_order(false));

        let output = aggregate_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(some_str(output), "a,b,c");
        Ok(())
    }

    #[test]
    fn duplicates_distinct_sort_desc() -> Result<()> {
        let builder = StringAggAccumulatorBuilder::new(",")
            .distinct()
            .order_by_col("col", sort_order(true));

        let output = aggregate_partitions(&builder, ["a", "c", "b"], ["b", "c", "a"])?;

        assert_eq!(some_str(output), "c,b,a");
        Ok(())
    }

    /// Builds two accumulators from `builder`, feeds `first` to one and `second` to the
    /// other (each next to a one-row "," delimiter column), merges the second's state into
    /// the first and returns the first's evaluated result.
    fn aggregate_partitions<const N: usize>(
        builder: &StringAggAccumulatorBuilder,
        first: [&str; N],
        second: [&str; N],
    ) -> Result<ScalarValue> {
        let (mut left, mut right) = builder.build_two()?;
        left.update_batch(&[data(first), data([","])])?;
        right.update_batch(&[data(second), data([","])])?;

        let mut merged = merge(left, right)?;
        merged.evaluate()
    }

    /// Sort options with the requested direction and `nulls_first` set to `false`.
    fn sort_order(descending: bool) -> SortOptions {
        SortOptions::new(descending, false)
    }

    /// Test-only builder for `string_agg` accumulators over a single nullable `LargeUtf8`
    /// column named "col".
    struct StringAggAccumulatorBuilder {
        sep: String,
        distinct: bool,
        order_bys: Vec<PhysicalSortExpr>,
        schema: Schema,
    }

    impl StringAggAccumulatorBuilder {
        /// Starts a non-distinct, unordered configuration using `sep` as the delimiter.
        fn new(sep: &str) -> Self {
            let column = Field::new("col", DataType::LargeUtf8, true);
            Self {
                sep: sep.to_string(),
                distinct: Default::default(),
                order_bys: vec![],
                schema: Schema {
                    fields: Fields::from(vec![column]),
                    metadata: Default::default(),
                },
            }
        }

        /// Switches the aggregate to DISTINCT.
        fn distinct(mut self) -> Self {
            self.distinct = true;
            self
        }

        /// Appends an ORDER BY on the schema column `col`.
        fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self {
            let column = Column::new_with_schema(col, &self.schema)
                .expect("column not available in schema");
            self.order_bys
                .push(PhysicalSortExpr::new(Arc::new(column), sort_options));
            self
        }

        /// Creates one accumulator for the current configuration.
        fn build(&self) -> Result<Box<dyn Accumulator>> {
            StringAgg::new().accumulator(AccumulatorArgs {
                return_field: Field::new("f", DataType::LargeUtf8, true).into(),
                schema: &self.schema,
                ignore_nulls: false,
                order_bys: &self.order_bys,
                is_reversed: false,
                name: "",
                is_distinct: self.distinct,
                exprs: &[
                    Arc::new(Column::new("col", 0)),
                    Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.clone())))),
                ],
            })
        }

        /// Creates two independent accumulators for the current configuration.
        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
            let first = self.build()?;
            let second = self.build()?;
            Ok((first, second))
        }
    }

    /// Unwraps a non-null `LargeUtf8` scalar, panicking on any other value.
    fn some_str(value: ScalarValue) -> String {
        large_utf8(value)
            .expect("ScalarValue was not a String")
            .expect("ScalarValue was None")
    }

    /// Like [`some_str`], but splits the text on `sep`, sorts the parts and joins them again.
    fn some_str_sorted(value: ScalarValue, sep: &str) -> String {
        let text = some_str(value);
        let mut pieces = text.split(sep).collect::<Vec<&str>>();
        pieces.sort();
        pieces.join(sep)
    }

    /// Extracts the payload of a `LargeUtf8` scalar; any other variant is an error.
    fn large_utf8(value: ScalarValue) -> Result<Option<String>> {
        let ScalarValue::LargeUtf8(text) = value else {
            return internal_err!(
                "Expected ScalarValue::LargeUtf8, got {}",
                value.data_type()
            );
        };
        Ok(text)
    }

    /// Builds a `LargeStringArray` column from string literals.
    fn data<const N: usize>(list: [&str; N]) -> ArrayRef {
        Arc::new(LargeStringArray::from(list.to_vec()))
    }

    /// Merges the intermediate state of `acc2` into `acc1` and returns `acc1`.
    fn merge(
        mut acc1: Box<dyn Accumulator>,
        mut acc2: Box<dyn Accumulator>,
    ) -> Result<Box<dyn Accumulator>> {
        let state_arrays = acc2
            .state()?
            .iter()
            .map(|scalar| scalar.to_array())
            .collect::<Result<Vec<ArrayRef>>>()?;
        acc1.merge_batch(&state_arrays)?;
        Ok(acc1)
    }
}