datafusion_functions/string/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, as_largestring_array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26    ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementing the SQL `concat` function.
///
/// The user-facing documentation (description, syntax, SQL example and
/// argument list) is declared through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together.",
    syntax_example = "concat(str[, ..., str_n])",
    sql_example = r#"```sql
> select concat('data', 'f', 'us', 'ion');
+-------------------------------------------------------+
| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
+-------------------------------------------------------+
| datafusion                                            |
+-------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat_ws")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConcatFunc {
    /// Accepted argument signature; `ConcatFunc::new` builds it as a variadic
    /// signature over `Utf8View`, `Utf8` and `LargeUtf8` with immutable
    /// volatility.
    signature: Signature,
}
59
60impl Default for ConcatFunc {
61    fn default() -> Self {
62        ConcatFunc::new()
63    }
64}
65
66impl ConcatFunc {
67    pub fn new() -> Self {
68        use DataType::*;
69        Self {
70            signature: Signature::variadic(
71                vec![Utf8View, Utf8, LargeUtf8],
72                Volatility::Immutable,
73            ),
74        }
75    }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &str {
84        "concat"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92        use DataType::*;
93        let mut dt = &Utf8;
94        arg_types.iter().for_each(|data_type| {
95            if data_type == &Utf8View {
96                dt = data_type;
97            }
98            if data_type == &LargeUtf8 && dt != &Utf8View {
99                dt = data_type;
100            }
101        });
102
103        Ok(dt.to_owned())
104    }
105
106    /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
107    /// concat('abcde', 2, NULL, 22) = 'abcde222'
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        let ScalarFunctionArgs { args, .. } = args;
110
111        let mut return_datatype = DataType::Utf8;
112        args.iter().for_each(|col| {
113            if col.data_type() == DataType::Utf8View {
114                return_datatype = col.data_type();
115            }
116            if col.data_type() == DataType::LargeUtf8
117                && return_datatype != DataType::Utf8View
118            {
119                return_datatype = col.data_type();
120            }
121        });
122
123        let array_len = args
124            .iter()
125            .filter_map(|x| match x {
126                ColumnarValue::Array(array) => Some(array.len()),
127                _ => None,
128            })
129            .next();
130
131        // Scalar
132        if array_len.is_none() {
133            let mut result = String::new();
134            for arg in args {
135                let ColumnarValue::Scalar(scalar) = arg else {
136                    return internal_err!("concat expected scalar value, got {arg:?}");
137                };
138
139                match scalar.try_as_str() {
140                    Some(Some(v)) => result.push_str(v),
141                    Some(None) => {} // null literal
142                    None => plan_err!(
143                        "Concat function does not support scalar type {}",
144                        scalar
145                    )?,
146                }
147            }
148
149            return match return_datatype {
150                DataType::Utf8View => {
151                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152                }
153                DataType::Utf8 => {
154                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155                }
156                DataType::LargeUtf8 => {
157                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158                }
159                other => {
160                    plan_err!("Concat function does not support datatype of {other}")
161                }
162            };
163        }
164
165        // Array
166        let len = array_len.unwrap();
167        let mut data_size = 0;
168        let mut columns = Vec::with_capacity(args.len());
169
170        for arg in &args {
171            match arg {
172                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175                    if let Some(s) = maybe_value {
176                        data_size += s.len() * len;
177                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178                    }
179                }
180                ColumnarValue::Array(array) => {
181                    match array.data_type() {
182                        DataType::Utf8 => {
183                            let string_array = as_string_array(array)?;
184
185                            data_size += string_array.values().len();
186                            let column = if array.is_nullable() {
187                                ColumnarValueRef::NullableArray(string_array)
188                            } else {
189                                ColumnarValueRef::NonNullableArray(string_array)
190                            };
191                            columns.push(column);
192                        }
193                        DataType::LargeUtf8 => {
194                            let string_array = as_largestring_array(array);
195
196                            data_size += string_array.values().len();
197                            let column = if array.is_nullable() {
198                                ColumnarValueRef::NullableLargeStringArray(string_array)
199                            } else {
200                                ColumnarValueRef::NonNullableLargeStringArray(
201                                    string_array,
202                                )
203                            };
204                            columns.push(column);
205                        }
206                        DataType::Utf8View => {
207                            let string_array = as_string_view_array(array)?;
208
209                            data_size += string_array.len();
210                            let column = if array.is_nullable() {
211                                ColumnarValueRef::NullableStringViewArray(string_array)
212                            } else {
213                                ColumnarValueRef::NonNullableStringViewArray(string_array)
214                            };
215                            columns.push(column);
216                        }
217                        other => {
218                            return plan_err!(
219                                "Input was {other} which is not a supported datatype for concat function"
220                            );
221                        }
222                    };
223                }
224                _ => unreachable!("concat"),
225            }
226        }
227
228        match return_datatype {
229            DataType::Utf8 => {
230                let mut builder = StringArrayBuilder::with_capacity(len, data_size);
231                for i in 0..len {
232                    columns
233                        .iter()
234                        .for_each(|column| builder.write::<true>(column, i));
235                    builder.append_offset();
236                }
237
238                let string_array = builder.finish(None);
239                Ok(ColumnarValue::Array(Arc::new(string_array)))
240            }
241            DataType::Utf8View => {
242                let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
243                for i in 0..len {
244                    columns
245                        .iter()
246                        .for_each(|column| builder.write::<true>(column, i));
247                    builder.append_offset();
248                }
249
250                let string_array = builder.finish();
251                Ok(ColumnarValue::Array(Arc::new(string_array)))
252            }
253            DataType::LargeUtf8 => {
254                let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
255                for i in 0..len {
256                    columns
257                        .iter()
258                        .for_each(|column| builder.write::<true>(column, i));
259                    builder.append_offset();
260                }
261
262                let string_array = builder.finish(None);
263                Ok(ColumnarValue::Array(Arc::new(string_array)))
264            }
265            _ => unreachable!(),
266        }
267    }
268
269    /// Simplify the `concat` function by
270    /// 1. filtering out all `null` literals
271    /// 2. concatenating contiguous literal arguments
272    ///
273    /// For example:
274    /// `concat(col(a), 'hello ', 'world', col(b), null)`
275    /// will be optimized to
276    /// `concat(col(a), 'hello world', col(b))`
277    fn simplify(
278        &self,
279        args: Vec<Expr>,
280        _info: &dyn SimplifyInfo,
281    ) -> Result<ExprSimplifyResult> {
282        simplify_concat(args)
283    }
284
285    fn documentation(&self) -> Option<&Documentation> {
286        self.doc()
287    }
288
289    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
290        Ok(true)
291    }
292}
293
294pub(crate) fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
295    let mut new_args = Vec::with_capacity(args.len());
296    let mut contiguous_scalar = "".to_string();
297
298    let return_type = {
299        let data_types: Vec<_> = args
300            .iter()
301            .filter_map(|expr| match expr {
302                Expr::Literal(l, _) => Some(l.data_type()),
303                _ => None,
304            })
305            .collect();
306        ConcatFunc::new().return_type(&data_types)
307    }?;
308
309    for arg in args.clone() {
310        match arg {
311            Expr::Literal(ScalarValue::Utf8(None), _) => {}
312            Expr::Literal(ScalarValue::LargeUtf8(None), _) => {}
313            Expr::Literal(ScalarValue::Utf8View(None), _) => {}
314
315            // filter out `null` args
316            // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
317            // Concatenate it with the `contiguous_scalar`.
318            Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
319                contiguous_scalar += &v;
320            }
321            Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
322                contiguous_scalar += &v;
323            }
324            Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
325                contiguous_scalar += &v;
326            }
327
328            Expr::Literal(x, _) => {
329                return internal_err!(
330                    "The scalar {x} should be casted to string type during the type coercion."
331                );
332            }
333            // If the arg is not a literal, we should first push the current `contiguous_scalar`
334            // to the `new_args` (if it is not empty) and reset it to empty string.
335            // Then pushing this arg to the `new_args`.
336            arg => {
337                if !contiguous_scalar.is_empty() {
338                    match return_type {
339                        DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
340                        DataType::LargeUtf8 => new_args
341                            .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
342                        DataType::Utf8View => new_args
343                            .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
344                        _ => unreachable!(),
345                    }
346                    contiguous_scalar = "".to_string();
347                }
348                new_args.push(arg);
349            }
350        }
351    }
352
353    if !contiguous_scalar.is_empty() {
354        match return_type {
355            DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
356            DataType::LargeUtf8 => {
357                new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
358            }
359            DataType::Utf8View => {
360                new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
361            }
362            _ => unreachable!(),
363        }
364    }
365
366    if !args.eq(&new_args) {
367        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
368            ScalarFunction {
369                func: concat(),
370                args: new_args,
371            },
372        )))
373    } else {
374        Ok(ExprSimplifyResult::Original(args))
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use DataType::*;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    /// All-scalar invocations: NULL arguments are skipped and the result type
    /// follows the widest string type among the arguments.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array and scalar arguments of `Utf8` and `Utf8View` type produce
    /// a `Utf8View` array; NULL elements contribute nothing to their row.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // The arguments include `Utf8View`, so the declared return field
            // matches the `StringViewArray` asserted below.
            return_field: Field::new("f", Utf8View, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            other => panic!("expected an array result, got {other:?}"),
        }
        Ok(())
    }
}