1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::utf8_to_str_type;
22use arrow::array::{
23 Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24 OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
32use datafusion_expr::{ColumnarValue, Documentation, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
35use datafusion_macros::user_doc;
36
37#[user_doc(
38 doc_section(label = "String Functions"),
39 description = "Returns a string with an input string repeated a specified number.",
40 syntax_example = "repeat(str, n)",
41 sql_example = r#"```sql
42> select repeat('data', 3);
43+-------------------------------+
44| repeat(Utf8("data"),Int64(3)) |
45+-------------------------------+
46| datadatadata |
47+-------------------------------+
48```"#,
49 standard_argument(name = "str", prefix = "String"),
50 argument(
51 name = "n",
52 description = "Number of times to repeat the input string."
53 )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct RepeatFunc {
57 signature: Signature,
58}
59
60impl Default for RepeatFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl RepeatFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 Coercion::new_implicit(
74 TypeSignatureClass::Native(logical_int64()),
75 vec![TypeSignatureClass::Integer],
76 NativeType::Int64,
77 ),
78 ],
79 Volatility::Immutable,
80 ),
81 }
82 }
83}
84
85impl ScalarUDFImpl for RepeatFunc {
86 fn as_any(&self) -> &dyn Any {
87 self
88 }
89
90 fn name(&self) -> &str {
91 "repeat"
92 }
93
94 fn signature(&self) -> &Signature {
95 &self.signature
96 }
97
98 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
99 utf8_to_str_type(&arg_types[0], "repeat")
100 }
101
102 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103 let return_type = args.return_field.data_type().clone();
104 let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;
105
106 if let ColumnarValue::Scalar(s) = &string_arg
108 && s.is_null()
109 {
110 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
111 }
112 if let ColumnarValue::Scalar(c) = &count_arg
113 && c.is_null()
114 {
115 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
116 }
117
118 match (&string_arg, &count_arg) {
119 (
120 ColumnarValue::Scalar(string_scalar),
121 ColumnarValue::Scalar(count_scalar),
122 ) => {
123 let count = match count_scalar {
124 ScalarValue::Int64(Some(n)) => *n,
125 _ => {
126 return internal_err!(
127 "Unexpected data type {:?} for repeat count",
128 count_scalar.data_type()
129 );
130 }
131 };
132
133 let result = match string_scalar {
134 ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
135 ScalarValue::Utf8(Some(compute_repeat(
136 s,
137 count,
138 i32::MAX as usize,
139 )?))
140 }
141 ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
142 compute_repeat(s, count, i64::MAX as usize)?,
143 )),
144 _ => {
145 return internal_err!(
146 "Unexpected data type {:?} for function repeat",
147 string_scalar.data_type()
148 );
149 }
150 };
151
152 Ok(ColumnarValue::Scalar(result))
153 }
154 _ => {
155 let string_array = string_arg.to_array(args.number_rows)?;
156 let count_array = count_arg.to_array(args.number_rows)?;
157 Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
158 }
159 }
160 }
161
162 fn documentation(&self) -> Option<&Documentation> {
163 self.doc()
164 }
165}
166
167#[inline]
169fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
170 if count <= 0 {
171 return Ok(String::new());
172 }
173 let result_len = s.len().saturating_mul(count as usize);
174 if result_len > max_size {
175 return exec_err!(
176 "string size overflow on repeat, max size is {}, but got {}",
177 max_size,
178 result_len
179 );
180 }
181 Ok(s.repeat(count as usize))
182}
183
184fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
187 let number_array = as_int64_array(count_array)?;
188 match string_array.data_type() {
189 Utf8View => {
190 let string_view_array = string_array.as_string_view();
191 repeat_impl::<i32, &StringViewArray>(
192 &string_view_array,
193 number_array,
194 i32::MAX as usize,
195 )
196 }
197 Utf8 => {
198 let string_arr = string_array.as_string::<i32>();
199 repeat_impl::<i32, &GenericStringArray<i32>>(
200 &string_arr,
201 number_array,
202 i32::MAX as usize,
203 )
204 }
205 LargeUtf8 => {
206 let string_arr = string_array.as_string::<i64>();
207 repeat_impl::<i64, &GenericStringArray<i64>>(
208 &string_arr,
209 number_array,
210 i64::MAX as usize,
211 )
212 }
213 other => exec_err!(
214 "Unsupported data type {other:?} for function repeat. \
215 Expected Utf8, Utf8View or LargeUtf8."
216 ),
217 }
218}
219
220fn repeat_impl<'a, T, S>(
221 string_array: &S,
222 number_array: &Int64Array,
223 max_str_len: usize,
224) -> Result<ArrayRef>
225where
226 T: OffsetSizeTrait,
227 S: StringArrayType<'a> + 'a,
228{
229 let mut total_capacity = 0;
230 let mut max_item_capacity = 0;
231 string_array.iter().zip(number_array.iter()).try_for_each(
232 |(string, number)| -> Result<(), DataFusionError> {
233 match (string, number) {
234 (Some(string), Some(number)) if number >= 0 => {
235 let item_capacity = string.len() * number as usize;
236 if item_capacity > max_str_len {
237 return exec_err!(
238 "string size overflow on repeat, max size is {}, but got {}",
239 max_str_len,
240 number as usize * string.len()
241 );
242 }
243 total_capacity += item_capacity;
244 max_item_capacity = max_item_capacity.max(item_capacity);
245 }
246 _ => (),
247 }
248 Ok(())
249 },
250 )?;
251
252 let mut builder =
253 GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
254
255 let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
257
258 #[inline]
261 fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
262 buffer.clear();
263 if !string.is_empty() {
264 let src = string.as_bytes();
265 buffer.extend_from_slice(src);
267 while buffer.len() < src.len() * count {
269 let copy_len = buffer.len().min(src.len() * count - buffer.len());
270 buffer.extend_from_within(..copy_len);
272 }
273 }
274 }
275
276 if string_array.null_count() == 0 && number_array.null_count() == 0 {
278 for i in 0..string_array.len() {
279 let string = unsafe { string_array.value_unchecked(i) };
281 let count = number_array.value(i);
282 if count > 0 {
283 repeat_to_buffer(&mut buffer, string, count as usize);
284 builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
286 } else {
287 builder.append_value("");
288 }
289 }
290 } else {
291 for (string, number) in string_array.iter().zip(number_array.iter()) {
293 match (string, number) {
294 (Some(string), Some(count)) if count > 0 => {
295 repeat_to_buffer(&mut buffer, string, count as usize);
296 builder
298 .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
299 }
300 (Some(_), Some(_)) => builder.append_value(""),
301 _ => builder.append_null(),
302 }
303 }
304 }
305
306 Ok(Arc::new(builder.finish()) as ArrayRef)
307}
308
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// Scalar `Utf8` argument.
    fn utf8(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(String::from)))
    }

    /// Scalar `Utf8View` argument.
    fn utf8_view(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(value.map(String::from)))
    }

    /// Scalar `Int64` argument.
    fn int64(value: Option<i64>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(value))
    }

    /// Scalar invocations of `repeat`: plain values, NULL propagation for
    /// both arguments, `Utf8View` inputs, and the size-overflow error.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 input.
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(Some(4))],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8(None), int64(Some(4))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(None)],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Utf8View input; the result is still reported as Utf8.
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(Some("Pg")), int64(Some(4))],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(None), int64(Some(4))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(Some("Pg")), int64(None)],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // 2 bytes * 2^30 repetitions is one byte past the Utf8 limit.
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(Some(1073741824))],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}