1use crate::utils::{make_scalar_function, utf8_to_str_type};
19use DataType::{LargeUtf8, Utf8, Utf8View};
20use arrow::array::{
21 ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
22 OffsetSizeTrait, StringArrayType, StringViewArray,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::DataFusionError;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::{
30 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_macros::user_doc;
33use std::any::Any;
34use std::fmt::Write;
35use std::sync::Arc;
36use unicode_segmentation::UnicodeSegmentation;
37
38#[user_doc(
39 doc_section(label = "String Functions"),
40 description = "Pads the right side of a string with another string to a specified string length.",
41 syntax_example = "rpad(str, n[, padding_str])",
42 sql_example = r#"```sql
43> select rpad('datafusion', 20, '_-');
44+-----------------------------------------------+
45| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
46+-----------------------------------------------+
47| datafusion_-_-_-_-_- |
48+-----------------------------------------------+
49```"#,
50 standard_argument(name = "str", prefix = "String"),
51 argument(name = "n", description = "String length to pad to."),
52 argument(
53 name = "padding_str",
54 description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
55 ),
56 related_udf(name = "lpad")
57)]
/// Scalar UDF implementing `rpad(str, n[, padding_str])`.
///
/// The accepted argument type combinations are fixed in `RPadFunc::new`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RPadFunc {
    /// Type signatures accepted by `rpad`; built in `RPadFunc::new` and
    /// exposed through `ScalarUDFImpl::signature`.
    signature: Signature,
}
62
63impl Default for RPadFunc {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl RPadFunc {
70 pub fn new() -> Self {
71 use DataType::*;
72 Self {
73 signature: Signature::one_of(
74 vec![
75 Exact(vec![Utf8View, Int64]),
76 Exact(vec![Utf8View, Int64, Utf8View]),
77 Exact(vec![Utf8View, Int64, Utf8]),
78 Exact(vec![Utf8View, Int64, LargeUtf8]),
79 Exact(vec![Utf8, Int64]),
80 Exact(vec![Utf8, Int64, Utf8View]),
81 Exact(vec![Utf8, Int64, Utf8]),
82 Exact(vec![Utf8, Int64, LargeUtf8]),
83 Exact(vec![LargeUtf8, Int64]),
84 Exact(vec![LargeUtf8, Int64, Utf8View]),
85 Exact(vec![LargeUtf8, Int64, Utf8]),
86 Exact(vec![LargeUtf8, Int64, LargeUtf8]),
87 ],
88 Volatility::Immutable,
89 ),
90 }
91 }
92}
93
94impl ScalarUDFImpl for RPadFunc {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "rpad"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108 utf8_to_str_type(&arg_types[0], "rpad")
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: datafusion_expr::ScalarFunctionArgs,
114 ) -> Result<ColumnarValue> {
115 let args = &args.args;
116 match (
117 args.len(),
118 args[0].data_type(),
119 args.get(2).map(|arg| arg.data_type()),
120 ) {
121 (2, Utf8 | Utf8View, _) => {
122 make_scalar_function(rpad::<i32, i32>, vec![])(args)
123 }
124 (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
125 (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
126 make_scalar_function(rpad::<i32, i32>, vec![])(args)
127 }
128 (3, LargeUtf8, Some(LargeUtf8)) => {
129 make_scalar_function(rpad::<i64, i64>, vec![])(args)
130 }
131 (3, Utf8 | Utf8View, Some(LargeUtf8)) => {
132 make_scalar_function(rpad::<i32, i64>, vec![])(args)
133 }
134 (3, LargeUtf8, Some(Utf8 | Utf8View)) => {
135 make_scalar_function(rpad::<i64, i32>, vec![])(args)
136 }
137 (_, _, _) => {
138 exec_err!("Unsupported combination of data types for function rpad")
139 }
140 }
141 }
142
143 fn documentation(&self) -> Option<&Documentation> {
144 self.doc()
145 }
146}
147
148fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
149 args: &[ArrayRef],
150) -> Result<ArrayRef> {
151 if args.len() < 2 || args.len() > 3 {
152 return exec_err!(
153 "rpad was called with {} arguments. It requires 2 or 3 arguments.",
154 args.len()
155 );
156 }
157
158 let length_array = as_int64_array(&args[1])?;
159 match (
160 args.len(),
161 args[0].data_type(),
162 args.get(2).map(|arg| arg.data_type()),
163 ) {
164 (2, Utf8View, _) => {
165 rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
166 &args[0].as_string_view(),
167 length_array,
168 None,
169 )
170 }
171 (3, Utf8View, Some(Utf8View)) => {
172 rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
173 &args[0].as_string_view(),
174 length_array,
175 Some(args[2].as_string_view()),
176 )
177 }
178 (3, Utf8View, Some(Utf8 | LargeUtf8)) => {
179 rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
180 &args[0].as_string_view(),
181 length_array,
182 Some(args[2].as_string::<FillArrayLen>()),
183 )
184 }
185 (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
186 &GenericStringArray<StringArrayLen>,
187 &StringViewArray,
188 StringArrayLen,
189 >(
190 &args[0].as_string::<StringArrayLen>(),
191 length_array,
192 Some(args[2].as_string_view()),
193 ),
194 (_, _, _) => rpad_impl::<
195 &GenericStringArray<StringArrayLen>,
196 &GenericStringArray<FillArrayLen>,
197 StringArrayLen,
198 >(
199 &args[0].as_string::<StringArrayLen>(),
200 length_array,
201 args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
202 ),
203 }
204}
205
206fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
209 string_array: &StringArrType,
210 length_array: &Int64Array,
211 fill_array: Option<FillArrType>,
212) -> Result<ArrayRef>
213where
214 StringArrType: StringArrayType<'a>,
215 FillArrType: StringArrayType<'a>,
216 StringArrayLen: OffsetSizeTrait,
217{
218 let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
219 let mut graphemes_buf = Vec::new();
220 let mut fill_chars_buf = Vec::new();
221
222 match fill_array {
223 None => {
224 string_array.iter().zip(length_array.iter()).try_for_each(
225 |(string, length)| -> Result<(), DataFusionError> {
226 match (string, length) {
227 (Some(string), Some(length)) => {
228 if length > i32::MAX as i64 {
229 return exec_err!(
230 "rpad requested length {} too large",
231 length
232 );
233 }
234 let length = if length < 0 { 0 } else { length as usize };
235 if length == 0 {
236 builder.append_value("");
237 } else {
238 graphemes_buf.clear();
240 graphemes_buf.extend(string.graphemes(true));
241
242 if length < graphemes_buf.len() {
243 builder
244 .append_value(graphemes_buf[..length].concat());
245 } else {
246 builder.write_str(string)?;
247 builder.write_str(
248 &" ".repeat(length - graphemes_buf.len()),
249 )?;
250 builder.append_value("");
251 }
252 }
253 }
254 _ => builder.append_null(),
255 }
256 Ok(())
257 },
258 )?;
259 }
260 Some(fill_array) => {
261 string_array
262 .iter()
263 .zip(length_array.iter())
264 .zip(fill_array.iter())
265 .try_for_each(
266 |((string, length), fill)| -> Result<(), DataFusionError> {
267 match (string, length, fill) {
268 (Some(string), Some(length), Some(fill)) => {
269 if length > i32::MAX as i64 {
270 return exec_err!(
271 "rpad requested length {} too large",
272 length
273 );
274 }
275 let length = if length < 0 { 0 } else { length as usize };
276 graphemes_buf.clear();
278 graphemes_buf.extend(string.graphemes(true));
279
280 if length < graphemes_buf.len() {
281 builder
282 .append_value(graphemes_buf[..length].concat());
283 } else if fill.is_empty() {
284 builder.append_value(string);
285 } else {
286 builder.write_str(string)?;
287 fill_chars_buf.clear();
289 fill_chars_buf.extend(fill.chars());
290 for l in 0..length - graphemes_buf.len() {
291 let c = *fill_chars_buf
292 .get(l % fill_chars_buf.len())
293 .unwrap();
294 builder.write_char(c)?;
295 }
296 builder.append_value("");
297 }
298 }
299 _ => builder.append_null(),
300 }
301 Ok(())
302 },
303 )?;
304 }
305 }
306
307 Ok(Arc::new(builder.finish()) as ArrayRef)
308}
309
#[cfg(test)]
mod tests {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::rpad::RPadFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("josé ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyx")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(21i64)),
                ColumnarValue::Scalar(ScalarValue::from("abcdef")),
            ],
            Ok(Some("hiabcdefabcdefabcdefa")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(" ")),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some("hi")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("éñ")),
            ],
            Ok(Some("josééñéñéñ")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation keeps whole graphemes, with and without a fill string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("jos")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("jos")),
            &str,
            Utf8,
            StringArray
        );
        // Negative lengths behave like 0.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // LargeUtf8 input produces LargeUtf8 output.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("hi".to_string()))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("xy".to_string()))),
            ],
            Ok(Some("hixyx")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // Utf8View input with a Utf8 fill produces Utf8 output.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("josé".to_string()))),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}