Skip to main content

datafusion_spark/function/string/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field};
19use datafusion_common::arrow::datatypes::FieldRef;
20use datafusion_common::{Result, ScalarValue};
21use datafusion_expr::ReturnFieldArgs;
22use datafusion_expr::{
23    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use datafusion_functions::string::concat::ConcatFunc;
26use std::any::Any;
27use std::sync::Arc;
28
29use crate::function::null_utils::{
30    NullMaskResolution, apply_null_mask, compute_null_mask,
31};
32
33/// Spark-compatible `concat` expression
34/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
35///
36/// Concatenates multiple input strings into a single string.
37/// Returns NULL if any input is NULL.
38///
39/// Differences with DataFusion concat:
40/// - Support 0 arguments
41/// - Return NULL if any input is NULL
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkConcat {
44    signature: Signature,
45}
46
47impl Default for SparkConcat {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkConcat {
54    pub fn new() -> Self {
55        Self {
56            signature: Signature::user_defined(Volatility::Immutable),
57        }
58    }
59}
60
61impl ScalarUDFImpl for SparkConcat {
62    fn as_any(&self) -> &dyn Any {
63        self
64    }
65
66    fn name(&self) -> &str {
67        "concat"
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
75        spark_concat(args)
76    }
77
78    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79        // Accept any string types, including zero arguments
80        Ok(arg_types.to_vec())
81    }
82    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83        datafusion_common::internal_err!(
84            "return_type should not be called for Spark concat"
85        )
86    }
87    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
88        use DataType::*;
89
90        // Spark semantics: concat returns NULL if ANY input is NULL
91        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
92
93        // Determine return type: Utf8View > LargeUtf8 > Utf8
94        let mut dt = &Utf8;
95        for field in args.arg_fields {
96            let data_type = field.data_type();
97            if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) {
98                dt = data_type;
99            }
100        }
101
102        Ok(Arc::new(Field::new("concat", dt.clone(), nullable)))
103    }
104}
105
106/// Concatenates strings, returning NULL if any input is NULL
107/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
108/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
109fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110    let ScalarFunctionArgs {
111        args: arg_values,
112        arg_fields,
113        number_rows,
114        return_field,
115        config_options,
116    } = args;
117
118    // Handle zero-argument case: return empty string
119    if arg_values.is_empty() {
120        let return_type = return_field.data_type();
121        return match return_type {
122            DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
123                String::new(),
124            )))),
125            DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(
126                Some(String::new()),
127            ))),
128            _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
129                Some(String::new()),
130            ))),
131        };
132    }
133
134    // Step 1: Check for NULL mask in incoming args
135    let null_mask = compute_null_mask(&arg_values, number_rows)?;
136
137    // If all scalars and any is NULL, return NULL immediately
138    if matches!(null_mask, NullMaskResolution::ReturnNull) {
139        let return_type = return_field.data_type();
140        return match return_type {
141            DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))),
142            DataType::LargeUtf8 => {
143                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)))
144            }
145            _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
146        };
147    }
148
149    // Step 2: Delegate to DataFusion's concat
150    let concat_func = ConcatFunc::new();
151    let return_type = return_field.data_type().clone();
152    let func_args = ScalarFunctionArgs {
153        args: arg_values,
154        arg_fields,
155        number_rows,
156        return_field,
157        config_options,
158    };
159    let result = concat_func.invoke_with_args(func_args)?;
160
161    // Step 3: Apply NULL mask to result
162    apply_null_mask(result, null_mask, &return_type)
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr::ReturnFieldArgs;
    use std::sync::Arc;

    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let func = SparkConcat::new();

        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, false)),
        ];

        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };

        let field = func.return_field_from_args(args)?;

        assert!(
            !field.is_nullable(),
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );

        Ok(())
    }

    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let func = SparkConcat::new();

        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];

        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };

        let field = func.return_field_from_args(args)?;

        assert!(
            field.is_nullable(),
            "Expected concat result to be nullable when any input is nullable"
        );

        Ok(())
    }

    /// `concat()` with no arguments returns an empty string, so its field
    /// must be a non-nullable `Utf8`.
    #[test]
    fn test_spark_concat_return_field_zero_args() -> Result<()> {
        let func = SparkConcat::new();

        let args = ReturnFieldArgs {
            arg_fields: &[],
            scalar_arguments: &[],
        };

        let field = func.return_field_from_args(args)?;

        assert_eq!(field.data_type(), &DataType::Utf8);
        assert!(
            !field.is_nullable(),
            "Expected zero-argument concat result to be non-nullable"
        );

        Ok(())
    }

    /// The return type follows the precedence `Utf8View` > `LargeUtf8` >
    /// `Utf8`, independent of argument order.
    #[test]
    fn test_spark_concat_return_field_type_precedence() -> Result<()> {
        let func = SparkConcat::new();

        let cases = [
            ([DataType::Utf8, DataType::Utf8], DataType::Utf8),
            ([DataType::Utf8, DataType::LargeUtf8], DataType::LargeUtf8),
            ([DataType::LargeUtf8, DataType::Utf8], DataType::LargeUtf8),
            ([DataType::LargeUtf8, DataType::Utf8View], DataType::Utf8View),
            ([DataType::Utf8View, DataType::LargeUtf8], DataType::Utf8View),
            ([DataType::Utf8View, DataType::Utf8], DataType::Utf8View),
        ];

        for (input_types, expected) in cases {
            let fields: Vec<FieldRef> = input_types
                .iter()
                .enumerate()
                .map(|(i, dt)| Arc::new(Field::new(format!("arg_{i}"), dt.clone(), false)))
                .collect();

            let args = ReturnFieldArgs {
                arg_fields: &fields,
                scalar_arguments: &[],
            };

            let field = func.return_field_from_args(args)?;

            assert_eq!(
                field.data_type(),
                &expected,
                "unexpected concat return type for inputs {input_types:?}"
            );
        }

        Ok(())
    }
}