1use arrow::array::{as_largestring_array, Array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26 ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Concatenates multiple strings together.",
39 syntax_example = "concat(str[, ..., str_n])",
40 sql_example = r#"```sql
41> select concat('data', 'f', 'us', 'ion');
42+-------------------------------------------------------+
43| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
44+-------------------------------------------------------+
45| datafusion |
46+-------------------------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 argument(
50 name = "str_n",
51 description = "Subsequent string expressions to concatenate."
52 ),
53 related_udf(name = "concat_ws")
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct ConcatFunc {
57 signature: Signature,
58}
59
60impl Default for ConcatFunc {
61 fn default() -> Self {
62 ConcatFunc::new()
63 }
64}
65
66impl ConcatFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::variadic(
71 vec![Utf8View, Utf8, LargeUtf8],
72 Volatility::Immutable,
73 ),
74 }
75 }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn name(&self) -> &str {
84 "concat"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92 use DataType::*;
93 let mut dt = &Utf8;
94 arg_types.iter().for_each(|data_type| {
95 if data_type == &Utf8View {
96 dt = data_type;
97 }
98 if data_type == &LargeUtf8 && dt != &Utf8View {
99 dt = data_type;
100 }
101 });
102
103 Ok(dt.to_owned())
104 }
105
106 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 let ScalarFunctionArgs { args, .. } = args;
110
111 let mut return_datatype = DataType::Utf8;
112 args.iter().for_each(|col| {
113 if col.data_type() == DataType::Utf8View {
114 return_datatype = col.data_type();
115 }
116 if col.data_type() == DataType::LargeUtf8
117 && return_datatype != DataType::Utf8View
118 {
119 return_datatype = col.data_type();
120 }
121 });
122
123 let array_len = args
124 .iter()
125 .filter_map(|x| match x {
126 ColumnarValue::Array(array) => Some(array.len()),
127 _ => None,
128 })
129 .next();
130
131 if array_len.is_none() {
133 let mut result = String::new();
134 for arg in args {
135 let ColumnarValue::Scalar(scalar) = arg else {
136 return internal_err!("concat expected scalar value, got {arg:?}");
137 };
138
139 match scalar.try_as_str() {
140 Some(Some(v)) => result.push_str(v),
141 Some(None) => {} None => plan_err!(
143 "Concat function does not support scalar type {:?}",
144 scalar
145 )?,
146 }
147 }
148
149 return match return_datatype {
150 DataType::Utf8View => {
151 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152 }
153 DataType::Utf8 => {
154 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155 }
156 DataType::LargeUtf8 => {
157 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158 }
159 other => {
160 plan_err!("Concat function does not support datatype of {other}")
161 }
162 };
163 }
164
165 let len = array_len.unwrap();
167 let mut data_size = 0;
168 let mut columns = Vec::with_capacity(args.len());
169
170 for arg in &args {
171 match arg {
172 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175 if let Some(s) = maybe_value {
176 data_size += s.len() * len;
177 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178 }
179 }
180 ColumnarValue::Array(array) => {
181 match array.data_type() {
182 DataType::Utf8 => {
183 let string_array = as_string_array(array)?;
184
185 data_size += string_array.values().len();
186 let column = if array.is_nullable() {
187 ColumnarValueRef::NullableArray(string_array)
188 } else {
189 ColumnarValueRef::NonNullableArray(string_array)
190 };
191 columns.push(column);
192 },
193 DataType::LargeUtf8 => {
194 let string_array = as_largestring_array(array);
195
196 data_size += string_array.values().len();
197 let column = if array.is_nullable() {
198 ColumnarValueRef::NullableLargeStringArray(string_array)
199 } else {
200 ColumnarValueRef::NonNullableLargeStringArray(string_array)
201 };
202 columns.push(column);
203 },
204 DataType::Utf8View => {
205 let string_array = as_string_view_array(array)?;
206
207 data_size += string_array.len();
208 let column = if array.is_nullable() {
209 ColumnarValueRef::NullableStringViewArray(string_array)
210 } else {
211 ColumnarValueRef::NonNullableStringViewArray(string_array)
212 };
213 columns.push(column);
214 },
215 other => {
216 return plan_err!("Input was {other} which is not a supported datatype for concat function")
217 }
218 };
219 }
220 _ => unreachable!("concat"),
221 }
222 }
223
224 match return_datatype {
225 DataType::Utf8 => {
226 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
227 for i in 0..len {
228 columns
229 .iter()
230 .for_each(|column| builder.write::<true>(column, i));
231 builder.append_offset();
232 }
233
234 let string_array = builder.finish(None);
235 Ok(ColumnarValue::Array(Arc::new(string_array)))
236 }
237 DataType::Utf8View => {
238 let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
239 for i in 0..len {
240 columns
241 .iter()
242 .for_each(|column| builder.write::<true>(column, i));
243 builder.append_offset();
244 }
245
246 let string_array = builder.finish();
247 Ok(ColumnarValue::Array(Arc::new(string_array)))
248 }
249 DataType::LargeUtf8 => {
250 let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
251 for i in 0..len {
252 columns
253 .iter()
254 .for_each(|column| builder.write::<true>(column, i));
255 builder.append_offset();
256 }
257
258 let string_array = builder.finish(None);
259 Ok(ColumnarValue::Array(Arc::new(string_array)))
260 }
261 _ => unreachable!(),
262 }
263 }
264
265 fn simplify(
274 &self,
275 args: Vec<Expr>,
276 _info: &dyn SimplifyInfo,
277 ) -> Result<ExprSimplifyResult> {
278 simplify_concat(args)
279 }
280
281 fn documentation(&self) -> Option<&Documentation> {
282 self.doc()
283 }
284
285 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
286 Ok(true)
287 }
288}
289
290pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
291 let mut new_args = Vec::with_capacity(args.len());
292 let mut contiguous_scalar = "".to_string();
293
294 let return_type = {
295 let data_types: Vec<_> = args
296 .iter()
297 .filter_map(|expr| match expr {
298 Expr::Literal(l, _) => Some(l.data_type()),
299 _ => None,
300 })
301 .collect();
302 ConcatFunc::new().return_type(&data_types)
303 }?;
304
305 for arg in args.clone() {
306 match arg {
307 Expr::Literal(ScalarValue::Utf8(None), _) => {}
308 Expr::Literal(ScalarValue::LargeUtf8(None), _) => {
309 }
310 Expr::Literal(ScalarValue::Utf8View(None), _) => { }
311
312 Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
316 contiguous_scalar += &v;
317 }
318 Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
319 contiguous_scalar += &v;
320 }
321 Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
322 contiguous_scalar += &v;
323 }
324
325 Expr::Literal(x, _) => {
326 return internal_err!(
327 "The scalar {x} should be casted to string type during the type coercion."
328 )
329 }
330 arg => {
334 if !contiguous_scalar.is_empty() {
335 match return_type {
336 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
337 DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
338 DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
339 _ => unreachable!(),
340 }
341 contiguous_scalar = "".to_string();
342 }
343 new_args.push(arg);
344 }
345 }
346 }
347
348 if !contiguous_scalar.is_empty() {
349 match return_type {
350 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
351 DataType::LargeUtf8 => {
352 new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
353 }
354 DataType::Utf8View => {
355 new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
356 }
357 _ => unreachable!(),
358 }
359 }
360
361 if !args.eq(&new_args) {
362 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
363 ScalarFunction {
364 func: concat(),
365 args: new_args,
366 },
367 )))
368 } else {
369 Ok(ExprSimplifyResult::Original(args))
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;
    use DataType::*;

    /// Scalar-only invocations: plain concatenation, null handling, and the
    /// result type chosen for mixed string types.
    #[test]
    fn test_functions() -> Result<()> {
        // All Utf8 scalars.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        // A null in the middle is skipped.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        // A single null yields the empty string, not null.
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View takes precedence over LargeUtf8 and Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // LargeUtf8 takes precedence over Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // Utf8View mixed with Utf8 produces Utf8View.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array/scalar invocation across Utf8 and Utf8View inputs.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));
        let arg_fields = vec![
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8, true),
            Field::new("a", Utf8View, true),
            Field::new("a", Utf8View, true),
        ]
        .into_iter()
        .map(Arc::new)
        .collect::<Vec<_>>();

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            arg_fields,
            number_rows: 3,
            // Utf8View inputs make the result Utf8View, matching `expected` below.
            return_field: Field::new("f", Utf8View, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            _ => panic!("expected concat to return an array"),
        }
        Ok(())
    }
}