// Source path: datafusion_spark/function/string/concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field};
19use datafusion_common::arrow::datatypes::FieldRef;
20use datafusion_common::{Result, ScalarValue};
21use datafusion_expr::ReturnFieldArgs;
22use datafusion_expr::{
23    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use datafusion_functions::string::concat::ConcatFunc;
26use std::sync::Arc;
27
28use crate::function::null_utils::{
29    NullMaskResolution, apply_null_mask, compute_null_mask,
30};
31
32/// Spark-compatible `concat` expression
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
34///
35/// Concatenates multiple input strings into a single string.
36/// Returns NULL if any input is NULL.
37///
38/// Differences with DataFusion concat:
39/// - Support 0 arguments
40/// - Return NULL if any input is NULL
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkConcat {
43    signature: Signature,
44}
45
46impl Default for SparkConcat {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkConcat {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature::user_defined(Volatility::Immutable),
56        }
57    }
58}
59
60impl ScalarUDFImpl for SparkConcat {
61    fn name(&self) -> &str {
62        "concat"
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
70        spark_concat(args)
71    }
72
73    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
74        // Accept any string types, including zero arguments
75        Ok(arg_types.to_vec())
76    }
77    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78        datafusion_common::internal_err!(
79            "return_type should not be called for Spark concat"
80        )
81    }
82    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
83        use DataType::*;
84
85        // Spark semantics: concat returns NULL if ANY input is NULL
86        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
87
88        // Determine return type: Utf8View > LargeUtf8 > Utf8
89        let mut dt = &Utf8;
90        for field in args.arg_fields {
91            let data_type = field.data_type();
92            if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) {
93                dt = data_type;
94            }
95        }
96
97        Ok(Arc::new(Field::new("concat", dt.clone(), nullable)))
98    }
99}
100
101/// Concatenates strings, returning NULL if any input is NULL
102/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
103/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
104fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105    let ScalarFunctionArgs {
106        args: arg_values,
107        arg_fields,
108        number_rows,
109        return_field,
110        config_options,
111    } = args;
112
113    // Handle zero-argument case: return empty string
114    if arg_values.is_empty() {
115        let return_type = return_field.data_type();
116        return match return_type {
117            DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
118                String::new(),
119            )))),
120            DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(
121                Some(String::new()),
122            ))),
123            _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
124                Some(String::new()),
125            ))),
126        };
127    }
128
129    // Step 1: Check for NULL mask in incoming args
130    let null_mask = compute_null_mask(&arg_values);
131
132    // If all scalars and any is NULL, return NULL immediately
133    if matches!(null_mask, NullMaskResolution::ReturnNull) {
134        let return_type = return_field.data_type();
135        return match return_type {
136            DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))),
137            DataType::LargeUtf8 => {
138                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)))
139            }
140            _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
141        };
142    }
143
144    // Step 2: Delegate to DataFusion's concat
145    let concat_func = ConcatFunc::new();
146    let return_type = return_field.data_type().clone();
147    let func_args = ScalarFunctionArgs {
148        args: arg_values,
149        arg_fields,
150        number_rows,
151        return_field,
152        config_options,
153    };
154    let result = concat_func.invoke_with_args(func_args)?;
155
156    // Step 3: Apply NULL mask to result
157    apply_null_mask(result, null_mask, &return_type)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, StringArray};

    /// Builds a `Utf8` scalar argument from an optional string slice.
    fn utf8_scalar(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(str::to_string)))
    }

    /// Asks `SparkConcat` for its return field given `Utf8` inputs described
    /// as `(name, nullable)` pairs, and reports whether that field is nullable.
    fn result_is_nullable(inputs: &[(&str, bool)]) -> Result<bool> {
        let fields: Vec<FieldRef> = inputs
            .iter()
            .map(|&(name, nullable)| Arc::new(Field::new(name, DataType::Utf8, nullable)))
            .collect();

        let field = SparkConcat::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        })?;

        Ok(field.is_nullable())
    }

    /// Two non-NULL scalars are joined in order.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_scalar(Some("Spark")), utf8_scalar(Some("SQL"))],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// A single NULL scalar among the inputs makes the whole result NULL.
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                utf8_scalar(Some("Spark")),
                utf8_scalar(Some("SQL")),
                utf8_scalar(None),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// All inputs non-nullable => the return field is non-nullable.
    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let nullable = result_is_nullable(&[("a", false), ("b", false)])?;

        assert!(
            !nullable,
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );

        Ok(())
    }

    /// At least one nullable input => the return field is nullable.
    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let nullable = result_is_nullable(&[("a", false), ("b", true)])?;

        assert!(
            nullable,
            "Expected concat result to be nullable when any input is nullable"
        );

        Ok(())
    }
}
245}