Skip to main content

datafusion_functions/string/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, as_largestring_array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::sync::Arc;
22
23use crate::string::concat;
24use crate::strings::{
25    ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder,
26    ConcatStringViewBuilder,
27};
28use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array};
29use datafusion_common::{
30    Result, ScalarValue, exec_datafusion_err, internal_err, plan_err,
31};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
34use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
35use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Concatenates multiple strings together.",
41    syntax_example = "concat(str[, ..., str_n])",
42    sql_example = r#"```sql
43> select concat('data', 'f', 'us', 'ion');
44+-------------------------------------------------------+
45| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
46+-------------------------------------------------------+
47| datafusion                                            |
48+-------------------------------------------------------+
49```"#,
50    standard_argument(name = "str", prefix = "String"),
51    argument(
52        name = "str_n",
53        description = "Subsequent string expressions to concatenate."
54    ),
55    related_udf(name = "concat_ws")
56)]
57#[derive(Debug, PartialEq, Eq, Hash)]
58pub struct ConcatFunc {
59    signature: Signature,
60}
61
62impl Default for ConcatFunc {
63    fn default() -> Self {
64        ConcatFunc::new()
65    }
66}
67
68impl ConcatFunc {
69    pub fn new() -> Self {
70        use DataType::*;
71        Self {
72            signature: Signature::variadic(
73                vec![Utf8View, Utf8, LargeUtf8, Binary],
74                Volatility::Immutable,
75            ),
76        }
77    }
78}
79
80fn deduce_return_type(arg_types: &[DataType]) -> DataType {
81    use DataType::*;
82    if arg_types.contains(&Utf8View) {
83        Utf8View
84    } else if arg_types.contains(&LargeUtf8) {
85        LargeUtf8
86    } else {
87        Utf8
88    }
89}
90
91impl ScalarUDFImpl for ConcatFunc {
92    fn name(&self) -> &str {
93        "concat"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    /// Match the return type to the input types to avoid unnecessary casts. On
101    /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid
102    /// potential overflow on LargeUtf8 input.
103    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
104        Ok(deduce_return_type(arg_types))
105    }
106
107    /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
108    /// concat('abcde', 2, NULL, 22) = 'abcde222'
109    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110        let ScalarFunctionArgs { args, .. } = args;
111
112        let arg_types: Vec<DataType> = args.iter().map(|c| c.data_type()).collect();
113        let return_datatype = deduce_return_type(&arg_types);
114
115        let array_len = args.iter().find_map(|x| match x {
116            ColumnarValue::Array(array) => Some(array.len()),
117            _ => None,
118        });
119
120        // Scalar
121        if array_len.is_none() {
122            let mut values: Vec<&[u8]> = Vec::with_capacity(args.len());
123            for arg in &args {
124                let ColumnarValue::Scalar(scalar) = arg else {
125                    return internal_err!("concat expected scalar value, got {arg:?}");
126                };
127                if let ScalarValue::Binary(Some(value)) = scalar {
128                    values.push(value);
129                } else {
130                    match scalar.try_as_str() {
131                        Some(Some(v)) => values.push(v.as_bytes()),
132                        Some(None) => {} // null literal
133                        None => plan_err!(
134                            "Concat function does not support scalar type {}",
135                            scalar
136                        )?,
137                    }
138                }
139            }
140            let concat_bytes = values.concat();
141            let result = std::str::from_utf8(&concat_bytes)
142                .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?
143                .to_string();
144
145            return match return_datatype {
146                DataType::Utf8View => {
147                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
148                }
149                DataType::Utf8 => {
150                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
151                }
152                DataType::LargeUtf8 => {
153                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
154                }
155                other => {
156                    plan_err!("Concat function does not support datatype of {other}")
157                }
158            };
159        }
160
161        // Array
162        let len = array_len.unwrap();
163        let mut data_size = 0;
164        let mut columns = Vec::with_capacity(args.len());
165
166        for arg in &args {
167            match arg {
168                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
169                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
170                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
171                    if let Some(s) = maybe_value {
172                        data_size += s.len() * len;
173                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
174                    }
175                }
176                ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => {
177                    if let Some(b) = maybe_value {
178                        // data_size is a capacity hint, so doesn't matter if it is chars or bytes
179                        data_size += b.len() * len;
180                        columns.push(ColumnarValueRef::Scalar(b.as_slice()));
181                    }
182                }
183                ColumnarValue::Array(array) => {
184                    match array.data_type() {
185                        DataType::Utf8 => {
186                            let string_array = as_string_array(array)?;
187
188                            data_size += string_array.values().len();
189                            let column = if array.is_nullable() {
190                                ColumnarValueRef::NullableArray(string_array)
191                            } else {
192                                ColumnarValueRef::NonNullableArray(string_array)
193                            };
194                            columns.push(column);
195                        }
196                        DataType::LargeUtf8 => {
197                            let string_array = as_largestring_array(array);
198
199                            data_size += string_array.values().len();
200                            let column = if array.is_nullable() {
201                                ColumnarValueRef::NullableLargeStringArray(string_array)
202                            } else {
203                                ColumnarValueRef::NonNullableLargeStringArray(
204                                    string_array,
205                                )
206                            };
207                            columns.push(column);
208                        }
209                        DataType::Utf8View => {
210                            let string_array = as_string_view_array(array)?;
211
212                            // This is an estimate; in particular, it will
213                            // undercount arrays of short strings (<= 12 bytes).
214                            data_size += string_array.total_buffer_bytes_used();
215                            let column = if array.is_nullable() {
216                                ColumnarValueRef::NullableStringViewArray(string_array)
217                            } else {
218                                ColumnarValueRef::NonNullableStringViewArray(string_array)
219                            };
220                            columns.push(column);
221                        }
222                        DataType::Binary => {
223                            let string_array = as_binary_array(array)?;
224
225                            data_size += string_array.values().len();
226                            let column = if array.is_nullable() {
227                                ColumnarValueRef::NullableBinaryArray(string_array)
228                            } else {
229                                ColumnarValueRef::NonNullableBinaryArray(string_array)
230                            };
231                            columns.push(column);
232                        }
233                        other => {
234                            return plan_err!(
235                                "Input was {other} which is not a supported datatype for concat function"
236                            );
237                        }
238                    };
239                }
240                _ => unreachable!("concat"),
241            }
242        }
243
244        match return_datatype {
245            DataType::Utf8 => {
246                let mut builder = ConcatStringBuilder::with_capacity(len, data_size);
247                for i in 0..len {
248                    columns
249                        .iter()
250                        .for_each(|column| builder.write::<true>(column, i));
251                    builder.append_offset()?;
252                }
253
254                let string_array = builder.finish(None)?;
255                Ok(ColumnarValue::Array(Arc::new(string_array)))
256            }
257            DataType::Utf8View => {
258                let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size);
259                for i in 0..len {
260                    columns
261                        .iter()
262                        .for_each(|column| builder.write::<true>(column, i));
263                    builder.append_offset()?;
264                }
265
266                let string_array = builder.finish(None)?;
267                Ok(ColumnarValue::Array(Arc::new(string_array)))
268            }
269            DataType::LargeUtf8 => {
270                let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size);
271                for i in 0..len {
272                    columns
273                        .iter()
274                        .for_each(|column| builder.write::<true>(column, i));
275                    builder.append_offset()?;
276                }
277
278                let string_array = builder.finish(None)?;
279                Ok(ColumnarValue::Array(Arc::new(string_array)))
280            }
281            _ => unreachable!(),
282        }
283    }
284
285    /// Simplify the `concat` function by
286    /// 1. filtering out all `null` literals
287    /// 2. concatenating contiguous literal arguments
288    ///
289    /// For example:
290    /// `concat(col(a), 'hello ', 'world', col(b), null)`
291    /// will be optimized to
292    /// `concat(col(a), 'hello world', col(b))`
293    fn simplify(
294        &self,
295        args: Vec<Expr>,
296        _info: &SimplifyContext,
297    ) -> Result<ExprSimplifyResult> {
298        simplify_concat(args)
299    }
300
301    fn documentation(&self) -> Option<&Documentation> {
302        self.doc()
303    }
304
305    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
306        Ok(true)
307    }
308}
309
310pub(crate) fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
311    let mut new_args = Vec::with_capacity(args.len());
312    let mut contiguous_scalar = "".to_string();
313
314    let return_type = {
315        let data_types: Vec<_> = args
316            .iter()
317            .filter_map(|expr| match expr {
318                Expr::Literal(l, _) => Some(l.data_type()),
319                _ => None,
320            })
321            .collect();
322        ConcatFunc::new().return_type(&data_types)
323    }?;
324
325    for arg in args.clone() {
326        match arg {
327            Expr::Literal(ScalarValue::Utf8(None), _) => {}
328            Expr::Literal(ScalarValue::LargeUtf8(None), _) => {}
329            Expr::Literal(ScalarValue::Utf8View(None), _) => {}
330
331            // filter out `null` args
332            // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
333            // Concatenate it with the `contiguous_scalar`.
334            Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
335                contiguous_scalar += &v;
336            }
337            Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
338                contiguous_scalar += &v;
339            }
340            Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
341                contiguous_scalar += &v;
342            }
343
344            Expr::Literal(x, _) => {
345                return internal_err!(
346                    "The scalar {x} should be casted to string type during the type coercion."
347                );
348            }
349            // If the arg is not a literal, we should first push the current `contiguous_scalar`
350            // to the `new_args` (if it is not empty) and reset it to empty string.
351            // Then pushing this arg to the `new_args`.
352            arg => {
353                if !contiguous_scalar.is_empty() {
354                    match return_type {
355                        DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
356                        DataType::LargeUtf8 => new_args
357                            .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
358                        DataType::Utf8View => new_args
359                            .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
360                        _ => unreachable!(),
361                    }
362                    contiguous_scalar = "".to_string();
363                }
364                new_args.push(arg);
365            }
366        }
367    }
368
369    if !contiguous_scalar.is_empty() {
370        match return_type {
371            DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
372            DataType::LargeUtf8 => {
373                new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
374            }
375            DataType::Utf8View => {
376                new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
377            }
378            _ => unreachable!(),
379        }
380    }
381
382    if !args.eq(&new_args) {
383        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
384            ScalarFunction {
385                func: concat(),
386                args: new_args,
387            },
388        )))
389    } else {
390        Ok(ExprSimplifyResult::Original(args))
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use DataType::*;
    use arrow::array::{ArrayRef, StringArray};
    use arrow::array::{LargeStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    /// All-scalar invocations: NULL handling, return-type selection for mixed
    /// string types, and binary inputs.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Binary(Some(
                    "Café".as_bytes().into()
                ))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("Cafécc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from(
                    "Café".as_bytes()
                )))),
                ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))),
            ],
            Ok(Some("Cafécc")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Mixed array/scalar invocation: NULL array elements contribute nothing
    /// and a `Utf8View` input makes the result a `StringViewArray`.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // The arguments include Utf8View, so the function returns
            // Utf8View; keep the declared return field consistent with the
            // StringViewArray asserted below.
            return_field: Field::new("f", Utf8View, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        let ColumnarValue::Array(array) = &result else {
            panic!("concat over array inputs should return an array, got {result:?}");
        };
        assert_eq!(&expected, array);
        Ok(())
    }
}