1use crate::strings::{
19 BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder,
20};
21use crate::utils::utf8_to_str_type;
22use arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringArrayType};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::DataType;
25use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::types::{NativeType, logical_int64, logical_string};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
30use datafusion_expr::{ColumnarValue, Documentation, Volatility};
31use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
32use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36 doc_section(label = "String Functions"),
37 description = "Returns a string with an input string repeated a specified number.",
38 syntax_example = "repeat(str, n)",
39 sql_example = r#"```sql
40> select repeat('data', 3);
41+-------------------------------+
42| repeat(Utf8("data"),Int64(3)) |
43+-------------------------------+
44| datadatadata |
45+-------------------------------+
46```"#,
47 standard_argument(name = "str", prefix = "String"),
48 argument(
49 name = "n",
50 description = "Number of times to repeat the input string."
51 )
52)]
53#[derive(Debug, PartialEq, Eq, Hash)]
54pub struct RepeatFunc {
55 signature: Signature,
56}
57
58impl Default for RepeatFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl RepeatFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_implicit(
72 TypeSignatureClass::Native(logical_int64()),
73 vec![TypeSignatureClass::Integer],
74 NativeType::Int64,
75 ),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for RepeatFunc {
84 fn name(&self) -> &str {
85 "repeat"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93 if arg_types[0] == Utf8View {
94 return Ok(Utf8View);
95 }
96 utf8_to_str_type(&arg_types[0], "repeat")
97 }
98
99 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100 let return_type = args.return_field.data_type().clone();
101 let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;
102
103 if let ColumnarValue::Scalar(s) = &string_arg
105 && s.is_null()
106 {
107 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
108 }
109 if let ColumnarValue::Scalar(c) = &count_arg
110 && c.is_null()
111 {
112 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
113 }
114
115 match (&string_arg, &count_arg) {
116 (
117 ColumnarValue::Scalar(string_scalar),
118 ColumnarValue::Scalar(count_scalar),
119 ) => {
120 let count = match count_scalar {
121 ScalarValue::Int64(Some(n)) => *n,
122 _ => {
123 return internal_err!(
124 "Unexpected data type {:?} for repeat count",
125 count_scalar.data_type()
126 );
127 }
128 };
129
130 let result = match string_scalar {
131 ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some(
132 compute_repeat(s, count, i32::MAX as usize)?,
133 )),
134 ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some(
135 compute_repeat(s, count, i32::MAX as usize)?,
136 )),
137 ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
138 compute_repeat(s, count, i64::MAX as usize)?,
139 )),
140 _ => {
141 return internal_err!(
142 "Unexpected data type {:?} for function repeat",
143 string_scalar.data_type()
144 );
145 }
146 };
147
148 Ok(ColumnarValue::Scalar(result))
149 }
150 _ => {
151 let string_array = string_arg.to_array(args.number_rows)?;
152 let count_array = count_arg.to_array(args.number_rows)?;
153 Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
154 }
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163#[inline]
165fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
166 if count <= 0 {
167 return Ok(String::new());
168 }
169 let result_len = s.len().saturating_mul(count as usize);
170 if result_len > max_size {
171 return exec_err!(
172 "string size overflow on repeat, max size is {}, but got {}",
173 max_size,
174 result_len
175 );
176 }
177 Ok(s.repeat(count as usize))
178}
179
180fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
183 let number_array = as_int64_array(count_array)?;
184 match string_array.data_type() {
185 Utf8View => {
186 let string_view_array = string_array.as_string_view();
187 let (_, max_item_capacity) = calculate_capacities(
188 &string_view_array,
189 number_array,
190 i32::MAX as usize,
191 )?;
192 let builder = StringViewArrayBuilder::with_capacity(string_array.len());
193 repeat_impl(&string_view_array, number_array, max_item_capacity, builder)
194 }
195 Utf8 => {
196 let string_arr = string_array.as_string::<i32>();
197 let (total_capacity, max_item_capacity) =
198 calculate_capacities(&string_arr, number_array, i32::MAX as usize)?;
199 let builder = GenericStringArrayBuilder::<i32>::with_capacity(
200 string_array.len(),
201 total_capacity,
202 );
203 repeat_impl(&string_arr, number_array, max_item_capacity, builder)
204 }
205 LargeUtf8 => {
206 let string_arr = string_array.as_string::<i64>();
207 let (total_capacity, max_item_capacity) =
208 calculate_capacities(&string_arr, number_array, i64::MAX as usize)?;
209 let builder = GenericStringArrayBuilder::<i64>::with_capacity(
210 string_array.len(),
211 total_capacity,
212 );
213 repeat_impl(&string_arr, number_array, max_item_capacity, builder)
214 }
215 other => exec_err!(
216 "Unsupported data type {other:?} for function repeat. \
217 Expected Utf8, Utf8View or LargeUtf8."
218 ),
219 }
220}
221
222fn calculate_capacities<'a, S>(
223 string_array: &S,
224 number_array: &Int64Array,
225 max_str_len: usize,
226) -> Result<(usize, usize)>
227where
228 S: StringArrayType<'a>,
229{
230 let mut total_capacity = 0;
231 let mut max_item_capacity = 0;
232
233 string_array.iter().zip(number_array.iter()).try_for_each(
234 |(string, number)| -> Result<(), DataFusionError> {
235 match (string, number) {
236 (Some(string), Some(number)) if number >= 0 => {
237 let item_capacity = string.len() * number as usize;
238 if item_capacity > max_str_len {
239 return exec_err!(
240 "string size overflow on repeat, max size is {}, but got {}",
241 max_str_len,
242 number as usize * string.len()
243 );
244 }
245 total_capacity += item_capacity;
246 max_item_capacity = max_item_capacity.max(item_capacity);
247 }
248 _ => (),
249 }
250 Ok(())
251 },
252 )?;
253
254 Ok((total_capacity, max_item_capacity))
255}
256
257fn repeat_impl<'a, S, B>(
258 string_array: &S,
259 number_array: &Int64Array,
260 max_item_capacity: usize,
261 mut builder: B,
262) -> Result<ArrayRef>
263where
264 S: StringArrayType<'a> + 'a,
265 B: BulkNullStringArrayBuilder,
266{
267 let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
269
270 #[inline]
273 fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
274 buffer.clear();
275 if !string.is_empty() {
276 let src = string.as_bytes();
277 buffer.extend_from_slice(src);
279 while buffer.len() < src.len() * count {
281 let copy_len = buffer.len().min(src.len() * count - buffer.len());
282 buffer.extend_from_within(..copy_len);
284 }
285 }
286 }
287
288 let nulls = NullBuffer::union(string_array.nulls(), number_array.nulls());
290
291 if let Some(ref n) = nulls {
292 for i in 0..string_array.len() {
293 if n.is_null(i) {
294 builder.append_placeholder();
295 continue;
296 }
297 let string = unsafe { string_array.value_unchecked(i) };
299 let count = unsafe { number_array.value_unchecked(i) };
300 if count > 0 {
301 repeat_to_buffer(&mut buffer, string, count as usize);
302 builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
304 } else {
305 builder.append_value("");
306 }
307 }
308 } else {
309 for i in 0..string_array.len() {
310 let string = unsafe { string_array.value_unchecked(i) };
312 let count = unsafe { number_array.value_unchecked(i) };
313 if count > 0 {
314 repeat_to_buffer(&mut buffer, string, count as usize);
315 builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
317 } else {
318 builder.append_value("");
319 }
320 }
321 }
322
323 builder.finish(nulls)
324}
325
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, Int64Array, LargeStringArray, StringArray, StringViewArray,
    };
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        // Zero and negative counts produce an empty string, not NULL.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))),
            ],
            Ok(Some("")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Builds string/count inputs sliced to rows 1..5 of a 6-row batch.
    ///
    /// The slice starts just past a NULL string and contains a NULL count and
    /// a NULL string, which exercises null-buffer offset handling.
    fn sliced_offset_inputs<F>(make_strings: F) -> (ArrayRef, ArrayRef)
    where
        F: FnOnce(Vec<Option<&'static str>>) -> ArrayRef,
    {
        let strings = make_strings(vec![
            None,
            Some("a"),
            Some("bb"),
            Some("c"),
            None,
            Some("d"),
        ]);
        let counts: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(2),
            Some(3),
            Some(2),
            None,
            Some(1),
            Some(2),
        ]));
        (strings.slice(1, 4), counts.slice(1, 4))
    }

    /// Checks the expected output for [`sliced_offset_inputs`]: two repeated
    /// values followed by two NULLs.
    fn assert_sliced_offset_output<A: Array + 'static>(result: ArrayRef)
    where
        for<'a> &'a A: arrow::array::ArrayAccessor<Item = &'a str>,
    {
        let result = result.as_any().downcast_ref::<A>().unwrap();
        assert_eq!(result.len(), 4);
        assert_eq!(arrow::array::ArrayAccessor::value(&result, 0), "aaa");
        assert_eq!(arrow::array::ArrayAccessor::value(&result, 1), "bbbb");
        assert!(result.is_null(2));
        assert!(result.is_null(3));
        assert_eq!(result.null_count(), 2);
    }

    #[test]
    fn test_repeat_sliced_string_with_null_offset() {
        let (strings, counts) = sliced_offset_inputs(|v| Arc::new(StringArray::from(v)));
        let result = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<StringArray>(result);
    }

    #[test]
    fn test_repeat_sliced_large_string_with_null_offset() {
        let (strings, counts) =
            sliced_offset_inputs(|v| Arc::new(LargeStringArray::from(v)));
        let result = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<LargeStringArray>(result);
    }

    #[test]
    fn test_repeat_sliced_string_view_with_null_offset() {
        let (strings, counts) =
            sliced_offset_inputs(|v| Arc::new(StringViewArray::from(v)));
        let result = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<StringViewArray>(result);
    }

    #[test]
    fn test_repeat_array_non_positive_counts() {
        let strings: ArrayRef = Arc::new(StringArray::from(vec!["ab", "cd", "ef"]));
        let counts: ArrayRef = Arc::new(Int64Array::from(vec![0, -3, 2]));
        let result = super::repeat(&strings, &counts).unwrap();
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), "");
        assert_eq!(result.value(1), "");
        assert_eq!(result.value(2), "efef");
        assert_eq!(result.null_count(), 0);
    }

    #[test]
    fn test_repeat_array_size_overflow() {
        // 2 bytes * 2^30 = 2^31 bytes, one past the Utf8 limit of i32::MAX.
        let strings: ArrayRef = Arc::new(StringArray::from(vec!["Pg"]));
        let counts: ArrayRef = Arc::new(Int64Array::from(vec![1073741824]));
        let err = super::repeat(&strings, &counts).unwrap_err();
        assert!(
            err.to_string().contains("string size overflow on repeat"),
            "unexpected error: {err}"
        );
    }
}