Skip to main content

datafusion_functions/string/
concat_ws.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::Array;
19use std::any::Any;
20use std::sync::Arc;
21
22use arrow::datatypes::DataType;
23
24use crate::string::concat;
25use crate::string::concat::simplify_concat;
26use crate::string::concat_ws;
27use crate::strings::{ColumnarValueRef, StringArrayBuilder};
28use datafusion_common::cast::{
29    as_large_string_array, as_string_array, as_string_view_array,
30};
31use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
34use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
35use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Concatenates multiple strings together with a specified separator.",
41    syntax_example = "concat_ws(separator, str[, ..., str_n])",
42    sql_example = r#"```sql
43> select concat_ws('_', 'data', 'fusion');
44+--------------------------------------------------+
45| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) |
46+--------------------------------------------------+
47| data_fusion                                      |
48+--------------------------------------------------+
49```"#,
50    argument(
51        name = "separator",
52        description = "Separator to insert between concatenated strings."
53    ),
54    argument(
55        name = "str",
56        description = "String expression to operate on. Can be a constant, column, or function, and any combination of operators."
57    ),
58    argument(
59        name = "str_n",
60        description = "Subsequent string expressions to concatenate."
61    ),
62    related_udf(name = "concat")
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct ConcatWsFunc {
66    signature: Signature,
67}
68
69impl Default for ConcatWsFunc {
70    fn default() -> Self {
71        ConcatWsFunc::new()
72    }
73}
74
75impl ConcatWsFunc {
76    pub fn new() -> Self {
77        use DataType::*;
78        Self {
79            signature: Signature::variadic(
80                vec![Utf8View, Utf8, LargeUtf8],
81                Volatility::Immutable,
82            ),
83        }
84    }
85}
86
87impl ScalarUDFImpl for ConcatWsFunc {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91
92    fn name(&self) -> &str {
93        "concat_ws"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
101        use DataType::*;
102        Ok(Utf8)
103    }
104
105    /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
106    /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
107    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108        let ScalarFunctionArgs { args, .. } = args;
109
110        if args.len() < 2 {
111            return exec_err!(
112                "concat_ws was called with {} arguments. It requires at least 2.",
113                args.len()
114            );
115        }
116
117        let array_len = args.iter().find_map(|x| match x {
118            ColumnarValue::Array(array) => Some(array.len()),
119            _ => None,
120        });
121
122        // Scalar
123        if array_len.is_none() {
124            let ColumnarValue::Scalar(scalar) = &args[0] else {
125                unreachable!()
126            };
127            let sep = match scalar.try_as_str() {
128                Some(Some(s)) => s,
129                Some(None) => {
130                    // null literal string
131                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
132                }
133                None => return internal_err!("Expected string literal, got {scalar:?}"),
134            };
135
136            let mut values = Vec::with_capacity(args.len() - 1);
137            for arg in &args[1..] {
138                let ColumnarValue::Scalar(scalar) = arg else {
139                    unreachable!()
140                };
141
142                match scalar.try_as_str() {
143                    Some(Some(v)) => values.push(v),
144                    Some(None) => {} // null literal string
145                    None => {
146                        return internal_err!("Expected string literal, got {scalar:?}");
147                    }
148                }
149            }
150            let result = values.join(sep);
151
152            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
153        }
154
155        // Array
156        let len = array_len.unwrap();
157        let mut data_size = 0;
158
159        // parse sep
160        let sep = match &args[0] {
161            ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
162                Some(Some(s)) => {
163                    data_size += s.len() * len * (args.len() - 2); // estimate
164                    ColumnarValueRef::Scalar(s.as_bytes())
165                }
166                Some(None) => {
167                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
168                }
169                None => {
170                    return internal_err!("Expected string separator, got {scalar:?}");
171                }
172            },
173            ColumnarValue::Array(array) => match array.data_type() {
174                DataType::Utf8 => {
175                    let string_array = as_string_array(array)?;
176                    data_size += string_array.values().len() * (args.len() - 2);
177                    if array.is_nullable() {
178                        ColumnarValueRef::NullableArray(string_array)
179                    } else {
180                        ColumnarValueRef::NonNullableArray(string_array)
181                    }
182                }
183                DataType::LargeUtf8 => {
184                    let string_array = as_large_string_array(array)?;
185                    data_size += string_array.values().len() * (args.len() - 2);
186                    if array.is_nullable() {
187                        ColumnarValueRef::NullableLargeStringArray(string_array)
188                    } else {
189                        ColumnarValueRef::NonNullableLargeStringArray(string_array)
190                    }
191                }
192                DataType::Utf8View => {
193                    let string_array = as_string_view_array(array)?;
194                    data_size +=
195                        string_array.total_buffer_bytes_used() * (args.len() - 2);
196                    if array.is_nullable() {
197                        ColumnarValueRef::NullableStringViewArray(string_array)
198                    } else {
199                        ColumnarValueRef::NonNullableStringViewArray(string_array)
200                    }
201                }
202                other => {
203                    return plan_err!(
204                        "Input was {other} which is not a supported datatype for concat_ws separator"
205                    );
206                }
207            },
208        };
209
210        let mut columns = Vec::with_capacity(args.len() - 1);
211        for arg in &args[1..] {
212            match arg {
213                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
214                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
215                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
216                    if let Some(s) = maybe_value {
217                        data_size += s.len() * len;
218                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
219                    }
220                }
221                ColumnarValue::Array(array) => {
222                    match array.data_type() {
223                        DataType::Utf8 => {
224                            let string_array = as_string_array(array)?;
225
226                            data_size += string_array.values().len();
227                            let column = if array.is_nullable() {
228                                ColumnarValueRef::NullableArray(string_array)
229                            } else {
230                                ColumnarValueRef::NonNullableArray(string_array)
231                            };
232                            columns.push(column);
233                        }
234                        DataType::LargeUtf8 => {
235                            let string_array = as_large_string_array(array)?;
236
237                            data_size += string_array.values().len();
238                            let column = if array.is_nullable() {
239                                ColumnarValueRef::NullableLargeStringArray(string_array)
240                            } else {
241                                ColumnarValueRef::NonNullableLargeStringArray(
242                                    string_array,
243                                )
244                            };
245                            columns.push(column);
246                        }
247                        DataType::Utf8View => {
248                            let string_array = as_string_view_array(array)?;
249
250                            // This is an estimate; in particular, it will
251                            // undercount arrays of short strings (<= 12 bytes).
252                            data_size += string_array.total_buffer_bytes_used();
253                            let column = if array.is_nullable() {
254                                ColumnarValueRef::NullableStringViewArray(string_array)
255                            } else {
256                                ColumnarValueRef::NonNullableStringViewArray(string_array)
257                            };
258                            columns.push(column);
259                        }
260                        other => {
261                            return plan_err!(
262                                "Input was {other} which is not a supported datatype for concat_ws function."
263                            );
264                        }
265                    };
266                }
267                _ => unreachable!(),
268            }
269        }
270
271        let mut builder = StringArrayBuilder::with_capacity(len, data_size);
272        for i in 0..len {
273            if !sep.is_valid(i) {
274                builder.append_offset();
275                continue;
276            }
277
278            let mut first = true;
279            for column in &columns {
280                if column.is_valid(i) {
281                    if !first {
282                        builder.write::<false>(&sep, i);
283                    }
284                    builder.write::<false>(column, i);
285                    first = false;
286                }
287            }
288
289            builder.append_offset();
290        }
291
292        Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
293    }
294
295    /// Simply the `concat_ws` function by
296    /// 1. folding to `null` if the delimiter is null
297    /// 2. filtering out `null` arguments
298    /// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string
299    /// 4. concatenating contiguous literals if the delimiter is a literal.
300    fn simplify(
301        &self,
302        args: Vec<Expr>,
303        _info: &SimplifyContext,
304    ) -> Result<ExprSimplifyResult> {
305        match &args[..] {
306            [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals),
307            _ => Ok(ExprSimplifyResult::Original(args)),
308        }
309    }
310
311    fn documentation(&self) -> Option<&Documentation> {
312        self.doc()
313    }
314}
315
316fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
317    match delimiter {
318        Expr::Literal(
319            ScalarValue::Utf8(delimiter)
320            | ScalarValue::LargeUtf8(delimiter)
321            | ScalarValue::Utf8View(delimiter),
322            _,
323        ) => {
324            match delimiter {
325                // when the delimiter is an empty string,
326                // we can use `concat` to replace `concat_ws`
327                Some(delimiter) if delimiter.is_empty() => {
328                    match simplify_concat(args.to_vec())? {
329                        ExprSimplifyResult::Original(_) => {
330                            Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
331                                ScalarFunction {
332                                    func: concat(),
333                                    args: args.to_vec(),
334                                },
335                            )))
336                        }
337                        expr => Ok(expr),
338                    }
339                }
340                Some(delimiter) => {
341                    let mut new_args = Vec::with_capacity(args.len());
342                    new_args.push(lit(delimiter));
343                    let mut contiguous_scalar = None;
344                    for arg in args {
345                        match arg {
346                            // filter out null args
347                            Expr::Literal(
348                                ScalarValue::Utf8(None)
349                                | ScalarValue::LargeUtf8(None)
350                                | ScalarValue::Utf8View(None),
351                                _,
352                            ) => {}
353                            Expr::Literal(
354                                ScalarValue::Utf8(Some(v))
355                                | ScalarValue::LargeUtf8(Some(v))
356                                | ScalarValue::Utf8View(Some(v)),
357                                _,
358                            ) => match contiguous_scalar {
359                                None => contiguous_scalar = Some(v.to_string()),
360                                Some(mut pre) => {
361                                    pre += delimiter;
362                                    pre += v;
363                                    contiguous_scalar = Some(pre)
364                                }
365                            },
366                            Expr::Literal(s, _) => {
367                                return internal_err!(
368                                    "The scalar {s} should be casted to string type during the type coercion."
369                                );
370                            }
371                            // If the arg is not a literal, we should first push the current `contiguous_scalar`
372                            // to the `new_args` and reset it to None.
373                            // Then pushing this arg to the `new_args`.
374                            arg => {
375                                if let Some(val) = contiguous_scalar {
376                                    new_args.push(lit(val));
377                                }
378                                new_args.push(arg.clone());
379                                contiguous_scalar = None;
380                            }
381                        }
382                    }
383                    if let Some(val) = contiguous_scalar {
384                        new_args.push(lit(val));
385                    }
386
387                    Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
388                        ScalarFunction {
389                            func: concat_ws(),
390                            args: new_args,
391                        },
392                    )))
393                }
394                // if the delimiter is null, then the value of the whole expression is null.
395                None => Ok(ExprSimplifyResult::Simplified(Expr::Literal(
396                    ScalarValue::Utf8(None),
397                    None,
398                ))),
399            }
400        }
401        Expr::Literal(d, _) => internal_err!(
402            "The scalar {d} should be casted to string type during the type coercion."
403        ),
404        _ => {
405            let mut args = args
406                .iter()
407                .filter(|&x| !is_null(x))
408                .cloned()
409                .collect::<Vec<Expr>>();
410            args.insert(0, delimiter.clone());
411            Ok(ExprSimplifyResult::Original(args))
412        }
413    }
414}
415
416fn is_null(expr: &Expr) -> bool {
417    match expr {
418        Expr::Literal(v, _) => v.is_null(),
419        _ => false,
420    }
421}
422
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::string::concat_ws::ConcatWsFunc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::datatypes::DataType::Utf8;
    use arrow::datatypes::Field;
    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::utils::test::test_function;

    /// Row count shared by every array-based case below.
    const NUM_ROWS: usize = 3;

    /// Wraps a `StringArray` as an array argument.
    fn array_arg(array: StringArray) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Invokes `concat_ws` on `args`, declaring each argument as a nullable
    /// `Utf8` field and the batch as `NUM_ROWS` rows.
    fn invoke_concat_ws(args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
        let arg_fields = (0..args.len())
            .map(|_| Arc::new(Field::new("a", Utf8, true)))
            .collect();
        ConcatWsFunc::new().invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: NUM_ROWS,
            return_field: Arc::new(Field::new("f", Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that `result` is an array equal to `expected`.
    fn assert_array_eq(result: &ColumnarValue, expected: ArrayRef) {
        match result {
            ColumnarValue::Array(actual) => assert_eq!(&expected, actual),
            _ => panic!("Expected array result"),
        }
    }

    /// Scalar separator `separator` (expected to be ",") joining a non-null
    /// column with a column that has a NULL in the middle row.
    fn assert_scalar_separator(separator: ScalarValue) -> Result<()> {
        let result = invoke_concat_ws(vec![
            ColumnarValue::Scalar(separator),
            array_arg(StringArray::from(vec!["foo", "bar", "baz"])),
            array_arg(StringArray::from(vec![Some("x"), None, Some("z")])),
        ])?;
        assert_array_eq(
            &result,
            Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])),
        );
        Ok(())
    }

    #[test]
    fn test_functions() -> Result<()> {
        // all values present
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|bb|cc")),
            &str,
            Utf8,
            StringArray
        );
        // only a NULL value: empty string, not NULL
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // NULL separator: NULL result
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // NULL in the middle is skipped without a doubled separator
        test_function!(
            ConcatWsFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("|")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aa|cc")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    #[test]
    fn concat_ws() -> Result<()> {
        // sep is scalar
        assert_scalar_separator(ScalarValue::Utf8(Some(",".to_string())))?;

        // sep is nullable array
        let result = invoke_concat_ws(vec![
            array_arg(StringArray::from(vec![Some(","), None, Some("+")])),
            array_arg(StringArray::from(vec!["foo", "bar", "baz"])),
            array_arg(StringArray::from(vec![Some("x"), Some("y"), Some("z")])),
        ])?;
        assert_array_eq(
            &result,
            Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])),
        );

        Ok(())
    }

    #[test]
    fn concat_ws_utf8view_scalar_separator() -> Result<()> {
        assert_scalar_separator(ScalarValue::Utf8View(Some(",".to_string())))
    }

    #[test]
    fn concat_ws_largeutf8_scalar_separator() -> Result<()> {
        assert_scalar_separator(ScalarValue::LargeUtf8(Some(",".to_string())))
    }
}