Skip to main content

datafusion_functions/string/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::utf8_to_str_type;
22use arrow::array::{
23    Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24    OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
32use datafusion_expr::{ColumnarValue, Documentation, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
35use datafusion_macros::user_doc;
36
37#[user_doc(
38    doc_section(label = "String Functions"),
39    description = "Returns a string with an input string repeated a specified number.",
40    syntax_example = "repeat(str, n)",
41    sql_example = r#"```sql
42> select repeat('data', 3);
43+-------------------------------+
44| repeat(Utf8("data"),Int64(3)) |
45+-------------------------------+
46| datadatadata                  |
47+-------------------------------+
48```"#,
49    standard_argument(name = "str", prefix = "String"),
50    argument(
51        name = "n",
52        description = "Number of times to repeat the input string."
53    )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct RepeatFunc {
57    signature: Signature,
58}
59
60impl Default for RepeatFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl RepeatFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    // Accept all integer types but cast them to i64
73                    Coercion::new_implicit(
74                        TypeSignatureClass::Native(logical_int64()),
75                        vec![TypeSignatureClass::Integer],
76                        NativeType::Int64,
77                    ),
78                ],
79                Volatility::Immutable,
80            ),
81        }
82    }
83}
84
85impl ScalarUDFImpl for RepeatFunc {
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89
90    fn name(&self) -> &str {
91        "repeat"
92    }
93
94    fn signature(&self) -> &Signature {
95        &self.signature
96    }
97
98    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
99        utf8_to_str_type(&arg_types[0], "repeat")
100    }
101
102    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103        let return_type = args.return_field.data_type().clone();
104        let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;
105
106        // Early return if either argument is a scalar null
107        if let ColumnarValue::Scalar(s) = &string_arg
108            && s.is_null()
109        {
110            return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
111        }
112        if let ColumnarValue::Scalar(c) = &count_arg
113            && c.is_null()
114        {
115            return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
116        }
117
118        match (&string_arg, &count_arg) {
119            (
120                ColumnarValue::Scalar(string_scalar),
121                ColumnarValue::Scalar(count_scalar),
122            ) => {
123                let count = match count_scalar {
124                    ScalarValue::Int64(Some(n)) => *n,
125                    _ => {
126                        return internal_err!(
127                            "Unexpected data type {:?} for repeat count",
128                            count_scalar.data_type()
129                        );
130                    }
131                };
132
133                let result = match string_scalar {
134                    ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
135                        ScalarValue::Utf8(Some(compute_repeat(
136                            s,
137                            count,
138                            i32::MAX as usize,
139                        )?))
140                    }
141                    ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
142                        compute_repeat(s, count, i64::MAX as usize)?,
143                    )),
144                    _ => {
145                        return internal_err!(
146                            "Unexpected data type {:?} for function repeat",
147                            string_scalar.data_type()
148                        );
149                    }
150                };
151
152                Ok(ColumnarValue::Scalar(result))
153            }
154            _ => {
155                let string_array = string_arg.to_array(args.number_rows)?;
156                let count_array = count_arg.to_array(args.number_rows)?;
157                Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
158            }
159        }
160    }
161
162    fn documentation(&self) -> Option<&Documentation> {
163        self.doc()
164    }
165}
166
167/// Computes repeat for a single string value with max size check
168#[inline]
169fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
170    if count <= 0 {
171        return Ok(String::new());
172    }
173    let result_len = s.len().saturating_mul(count as usize);
174    if result_len > max_size {
175        return exec_err!(
176            "string size overflow on repeat, max size is {}, but got {}",
177            max_size,
178            result_len
179        );
180    }
181    Ok(s.repeat(count as usize))
182}
183
184/// Repeats string the specified number of times.
185/// repeat('Pg', 4) = 'PgPgPgPg'
186fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
187    let number_array = as_int64_array(count_array)?;
188    match string_array.data_type() {
189        Utf8View => {
190            let string_view_array = string_array.as_string_view();
191            repeat_impl::<i32, &StringViewArray>(
192                &string_view_array,
193                number_array,
194                i32::MAX as usize,
195            )
196        }
197        Utf8 => {
198            let string_arr = string_array.as_string::<i32>();
199            repeat_impl::<i32, &GenericStringArray<i32>>(
200                &string_arr,
201                number_array,
202                i32::MAX as usize,
203            )
204        }
205        LargeUtf8 => {
206            let string_arr = string_array.as_string::<i64>();
207            repeat_impl::<i64, &GenericStringArray<i64>>(
208                &string_arr,
209                number_array,
210                i64::MAX as usize,
211            )
212        }
213        other => exec_err!(
214            "Unsupported data type {other:?} for function repeat. \
215        Expected Utf8, Utf8View or LargeUtf8."
216        ),
217    }
218}
219
220fn repeat_impl<'a, T, S>(
221    string_array: &S,
222    number_array: &Int64Array,
223    max_str_len: usize,
224) -> Result<ArrayRef>
225where
226    T: OffsetSizeTrait,
227    S: StringArrayType<'a> + 'a,
228{
229    let mut total_capacity = 0;
230    let mut max_item_capacity = 0;
231    string_array.iter().zip(number_array.iter()).try_for_each(
232        |(string, number)| -> Result<(), DataFusionError> {
233            match (string, number) {
234                (Some(string), Some(number)) if number >= 0 => {
235                    let item_capacity = string.len() * number as usize;
236                    if item_capacity > max_str_len {
237                        return exec_err!(
238                            "string size overflow on repeat, max size is {}, but got {}",
239                            max_str_len,
240                            number as usize * string.len()
241                        );
242                    }
243                    total_capacity += item_capacity;
244                    max_item_capacity = max_item_capacity.max(item_capacity);
245                }
246                _ => (),
247            }
248            Ok(())
249        },
250    )?;
251
252    let mut builder =
253        GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
254
255    // Reusable buffer to avoid allocations in string.repeat()
256    let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
257
258    // Helper function to repeat a string into a buffer using doubling strategy
259    // count must be > 0
260    #[inline]
261    fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
262        buffer.clear();
263        if !string.is_empty() {
264            let src = string.as_bytes();
265            // Initial copy
266            buffer.extend_from_slice(src);
267            // Doubling strategy: copy what we have so far until we reach the target
268            while buffer.len() < src.len() * count {
269                let copy_len = buffer.len().min(src.len() * count - buffer.len());
270                // SAFETY: we're copying valid UTF-8 bytes that we already verified
271                buffer.extend_from_within(..copy_len);
272            }
273        }
274    }
275
276    // Fast path: no nulls in either array
277    if string_array.null_count() == 0 && number_array.null_count() == 0 {
278        for i in 0..string_array.len() {
279            // SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
280            let string = unsafe { string_array.value_unchecked(i) };
281            let count = number_array.value(i);
282            if count > 0 {
283                repeat_to_buffer(&mut buffer, string, count as usize);
284                // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
285                builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
286            } else {
287                builder.append_value("");
288            }
289        }
290    } else {
291        // Slow path: handle nulls
292        for (string, number) in string_array.iter().zip(number_array.iter()) {
293            match (string, number) {
294                (Some(string), Some(count)) if count > 0 => {
295                    repeat_to_buffer(&mut buffer, string, count as usize);
296                    // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
297                    builder
298                        .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
299                }
300                (Some(_), Some(_)) => builder.append_value(""),
301                _ => builder.append_null(),
302            }
303        }
304    }
305
306    Ok(Arc::new(builder.finish()) as ArrayRef)
307}
308
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{Array, Int64Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// Zero and negative counts both produce an empty (non-null) string.
    #[test]
    fn test_non_positive_counts() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// LargeUtf8 input keeps its LargeUtf8 output type.
    #[test]
    fn test_large_utf8() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("PgPg")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        Ok(())
    }

    /// A size product that overflows `usize` saturates and is reported as an error
    /// rather than wrapping around.
    #[test]
    fn test_scalar_size_product_overflow() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                usize::MAX
            ),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Exercises the columnar path (both the no-null fast path and the null-aware path).
    #[test]
    fn test_array_input() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["Pg"]))),
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![3]))),
            ],
            Ok(Some("PgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["Pg"]))),
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![-1]))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["Pg"]))),
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![None]))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }
}