1use arrow::array::Array;
19use std::any::Any;
20use std::sync::Arc;
21
22use arrow::datatypes::DataType;
23
24use crate::string::concat;
25use crate::string::concat::simplify_concat;
26use crate::string::concat_ws;
27use crate::strings::{ColumnarValueRef, StringArrayBuilder};
28use datafusion_common::cast::{
29 as_large_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
34use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
35use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_macros::user_doc;
37
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together with a specified separator.",
    syntax_example = "concat_ws(separator, str[, ..., str_n])",
    sql_example = r#"```sql
> select concat_ws('_', 'data', 'fusion');
+--------------------------------------------------+
| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) |
+--------------------------------------------------+
| data_fusion                                      |
+--------------------------------------------------+
```"#,
    argument(
        name = "separator",
        description = "Separator to insert between concatenated strings."
    ),
    argument(
        name = "str",
        description = "String expression to operate on. Can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat")
)]
/// Scalar UDF implementing `concat_ws(separator, str[, ..., str_n])`: joins the
/// non-null string arguments with `separator`.
///
/// The user-facing SQL documentation (returned from `documentation()` via
/// `self.doc()`) is declared in the `#[user_doc]` attribute above.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConcatWsFunc {
    // Variadic signature over the string types; built in `ConcatWsFunc::new`.
    signature: Signature,
}
68
69impl Default for ConcatWsFunc {
70 fn default() -> Self {
71 ConcatWsFunc::new()
72 }
73}
74
75impl ConcatWsFunc {
76 pub fn new() -> Self {
77 use DataType::*;
78 Self {
79 signature: Signature::variadic(
80 vec![Utf8View, Utf8, LargeUtf8],
81 Volatility::Immutable,
82 ),
83 }
84 }
85}
86
87impl ScalarUDFImpl for ConcatWsFunc {
88 fn as_any(&self) -> &dyn Any {
89 self
90 }
91
92 fn name(&self) -> &str {
93 "concat_ws"
94 }
95
96 fn signature(&self) -> &Signature {
97 &self.signature
98 }
99
100 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
101 use DataType::*;
102 Ok(Utf8)
103 }
104
105 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108 let ScalarFunctionArgs { args, .. } = args;
109
110 if args.len() < 2 {
111 return exec_err!(
112 "concat_ws was called with {} arguments. It requires at least 2.",
113 args.len()
114 );
115 }
116
117 let array_len = args.iter().find_map(|x| match x {
118 ColumnarValue::Array(array) => Some(array.len()),
119 _ => None,
120 });
121
122 if array_len.is_none() {
124 let ColumnarValue::Scalar(scalar) = &args[0] else {
125 unreachable!()
126 };
127 let sep = match scalar.try_as_str() {
128 Some(Some(s)) => s,
129 Some(None) => {
130 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
132 }
133 None => return internal_err!("Expected string literal, got {scalar:?}"),
134 };
135
136 let mut values = Vec::with_capacity(args.len() - 1);
137 for arg in &args[1..] {
138 let ColumnarValue::Scalar(scalar) = arg else {
139 unreachable!()
140 };
141
142 match scalar.try_as_str() {
143 Some(Some(v)) => values.push(v),
144 Some(None) => {} None => {
146 return internal_err!("Expected string literal, got {scalar:?}");
147 }
148 }
149 }
150 let result = values.join(sep);
151
152 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
153 }
154
155 let len = array_len.unwrap();
157 let mut data_size = 0;
158
159 let sep = match &args[0] {
161 ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
162 Some(Some(s)) => {
163 data_size += s.len() * len * (args.len() - 2); ColumnarValueRef::Scalar(s.as_bytes())
165 }
166 Some(None) => {
167 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
168 }
169 None => {
170 return internal_err!("Expected string separator, got {scalar:?}");
171 }
172 },
173 ColumnarValue::Array(array) => match array.data_type() {
174 DataType::Utf8 => {
175 let string_array = as_string_array(array)?;
176 data_size += string_array.values().len() * (args.len() - 2);
177 if array.is_nullable() {
178 ColumnarValueRef::NullableArray(string_array)
179 } else {
180 ColumnarValueRef::NonNullableArray(string_array)
181 }
182 }
183 DataType::LargeUtf8 => {
184 let string_array = as_large_string_array(array)?;
185 data_size += string_array.values().len() * (args.len() - 2);
186 if array.is_nullable() {
187 ColumnarValueRef::NullableLargeStringArray(string_array)
188 } else {
189 ColumnarValueRef::NonNullableLargeStringArray(string_array)
190 }
191 }
192 DataType::Utf8View => {
193 let string_array = as_string_view_array(array)?;
194 data_size +=
195 string_array.total_buffer_bytes_used() * (args.len() - 2);
196 if array.is_nullable() {
197 ColumnarValueRef::NullableStringViewArray(string_array)
198 } else {
199 ColumnarValueRef::NonNullableStringViewArray(string_array)
200 }
201 }
202 other => {
203 return plan_err!(
204 "Input was {other} which is not a supported datatype for concat_ws separator"
205 );
206 }
207 },
208 };
209
210 let mut columns = Vec::with_capacity(args.len() - 1);
211 for arg in &args[1..] {
212 match arg {
213 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
214 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
215 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
216 if let Some(s) = maybe_value {
217 data_size += s.len() * len;
218 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
219 }
220 }
221 ColumnarValue::Array(array) => {
222 match array.data_type() {
223 DataType::Utf8 => {
224 let string_array = as_string_array(array)?;
225
226 data_size += string_array.values().len();
227 let column = if array.is_nullable() {
228 ColumnarValueRef::NullableArray(string_array)
229 } else {
230 ColumnarValueRef::NonNullableArray(string_array)
231 };
232 columns.push(column);
233 }
234 DataType::LargeUtf8 => {
235 let string_array = as_large_string_array(array)?;
236
237 data_size += string_array.values().len();
238 let column = if array.is_nullable() {
239 ColumnarValueRef::NullableLargeStringArray(string_array)
240 } else {
241 ColumnarValueRef::NonNullableLargeStringArray(
242 string_array,
243 )
244 };
245 columns.push(column);
246 }
247 DataType::Utf8View => {
248 let string_array = as_string_view_array(array)?;
249
250 data_size += string_array.total_buffer_bytes_used();
253 let column = if array.is_nullable() {
254 ColumnarValueRef::NullableStringViewArray(string_array)
255 } else {
256 ColumnarValueRef::NonNullableStringViewArray(string_array)
257 };
258 columns.push(column);
259 }
260 other => {
261 return plan_err!(
262 "Input was {other} which is not a supported datatype for concat_ws function."
263 );
264 }
265 };
266 }
267 _ => unreachable!(),
268 }
269 }
270
271 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
272 for i in 0..len {
273 if !sep.is_valid(i) {
274 builder.append_offset();
275 continue;
276 }
277
278 let mut first = true;
279 for column in &columns {
280 if column.is_valid(i) {
281 if !first {
282 builder.write::<false>(&sep, i);
283 }
284 builder.write::<false>(column, i);
285 first = false;
286 }
287 }
288
289 builder.append_offset();
290 }
291
292 Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
293 }
294
295 fn simplify(
301 &self,
302 args: Vec<Expr>,
303 _info: &SimplifyContext,
304 ) -> Result<ExprSimplifyResult> {
305 match &args[..] {
306 [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals),
307 _ => Ok(ExprSimplifyResult::Original(args)),
308 }
309 }
310
311 fn documentation(&self) -> Option<&Documentation> {
312 self.doc()
313 }
314}
315
316fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
317 match delimiter {
318 Expr::Literal(
319 ScalarValue::Utf8(delimiter)
320 | ScalarValue::LargeUtf8(delimiter)
321 | ScalarValue::Utf8View(delimiter),
322 _,
323 ) => {
324 match delimiter {
325 Some(delimiter) if delimiter.is_empty() => {
328 match simplify_concat(args.to_vec())? {
329 ExprSimplifyResult::Original(_) => {
330 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
331 ScalarFunction {
332 func: concat(),
333 args: args.to_vec(),
334 },
335 )))
336 }
337 expr => Ok(expr),
338 }
339 }
340 Some(delimiter) => {
341 let mut new_args = Vec::with_capacity(args.len());
342 new_args.push(lit(delimiter));
343 let mut contiguous_scalar = None;
344 for arg in args {
345 match arg {
346 Expr::Literal(
348 ScalarValue::Utf8(None)
349 | ScalarValue::LargeUtf8(None)
350 | ScalarValue::Utf8View(None),
351 _,
352 ) => {}
353 Expr::Literal(
354 ScalarValue::Utf8(Some(v))
355 | ScalarValue::LargeUtf8(Some(v))
356 | ScalarValue::Utf8View(Some(v)),
357 _,
358 ) => match contiguous_scalar {
359 None => contiguous_scalar = Some(v.to_string()),
360 Some(mut pre) => {
361 pre += delimiter;
362 pre += v;
363 contiguous_scalar = Some(pre)
364 }
365 },
366 Expr::Literal(s, _) => {
367 return internal_err!(
368 "The scalar {s} should be casted to string type during the type coercion."
369 );
370 }
371 arg => {
375 if let Some(val) = contiguous_scalar {
376 new_args.push(lit(val));
377 }
378 new_args.push(arg.clone());
379 contiguous_scalar = None;
380 }
381 }
382 }
383 if let Some(val) = contiguous_scalar {
384 new_args.push(lit(val));
385 }
386
387 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
388 ScalarFunction {
389 func: concat_ws(),
390 args: new_args,
391 },
392 )))
393 }
394 None => Ok(ExprSimplifyResult::Simplified(Expr::Literal(
396 ScalarValue::Utf8(None),
397 None,
398 ))),
399 }
400 }
401 Expr::Literal(d, _) => internal_err!(
402 "The scalar {d} should be casted to string type during the type coercion."
403 ),
404 _ => {
405 let mut args = args
406 .iter()
407 .filter(|&x| !is_null(x))
408 .cloned()
409 .collect::<Vec<Expr>>();
410 args.insert(0, delimiter.clone());
411 Ok(ExprSimplifyResult::Original(args))
412 }
413 }
414}
415
416fn is_null(expr: &Expr) -> bool {
417 match expr {
418 Expr::Literal(v, _) => v.is_null(),
419 _ => false,
420 }
421}
422
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::string::concat_ws::ConcatWsFunc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::datatypes::DataType::Utf8;
    use arrow::datatypes::Field;
    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::utils::test::test_function;

    /// Runs `ConcatWsFunc::invoke_with_args` on `args`, describing every
    /// argument as a nullable `Utf8` field named "a" and the result as a
    /// nullable `Utf8` field named "f".
    fn invoke(args: Vec<ColumnarValue>, number_rows: usize) -> Result<ColumnarValue> {
        let arg_fields = args
            .iter()
            .map(|_| Arc::new(Field::new("a", Utf8, true)))
            .collect();
        let call = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Arc::new(Field::new("f", Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        ConcatWsFunc::new().invoke_with_args(call)
    }

    #[test]
    fn test_functions() -> Result<()> {
        // All-scalar inputs are joined with the separator.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|bb|cc")),
            &str,
            Utf8,
            StringArray
        );
        // Only null values: nothing is joined, giving the empty string.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A null separator makes the result null.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Null values are skipped together with their separator.
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|cc")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    #[test]
    fn concat_ws() -> Result<()> {
        // Scalar separator with array values; the null in `values` is skipped.
        let separator = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let words =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let values = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));

        let result = invoke(vec![separator, words, values], 3)?;
        let expected =
            Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
        let ColumnarValue::Array(array) = &result else {
            panic!()
        };
        assert_eq!(&expected, array);

        // Array separator; the row with a null separator becomes null.
        let separator = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some(","),
            None,
            Some("+"),
        ])));
        let words =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let values = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            Some("y"),
            Some("z"),
        ])));

        let result = invoke(vec![separator, words, values], 3)?;
        let expected =
            Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")]))
                as ArrayRef;
        let ColumnarValue::Array(array) = &result else {
            panic!()
        };
        assert_eq!(&expected, array);

        Ok(())
    }

    #[test]
    fn concat_ws_utf8view_scalar_separator() -> Result<()> {
        let separator =
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let words =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let values = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));

        let result = invoke(vec![separator, words, values], 3)?;
        let expected =
            Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
        let ColumnarValue::Array(array) = &result else {
            panic!("Expected array result")
        };
        assert_eq!(&expected, array);

        Ok(())
    }

    #[test]
    fn concat_ws_largeutf8_scalar_separator() -> Result<()> {
        let separator =
            ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string())));
        let words =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let values = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));

        let result = invoke(vec![separator, words, values], 3)?;
        let expected =
            Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
        let ColumnarValue::Array(array) = &result else {
            panic!("Expected array result")
        };
        assert_eq!(&expected, array);

        Ok(())
    }
}