datafusion_spark/function/string/concat.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible `concat` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
///
/// Concatenates multiple input strings into a single string.
/// Returns NULL if any input is NULL.
///
/// Differences from DataFusion's `concat`:
/// - Supports zero arguments (returns an empty string)
/// - Returns NULL if any input is NULL (DataFusion's `concat` skips NULL arguments)
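///
/// Illustrative examples (Spark SQL semantics, mirroring the behavior implemented
/// and tested below):
///
/// ```sql
/// SELECT concat('Spark', 'SQL');   -- 'SparkSQL'
/// SELECT concat('Spark', NULL);    -- NULL
/// SELECT concat();                 -- ''
/// ```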
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
    signature: Signature,
}

impl Default for SparkConcat {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkConcat {
    pub fn new() -> Self {
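        // `TypeSignature::UserDefined` defers argument-type checks to `coerce_types`
        // below, while `TypeSignature::Nullary` allows calling `concat()` with no
        // arguments at all.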
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for SparkConcat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "concat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_concat(args)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        // Accept any string types, including zero arguments
        Ok(arg_types.to_vec())
    }
}

/// Represents the NULL-handling decision for Spark `concat`
enum NullMaskResolution {
    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
    ReturnNull,
    /// No null mask needed (e.g., all scalar inputs are non-NULL)
    NoMask,
    /// Null mask to apply for arrays
    Apply(NullBuffer),
}

/// Concatenates strings, returning NULL if any input is NULL.
///
/// This is a Spark-specific wrapper around DataFusion's `concat`: Spark returns NULL
/// when any argument is NULL, whereas DataFusion's `concat` ignores NULL arguments.
fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    } = args;

    // Handle zero-argument case: return empty string
    if arg_values.is_empty() {
        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
            Some(String::new()),
        )));
    }

    // Step 1: Check for NULL mask in incoming args
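    // (`number_rows` is only used as a fallback row count when none of the
    // arguments is an array; see `compute_null_mask` below.)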
    let null_mask = compute_null_mask(&arg_values, number_rows)?;

    // If all scalars and any is NULL, return NULL immediately
    if matches!(null_mask, NullMaskResolution::ReturnNull) {
        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
    }

    // Step 2: Delegate to DataFusion's concat
    let concat_func = ConcatFunc::new();
    let func_args = ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    };
    let result = concat_func.invoke_with_args(func_args)?;

    // Step 3: Apply NULL mask to result
    apply_null_mask(result, null_mask)
}

/// Compute NULL mask for the arguments using `NullBuffer::union`
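///
/// For example (illustrative values): given inputs `["a", NULL, "c"]` and
/// `["x", "y", NULL]`, the union of the null buffers marks rows 1 and 2 as NULL,
/// so after the mask is applied the concatenated result is `["ax", NULL, NULL]`.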
fn compute_null_mask(
    args: &[ColumnarValue],
    number_rows: usize,
) -> Result<NullMaskResolution> {
    // Check if all arguments are scalars
    let all_scalars = args
        .iter()
        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

    if all_scalars {
        // For scalars, check if any is NULL
        for arg in args {
            if let ColumnarValue::Scalar(scalar) = arg {
                if scalar.is_null() {
                    return Ok(NullMaskResolution::ReturnNull);
                }
            }
        }
        // No NULLs in scalars
        Ok(NullMaskResolution::NoMask)
    } else {
        // For arrays, compute NULL mask for each row using NullBuffer::union
        let array_len = args
            .iter()
            .find_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                _ => None,
            })
            .unwrap_or(number_rows);

        // Convert all scalars to arrays for uniform processing
        let arrays: Result<Vec<_>> = args
            .iter()
            .map(|arg| match arg {
                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
            })
            .collect();
        let arrays = arrays?;

        // Use NullBuffer::union to combine all null buffers
        let combined_nulls = arrays
            .iter()
            .map(|arr| arr.nulls())
            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

        match combined_nulls {
            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
            None => Ok(NullMaskResolution::NoMask),
        }
    }
}

/// Apply NULL mask to the result using `NullBuffer::union`
fn apply_null_mask(
    result: ColumnarValue,
    null_mask: NullMaskResolution,
) -> Result<ColumnarValue> {
    match (result, null_mask) {
        // Scalar with ReturnNull mask means return NULL
        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
        }
        // Scalar without mask, return as-is
        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
        // Array with NULL mask - use NullBuffer::union to combine nulls
        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
            // Combine the result's existing nulls with our computed null mask
            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));

            // Create new array with combined nulls
            let new_array = array
                .into_data()
                .into_builder()
                .nulls(combined_nulls)
                .build()?;

            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
                new_array,
            ))))
        }
        // Array without NULL mask, return as-is
        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
        // Remaining combinations shouldn't occur in practice; pass the value through
        (other, _) => Ok(other),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;

    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }
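
    // A sketch of an additional case for the zero-argument path: Spark's `concat()`
    // returns an empty string. This assumes the `test_scalar_function!` helper
    // accepts an empty argument list; if it does not, call
    // `SparkConcat::new().invoke_with_args(...)` directly instead.
    #[test]
    fn test_concat_no_args() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            Vec::<ColumnarValue>::new(),
            Ok(Some("")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }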
}