datafusion_functions/string/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{as_largestring_array, Array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26    ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37    doc_section(label = "String Functions"),
38    description = "Concatenates multiple strings together.",
39    syntax_example = "concat(str[, ..., str_n])",
40    sql_example = r#"```sql
41> select concat('data', 'f', 'us', 'ion');
42+-------------------------------------------------------+
43| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
44+-------------------------------------------------------+
45| datafusion                                            |
46+-------------------------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    argument(
50        name = "str_n",
51        description = "Subsequent string expressions to concatenate."
52    ),
53    related_udf(name = "concat_ws")
54)]
55#[derive(Debug)]
56pub struct ConcatFunc {
57    signature: Signature,
58}
59
60impl Default for ConcatFunc {
61    fn default() -> Self {
62        ConcatFunc::new()
63    }
64}
65
66impl ConcatFunc {
67    pub fn new() -> Self {
68        use DataType::*;
69        Self {
70            signature: Signature::variadic(
71                vec![Utf8View, Utf8, LargeUtf8],
72                Volatility::Immutable,
73            ),
74        }
75    }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &str {
84        "concat"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92        use DataType::*;
93        let mut dt = &Utf8;
94        arg_types.iter().for_each(|data_type| {
95            if data_type == &Utf8View {
96                dt = data_type;
97            }
98            if data_type == &LargeUtf8 && dt != &Utf8View {
99                dt = data_type;
100            }
101        });
102
103        Ok(dt.to_owned())
104    }
105
106    /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
107    /// concat('abcde', 2, NULL, 22) = 'abcde222'
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        let ScalarFunctionArgs { args, .. } = args;
110
111        let mut return_datatype = DataType::Utf8;
112        args.iter().for_each(|col| {
113            if col.data_type() == DataType::Utf8View {
114                return_datatype = col.data_type();
115            }
116            if col.data_type() == DataType::LargeUtf8
117                && return_datatype != DataType::Utf8View
118            {
119                return_datatype = col.data_type();
120            }
121        });
122
123        let array_len = args
124            .iter()
125            .filter_map(|x| match x {
126                ColumnarValue::Array(array) => Some(array.len()),
127                _ => None,
128            })
129            .next();
130
131        // Scalar
132        if array_len.is_none() {
133            let mut result = String::new();
134            for arg in args {
135                let ColumnarValue::Scalar(scalar) = arg else {
136                    return internal_err!("concat expected scalar value, got {arg:?}");
137                };
138
139                match scalar.try_as_str() {
140                    Some(Some(v)) => result.push_str(v),
141                    Some(None) => {} // null literal
142                    None => plan_err!(
143                        "Concat function does not support scalar type {:?}",
144                        scalar
145                    )?,
146                }
147            }
148
149            return match return_datatype {
150                DataType::Utf8View => {
151                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152                }
153                DataType::Utf8 => {
154                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155                }
156                DataType::LargeUtf8 => {
157                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158                }
159                other => {
160                    plan_err!("Concat function does not support datatype of {other}")
161                }
162            };
163        }
164
165        // Array
166        let len = array_len.unwrap();
167        let mut data_size = 0;
168        let mut columns = Vec::with_capacity(args.len());
169
170        for arg in &args {
171            match arg {
172                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175                    if let Some(s) = maybe_value {
176                        data_size += s.len() * len;
177                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178                    }
179                }
180                ColumnarValue::Array(array) => {
181                    match array.data_type() {
182                        DataType::Utf8 => {
183                            let string_array = as_string_array(array)?;
184
185                            data_size += string_array.values().len();
186                            let column = if array.is_nullable() {
187                                ColumnarValueRef::NullableArray(string_array)
188                            } else {
189                                ColumnarValueRef::NonNullableArray(string_array)
190                            };
191                            columns.push(column);
192                        },
193                        DataType::LargeUtf8 => {
194                            let string_array = as_largestring_array(array);
195
196                            data_size += string_array.values().len();
197                            let column = if array.is_nullable() {
198                                ColumnarValueRef::NullableLargeStringArray(string_array)
199                            } else {
200                                ColumnarValueRef::NonNullableLargeStringArray(string_array)
201                            };
202                            columns.push(column);
203                        },
204                        DataType::Utf8View => {
205                            let string_array = as_string_view_array(array)?;
206
207                            data_size += string_array.len();
208                            let column = if array.is_nullable() {
209                                ColumnarValueRef::NullableStringViewArray(string_array)
210                            } else {
211                                ColumnarValueRef::NonNullableStringViewArray(string_array)
212                            };
213                            columns.push(column);
214                        },
215                        other => {
216                            return plan_err!("Input was {other} which is not a supported datatype for concat function")
217                        }
218                    };
219                }
220                _ => unreachable!("concat"),
221            }
222        }
223
224        match return_datatype {
225            DataType::Utf8 => {
226                let mut builder = StringArrayBuilder::with_capacity(len, data_size);
227                for i in 0..len {
228                    columns
229                        .iter()
230                        .for_each(|column| builder.write::<true>(column, i));
231                    builder.append_offset();
232                }
233
234                let string_array = builder.finish(None);
235                Ok(ColumnarValue::Array(Arc::new(string_array)))
236            }
237            DataType::Utf8View => {
238                let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
239                for i in 0..len {
240                    columns
241                        .iter()
242                        .for_each(|column| builder.write::<true>(column, i));
243                    builder.append_offset();
244                }
245
246                let string_array = builder.finish();
247                Ok(ColumnarValue::Array(Arc::new(string_array)))
248            }
249            DataType::LargeUtf8 => {
250                let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
251                for i in 0..len {
252                    columns
253                        .iter()
254                        .for_each(|column| builder.write::<true>(column, i));
255                    builder.append_offset();
256                }
257
258                let string_array = builder.finish(None);
259                Ok(ColumnarValue::Array(Arc::new(string_array)))
260            }
261            _ => unreachable!(),
262        }
263    }
264
265    /// Simplify the `concat` function by
266    /// 1. filtering out all `null` literals
267    /// 2. concatenating contiguous literal arguments
268    ///
269    /// For example:
270    /// `concat(col(a), 'hello ', 'world', col(b), null)`
271    /// will be optimized to
272    /// `concat(col(a), 'hello world', col(b))`
273    fn simplify(
274        &self,
275        args: Vec<Expr>,
276        _info: &dyn SimplifyInfo,
277    ) -> Result<ExprSimplifyResult> {
278        simplify_concat(args)
279    }
280
281    fn documentation(&self) -> Option<&Documentation> {
282        self.doc()
283    }
284
285    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
286        Ok(true)
287    }
288}
289
290pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
291    let mut new_args = Vec::with_capacity(args.len());
292    let mut contiguous_scalar = "".to_string();
293
294    let return_type = {
295        let data_types: Vec<_> = args
296            .iter()
297            .filter_map(|expr| match expr {
298                Expr::Literal(l, _) => Some(l.data_type()),
299                _ => None,
300            })
301            .collect();
302        ConcatFunc::new().return_type(&data_types)
303    }?;
304
305    for arg in args.clone() {
306        match arg {
307            Expr::Literal(ScalarValue::Utf8(None), _) => {}
308            Expr::Literal(ScalarValue::LargeUtf8(None), _) => {
309            }
310            Expr::Literal(ScalarValue::Utf8View(None), _) => { }
311
312            // filter out `null` args
313            // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
314            // Concatenate it with the `contiguous_scalar`.
315            Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
316                contiguous_scalar += &v;
317            }
318            Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
319                contiguous_scalar += &v;
320            }
321            Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
322                contiguous_scalar += &v;
323            }
324
325            Expr::Literal(x, _) => {
326                return internal_err!(
327                    "The scalar {x} should be casted to string type during the type coercion."
328                )
329            }
330            // If the arg is not a literal, we should first push the current `contiguous_scalar`
331            // to the `new_args` (if it is not empty) and reset it to empty string.
332            // Then pushing this arg to the `new_args`.
333            arg => {
334                if !contiguous_scalar.is_empty() {
335                    match return_type {
336                        DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
337                        DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
338                        DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
339                        _ => unreachable!(),
340                    }
341                    contiguous_scalar = "".to_string();
342                }
343                new_args.push(arg);
344            }
345        }
346    }
347
348    if !contiguous_scalar.is_empty() {
349        match return_type {
350            DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
351            DataType::LargeUtf8 => {
352                new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
353            }
354            DataType::Utf8View => {
355                new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
356            }
357            _ => unreachable!(),
358        }
359    }
360
361    if !args.eq(&new_args) {
362        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
363            ScalarFunction {
364                func: concat(),
365                args: new_args,
366            },
367        )))
368    } else {
369        Ok(ExprSimplifyResult::Original(args))
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::Field;
    use DataType::*;

    /// Scalar-only invocations: null handling and output type resolution
    /// across `Utf8`, `LargeUtf8` and `Utf8View` inputs.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array/scalar invocation: `Utf8` and `Utf8View` inputs produce a
    /// `Utf8View` array, and null array slots contribute nothing to their row.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // Matches the type `return_type` resolves for these inputs and the
            // `StringViewArray` expected below.
            return_field: Field::new("f", Utf8View, true).into(),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            other => panic!("expected an array result, got {other:?}"),
        }
        Ok(())
    }
}