1use arrow::array::{Array, StringArray, as_largestring_array};
19use std::any::Any;
20use std::sync::Arc;
21
22use arrow::datatypes::DataType;
23
24use crate::string::concat;
25use crate::string::concat::simplify_concat;
26use crate::string::concat_ws;
27use crate::strings::{ColumnarValueRef, StringArrayBuilder};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Concatenates multiple strings together with a specified separator.",
39 syntax_example = "concat_ws(separator, str[, ..., str_n])",
40 sql_example = r#"```sql
41> select concat_ws('_', 'data', 'fusion');
42+--------------------------------------------------+
43| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) |
44+--------------------------------------------------+
45| data_fusion |
46+--------------------------------------------------+
47```"#,
48 argument(
49 name = "separator",
50 description = "Separator to insert between concatenated strings."
51 ),
52 argument(
53 name = "str",
54 description = "String expression to operate on. Can be a constant, column, or function, and any combination of operators."
55 ),
56 argument(
57 name = "str_n",
58 description = "Subsequent string expressions to concatenate."
59 ),
60 related_udf(name = "concat")
61)]
62#[derive(Debug, PartialEq, Eq, Hash)]
63pub struct ConcatWsFunc {
64 signature: Signature,
65}
66
67impl Default for ConcatWsFunc {
68 fn default() -> Self {
69 ConcatWsFunc::new()
70 }
71}
72
73impl ConcatWsFunc {
74 pub fn new() -> Self {
75 use DataType::*;
76 Self {
77 signature: Signature::variadic(
78 vec![Utf8View, Utf8, LargeUtf8],
79 Volatility::Immutable,
80 ),
81 }
82 }
83}
84
85impl ScalarUDFImpl for ConcatWsFunc {
86 fn as_any(&self) -> &dyn Any {
87 self
88 }
89
90 fn name(&self) -> &str {
91 "concat_ws"
92 }
93
94 fn signature(&self) -> &Signature {
95 &self.signature
96 }
97
98 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99 use DataType::*;
100 Ok(Utf8)
101 }
102
103 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106 let ScalarFunctionArgs { args, .. } = args;
107
108 if args.len() < 2 {
110 return exec_err!(
111 "concat_ws was called with {} arguments. It requires at least 2.",
112 args.len()
113 );
114 }
115
116 let array_len = args
117 .iter()
118 .filter_map(|x| match x {
119 ColumnarValue::Array(array) => Some(array.len()),
120 _ => None,
121 })
122 .next();
123
124 if array_len.is_none() {
126 let ColumnarValue::Scalar(scalar) = &args[0] else {
127 unreachable!()
129 };
130 let sep = match scalar.try_as_str() {
131 Some(Some(s)) => s,
132 Some(None) => {
133 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
135 }
136 None => return internal_err!("Expected string literal, got {scalar:?}"),
137 };
138
139 let mut result = String::new();
140 let iter = &mut args[1..].iter().map(|arg| {
142 let ColumnarValue::Scalar(scalar) = arg else {
143 unreachable!()
145 };
146 scalar.try_as_str()
147 });
148
149 for scalar in iter.by_ref() {
151 match scalar {
152 Some(Some(s)) => {
153 result.push_str(s);
154 break;
155 }
156 Some(None) => {} None => {
158 return internal_err!("Expected string literal, got {scalar:?}");
159 }
160 }
161 }
162
163 for scalar in iter.by_ref() {
165 match scalar {
166 Some(Some(s)) => {
167 result.push_str(sep);
168 result.push_str(s);
169 }
170 Some(None) => {} None => {
172 return internal_err!("Expected string literal, got {scalar:?}");
173 }
174 }
175 }
176
177 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
178 }
179
180 let len = array_len.unwrap();
182 let mut data_size = 0;
183
184 let sep = match &args[0] {
186 ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
187 data_size += s.len() * len * (args.len() - 2); ColumnarValueRef::Scalar(s.as_bytes())
189 }
190 ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
191 return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
192 }
193 ColumnarValue::Array(array) => {
194 let string_array = as_string_array(array)?;
195 data_size += string_array.values().len() * (args.len() - 2); if array.is_nullable() {
197 ColumnarValueRef::NullableArray(string_array)
198 } else {
199 ColumnarValueRef::NonNullableArray(string_array)
200 }
201 }
202 _ => unreachable!("concat ws"),
203 };
204
205 let mut columns = Vec::with_capacity(args.len() - 1);
206 for arg in &args[1..] {
207 match arg {
208 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
209 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
210 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
211 if let Some(s) = maybe_value {
212 data_size += s.len() * len;
213 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
214 }
215 }
216 ColumnarValue::Array(array) => {
217 match array.data_type() {
218 DataType::Utf8 => {
219 let string_array = as_string_array(array)?;
220
221 data_size += string_array.values().len();
222 let column = if array.is_nullable() {
223 ColumnarValueRef::NullableArray(string_array)
224 } else {
225 ColumnarValueRef::NonNullableArray(string_array)
226 };
227 columns.push(column);
228 }
229 DataType::LargeUtf8 => {
230 let string_array = as_largestring_array(array);
231
232 data_size += string_array.values().len();
233 let column = if array.is_nullable() {
234 ColumnarValueRef::NullableLargeStringArray(string_array)
235 } else {
236 ColumnarValueRef::NonNullableLargeStringArray(
237 string_array,
238 )
239 };
240 columns.push(column);
241 }
242 DataType::Utf8View => {
243 let string_array = as_string_view_array(array)?;
244
245 data_size += string_array
246 .data_buffers()
247 .iter()
248 .map(|buf| buf.len())
249 .sum::<usize>();
250 let column = if array.is_nullable() {
251 ColumnarValueRef::NullableStringViewArray(string_array)
252 } else {
253 ColumnarValueRef::NonNullableStringViewArray(string_array)
254 };
255 columns.push(column);
256 }
257 other => {
258 return plan_err!(
259 "Input was {other} which is not a supported datatype for concat_ws function."
260 );
261 }
262 };
263 }
264 _ => unreachable!(),
265 }
266 }
267
268 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
269 for i in 0..len {
270 if !sep.is_valid(i) {
271 builder.append_offset();
272 continue;
273 }
274
275 let mut iter = columns.iter();
276 for column in iter.by_ref() {
277 if column.is_valid(i) {
278 builder.write::<false>(column, i);
279 break;
280 }
281 }
282
283 for column in iter {
284 if column.is_valid(i) {
285 builder.write::<false>(&sep, i);
286 builder.write::<false>(column, i);
287 }
288 }
289
290 builder.append_offset();
291 }
292
293 Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
294 }
295
296 fn simplify(
302 &self,
303 args: Vec<Expr>,
304 _info: &dyn SimplifyInfo,
305 ) -> Result<ExprSimplifyResult> {
306 match &args[..] {
307 [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals),
308 _ => Ok(ExprSimplifyResult::Original(args)),
309 }
310 }
311
312 fn documentation(&self) -> Option<&Documentation> {
313 self.doc()
314 }
315}
316
317fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
318 match delimiter {
319 Expr::Literal(
320 ScalarValue::Utf8(delimiter)
321 | ScalarValue::LargeUtf8(delimiter)
322 | ScalarValue::Utf8View(delimiter),
323 _,
324 ) => {
325 match delimiter {
326 Some(delimiter) if delimiter.is_empty() => {
329 match simplify_concat(args.to_vec())? {
330 ExprSimplifyResult::Original(_) => {
331 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
332 ScalarFunction {
333 func: concat(),
334 args: args.to_vec(),
335 },
336 )))
337 }
338 expr => Ok(expr),
339 }
340 }
341 Some(delimiter) => {
342 let mut new_args = Vec::with_capacity(args.len());
343 new_args.push(lit(delimiter));
344 let mut contiguous_scalar = None;
345 for arg in args {
346 match arg {
347 Expr::Literal(
349 ScalarValue::Utf8(None)
350 | ScalarValue::LargeUtf8(None)
351 | ScalarValue::Utf8View(None),
352 _,
353 ) => {}
354 Expr::Literal(
355 ScalarValue::Utf8(Some(v))
356 | ScalarValue::LargeUtf8(Some(v))
357 | ScalarValue::Utf8View(Some(v)),
358 _,
359 ) => match contiguous_scalar {
360 None => contiguous_scalar = Some(v.to_string()),
361 Some(mut pre) => {
362 pre += delimiter;
363 pre += v;
364 contiguous_scalar = Some(pre)
365 }
366 },
367 Expr::Literal(s, _) => {
368 return internal_err!(
369 "The scalar {s} should be casted to string type during the type coercion."
370 );
371 }
372 arg => {
376 if let Some(val) = contiguous_scalar {
377 new_args.push(lit(val));
378 }
379 new_args.push(arg.clone());
380 contiguous_scalar = None;
381 }
382 }
383 }
384 if let Some(val) = contiguous_scalar {
385 new_args.push(lit(val));
386 }
387
388 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
389 ScalarFunction {
390 func: concat_ws(),
391 args: new_args,
392 },
393 )))
394 }
395 None => Ok(ExprSimplifyResult::Simplified(Expr::Literal(
397 ScalarValue::Utf8(None),
398 None,
399 ))),
400 }
401 }
402 Expr::Literal(d, _) => internal_err!(
403 "The scalar {d} should be casted to string type during the type coercion."
404 ),
405 _ => {
406 let mut args = args
407 .iter()
408 .filter(|&x| !is_null(x))
409 .cloned()
410 .collect::<Vec<Expr>>();
411 args.insert(0, delimiter.clone());
412 Ok(ExprSimplifyResult::Original(args))
413 }
414 }
415}
416
417fn is_null(expr: &Expr) -> bool {
418 match expr {
419 Expr::Literal(v, _) => v.is_null(),
420 _ => false,
421 }
422}
423
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::string::concat_ws::ConcatWsFunc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::datatypes::DataType::Utf8;
    use arrow::datatypes::Field;
    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::utils::test::test_function;

    /// Invokes `concat_ws` on three arguments evaluated over three rows.
    fn invoke_three_rows(args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
        let arg_fields = vec![
            Field::new("a", Utf8, true).into(),
            Field::new("a", Utf8, true).into(),
            Field::new("a", Utf8, true).into(),
        ];
        ConcatWsFunc::new().invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 3,
            return_field: Field::new("f", Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that `result` is an array equal to `expected`.
    fn assert_array_result(result: &ColumnarValue, expected: StringArray) {
        let expected = Arc::new(expected) as ArrayRef;
        let ColumnarValue::Array(array) = result else {
            panic!()
        };
        assert_eq!(&expected, array);
    }

    #[test]
    fn test_functions() -> Result<()> {
        // All values present.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|bb|cc")),
            &str,
            Utf8,
            StringArray
        );
        // Only a NULL value: the result is the empty string.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // NULL separator: the result is NULL.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // A NULL value in the middle is skipped along with its separator.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|cc")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    #[test]
    fn concat_ws() -> Result<()> {
        // Scalar separator, array values with one null.
        let separator = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let left =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));

        let result = invoke_three_rows(vec![separator, left, right])?;
        assert_array_result(
            &result,
            StringArray::from(vec!["foo,x", "bar", "baz,z"]),
        );

        // Array separator with one null: that row becomes null.
        let separator = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some(","),
            None,
            Some("+"),
        ])));
        let left =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            Some("y"),
            Some("z"),
        ])));

        let result = invoke_three_rows(vec![separator, left, right])?;
        assert_array_result(
            &result,
            StringArray::from(vec![Some("foo,x"), None, Some("baz+z")]),
        );

        Ok(())
    }
}