//! Spark-compatible `concat` string function.
//!
//! Path: `datafusion_spark/function/string/concat.rs`
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::arrow::datatypes::FieldRef;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
    TypeSignature, Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
use std::any::Any;
use std::sync::Arc;
31
32/// Spark-compatible `concat` expression
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
34///
35/// Concatenates multiple input strings into a single string.
36/// Returns NULL if any input is NULL.
37///
38/// Differences with DataFusion concat:
39/// - Support 0 arguments
40/// - Return NULL if any input is NULL
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkConcat {
43    signature: Signature,
44}
45
46impl Default for SparkConcat {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkConcat {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature::one_of(
56                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
57                Volatility::Immutable,
58            ),
59        }
60    }
61}
62
63impl ScalarUDFImpl for SparkConcat {
64    fn as_any(&self) -> &dyn Any {
65        self
66    }
67
68    fn name(&self) -> &str {
69        "concat"
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
77        spark_concat(args)
78    }
79
80    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
81        // Accept any string types, including zero arguments
82        Ok(arg_types.to_vec())
83    }
84    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85        datafusion_common::internal_err!(
86            "return_type should not be called for Spark concat"
87        )
88    }
89    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90        // Spark semantics: concat returns NULL if ANY input is NULL
91        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
92
93        Ok(Arc::new(Field::new("concat", DataType::Utf8, nullable)))
94    }
95}
96
97/// Represents the null state for Spark concat
98enum NullMaskResolution {
99    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
100    ReturnNull,
101    /// No null mask needed (e.g., all scalar inputs are non-NULL)
102    NoMask,
103    /// Null mask to apply for arrays
104    Apply(NullBuffer),
105}
106
107/// Concatenates strings, returning NULL if any input is NULL
108/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
109/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
110fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
111    let ScalarFunctionArgs {
112        args: arg_values,
113        arg_fields,
114        number_rows,
115        return_field,
116        config_options,
117    } = args;
118
119    // Handle zero-argument case: return empty string
120    if arg_values.is_empty() {
121        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
122            Some(String::new()),
123        )));
124    }
125
126    // Step 1: Check for NULL mask in incoming args
127    let null_mask = compute_null_mask(&arg_values, number_rows)?;
128
129    // If all scalars and any is NULL, return NULL immediately
130    if matches!(null_mask, NullMaskResolution::ReturnNull) {
131        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
132    }
133
134    // Step 2: Delegate to DataFusion's concat
135    let concat_func = ConcatFunc::new();
136    let func_args = ScalarFunctionArgs {
137        args: arg_values,
138        arg_fields,
139        number_rows,
140        return_field,
141        config_options,
142    };
143    let result = concat_func.invoke_with_args(func_args)?;
144
145    // Step 3: Apply NULL mask to result
146    apply_null_mask(result, null_mask)
147}
148
149/// Compute NULL mask for the arguments using NullBuffer::union
150fn compute_null_mask(
151    args: &[ColumnarValue],
152    number_rows: usize,
153) -> Result<NullMaskResolution> {
154    // Check if all arguments are scalars
155    let all_scalars = args
156        .iter()
157        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
158
159    if all_scalars {
160        // For scalars, check if any is NULL
161        for arg in args {
162            if let ColumnarValue::Scalar(scalar) = arg
163                && scalar.is_null()
164            {
165                return Ok(NullMaskResolution::ReturnNull);
166            }
167        }
168        // No NULLs in scalars
169        Ok(NullMaskResolution::NoMask)
170    } else {
171        // For arrays, compute NULL mask for each row using NullBuffer::union
172        let array_len = args
173            .iter()
174            .find_map(|arg| match arg {
175                ColumnarValue::Array(array) => Some(array.len()),
176                _ => None,
177            })
178            .unwrap_or(number_rows);
179
180        // Convert all scalars to arrays for uniform processing
181        let arrays: Result<Vec<_>> = args
182            .iter()
183            .map(|arg| match arg {
184                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
185                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
186            })
187            .collect();
188        let arrays = arrays?;
189
190        // Use NullBuffer::union to combine all null buffers
191        let combined_nulls = arrays
192            .iter()
193            .map(|arr| arr.nulls())
194            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
195
196        match combined_nulls {
197            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
198            None => Ok(NullMaskResolution::NoMask),
199        }
200    }
201}
202
203/// Apply NULL mask to the result using NullBuffer::union
204fn apply_null_mask(
205    result: ColumnarValue,
206    null_mask: NullMaskResolution,
207) -> Result<ColumnarValue> {
208    match (result, null_mask) {
209        // Scalar with ReturnNull mask means return NULL
210        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
211            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
212        }
213        // Scalar without mask, return as-is
214        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
215        // Array with NULL mask - use NullBuffer::union to combine nulls
216        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
217            // Combine the result's existing nulls with our computed null mask
218            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
219
220            // Create new array with combined nulls
221            let new_array = array
222                .into_data()
223                .into_builder()
224                .nulls(combined_nulls)
225                .build()?;
226
227            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
228                new_array,
229            ))))
230        }
231        // Array without NULL mask, return as-is
232        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
233        // Edge cases that shouldn't happen in practice
234        (scalar, _) => Ok(scalar),
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr::ReturnFieldArgs;
    use std::sync::Arc;

    /// Two non-NULL scalars concatenate into a single string.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// Spark semantics: one NULL argument makes the whole result NULL.
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// All-non-nullable inputs must yield a non-nullable result field.
    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let udf = SparkConcat::new();
        let arg_fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, false)),
        ];

        let result = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[],
        })?;

        assert!(
            !result.is_nullable(),
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );
        Ok(())
    }

    /// A single nullable input must make the result field nullable.
    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let udf = SparkConcat::new();
        let arg_fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];

        let result = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[],
        })?;

        assert!(
            result.is_nullable(),
            "Expected concat result to be nullable when any input is nullable"
        );
        Ok(())
    }
}