datafusion_functions/string/concat_ws.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, StringArray, as_largestring_array};
19use std::any::Any;
20use std::sync::Arc;
21
22use arrow::datatypes::DataType;
23
24use crate::string::concat;
25use crate::string::concat::simplify_concat;
26use crate::string::concat_ws;
27use crate::strings::{ColumnarValueRef, StringArrayBuilder};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together with a specified separator.",
    syntax_example = "concat_ws(separator, str[, ..., str_n])",
    sql_example = r#"```sql
> select concat_ws('_', 'data', 'fusion');
+--------------------------------------------------+
| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) |
+--------------------------------------------------+
| data_fusion                                      |
+--------------------------------------------------+
```"#,
    argument(
        name = "separator",
        description = "Separator to insert between concatenated strings."
    ),
    argument(
        name = "str",
        description = "String expression to operate on. Can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF implementing `concat_ws(separator, str[, ..., str_n])`.
///
/// Joins the non-NULL `str` arguments, inserting `separator` between each pair;
/// NULL values are skipped, and a NULL separator makes the result NULL.
pub struct ConcatWsFunc {
    /// Variadic signature over `Utf8View`, `Utf8` and `LargeUtf8`, built in
    /// [`ConcatWsFunc::new`].
    signature: Signature,
}
66
67impl Default for ConcatWsFunc {
68    fn default() -> Self {
69        ConcatWsFunc::new()
70    }
71}
72
73impl ConcatWsFunc {
74    pub fn new() -> Self {
75        use DataType::*;
76        Self {
77            signature: Signature::variadic(
78                vec![Utf8View, Utf8, LargeUtf8],
79                Volatility::Immutable,
80            ),
81        }
82    }
83}
84
85impl ScalarUDFImpl for ConcatWsFunc {
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89
90    fn name(&self) -> &str {
91        "concat_ws"
92    }
93
94    fn signature(&self) -> &Signature {
95        &self.signature
96    }
97
98    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99        use DataType::*;
100        Ok(Utf8)
101    }
102
103    /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
104    /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
105    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106        let ScalarFunctionArgs { args, .. } = args;
107
108        // do not accept 0 arguments.
109        if args.len() < 2 {
110            return exec_err!(
111                "concat_ws was called with {} arguments. It requires at least 2.",
112                args.len()
113            );
114        }
115
116        let array_len = args
117            .iter()
118            .filter_map(|x| match x {
119                ColumnarValue::Array(array) => Some(array.len()),
120                _ => None,
121            })
122            .next();
123
124        // Scalar
125        if array_len.is_none() {
126            let ColumnarValue::Scalar(scalar) = &args[0] else {
127                // loop above checks for all args being scalar
128                unreachable!()
129            };
130            let sep = match scalar.try_as_str() {
131                Some(Some(s)) => s,
132                Some(None) => {
133                    // null literal string
134                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
135                }
136                None => return internal_err!("Expected string literal, got {scalar:?}"),
137            };
138
139            let mut result = String::new();
140            // iterator over Option<str>
141            let iter = &mut args[1..].iter().map(|arg| {
142                let ColumnarValue::Scalar(scalar) = arg else {
143                    // loop above checks for all args being scalar
144                    unreachable!()
145                };
146                scalar.try_as_str()
147            });
148
149            // append first non null arg
150            for scalar in iter.by_ref() {
151                match scalar {
152                    Some(Some(s)) => {
153                        result.push_str(s);
154                        break;
155                    }
156                    Some(None) => {} // null literal string
157                    None => {
158                        return internal_err!("Expected string literal, got {scalar:?}");
159                    }
160                }
161            }
162
163            // handle subsequent non null args
164            for scalar in iter.by_ref() {
165                match scalar {
166                    Some(Some(s)) => {
167                        result.push_str(sep);
168                        result.push_str(s);
169                    }
170                    Some(None) => {} // null literal string
171                    None => {
172                        return internal_err!("Expected string literal, got {scalar:?}");
173                    }
174                }
175            }
176
177            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
178        }
179
180        // Array
181        let len = array_len.unwrap();
182        let mut data_size = 0;
183
184        // parse sep
185        let sep = match &args[0] {
186            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
187                data_size += s.len() * len * (args.len() - 2); // estimate
188                ColumnarValueRef::Scalar(s.as_bytes())
189            }
190            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
191                return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
192            }
193            ColumnarValue::Array(array) => {
194                let string_array = as_string_array(array)?;
195                data_size += string_array.values().len() * (args.len() - 2); // estimate
196                if array.is_nullable() {
197                    ColumnarValueRef::NullableArray(string_array)
198                } else {
199                    ColumnarValueRef::NonNullableArray(string_array)
200                }
201            }
202            _ => unreachable!("concat ws"),
203        };
204
205        let mut columns = Vec::with_capacity(args.len() - 1);
206        for arg in &args[1..] {
207            match arg {
208                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
209                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
210                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
211                    if let Some(s) = maybe_value {
212                        data_size += s.len() * len;
213                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
214                    }
215                }
216                ColumnarValue::Array(array) => {
217                    match array.data_type() {
218                        DataType::Utf8 => {
219                            let string_array = as_string_array(array)?;
220
221                            data_size += string_array.values().len();
222                            let column = if array.is_nullable() {
223                                ColumnarValueRef::NullableArray(string_array)
224                            } else {
225                                ColumnarValueRef::NonNullableArray(string_array)
226                            };
227                            columns.push(column);
228                        }
229                        DataType::LargeUtf8 => {
230                            let string_array = as_largestring_array(array);
231
232                            data_size += string_array.values().len();
233                            let column = if array.is_nullable() {
234                                ColumnarValueRef::NullableLargeStringArray(string_array)
235                            } else {
236                                ColumnarValueRef::NonNullableLargeStringArray(
237                                    string_array,
238                                )
239                            };
240                            columns.push(column);
241                        }
242                        DataType::Utf8View => {
243                            let string_array = as_string_view_array(array)?;
244
245                            data_size += string_array
246                                .data_buffers()
247                                .iter()
248                                .map(|buf| buf.len())
249                                .sum::<usize>();
250                            let column = if array.is_nullable() {
251                                ColumnarValueRef::NullableStringViewArray(string_array)
252                            } else {
253                                ColumnarValueRef::NonNullableStringViewArray(string_array)
254                            };
255                            columns.push(column);
256                        }
257                        other => {
258                            return plan_err!(
259                                "Input was {other} which is not a supported datatype for concat_ws function."
260                            );
261                        }
262                    };
263                }
264                _ => unreachable!(),
265            }
266        }
267
268        let mut builder = StringArrayBuilder::with_capacity(len, data_size);
269        for i in 0..len {
270            if !sep.is_valid(i) {
271                builder.append_offset();
272                continue;
273            }
274
275            let mut iter = columns.iter();
276            for column in iter.by_ref() {
277                if column.is_valid(i) {
278                    builder.write::<false>(column, i);
279                    break;
280                }
281            }
282
283            for column in iter {
284                if column.is_valid(i) {
285                    builder.write::<false>(&sep, i);
286                    builder.write::<false>(column, i);
287                }
288            }
289
290            builder.append_offset();
291        }
292
293        Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
294    }
295
296    /// Simply the `concat_ws` function by
297    /// 1. folding to `null` if the delimiter is null
298    /// 2. filtering out `null` arguments
299    /// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string
300    /// 4. concatenating contiguous literals if the delimiter is a literal.
301    fn simplify(
302        &self,
303        args: Vec<Expr>,
304        _info: &dyn SimplifyInfo,
305    ) -> Result<ExprSimplifyResult> {
306        match &args[..] {
307            [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals),
308            _ => Ok(ExprSimplifyResult::Original(args)),
309        }
310    }
311
312    fn documentation(&self) -> Option<&Documentation> {
313        self.doc()
314    }
315}
316
317fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
318    match delimiter {
319        Expr::Literal(
320            ScalarValue::Utf8(delimiter)
321            | ScalarValue::LargeUtf8(delimiter)
322            | ScalarValue::Utf8View(delimiter),
323            _,
324        ) => {
325            match delimiter {
326                // when the delimiter is an empty string,
327                // we can use `concat` to replace `concat_ws`
328                Some(delimiter) if delimiter.is_empty() => {
329                    match simplify_concat(args.to_vec())? {
330                        ExprSimplifyResult::Original(_) => {
331                            Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
332                                ScalarFunction {
333                                    func: concat(),
334                                    args: args.to_vec(),
335                                },
336                            )))
337                        }
338                        expr => Ok(expr),
339                    }
340                }
341                Some(delimiter) => {
342                    let mut new_args = Vec::with_capacity(args.len());
343                    new_args.push(lit(delimiter));
344                    let mut contiguous_scalar = None;
345                    for arg in args {
346                        match arg {
347                            // filter out null args
348                            Expr::Literal(
349                                ScalarValue::Utf8(None)
350                                | ScalarValue::LargeUtf8(None)
351                                | ScalarValue::Utf8View(None),
352                                _,
353                            ) => {}
354                            Expr::Literal(
355                                ScalarValue::Utf8(Some(v))
356                                | ScalarValue::LargeUtf8(Some(v))
357                                | ScalarValue::Utf8View(Some(v)),
358                                _,
359                            ) => match contiguous_scalar {
360                                None => contiguous_scalar = Some(v.to_string()),
361                                Some(mut pre) => {
362                                    pre += delimiter;
363                                    pre += v;
364                                    contiguous_scalar = Some(pre)
365                                }
366                            },
367                            Expr::Literal(s, _) => {
368                                return internal_err!(
369                                    "The scalar {s} should be casted to string type during the type coercion."
370                                );
371                            }
372                            // If the arg is not a literal, we should first push the current `contiguous_scalar`
373                            // to the `new_args` and reset it to None.
374                            // Then pushing this arg to the `new_args`.
375                            arg => {
376                                if let Some(val) = contiguous_scalar {
377                                    new_args.push(lit(val));
378                                }
379                                new_args.push(arg.clone());
380                                contiguous_scalar = None;
381                            }
382                        }
383                    }
384                    if let Some(val) = contiguous_scalar {
385                        new_args.push(lit(val));
386                    }
387
388                    Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
389                        ScalarFunction {
390                            func: concat_ws(),
391                            args: new_args,
392                        },
393                    )))
394                }
395                // if the delimiter is null, then the value of the whole expression is null.
396                None => Ok(ExprSimplifyResult::Simplified(Expr::Literal(
397                    ScalarValue::Utf8(None),
398                    None,
399                ))),
400            }
401        }
402        Expr::Literal(d, _) => internal_err!(
403            "The scalar {d} should be casted to string type during the type coercion."
404        ),
405        _ => {
406            let mut args = args
407                .iter()
408                .filter(|&x| !is_null(x))
409                .cloned()
410                .collect::<Vec<Expr>>();
411            args.insert(0, delimiter.clone());
412            Ok(ExprSimplifyResult::Original(args))
413        }
414    }
415}
416
417fn is_null(expr: &Expr) -> bool {
418    match expr {
419        Expr::Literal(v, _) => v.is_null(),
420        _ => false,
421    }
422}
423
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::string::concat_ws::ConcatWsFunc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::datatypes::DataType::Utf8;
    use arrow::datatypes::Field;
    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::utils::test::test_function;

    /// A non-NULL `Utf8` scalar argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    /// A NULL `Utf8` scalar argument.
    fn null_utf8() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(None))
    }

    /// Runs `concat_ws` over `args`, declaring every argument and the result as
    /// nullable `Utf8` fields.
    fn invoke(args: Vec<ColumnarValue>, number_rows: usize) -> Result<ColumnarValue> {
        let arg_fields = args
            .iter()
            .map(|_| Arc::new(Field::new("a", Utf8, true)))
            .collect();
        ConcatWsFunc::new().invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Arc::new(Field::new("f", Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatWsFunc::new(),
            vec![utf8("|"), utf8("aa"), utf8("bb"), utf8("cc")],
            Ok(Some("aa|bb|cc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatWsFunc::new(),
            vec![utf8("|"), null_utf8()],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatWsFunc::new(),
            vec![null_utf8(), utf8("aa"), utf8("bb"), utf8("cc")],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatWsFunc::new(),
            vec![utf8("|"), utf8("aa"), null_utf8(), utf8("cc")],
            Ok(Some("aa|cc")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    #[test]
    fn concat_ws() -> Result<()> {
        // sep is scalar
        let result = invoke(
            vec![
                utf8(","),
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    Some("x"),
                    None,
                    Some("z"),
                ]))),
            ],
            3,
        )?;
        let expected: ArrayRef = Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"]));
        let ColumnarValue::Array(actual) = &result else {
            panic!()
        };
        assert_eq!(&expected, actual);

        // sep is nullable array
        let result = invoke(
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    Some(","),
                    None,
                    Some("+"),
                ]))),
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    Some("x"),
                    Some("y"),
                    Some("z"),
                ]))),
            ],
            3,
        )?;
        let expected: ArrayRef =
            Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")]));
        let ColumnarValue::Array(actual) = &result else {
            panic!()
        };
        assert_eq!(&expected, actual);

        Ok(())
    }
}