// datafusion_spark/function/string/concat.rs
use arrow::datatypes::{DataType, Field};
19use datafusion_common::arrow::datatypes::FieldRef;
20use datafusion_common::{Result, ScalarValue};
21use datafusion_expr::ReturnFieldArgs;
22use datafusion_expr::{
23 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use datafusion_functions::string::concat::ConcatFunc;
26use std::any::Any;
27use std::sync::Arc;
28
29use crate::function::null_utils::{
30 NullMaskResolution, apply_null_mask, compute_null_mask,
31};
32
/// Spark-compatible `concat` scalar UDF.
///
/// Unlike the stock DataFusion `concat`, this variant returns NULL when any
/// argument is NULL (see `spark_concat`, demonstrated by `test_concat_with_null`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
    // User-defined signature: argument types are accepted as-is and handled
    // in `coerce_types` / `return_field_from_args`.
    signature: Signature,
}
46
impl Default for SparkConcat {
    /// Delegates to [`SparkConcat::new`].
    fn default() -> Self {
        Self::new()
    }
}
52
impl SparkConcat {
    /// Creates the UDF with a user-defined signature, deferring all argument
    /// type handling to `coerce_types` (which accepts the types unchanged).
    pub fn new() -> Self {
        Self {
            signature: Signature::user_defined(Volatility::Immutable),
        }
    }
}
60
61impl ScalarUDFImpl for SparkConcat {
62 fn as_any(&self) -> &dyn Any {
63 self
64 }
65
66 fn name(&self) -> &str {
67 "concat"
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
75 spark_concat(args)
76 }
77
78 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79 Ok(arg_types.to_vec())
81 }
82 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83 datafusion_common::internal_err!(
84 "return_type should not be called for Spark concat"
85 )
86 }
87 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
88 use DataType::*;
89
90 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
92
93 let mut dt = &Utf8;
95 for field in args.arg_fields {
96 let data_type = field.data_type();
97 if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) {
98 dt = data_type;
99 }
100 }
101
102 Ok(Arc::new(Field::new("concat", dt.clone(), nullable)))
103 }
104}
105
106fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110 let ScalarFunctionArgs {
111 args: arg_values,
112 arg_fields,
113 number_rows,
114 return_field,
115 config_options,
116 } = args;
117
118 if arg_values.is_empty() {
120 let return_type = return_field.data_type();
121 return match return_type {
122 DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
123 String::new(),
124 )))),
125 DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(
126 Some(String::new()),
127 ))),
128 _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
129 Some(String::new()),
130 ))),
131 };
132 }
133
134 let null_mask = compute_null_mask(&arg_values, number_rows)?;
136
137 if matches!(null_mask, NullMaskResolution::ReturnNull) {
139 let return_type = return_field.data_type();
140 return match return_type {
141 DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))),
142 DataType::LargeUtf8 => {
143 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)))
144 }
145 _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
146 };
147 }
148
149 let concat_func = ConcatFunc::new();
151 let return_type = return_field.data_type().clone();
152 let func_args = ScalarFunctionArgs {
153 args: arg_values,
154 arg_fields,
155 number_rows,
156 return_field,
157 config_options,
158 };
159 let result = concat_func.invoke_with_args(func_args)?;
160
161 apply_null_mask(result, null_mask, &return_type)
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr::ReturnFieldArgs;
    use std::sync::Arc;

    /// Happy path: concat("Spark", "SQL") -> "SparkSQL".
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// Spark semantics: any NULL argument makes the whole result NULL
    /// (unlike DataFusion's default concat, which skips NULLs).
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// All inputs non-nullable -> the planned output field is non-nullable.
    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let func = SparkConcat::new();

        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, false)),
        ];

        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };

        let field = func.return_field_from_args(args)?;

        assert!(
            !field.is_nullable(),
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );

        Ok(())
    }

    /// A single nullable input is enough to make the output field nullable.
    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let func = SparkConcat::new();

        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];

        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };

        let field = func.return_field_from_args(args)?;

        assert!(
            field.is_nullable(),
            "Expected concat result to be nullable when any input is nullable"
        );

        Ok(())
    }
}