// datafusion_spark/function/string/concat.rs
use std::any::Any;
use std::sync::Arc;

use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::{DataType, Field};

use datafusion_common::arrow::datatypes::FieldRef;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ReturnFieldArgs;
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
31
32#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkConcat {
43 signature: Signature,
44}
45
46impl Default for SparkConcat {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl SparkConcat {
53 pub fn new() -> Self {
54 Self {
55 signature: Signature::one_of(
56 vec![TypeSignature::UserDefined, TypeSignature::Nullary],
57 Volatility::Immutable,
58 ),
59 }
60 }
61}
62
63impl ScalarUDFImpl for SparkConcat {
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn name(&self) -> &str {
69 "concat"
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
77 spark_concat(args)
78 }
79
80 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
81 Ok(arg_types.to_vec())
83 }
84 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85 datafusion_common::internal_err!(
86 "return_type should not be called for Spark concat"
87 )
88 }
89 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
90 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
92
93 Ok(Arc::new(Field::new("concat", DataType::Utf8, nullable)))
94 }
95}
96
97enum NullMaskResolution {
99 ReturnNull,
101 NoMask,
103 Apply(NullBuffer),
105}
106
107fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
111 let ScalarFunctionArgs {
112 args: arg_values,
113 arg_fields,
114 number_rows,
115 return_field,
116 config_options,
117 } = args;
118
119 if arg_values.is_empty() {
121 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
122 Some(String::new()),
123 )));
124 }
125
126 let null_mask = compute_null_mask(&arg_values, number_rows)?;
128
129 if matches!(null_mask, NullMaskResolution::ReturnNull) {
131 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
132 }
133
134 let concat_func = ConcatFunc::new();
136 let func_args = ScalarFunctionArgs {
137 args: arg_values,
138 arg_fields,
139 number_rows,
140 return_field,
141 config_options,
142 };
143 let result = concat_func.invoke_with_args(func_args)?;
144
145 apply_null_mask(result, null_mask)
147}
148
149fn compute_null_mask(
151 args: &[ColumnarValue],
152 number_rows: usize,
153) -> Result<NullMaskResolution> {
154 let all_scalars = args
156 .iter()
157 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
158
159 if all_scalars {
160 for arg in args {
162 if let ColumnarValue::Scalar(scalar) = arg
163 && scalar.is_null()
164 {
165 return Ok(NullMaskResolution::ReturnNull);
166 }
167 }
168 Ok(NullMaskResolution::NoMask)
170 } else {
171 let array_len = args
173 .iter()
174 .find_map(|arg| match arg {
175 ColumnarValue::Array(array) => Some(array.len()),
176 _ => None,
177 })
178 .unwrap_or(number_rows);
179
180 let arrays: Result<Vec<_>> = args
182 .iter()
183 .map(|arg| match arg {
184 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
185 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
186 })
187 .collect();
188 let arrays = arrays?;
189
190 let combined_nulls = arrays
192 .iter()
193 .map(|arr| arr.nulls())
194 .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
195
196 match combined_nulls {
197 Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
198 None => Ok(NullMaskResolution::NoMask),
199 }
200 }
201}
202
203fn apply_null_mask(
205 result: ColumnarValue,
206 null_mask: NullMaskResolution,
207) -> Result<ColumnarValue> {
208 match (result, null_mask) {
209 (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
211 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
212 }
213 (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
215 (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
217 let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
219
220 let new_array = array
222 .into_data()
223 .into_builder()
224 .nulls(combined_nulls)
225 .build()?;
226
227 Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
228 new_array,
229 ))))
230 }
231 (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
233 (scalar, _) => Ok(scalar),
235 }
236}
237
238#[cfg(test)]
239mod tests {
240 use super::*;
241 use crate::function::utils::test::test_scalar_function;
242 use arrow::array::StringArray;
243 use arrow::datatypes::{DataType, Field};
244 use datafusion_common::Result;
245 use datafusion_expr::ReturnFieldArgs;
246 use std::sync::Arc;
247
248 #[test]
249 fn test_concat_basic() -> Result<()> {
250 test_scalar_function!(
251 SparkConcat::new(),
252 vec![
253 ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
254 ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
255 ],
256 Ok(Some("SparkSQL")),
257 &str,
258 DataType::Utf8,
259 StringArray
260 );
261 Ok(())
262 }
263
264 #[test]
265 fn test_concat_with_null() -> Result<()> {
266 test_scalar_function!(
267 SparkConcat::new(),
268 vec![
269 ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
270 ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
271 ColumnarValue::Scalar(ScalarValue::Utf8(None)),
272 ],
273 Ok(None),
274 &str,
275 DataType::Utf8,
276 StringArray
277 );
278 Ok(())
279 }
280 #[test]
281 fn test_spark_concat_return_field_non_nullable() -> Result<()> {
282 let func = SparkConcat::new();
283
284 let fields = vec![
285 Arc::new(Field::new("a", DataType::Utf8, false)),
286 Arc::new(Field::new("b", DataType::Utf8, false)),
287 ];
288
289 let args = ReturnFieldArgs {
290 arg_fields: &fields,
291 scalar_arguments: &[],
292 };
293
294 let field = func.return_field_from_args(args)?;
295
296 assert!(
297 !field.is_nullable(),
298 "Expected concat result to be non-nullable when all inputs are non-nullable"
299 );
300
301 Ok(())
302 }
303 #[test]
304 fn test_spark_concat_return_field_nullable() -> Result<()> {
305 let func = SparkConcat::new();
306
307 let fields = vec![
308 Arc::new(Field::new("a", DataType::Utf8, false)),
309 Arc::new(Field::new("b", DataType::Utf8, true)),
310 ];
311
312 let args = ReturnFieldArgs {
313 arg_fields: &fields,
314 scalar_arguments: &[],
315 };
316
317 let field = func.return_field_from_args(args)?;
318
319 assert!(
320 field.is_nullable(),
321 "Expected concat result to be nullable when any input is nullable"
322 );
323
324 Ok(())
325 }
326}