datafusion_functions/string/
repeat.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_str_type};
22use arrow::array::{
23 ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24 OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::{DataFusionError, Result, exec_err};
31use datafusion_expr::{ColumnarValue, Documentation, Volatility};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Returns a string with an input string repeated a specified number.",
39 syntax_example = "repeat(str, n)",
40 sql_example = r#"```sql
41> select repeat('data', 3);
42+-------------------------------+
43| repeat(Utf8("data"),Int64(3)) |
44+-------------------------------+
45| datadatadata |
46+-------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 argument(
50 name = "n",
51 description = "Number of times to repeat the input string."
52 )
53)]
54#[derive(Debug, PartialEq, Eq, Hash)]
55pub struct RepeatFunc {
56 signature: Signature,
57}
58
59impl Default for RepeatFunc {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl RepeatFunc {
66 pub fn new() -> Self {
67 Self {
68 signature: Signature::coercible(
69 vec![
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 Coercion::new_implicit(
73 TypeSignatureClass::Native(logical_int64()),
74 vec![TypeSignatureClass::Integer],
75 NativeType::Int64,
76 ),
77 ],
78 Volatility::Immutable,
79 ),
80 }
81 }
82}
83
84impl ScalarUDFImpl for RepeatFunc {
85 fn as_any(&self) -> &dyn Any {
86 self
87 }
88
89 fn name(&self) -> &str {
90 "repeat"
91 }
92
93 fn signature(&self) -> &Signature {
94 &self.signature
95 }
96
97 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
98 utf8_to_str_type(&arg_types[0], "repeat")
99 }
100
101 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102 make_scalar_function(repeat, vec![])(&args.args)
103 }
104
105 fn documentation(&self) -> Option<&Documentation> {
106 self.doc()
107 }
108}
109
110fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
113 let number_array = as_int64_array(&args[1])?;
114 match args[0].data_type() {
115 Utf8View => {
116 let string_view_array = args[0].as_string_view();
117 repeat_impl::<i32, &StringViewArray>(
118 &string_view_array,
119 number_array,
120 i32::MAX as usize,
121 )
122 }
123 Utf8 => {
124 let string_array = args[0].as_string::<i32>();
125 repeat_impl::<i32, &GenericStringArray<i32>>(
126 &string_array,
127 number_array,
128 i32::MAX as usize,
129 )
130 }
131 LargeUtf8 => {
132 let string_array = args[0].as_string::<i64>();
133 repeat_impl::<i64, &GenericStringArray<i64>>(
134 &string_array,
135 number_array,
136 i64::MAX as usize,
137 )
138 }
139 other => exec_err!(
140 "Unsupported data type {other:?} for function repeat. \
141 Expected Utf8, Utf8View or LargeUtf8."
142 ),
143 }
144}
145
146fn repeat_impl<'a, T, S>(
147 string_array: &S,
148 number_array: &Int64Array,
149 max_str_len: usize,
150) -> Result<ArrayRef>
151where
152 T: OffsetSizeTrait,
153 S: StringArrayType<'a>,
154{
155 let mut total_capacity = 0;
156 let mut max_item_capacity = 0;
157 string_array.iter().zip(number_array.iter()).try_for_each(
158 |(string, number)| -> Result<(), DataFusionError> {
159 match (string, number) {
160 (Some(string), Some(number)) if number >= 0 => {
161 let item_capacity = string.len() * number as usize;
162 if item_capacity > max_str_len {
163 return exec_err!(
164 "string size overflow on repeat, max size is {}, but got {}",
165 max_str_len,
166 number as usize * string.len()
167 );
168 }
169 total_capacity += item_capacity;
170 max_item_capacity = max_item_capacity.max(item_capacity);
171 }
172 _ => (),
173 }
174 Ok(())
175 },
176 )?;
177
178 let mut builder =
179 GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
180
181 let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
183
184 string_array
185 .iter()
186 .zip(number_array.iter())
187 .for_each(|(string, number)| {
188 match (string, number) {
189 (Some(string), Some(number)) if number >= 0 => {
190 buffer.clear();
191 let count = number as usize;
192 if count > 0 && !string.is_empty() {
193 let src = string.as_bytes();
194 buffer.extend_from_slice(src);
196 while buffer.len() < src.len() * count {
198 let copy_len =
199 buffer.len().min(src.len() * count - buffer.len());
200 buffer.extend_from_within(..copy_len);
202 }
203 }
204 builder
206 .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
207 }
208 (Some(_), Some(_)) => builder.append_value(""),
209 _ => builder.append_null(),
210 }
211 });
212 let array = builder.finish();
213
214 Ok(Arc::new(array) as ArrayRef)
215}
216
#[cfg(test)]
mod tests {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// Exercises `repeat` with scalar arguments: plain repetition, null
    /// propagation, zero and negative counts, each supported string type used
    /// here, and the size-overflow error.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 input: basic repetition.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8 input: a null string yields null.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Utf8 input: a null count yields null.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Utf8View input: the result is returned as Utf8.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // A single result larger than i32::MAX bytes is rejected with an error.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        // A count of zero produces an empty, non-null string.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A negative count also produces an empty, non-null string.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // LargeUtf8 input keeps the LargeUtf8 type in the result.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        Ok(())
    }
}