datafusion_functions/string/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_str_type};
22use arrow::array::{
23    ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24    OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::{DataFusionError, Result, exec_err};
31use datafusion_expr::{ColumnarValue, Documentation, Volatility};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37    doc_section(label = "String Functions"),
38    description = "Returns a string with an input string repeated a specified number.",
39    syntax_example = "repeat(str, n)",
40    sql_example = r#"```sql
41> select repeat('data', 3);
42+-------------------------------+
43| repeat(Utf8("data"),Int64(3)) |
44+-------------------------------+
45| datadatadata                  |
46+-------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    argument(
50        name = "n",
51        description = "Number of times to repeat the input string."
52    )
53)]
54#[derive(Debug, PartialEq, Eq, Hash)]
55pub struct RepeatFunc {
56    signature: Signature,
57}
58
59impl Default for RepeatFunc {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl RepeatFunc {
66    pub fn new() -> Self {
67        Self {
68            signature: Signature::coercible(
69                vec![
70                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71                    // Accept all integer types but cast them to i64
72                    Coercion::new_implicit(
73                        TypeSignatureClass::Native(logical_int64()),
74                        vec![TypeSignatureClass::Integer],
75                        NativeType::Int64,
76                    ),
77                ],
78                Volatility::Immutable,
79            ),
80        }
81    }
82}
83
84impl ScalarUDFImpl for RepeatFunc {
85    fn as_any(&self) -> &dyn Any {
86        self
87    }
88
89    fn name(&self) -> &str {
90        "repeat"
91    }
92
93    fn signature(&self) -> &Signature {
94        &self.signature
95    }
96
97    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
98        utf8_to_str_type(&arg_types[0], "repeat")
99    }
100
101    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102        make_scalar_function(repeat, vec![])(&args.args)
103    }
104
105    fn documentation(&self) -> Option<&Documentation> {
106        self.doc()
107    }
108}
109
110/// Repeats string the specified number of times.
111/// repeat('Pg', 4) = 'PgPgPgPg'
112fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
113    let number_array = as_int64_array(&args[1])?;
114    match args[0].data_type() {
115        Utf8View => {
116            let string_view_array = args[0].as_string_view();
117            repeat_impl::<i32, &StringViewArray>(
118                &string_view_array,
119                number_array,
120                i32::MAX as usize,
121            )
122        }
123        Utf8 => {
124            let string_array = args[0].as_string::<i32>();
125            repeat_impl::<i32, &GenericStringArray<i32>>(
126                &string_array,
127                number_array,
128                i32::MAX as usize,
129            )
130        }
131        LargeUtf8 => {
132            let string_array = args[0].as_string::<i64>();
133            repeat_impl::<i64, &GenericStringArray<i64>>(
134                &string_array,
135                number_array,
136                i64::MAX as usize,
137            )
138        }
139        other => exec_err!(
140            "Unsupported data type {other:?} for function repeat. \
141        Expected Utf8, Utf8View or LargeUtf8."
142        ),
143    }
144}
145
146fn repeat_impl<'a, T, S>(
147    string_array: &S,
148    number_array: &Int64Array,
149    max_str_len: usize,
150) -> Result<ArrayRef>
151where
152    T: OffsetSizeTrait,
153    S: StringArrayType<'a>,
154{
155    let mut total_capacity = 0;
156    let mut max_item_capacity = 0;
157    string_array.iter().zip(number_array.iter()).try_for_each(
158        |(string, number)| -> Result<(), DataFusionError> {
159            match (string, number) {
160                (Some(string), Some(number)) if number >= 0 => {
161                    let item_capacity = string.len() * number as usize;
162                    if item_capacity > max_str_len {
163                        return exec_err!(
164                            "string size overflow on repeat, max size is {}, but got {}",
165                            max_str_len,
166                            number as usize * string.len()
167                        );
168                    }
169                    total_capacity += item_capacity;
170                    max_item_capacity = max_item_capacity.max(item_capacity);
171                }
172                _ => (),
173            }
174            Ok(())
175        },
176    )?;
177
178    let mut builder =
179        GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
180
181    // Reusable buffer to avoid allocations in string.repeat()
182    let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
183
184    string_array
185        .iter()
186        .zip(number_array.iter())
187        .for_each(|(string, number)| {
188            match (string, number) {
189                (Some(string), Some(number)) if number >= 0 => {
190                    buffer.clear();
191                    let count = number as usize;
192                    if count > 0 && !string.is_empty() {
193                        let src = string.as_bytes();
194                        // Initial copy
195                        buffer.extend_from_slice(src);
196                        // Doubling strategy: copy what we have so far until we reach the target
197                        while buffer.len() < src.len() * count {
198                            let copy_len =
199                                buffer.len().min(src.len() * count - buffer.len());
200                            // SAFETY: we're copying valid UTF-8 bytes that we already verified
201                            buffer.extend_from_within(..copy_len);
202                        }
203                    }
204                    // SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str
205                    builder
206                        .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
207                }
208                (Some(_), Some(_)) => builder.append_value(""),
209                _ => builder.append_null(),
210            }
211        });
212    let array = builder.finish();
213
214    Ok(Arc::new(array) as ArrayRef)
215}
216
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// `Utf8` input: a plain repetition, and null propagation from either argument.
    #[test]
    fn test_repeat_utf8() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// `Utf8View` input: the same cases; the expected result type is `Utf8`.
    #[test]
    fn test_repeat_utf8_view() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// A `Utf8` result longer than `i32::MAX` bytes is rejected with an error.
    #[test]
    fn test_repeat_size_overflow() -> Result<()> {
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }
}