datafusion_spark/function/string/
concat.rs1use arrow::array::Array;
19use arrow::buffer::NullBuffer;
20use arrow::datatypes::DataType;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
24 Volatility,
25};
26use datafusion_functions::string::concat::ConcatFunc;
27use std::any::Any;
28use std::sync::Arc;
29
30#[derive(Debug, PartialEq, Eq, Hash)]
40pub struct SparkConcat {
41 signature: Signature,
42}
43
44impl Default for SparkConcat {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl SparkConcat {
51 pub fn new() -> Self {
52 Self {
53 signature: Signature::one_of(
54 vec![TypeSignature::UserDefined, TypeSignature::Nullary],
55 Volatility::Immutable,
56 ),
57 }
58 }
59}
60
61impl ScalarUDFImpl for SparkConcat {
62 fn as_any(&self) -> &dyn Any {
63 self
64 }
65
66 fn name(&self) -> &str {
67 "concat"
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75 Ok(DataType::Utf8)
76 }
77
78 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
79 spark_concat(args)
80 }
81
82 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
83 Ok(arg_types.to_vec())
85 }
86}
87
88enum NullMaskResolution {
90 ReturnNull,
92 NoMask,
94 Apply(NullBuffer),
96}
97
98fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102 let ScalarFunctionArgs {
103 args: arg_values,
104 arg_fields,
105 number_rows,
106 return_field,
107 config_options,
108 } = args;
109
110 if arg_values.is_empty() {
112 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
113 Some(String::new()),
114 )));
115 }
116
117 let null_mask = compute_null_mask(&arg_values, number_rows)?;
119
120 if matches!(null_mask, NullMaskResolution::ReturnNull) {
122 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
123 }
124
125 let concat_func = ConcatFunc::new();
127 let func_args = ScalarFunctionArgs {
128 args: arg_values,
129 arg_fields,
130 number_rows,
131 return_field,
132 config_options,
133 };
134 let result = concat_func.invoke_with_args(func_args)?;
135
136 apply_null_mask(result, null_mask)
138}
139
140fn compute_null_mask(
142 args: &[ColumnarValue],
143 number_rows: usize,
144) -> Result<NullMaskResolution> {
145 let all_scalars = args
147 .iter()
148 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
149
150 if all_scalars {
151 for arg in args {
153 if let ColumnarValue::Scalar(scalar) = arg {
154 if scalar.is_null() {
155 return Ok(NullMaskResolution::ReturnNull);
156 }
157 }
158 }
159 Ok(NullMaskResolution::NoMask)
161 } else {
162 let array_len = args
164 .iter()
165 .find_map(|arg| match arg {
166 ColumnarValue::Array(array) => Some(array.len()),
167 _ => None,
168 })
169 .unwrap_or(number_rows);
170
171 let arrays: Result<Vec<_>> = args
173 .iter()
174 .map(|arg| match arg {
175 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
176 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
177 })
178 .collect();
179 let arrays = arrays?;
180
181 let combined_nulls = arrays
183 .iter()
184 .map(|arr| arr.nulls())
185 .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
186
187 match combined_nulls {
188 Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
189 None => Ok(NullMaskResolution::NoMask),
190 }
191 }
192}
193
194fn apply_null_mask(
196 result: ColumnarValue,
197 null_mask: NullMaskResolution,
198) -> Result<ColumnarValue> {
199 match (result, null_mask) {
200 (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
202 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
203 }
204 (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
206 (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
208 let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
210
211 let new_array = array
213 .into_data()
214 .into_builder()
215 .nulls(combined_nulls)
216 .build()?;
217
218 Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
219 new_array,
220 ))))
221 }
222 (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
224 (scalar, _) => Ok(scalar),
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;

    /// Wraps an optional string in a `Utf8` scalar argument.
    fn utf8_arg(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(String::from)))
    }

    /// Two non-null scalars are joined in order.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_arg(Some("Spark")), utf8_arg(Some("SQL"))],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// A single NULL scalar among the arguments makes the whole result NULL.
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_arg(Some("Spark")), utf8_arg(Some("SQL")), utf8_arg(None)],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }
}