1use arrow::array::{as_largestring_array, Array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26 ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Concatenates multiple strings together.",
39 syntax_example = "concat(str[, ..., str_n])",
40 sql_example = r#"```sql
41> select concat('data', 'f', 'us', 'ion');
42+-------------------------------------------------------+
43| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
44+-------------------------------------------------------+
45| datafusion |
46+-------------------------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 argument(
50 name = "str_n",
51 description = "Subsequent string expressions to concatenate."
52 ),
53 related_udf(name = "concat_ws")
54)]
/// Scalar UDF registered under the name `concat`.
///
/// The only state is the function [`Signature`], which `ConcatFunc::new`
/// builds as variadic over `Utf8View`, `Utf8` and `LargeUtf8`.
#[derive(Debug)]
pub struct ConcatFunc {
    /// Signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
}
59
60impl Default for ConcatFunc {
61 fn default() -> Self {
62 ConcatFunc::new()
63 }
64}
65
66impl ConcatFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::variadic(
71 vec![Utf8View, Utf8, LargeUtf8],
72 Volatility::Immutable,
73 ),
74 }
75 }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn name(&self) -> &str {
84 "concat"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92 use DataType::*;
93 let mut dt = &Utf8;
94 arg_types.iter().for_each(|data_type| {
95 if data_type == &Utf8View {
96 dt = data_type;
97 }
98 if data_type == &LargeUtf8 && dt != &Utf8View {
99 dt = data_type;
100 }
101 });
102
103 Ok(dt.to_owned())
104 }
105
106 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 let ScalarFunctionArgs { args, .. } = args;
110
111 let mut return_datatype = DataType::Utf8;
112 args.iter().for_each(|col| {
113 if col.data_type() == DataType::Utf8View {
114 return_datatype = col.data_type();
115 }
116 if col.data_type() == DataType::LargeUtf8
117 && return_datatype != DataType::Utf8View
118 {
119 return_datatype = col.data_type();
120 }
121 });
122
123 let array_len = args
124 .iter()
125 .filter_map(|x| match x {
126 ColumnarValue::Array(array) => Some(array.len()),
127 _ => None,
128 })
129 .next();
130
131 if array_len.is_none() {
133 let mut result = String::new();
134 for arg in args {
135 let ColumnarValue::Scalar(scalar) = arg else {
136 return internal_err!("concat expected scalar value, got {arg:?}");
137 };
138
139 match scalar.try_as_str() {
140 Some(Some(v)) => result.push_str(v),
141 Some(None) => {} None => plan_err!(
143 "Concat function does not support scalar type {:?}",
144 scalar
145 )?,
146 }
147 }
148
149 return match return_datatype {
150 DataType::Utf8View => {
151 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152 }
153 DataType::Utf8 => {
154 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155 }
156 DataType::LargeUtf8 => {
157 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158 }
159 other => {
160 plan_err!("Concat function does not support datatype of {other}")
161 }
162 };
163 }
164
165 let len = array_len.unwrap();
167 let mut data_size = 0;
168 let mut columns = Vec::with_capacity(args.len());
169
170 for arg in &args {
171 match arg {
172 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175 if let Some(s) = maybe_value {
176 data_size += s.len() * len;
177 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178 }
179 }
180 ColumnarValue::Array(array) => {
181 match array.data_type() {
182 DataType::Utf8 => {
183 let string_array = as_string_array(array)?;
184
185 data_size += string_array.values().len();
186 let column = if array.is_nullable() {
187 ColumnarValueRef::NullableArray(string_array)
188 } else {
189 ColumnarValueRef::NonNullableArray(string_array)
190 };
191 columns.push(column);
192 },
193 DataType::LargeUtf8 => {
194 let string_array = as_largestring_array(array);
195
196 data_size += string_array.values().len();
197 let column = if array.is_nullable() {
198 ColumnarValueRef::NullableLargeStringArray(string_array)
199 } else {
200 ColumnarValueRef::NonNullableLargeStringArray(string_array)
201 };
202 columns.push(column);
203 },
204 DataType::Utf8View => {
205 let string_array = as_string_view_array(array)?;
206
207 data_size += string_array.len();
208 let column = if array.is_nullable() {
209 ColumnarValueRef::NullableStringViewArray(string_array)
210 } else {
211 ColumnarValueRef::NonNullableStringViewArray(string_array)
212 };
213 columns.push(column);
214 },
215 other => {
216 return plan_err!("Input was {other} which is not a supported datatype for concat function")
217 }
218 };
219 }
220 _ => unreachable!("concat"),
221 }
222 }
223
224 match return_datatype {
225 DataType::Utf8 => {
226 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
227 for i in 0..len {
228 columns
229 .iter()
230 .for_each(|column| builder.write::<true>(column, i));
231 builder.append_offset();
232 }
233
234 let string_array = builder.finish(None);
235 Ok(ColumnarValue::Array(Arc::new(string_array)))
236 }
237 DataType::Utf8View => {
238 let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
239 for i in 0..len {
240 columns
241 .iter()
242 .for_each(|column| builder.write::<true>(column, i));
243 builder.append_offset();
244 }
245
246 let string_array = builder.finish();
247 Ok(ColumnarValue::Array(Arc::new(string_array)))
248 }
249 DataType::LargeUtf8 => {
250 let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
251 for i in 0..len {
252 columns
253 .iter()
254 .for_each(|column| builder.write::<true>(column, i));
255 builder.append_offset();
256 }
257
258 let string_array = builder.finish(None);
259 Ok(ColumnarValue::Array(Arc::new(string_array)))
260 }
261 _ => unreachable!(),
262 }
263 }
264
265 fn simplify(
274 &self,
275 args: Vec<Expr>,
276 _info: &dyn SimplifyInfo,
277 ) -> Result<ExprSimplifyResult> {
278 simplify_concat(args)
279 }
280
281 fn documentation(&self) -> Option<&Documentation> {
282 self.doc()
283 }
284
285 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
286 Ok(true)
287 }
288}
289
290pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
291 let mut new_args = Vec::with_capacity(args.len());
292 let mut contiguous_scalar = "".to_string();
293
294 let return_type = {
295 let data_types: Vec<_> = args
296 .iter()
297 .filter_map(|expr| match expr {
298 Expr::Literal(l, _) => Some(l.data_type()),
299 _ => None,
300 })
301 .collect();
302 ConcatFunc::new().return_type(&data_types)
303 }?;
304
305 for arg in args.clone() {
306 match arg {
307 Expr::Literal(ScalarValue::Utf8(None), _) => {}
308 Expr::Literal(ScalarValue::LargeUtf8(None), _) => {
309 }
310 Expr::Literal(ScalarValue::Utf8View(None), _) => { }
311
312 Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
316 contiguous_scalar += &v;
317 }
318 Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
319 contiguous_scalar += &v;
320 }
321 Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
322 contiguous_scalar += &v;
323 }
324
325 Expr::Literal(x, _) => {
326 return internal_err!(
327 "The scalar {x} should be casted to string type during the type coercion."
328 )
329 }
330 arg => {
334 if !contiguous_scalar.is_empty() {
335 match return_type {
336 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
337 DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
338 DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
339 _ => unreachable!(),
340 }
341 contiguous_scalar = "".to_string();
342 }
343 new_args.push(arg);
344 }
345 }
346 }
347
348 if !contiguous_scalar.is_empty() {
349 match return_type {
350 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
351 DataType::LargeUtf8 => {
352 new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
353 }
354 DataType::Utf8View => {
355 new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
356 }
357 _ => unreachable!(),
358 }
359 }
360
361 if !args.eq(&new_args) {
362 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
363 ScalarFunction {
364 func: concat(),
365 args: new_args,
366 },
367 )))
368 } else {
369 Ok(ExprSimplifyResult::Original(args))
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::Field;
    use DataType::*;

    /// Scalar-only invocations: NULL arguments are skipped and the result
    /// type follows the widest string type among the arguments.
    #[test]
    fn test_functions() -> Result<()> {
        // Plain Utf8 scalars.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        // A NULL in the middle is ignored.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        // A lone NULL yields the empty string, not NULL.
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View takes precedence over LargeUtf8 and Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // LargeUtf8 takes precedence over Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // Mixing Utf8View and Utf8 produces Utf8View.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array/scalar invocation: NULL array elements contribute nothing
    /// and the presence of Utf8View inputs makes the output a StringViewArray.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // Matches what `return_type` yields for these argument types
            // and the StringViewArray asserted below.
            return_field: Field::new("f", Utf8View, true).into(),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            other => panic!("expected an array result, got {other:?}"),
        }
        Ok(())
    }
}