1use arrow::array::{Array, as_largestring_array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::sync::Arc;
22
23use crate::string::concat;
24use crate::strings::{
25 ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder,
26 ConcatStringViewBuilder,
27};
28use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array};
29use datafusion_common::{
30 Result, ScalarValue, exec_datafusion_err, internal_err, plan_err,
31};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
34use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
35use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_macros::user_doc;
37
/// Scalar UDF implementing the SQL `concat` function.
///
/// The `user_doc` attribute below carries the SQL-level documentation
/// (description, syntax, example and argument descriptions).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together.",
    syntax_example = "concat(str[, ..., str_n])",
    sql_example = r#"```sql
> select concat('data', 'f', 'us', 'ion');
+-------------------------------------------------------+
| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
+-------------------------------------------------------+
| datafusion |
+-------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat_ws")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConcatFunc {
    /// Accepted argument types and volatility; populated by `ConcatFunc::new`.
    signature: Signature,
}
61
62impl Default for ConcatFunc {
63 fn default() -> Self {
64 ConcatFunc::new()
65 }
66}
67
68impl ConcatFunc {
69 pub fn new() -> Self {
70 use DataType::*;
71 Self {
72 signature: Signature::variadic(
73 vec![Utf8View, Utf8, LargeUtf8, Binary],
74 Volatility::Immutable,
75 ),
76 }
77 }
78}
79
80fn deduce_return_type(arg_types: &[DataType]) -> DataType {
81 use DataType::*;
82 if arg_types.contains(&Utf8View) {
83 Utf8View
84 } else if arg_types.contains(&LargeUtf8) {
85 LargeUtf8
86 } else {
87 Utf8
88 }
89}
90
91impl ScalarUDFImpl for ConcatFunc {
92 fn name(&self) -> &str {
93 "concat"
94 }
95
96 fn signature(&self) -> &Signature {
97 &self.signature
98 }
99
100 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
104 Ok(deduce_return_type(arg_types))
105 }
106
107 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110 let ScalarFunctionArgs { args, .. } = args;
111
112 let arg_types: Vec<DataType> = args.iter().map(|c| c.data_type()).collect();
113 let return_datatype = deduce_return_type(&arg_types);
114
115 let array_len = args.iter().find_map(|x| match x {
116 ColumnarValue::Array(array) => Some(array.len()),
117 _ => None,
118 });
119
120 if array_len.is_none() {
122 let mut values: Vec<&[u8]> = Vec::with_capacity(args.len());
123 for arg in &args {
124 let ColumnarValue::Scalar(scalar) = arg else {
125 return internal_err!("concat expected scalar value, got {arg:?}");
126 };
127 if let ScalarValue::Binary(Some(value)) = scalar {
128 values.push(value);
129 } else {
130 match scalar.try_as_str() {
131 Some(Some(v)) => values.push(v.as_bytes()),
132 Some(None) => {} None => plan_err!(
134 "Concat function does not support scalar type {}",
135 scalar
136 )?,
137 }
138 }
139 }
140 let concat_bytes = values.concat();
141 let result = std::str::from_utf8(&concat_bytes)
142 .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?
143 .to_string();
144
145 return match return_datatype {
146 DataType::Utf8View => {
147 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
148 }
149 DataType::Utf8 => {
150 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
151 }
152 DataType::LargeUtf8 => {
153 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
154 }
155 other => {
156 plan_err!("Concat function does not support datatype of {other}")
157 }
158 };
159 }
160
161 let len = array_len.unwrap();
163 let mut data_size = 0;
164 let mut columns = Vec::with_capacity(args.len());
165
166 for arg in &args {
167 match arg {
168 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
169 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
170 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
171 if let Some(s) = maybe_value {
172 data_size += s.len() * len;
173 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
174 }
175 }
176 ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => {
177 if let Some(b) = maybe_value {
178 data_size += b.len() * len;
180 columns.push(ColumnarValueRef::Scalar(b.as_slice()));
181 }
182 }
183 ColumnarValue::Array(array) => {
184 match array.data_type() {
185 DataType::Utf8 => {
186 let string_array = as_string_array(array)?;
187
188 data_size += string_array.values().len();
189 let column = if array.is_nullable() {
190 ColumnarValueRef::NullableArray(string_array)
191 } else {
192 ColumnarValueRef::NonNullableArray(string_array)
193 };
194 columns.push(column);
195 }
196 DataType::LargeUtf8 => {
197 let string_array = as_largestring_array(array);
198
199 data_size += string_array.values().len();
200 let column = if array.is_nullable() {
201 ColumnarValueRef::NullableLargeStringArray(string_array)
202 } else {
203 ColumnarValueRef::NonNullableLargeStringArray(
204 string_array,
205 )
206 };
207 columns.push(column);
208 }
209 DataType::Utf8View => {
210 let string_array = as_string_view_array(array)?;
211
212 data_size += string_array.total_buffer_bytes_used();
215 let column = if array.is_nullable() {
216 ColumnarValueRef::NullableStringViewArray(string_array)
217 } else {
218 ColumnarValueRef::NonNullableStringViewArray(string_array)
219 };
220 columns.push(column);
221 }
222 DataType::Binary => {
223 let string_array = as_binary_array(array)?;
224
225 data_size += string_array.values().len();
226 let column = if array.is_nullable() {
227 ColumnarValueRef::NullableBinaryArray(string_array)
228 } else {
229 ColumnarValueRef::NonNullableBinaryArray(string_array)
230 };
231 columns.push(column);
232 }
233 other => {
234 return plan_err!(
235 "Input was {other} which is not a supported datatype for concat function"
236 );
237 }
238 };
239 }
240 _ => unreachable!("concat"),
241 }
242 }
243
244 match return_datatype {
245 DataType::Utf8 => {
246 let mut builder = ConcatStringBuilder::with_capacity(len, data_size);
247 for i in 0..len {
248 columns
249 .iter()
250 .for_each(|column| builder.write::<true>(column, i));
251 builder.append_offset()?;
252 }
253
254 let string_array = builder.finish(None)?;
255 Ok(ColumnarValue::Array(Arc::new(string_array)))
256 }
257 DataType::Utf8View => {
258 let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size);
259 for i in 0..len {
260 columns
261 .iter()
262 .for_each(|column| builder.write::<true>(column, i));
263 builder.append_offset()?;
264 }
265
266 let string_array = builder.finish(None)?;
267 Ok(ColumnarValue::Array(Arc::new(string_array)))
268 }
269 DataType::LargeUtf8 => {
270 let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size);
271 for i in 0..len {
272 columns
273 .iter()
274 .for_each(|column| builder.write::<true>(column, i));
275 builder.append_offset()?;
276 }
277
278 let string_array = builder.finish(None)?;
279 Ok(ColumnarValue::Array(Arc::new(string_array)))
280 }
281 _ => unreachable!(),
282 }
283 }
284
285 fn simplify(
294 &self,
295 args: Vec<Expr>,
296 _info: &SimplifyContext,
297 ) -> Result<ExprSimplifyResult> {
298 simplify_concat(args)
299 }
300
301 fn documentation(&self) -> Option<&Documentation> {
302 self.doc()
303 }
304
305 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
306 Ok(true)
307 }
308}
309
310pub(crate) fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
311 let mut new_args = Vec::with_capacity(args.len());
312 let mut contiguous_scalar = "".to_string();
313
314 let return_type = {
315 let data_types: Vec<_> = args
316 .iter()
317 .filter_map(|expr| match expr {
318 Expr::Literal(l, _) => Some(l.data_type()),
319 _ => None,
320 })
321 .collect();
322 ConcatFunc::new().return_type(&data_types)
323 }?;
324
325 for arg in args.clone() {
326 match arg {
327 Expr::Literal(ScalarValue::Utf8(None), _) => {}
328 Expr::Literal(ScalarValue::LargeUtf8(None), _) => {}
329 Expr::Literal(ScalarValue::Utf8View(None), _) => {}
330
331 Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
335 contiguous_scalar += &v;
336 }
337 Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
338 contiguous_scalar += &v;
339 }
340 Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
341 contiguous_scalar += &v;
342 }
343
344 Expr::Literal(x, _) => {
345 return internal_err!(
346 "The scalar {x} should be casted to string type during the type coercion."
347 );
348 }
349 arg => {
353 if !contiguous_scalar.is_empty() {
354 match return_type {
355 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
356 DataType::LargeUtf8 => new_args
357 .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
358 DataType::Utf8View => new_args
359 .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
360 _ => unreachable!(),
361 }
362 contiguous_scalar = "".to_string();
363 }
364 new_args.push(arg);
365 }
366 }
367 }
368
369 if !contiguous_scalar.is_empty() {
370 match return_type {
371 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
372 DataType::LargeUtf8 => {
373 new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
374 }
375 DataType::Utf8View => {
376 new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
377 }
378 _ => unreachable!(),
379 }
380 }
381
382 if !args.eq(&new_args) {
383 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
384 ScalarFunction {
385 func: concat(),
386 args: new_args,
387 },
388 )))
389 } else {
390 Ok(ExprSimplifyResult::Original(args))
391 }
392}
393
/// Unit tests for `ConcatFunc`: scalar-only inputs and mixed array/scalar inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use DataType::*;
    use arrow::array::{ArrayRef, StringArray};
    use arrow::array::{LargeStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    /// Scalar-only invocations. Each `test_function!` call lists the
    /// arguments followed by what appear to be the expected value, its Rust
    /// type, the expected Arrow data type and the expected array type
    /// (NOTE(review): macro is defined elsewhere — confirm argument meaning).
    #[test]
    fn test_functions() -> Result<()> {
        // Plain Utf8 scalars.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        // A NULL in the middle is skipped.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        // A lone NULL yields the empty string, not NULL.
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A Utf8View argument (even NULL) makes the result Utf8View.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // A LargeUtf8 argument (even NULL) makes the result LargeUtf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // Mixed Utf8View and Utf8 values.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Binary scalar holding valid multi-byte UTF-8, mixed with Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Binary(Some(
                    "Café".as_bytes().into()
                ))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("Cafécc")),
            &str,
            Utf8,
            StringArray
        );
        // Binary-only arguments still produce a Utf8 result.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from(
                    "Café".as_bytes()
                )))),
                ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))),
            ],
            Ok(Some("Cafécc")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Mixed array and scalar arguments: NULL array entries contribute
    /// nothing to their row, and the Utf8View inputs make the output a
    /// `StringViewArray`.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        // `return_field` is declared Utf8 here although the expected output
        // below is a StringViewArray; `invoke_with_args` only destructures
        // `args`, so the declared return field does not affect the result.
        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            return_field: Field::new("f", Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            _ => panic!(),
        }
        Ok(())
    }
}