datafusion_spark/function/string/
concat.rs1use arrow::datatypes::{DataType, Field};
19use datafusion_common::arrow::datatypes::FieldRef;
20use datafusion_common::{Result, ScalarValue};
21use datafusion_expr::ReturnFieldArgs;
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use datafusion_functions::string::concat::ConcatFunc;
26use std::sync::Arc;
27
28use crate::function::null_utils::{
29 NullMaskResolution, apply_null_mask, compute_null_mask,
30};
31
32#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkConcat {
43 signature: Signature,
44}
45
46impl Default for SparkConcat {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl SparkConcat {
53 pub fn new() -> Self {
54 Self {
55 signature: Signature::user_defined(Volatility::Immutable),
56 }
57 }
58}
59
60impl ScalarUDFImpl for SparkConcat {
61 fn name(&self) -> &str {
62 "concat"
63 }
64
65 fn signature(&self) -> &Signature {
66 &self.signature
67 }
68
69 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
70 spark_concat(args)
71 }
72
73 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
74 Ok(arg_types.to_vec())
76 }
77 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78 datafusion_common::internal_err!(
79 "return_type should not be called for Spark concat"
80 )
81 }
82 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
83 use DataType::*;
84
85 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
87
88 let mut dt = &Utf8;
90 for field in args.arg_fields {
91 let data_type = field.data_type();
92 if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) {
93 dt = data_type;
94 }
95 }
96
97 Ok(Arc::new(Field::new("concat", dt.clone(), nullable)))
98 }
99}
100
101fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105 let ScalarFunctionArgs {
106 args: arg_values,
107 arg_fields,
108 number_rows,
109 return_field,
110 config_options,
111 } = args;
112
113 if arg_values.is_empty() {
115 let return_type = return_field.data_type();
116 return match return_type {
117 DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
118 String::new(),
119 )))),
120 DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(
121 Some(String::new()),
122 ))),
123 _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
124 Some(String::new()),
125 ))),
126 };
127 }
128
129 let null_mask = compute_null_mask(&arg_values);
131
132 if matches!(null_mask, NullMaskResolution::ReturnNull) {
134 let return_type = return_field.data_type();
135 return match return_type {
136 DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))),
137 DataType::LargeUtf8 => {
138 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)))
139 }
140 _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
141 };
142 }
143
144 let concat_func = ConcatFunc::new();
146 let return_type = return_field.data_type().clone();
147 let func_args = ScalarFunctionArgs {
148 args: arg_values,
149 arg_fields,
150 number_rows,
151 return_field,
152 config_options,
153 };
154 let result = concat_func.invoke_with_args(func_args)?;
155
156 apply_null_mask(result, null_mask, &return_type)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, StringArray};

    /// Wraps an optional string in a `Utf8` scalar argument.
    fn utf8_scalar(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(String::from)))
    }

    /// Returns the nullability `SparkConcat` reports for `Utf8` inputs
    /// described as `(field name, nullable)` pairs.
    fn concat_nullability(inputs: &[(&str, bool)]) -> Result<bool> {
        let fields: Vec<FieldRef> = inputs
            .iter()
            .map(|&(name, nullable)| Arc::new(Field::new(name, DataType::Utf8, nullable)))
            .collect();

        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };

        let field = SparkConcat::new().return_field_from_args(args)?;
        Ok(field.is_nullable())
    }

    /// Two non-null scalars are joined in argument order.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_scalar(Some("Spark")), utf8_scalar(Some("SQL"))],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// A single NULL scalar makes the whole result NULL.
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                utf8_scalar(Some("Spark")),
                utf8_scalar(Some("SQL")),
                utf8_scalar(None),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// All inputs non-nullable: the output field is non-nullable.
    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let nullable = concat_nullability(&[("a", false), ("b", false)])?;

        assert!(
            !nullable,
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );

        Ok(())
    }

    /// Any nullable input: the output field is nullable.
    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let nullable = concat_nullability(&[("a", false), ("b", true)])?;

        assert!(
            nullable,
            "Expected concat result to be nullable when any input is nullable"
        );

        Ok(())
    }
}