Skip to main content

datafusion_functions/string/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::strings::{
19    BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder,
20};
21use crate::utils::utf8_to_str_type;
22use arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringArrayType};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::DataType;
25use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::types::{NativeType, logical_int64, logical_string};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
30use datafusion_expr::{ColumnarValue, Documentation, Volatility};
31use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
32use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36    doc_section(label = "String Functions"),
37    description = "Returns a string with an input string repeated a specified number.",
38    syntax_example = "repeat(str, n)",
39    sql_example = r#"```sql
40> select repeat('data', 3);
41+-------------------------------+
42| repeat(Utf8("data"),Int64(3)) |
43+-------------------------------+
44| datadatadata                  |
45+-------------------------------+
46```"#,
47    standard_argument(name = "str", prefix = "String"),
48    argument(
49        name = "n",
50        description = "Number of times to repeat the input string."
51    )
52)]
53#[derive(Debug, PartialEq, Eq, Hash)]
54pub struct RepeatFunc {
55    signature: Signature,
56}
57
58impl Default for RepeatFunc {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl RepeatFunc {
65    pub fn new() -> Self {
66        Self {
67            signature: Signature::coercible(
68                vec![
69                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70                    // Accept all integer types but cast them to i64
71                    Coercion::new_implicit(
72                        TypeSignatureClass::Native(logical_int64()),
73                        vec![TypeSignatureClass::Integer],
74                        NativeType::Int64,
75                    ),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for RepeatFunc {
84    fn name(&self) -> &str {
85        "repeat"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93        if arg_types[0] == Utf8View {
94            return Ok(Utf8View);
95        }
96        utf8_to_str_type(&arg_types[0], "repeat")
97    }
98
99    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100        let return_type = args.return_field.data_type().clone();
101        let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;
102
103        // Early return if either argument is a scalar null
104        if let ColumnarValue::Scalar(s) = &string_arg
105            && s.is_null()
106        {
107            return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
108        }
109        if let ColumnarValue::Scalar(c) = &count_arg
110            && c.is_null()
111        {
112            return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
113        }
114
115        match (&string_arg, &count_arg) {
116            (
117                ColumnarValue::Scalar(string_scalar),
118                ColumnarValue::Scalar(count_scalar),
119            ) => {
120                let count = match count_scalar {
121                    ScalarValue::Int64(Some(n)) => *n,
122                    _ => {
123                        return internal_err!(
124                            "Unexpected data type {:?} for repeat count",
125                            count_scalar.data_type()
126                        );
127                    }
128                };
129
130                let result = match string_scalar {
131                    ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some(
132                        compute_repeat(s, count, i32::MAX as usize)?,
133                    )),
134                    ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some(
135                        compute_repeat(s, count, i32::MAX as usize)?,
136                    )),
137                    ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
138                        compute_repeat(s, count, i64::MAX as usize)?,
139                    )),
140                    _ => {
141                        return internal_err!(
142                            "Unexpected data type {:?} for function repeat",
143                            string_scalar.data_type()
144                        );
145                    }
146                };
147
148                Ok(ColumnarValue::Scalar(result))
149            }
150            _ => {
151                let string_array = string_arg.to_array(args.number_rows)?;
152                let count_array = count_arg.to_array(args.number_rows)?;
153                Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
154            }
155        }
156    }
157
158    fn documentation(&self) -> Option<&Documentation> {
159        self.doc()
160    }
161}
162
163/// Computes repeat for a single string value with max size check
164#[inline]
165fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
166    if count <= 0 {
167        return Ok(String::new());
168    }
169    let result_len = s.len().saturating_mul(count as usize);
170    if result_len > max_size {
171        return exec_err!(
172            "string size overflow on repeat, max size is {}, but got {}",
173            max_size,
174            result_len
175        );
176    }
177    Ok(s.repeat(count as usize))
178}
179
180/// Repeats string the specified number of times.
181/// repeat('Pg', 4) = 'PgPgPgPg'
182fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
183    let number_array = as_int64_array(count_array)?;
184    match string_array.data_type() {
185        Utf8View => {
186            let string_view_array = string_array.as_string_view();
187            let (_, max_item_capacity) = calculate_capacities(
188                &string_view_array,
189                number_array,
190                i32::MAX as usize,
191            )?;
192            let builder = StringViewArrayBuilder::with_capacity(string_array.len());
193            repeat_impl(&string_view_array, number_array, max_item_capacity, builder)
194        }
195        Utf8 => {
196            let string_arr = string_array.as_string::<i32>();
197            let (total_capacity, max_item_capacity) =
198                calculate_capacities(&string_arr, number_array, i32::MAX as usize)?;
199            let builder = GenericStringArrayBuilder::<i32>::with_capacity(
200                string_array.len(),
201                total_capacity,
202            );
203            repeat_impl(&string_arr, number_array, max_item_capacity, builder)
204        }
205        LargeUtf8 => {
206            let string_arr = string_array.as_string::<i64>();
207            let (total_capacity, max_item_capacity) =
208                calculate_capacities(&string_arr, number_array, i64::MAX as usize)?;
209            let builder = GenericStringArrayBuilder::<i64>::with_capacity(
210                string_array.len(),
211                total_capacity,
212            );
213            repeat_impl(&string_arr, number_array, max_item_capacity, builder)
214        }
215        other => exec_err!(
216            "Unsupported data type {other:?} for function repeat. \
217        Expected Utf8, Utf8View or LargeUtf8."
218        ),
219    }
220}
221
222fn calculate_capacities<'a, S>(
223    string_array: &S,
224    number_array: &Int64Array,
225    max_str_len: usize,
226) -> Result<(usize, usize)>
227where
228    S: StringArrayType<'a>,
229{
230    let mut total_capacity = 0;
231    let mut max_item_capacity = 0;
232
233    string_array.iter().zip(number_array.iter()).try_for_each(
234        |(string, number)| -> Result<(), DataFusionError> {
235            match (string, number) {
236                (Some(string), Some(number)) if number >= 0 => {
237                    let item_capacity = string.len() * number as usize;
238                    if item_capacity > max_str_len {
239                        return exec_err!(
240                            "string size overflow on repeat, max size is {}, but got {}",
241                            max_str_len,
242                            number as usize * string.len()
243                        );
244                    }
245                    total_capacity += item_capacity;
246                    max_item_capacity = max_item_capacity.max(item_capacity);
247                }
248                _ => (),
249            }
250            Ok(())
251        },
252    )?;
253
254    Ok((total_capacity, max_item_capacity))
255}
256
257fn repeat_impl<'a, S, B>(
258    string_array: &S,
259    number_array: &Int64Array,
260    max_item_capacity: usize,
261    mut builder: B,
262) -> Result<ArrayRef>
263where
264    S: StringArrayType<'a> + 'a,
265    B: BulkNullStringArrayBuilder,
266{
267    // Reusable buffer to avoid allocations in string.repeat()
268    let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
269
270    // Helper function to repeat a string into a buffer using doubling strategy
271    // count must be > 0
272    #[inline]
273    fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
274        buffer.clear();
275        if !string.is_empty() {
276            let src = string.as_bytes();
277            // Initial copy
278            buffer.extend_from_slice(src);
279            // Doubling strategy: copy what we have so far until we reach the target
280            while buffer.len() < src.len() * count {
281                let copy_len = buffer.len().min(src.len() * count - buffer.len());
282                // SAFETY: we're copying valid UTF-8 bytes that we already verified
283                buffer.extend_from_within(..copy_len);
284            }
285        }
286    }
287
288    // Output is null IFF either input is null
289    let nulls = NullBuffer::union(string_array.nulls(), number_array.nulls());
290
291    if let Some(ref n) = nulls {
292        for i in 0..string_array.len() {
293            if n.is_null(i) {
294                builder.append_placeholder();
295                continue;
296            }
297            // SAFETY: index `i` in both arrays is valid
298            let string = unsafe { string_array.value_unchecked(i) };
299            let count = unsafe { number_array.value_unchecked(i) };
300            if count > 0 {
301                repeat_to_buffer(&mut buffer, string, count as usize);
302                // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
303                builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
304            } else {
305                builder.append_value("");
306            }
307        }
308    } else {
309        for i in 0..string_array.len() {
310            // SAFETY: no nulls, so every index in both arrays is valid
311            let string = unsafe { string_array.value_unchecked(i) };
312            let count = unsafe { number_array.value_unchecked(i) };
313            if count > 0 {
314                repeat_to_buffer(&mut buffer, string, count as usize);
315                // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
316                builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
317            } else {
318                builder.append_value("");
319            }
320        }
321    }
322
323    builder.finish(nulls)
324}
325
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayAccessor, ArrayRef, Int64Array, LargeStringArray, StringArray,
        StringViewArray,
    };
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// Scalar-input cases: plain repetition, null on either side for each
    /// string type, and the size-limit error.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8: basic repetition
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8: null string
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Utf8: null count
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Utf8View: basic repetition
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Utf8View: null string
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // LargeUtf8: null count
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // Utf8View: null count
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Utf8: 2 bytes * 2^30 is one byte past the i32::MAX limit
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    // Slicing the input arrays produces a NullBuffer with a non-zero offset.
    // The helper builds 6-row inputs and slices them to (1, 4), leaving:
    //   slot 0 (orig 1): "a"  × 3    → "aaa"
    //   slot 1 (orig 2): "bb" × 2    → "bbbb"
    //   slot 2 (orig 3): "c"  × NULL → NULL (count-side null)
    //   slot 3 (orig 4): NULL × 1    → NULL (string-side null)
    fn sliced_offset_inputs<F>(make_strings: F) -> (ArrayRef, ArrayRef)
    where
        F: FnOnce(Vec<Option<&'static str>>) -> ArrayRef,
    {
        let string_values = vec![
            None,
            Some("a"),
            Some("bb"),
            Some("c"),
            None,
            Some("d"),
        ];
        let count_values = vec![Some(2), Some(3), Some(2), None, Some(1), Some(2)];

        let strings = make_strings(string_values);
        let counts: ArrayRef = Arc::new(Int64Array::from(count_values));
        (strings.slice(1, 4), counts.slice(1, 4))
    }

    /// Checks `result` against the four rows expected from
    /// `sliced_offset_inputs`, after downcasting it to `A`.
    fn assert_sliced_offset_output<A: Array + 'static>(result: ArrayRef)
    where
        for<'a> &'a A: ArrayAccessor<Item = &'a str>,
    {
        let typed = result.as_any().downcast_ref::<A>().unwrap();
        assert_eq!(typed.len(), 4);
        assert_eq!(ArrayAccessor::value(&typed, 0), "aaa");
        assert_eq!(ArrayAccessor::value(&typed, 1), "bbbb");
        assert!(typed.is_null(2));
        assert!(typed.is_null(3));
        assert_eq!(typed.null_count(), 2);
    }

    #[test]
    fn test_repeat_sliced_string_with_null_offset() {
        let (strings, counts) =
            sliced_offset_inputs(|values| Arc::new(StringArray::from(values)));
        let output = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<StringArray>(output);
    }

    #[test]
    fn test_repeat_sliced_large_string_with_null_offset() {
        let (strings, counts) =
            sliced_offset_inputs(|values| Arc::new(LargeStringArray::from(values)));
        let output = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<LargeStringArray>(output);
    }

    #[test]
    fn test_repeat_sliced_string_view_with_null_offset() {
        let (strings, counts) =
            sliced_offset_inputs(|values| Arc::new(StringViewArray::from(values)));
        let output = super::repeat(&strings, &counts).unwrap();
        assert_sliced_offset_output::<StringViewArray>(output);
    }
}