1use arrow::array::{Array, as_largestring_array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26 ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementing SQL `concat`: joins its string arguments in order,
/// with NULL arguments contributing nothing to the result.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together.",
    syntax_example = "concat(str[, ..., str_n])",
    sql_example = r#"```sql
> select concat('data', 'f', 'us', 'ion');
+-------------------------------------------------------+
| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
+-------------------------------------------------------+
| datafusion                                            |
+-------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat_ws")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ConcatFunc {
    /// Accepted argument types and volatility; built in `ConcatFunc::new`.
    signature: Signature,
}
59
60impl Default for ConcatFunc {
61 fn default() -> Self {
62 ConcatFunc::new()
63 }
64}
65
66impl ConcatFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::variadic(
71 vec![Utf8View, Utf8, LargeUtf8],
72 Volatility::Immutable,
73 ),
74 }
75 }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn name(&self) -> &str {
84 "concat"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92 use DataType::*;
93 let mut dt = &Utf8;
94 arg_types.iter().for_each(|data_type| {
95 if data_type == &Utf8View {
96 dt = data_type;
97 }
98 if data_type == &LargeUtf8 && dt != &Utf8View {
99 dt = data_type;
100 }
101 });
102
103 Ok(dt.to_owned())
104 }
105
106 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 let ScalarFunctionArgs { args, .. } = args;
110
111 let mut return_datatype = DataType::Utf8;
112 args.iter().for_each(|col| {
113 if col.data_type() == DataType::Utf8View {
114 return_datatype = col.data_type();
115 }
116 if col.data_type() == DataType::LargeUtf8
117 && return_datatype != DataType::Utf8View
118 {
119 return_datatype = col.data_type();
120 }
121 });
122
123 let array_len = args
124 .iter()
125 .filter_map(|x| match x {
126 ColumnarValue::Array(array) => Some(array.len()),
127 _ => None,
128 })
129 .next();
130
131 if array_len.is_none() {
133 let mut result = String::new();
134 for arg in args {
135 let ColumnarValue::Scalar(scalar) = arg else {
136 return internal_err!("concat expected scalar value, got {arg:?}");
137 };
138
139 match scalar.try_as_str() {
140 Some(Some(v)) => result.push_str(v),
141 Some(None) => {} None => plan_err!(
143 "Concat function does not support scalar type {}",
144 scalar
145 )?,
146 }
147 }
148
149 return match return_datatype {
150 DataType::Utf8View => {
151 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152 }
153 DataType::Utf8 => {
154 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155 }
156 DataType::LargeUtf8 => {
157 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158 }
159 other => {
160 plan_err!("Concat function does not support datatype of {other}")
161 }
162 };
163 }
164
165 let len = array_len.unwrap();
167 let mut data_size = 0;
168 let mut columns = Vec::with_capacity(args.len());
169
170 for arg in &args {
171 match arg {
172 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175 if let Some(s) = maybe_value {
176 data_size += s.len() * len;
177 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178 }
179 }
180 ColumnarValue::Array(array) => {
181 match array.data_type() {
182 DataType::Utf8 => {
183 let string_array = as_string_array(array)?;
184
185 data_size += string_array.values().len();
186 let column = if array.is_nullable() {
187 ColumnarValueRef::NullableArray(string_array)
188 } else {
189 ColumnarValueRef::NonNullableArray(string_array)
190 };
191 columns.push(column);
192 }
193 DataType::LargeUtf8 => {
194 let string_array = as_largestring_array(array);
195
196 data_size += string_array.values().len();
197 let column = if array.is_nullable() {
198 ColumnarValueRef::NullableLargeStringArray(string_array)
199 } else {
200 ColumnarValueRef::NonNullableLargeStringArray(
201 string_array,
202 )
203 };
204 columns.push(column);
205 }
206 DataType::Utf8View => {
207 let string_array = as_string_view_array(array)?;
208
209 data_size += string_array.len();
210 let column = if array.is_nullable() {
211 ColumnarValueRef::NullableStringViewArray(string_array)
212 } else {
213 ColumnarValueRef::NonNullableStringViewArray(string_array)
214 };
215 columns.push(column);
216 }
217 other => {
218 return plan_err!(
219 "Input was {other} which is not a supported datatype for concat function"
220 );
221 }
222 };
223 }
224 _ => unreachable!("concat"),
225 }
226 }
227
228 match return_datatype {
229 DataType::Utf8 => {
230 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
231 for i in 0..len {
232 columns
233 .iter()
234 .for_each(|column| builder.write::<true>(column, i));
235 builder.append_offset();
236 }
237
238 let string_array = builder.finish(None);
239 Ok(ColumnarValue::Array(Arc::new(string_array)))
240 }
241 DataType::Utf8View => {
242 let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
243 for i in 0..len {
244 columns
245 .iter()
246 .for_each(|column| builder.write::<true>(column, i));
247 builder.append_offset();
248 }
249
250 let string_array = builder.finish();
251 Ok(ColumnarValue::Array(Arc::new(string_array)))
252 }
253 DataType::LargeUtf8 => {
254 let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
255 for i in 0..len {
256 columns
257 .iter()
258 .for_each(|column| builder.write::<true>(column, i));
259 builder.append_offset();
260 }
261
262 let string_array = builder.finish(None);
263 Ok(ColumnarValue::Array(Arc::new(string_array)))
264 }
265 _ => unreachable!(),
266 }
267 }
268
269 fn simplify(
278 &self,
279 args: Vec<Expr>,
280 _info: &dyn SimplifyInfo,
281 ) -> Result<ExprSimplifyResult> {
282 simplify_concat(args)
283 }
284
285 fn documentation(&self) -> Option<&Documentation> {
286 self.doc()
287 }
288
289 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
290 Ok(true)
291 }
292}
293
294pub(crate) fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
295 let mut new_args = Vec::with_capacity(args.len());
296 let mut contiguous_scalar = "".to_string();
297
298 let return_type = {
299 let data_types: Vec<_> = args
300 .iter()
301 .filter_map(|expr| match expr {
302 Expr::Literal(l, _) => Some(l.data_type()),
303 _ => None,
304 })
305 .collect();
306 ConcatFunc::new().return_type(&data_types)
307 }?;
308
309 for arg in args.clone() {
310 match arg {
311 Expr::Literal(ScalarValue::Utf8(None), _) => {}
312 Expr::Literal(ScalarValue::LargeUtf8(None), _) => {}
313 Expr::Literal(ScalarValue::Utf8View(None), _) => {}
314
315 Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
319 contiguous_scalar += &v;
320 }
321 Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
322 contiguous_scalar += &v;
323 }
324 Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
325 contiguous_scalar += &v;
326 }
327
328 Expr::Literal(x, _) => {
329 return internal_err!(
330 "The scalar {x} should be casted to string type during the type coercion."
331 );
332 }
333 arg => {
337 if !contiguous_scalar.is_empty() {
338 match return_type {
339 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
340 DataType::LargeUtf8 => new_args
341 .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
342 DataType::Utf8View => new_args
343 .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
344 _ => unreachable!(),
345 }
346 contiguous_scalar = "".to_string();
347 }
348 new_args.push(arg);
349 }
350 }
351 }
352
353 if !contiguous_scalar.is_empty() {
354 match return_type {
355 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
356 DataType::LargeUtf8 => {
357 new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
358 }
359 DataType::Utf8View => {
360 new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
361 }
362 _ => unreachable!(),
363 }
364 }
365
366 if !args.eq(&new_args) {
367 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
368 ScalarFunction {
369 func: concat(),
370 args: new_args,
371 },
372 )))
373 } else {
374 Ok(ExprSimplifyResult::Original(args))
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use DataType::*;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // A Utf8View argument makes the result Utf8View; declare that here
            // so the fixture matches the expected array below.
            return_field: Field::new("f", Utf8View, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            _ => panic!(),
        }
        Ok(())
    }

    #[test]
    fn concat_large_utf8_array() -> Result<()> {
        let c0 = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![
            Some("foo"),
            None,
        ])));
        let c1 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("!".to_string())));
        let arg_fields = vec![
            Field::new("a", LargeUtf8, true),
            Field::new("a", LargeUtf8, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1],
            arg_fields,
            number_rows: 2,
            return_field: Field::new("f", LargeUtf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        // The NULL row contributes nothing, leaving only the scalar suffix.
        let expected = Arc::new(LargeStringArray::from(vec!["foo!", "!"])) as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            _ => panic!("expected an array result"),
        }
        Ok(())
    }
}