1use crate::utils::{make_scalar_function, utf8_to_str_type};
19use DataType::{LargeUtf8, Utf8, Utf8View};
20use arrow::array::{
21 ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
22 OffsetSizeTrait, StringArrayType, StringViewArray,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::DataFusionError;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::{
30 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_macros::user_doc;
33use std::any::Any;
34use std::fmt::Write;
35use std::sync::Arc;
36use unicode_segmentation::UnicodeSegmentation;
37
38#[user_doc(
39 doc_section(label = "String Functions"),
40 description = "Pads the right side of a string with another string to a specified string length.",
41 syntax_example = "rpad(str, n[, padding_str])",
42 sql_example = r#"```sql
43> select rpad('datafusion', 20, '_-');
44+-----------------------------------------------+
45| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
46+-----------------------------------------------+
47| datafusion_-_-_-_-_- |
48+-----------------------------------------------+
49```"#,
50 standard_argument(name = "str", prefix = "String"),
51 argument(
52 name = "n",
53 description = "String length to pad to. If the input string is longer than this length, it is truncated."
54 ),
55 argument(
56 name = "padding_str",
57 description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
58 ),
59 related_udf(name = "lpad")
60)]
61#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct RPadFunc {
63 signature: Signature,
64}
65
66impl Default for RPadFunc {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl RPadFunc {
73 pub fn new() -> Self {
74 use DataType::*;
75 Self {
76 signature: Signature::one_of(
77 vec![
78 Exact(vec![Utf8View, Int64]),
79 Exact(vec![Utf8View, Int64, Utf8View]),
80 Exact(vec![Utf8View, Int64, Utf8]),
81 Exact(vec![Utf8View, Int64, LargeUtf8]),
82 Exact(vec![Utf8, Int64]),
83 Exact(vec![Utf8, Int64, Utf8View]),
84 Exact(vec![Utf8, Int64, Utf8]),
85 Exact(vec![Utf8, Int64, LargeUtf8]),
86 Exact(vec![LargeUtf8, Int64]),
87 Exact(vec![LargeUtf8, Int64, Utf8View]),
88 Exact(vec![LargeUtf8, Int64, Utf8]),
89 Exact(vec![LargeUtf8, Int64, LargeUtf8]),
90 ],
91 Volatility::Immutable,
92 ),
93 }
94 }
95}
96
97impl ScalarUDFImpl for RPadFunc {
98 fn as_any(&self) -> &dyn Any {
99 self
100 }
101
102 fn name(&self) -> &str {
103 "rpad"
104 }
105
106 fn signature(&self) -> &Signature {
107 &self.signature
108 }
109
110 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
111 utf8_to_str_type(&arg_types[0], "rpad")
112 }
113
114 fn invoke_with_args(
115 &self,
116 args: datafusion_expr::ScalarFunctionArgs,
117 ) -> Result<ColumnarValue> {
118 let args = &args.args;
119 match (
120 args.len(),
121 args[0].data_type(),
122 args.get(2).map(|arg| arg.data_type()),
123 ) {
124 (2, Utf8 | Utf8View, _) => {
125 make_scalar_function(rpad::<i32, i32>, vec![])(args)
126 }
127 (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
128 (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
129 make_scalar_function(rpad::<i32, i32>, vec![])(args)
130 }
131 (3, LargeUtf8, Some(LargeUtf8)) => {
132 make_scalar_function(rpad::<i64, i64>, vec![])(args)
133 }
134 (3, Utf8 | Utf8View, Some(LargeUtf8)) => {
135 make_scalar_function(rpad::<i32, i64>, vec![])(args)
136 }
137 (3, LargeUtf8, Some(Utf8 | Utf8View)) => {
138 make_scalar_function(rpad::<i64, i32>, vec![])(args)
139 }
140 (_, _, _) => {
141 exec_err!("Unsupported combination of data types for function rpad")
142 }
143 }
144 }
145
146 fn documentation(&self) -> Option<&Documentation> {
147 self.doc()
148 }
149}
150
151fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
152 args: &[ArrayRef],
153) -> Result<ArrayRef> {
154 if args.len() < 2 || args.len() > 3 {
155 return exec_err!(
156 "rpad was called with {} arguments. It requires 2 or 3 arguments.",
157 args.len()
158 );
159 }
160
161 let length_array = as_int64_array(&args[1])?;
162 match (
163 args.len(),
164 args[0].data_type(),
165 args.get(2).map(|arg| arg.data_type()),
166 ) {
167 (2, Utf8View, _) => {
168 rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
169 &args[0].as_string_view(),
170 length_array,
171 None,
172 )
173 }
174 (3, Utf8View, Some(Utf8View)) => {
175 rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
176 &args[0].as_string_view(),
177 length_array,
178 Some(args[2].as_string_view()),
179 )
180 }
181 (3, Utf8View, Some(Utf8 | LargeUtf8)) => {
182 rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
183 &args[0].as_string_view(),
184 length_array,
185 Some(args[2].as_string::<FillArrayLen>()),
186 )
187 }
188 (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
189 &GenericStringArray<StringArrayLen>,
190 &StringViewArray,
191 StringArrayLen,
192 >(
193 &args[0].as_string::<StringArrayLen>(),
194 length_array,
195 Some(args[2].as_string_view()),
196 ),
197 (_, _, _) => rpad_impl::<
198 &GenericStringArray<StringArrayLen>,
199 &GenericStringArray<FillArrayLen>,
200 StringArrayLen,
201 >(
202 &args[0].as_string::<StringArrayLen>(),
203 length_array,
204 args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
205 ),
206 }
207}
208
209fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
213 string_array: &StringArrType,
214 length_array: &Int64Array,
215 fill_array: Option<FillArrType>,
216) -> Result<ArrayRef>
217where
218 StringArrType: StringArrayType<'a>,
219 FillArrType: StringArrayType<'a>,
220 StringArrayLen: OffsetSizeTrait,
221{
222 let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
223 let mut graphemes_buf = Vec::new();
224 let mut fill_chars_buf = Vec::new();
225
226 match fill_array {
227 None => {
228 string_array.iter().zip(length_array.iter()).try_for_each(
229 |(string, length)| -> Result<(), DataFusionError> {
230 match (string, length) {
231 (Some(string), Some(length)) => {
232 if length > i32::MAX as i64 {
233 return exec_err!(
234 "rpad requested length {} too large",
235 length
236 );
237 }
238 let length = if length < 0 { 0 } else { length as usize };
239 if length == 0 {
240 builder.append_value("");
241 } else if string.is_ascii() {
242 let str_len = string.len();
244 if length < str_len {
245 builder.append_value(&string[..length]);
246 } else {
247 builder.write_str(string)?;
248 for _ in 0..(length - str_len) {
249 builder.write_str(" ")?;
250 }
251 builder.append_value("");
252 }
253 } else {
254 graphemes_buf.clear();
256 graphemes_buf.extend(string.graphemes(true));
257
258 if length < graphemes_buf.len() {
259 builder
260 .append_value(graphemes_buf[..length].concat());
261 } else {
262 builder.write_str(string)?;
263 for _ in 0..(length - graphemes_buf.len()) {
264 builder.write_str(" ")?;
265 }
266 builder.append_value("");
267 }
268 }
269 }
270 _ => builder.append_null(),
271 }
272 Ok(())
273 },
274 )?;
275 }
276 Some(fill_array) => {
277 string_array
278 .iter()
279 .zip(length_array.iter())
280 .zip(fill_array.iter())
281 .try_for_each(
282 |((string, length), fill)| -> Result<(), DataFusionError> {
283 match (string, length, fill) {
284 (Some(string), Some(length), Some(fill)) => {
285 if length > i32::MAX as i64 {
286 return exec_err!(
287 "rpad requested length {} too large",
288 length
289 );
290 }
291 let length = if length < 0 { 0 } else { length as usize };
292 if string.is_ascii() && fill.is_ascii() {
293 let str_len = string.len();
296 if length < str_len {
297 builder.append_value(&string[..length]);
298 } else if fill.is_empty() {
299 builder.append_value(string);
300 } else {
301 let pad_len = length - str_len;
302 let fill_len = fill.len();
303 let full_reps = pad_len / fill_len;
304 let remainder = pad_len % fill_len;
305 builder.write_str(string)?;
306 for _ in 0..full_reps {
307 builder.write_str(fill)?;
308 }
309 if remainder > 0 {
310 builder.write_str(&fill[..remainder])?;
311 }
312 builder.append_value("");
313 }
314 } else {
315 graphemes_buf.clear();
317 graphemes_buf.extend(string.graphemes(true));
318
319 if length < graphemes_buf.len() {
320 builder.append_value(
321 graphemes_buf[..length].concat(),
322 );
323 } else if fill.is_empty() {
324 builder.append_value(string);
325 } else {
326 builder.write_str(string)?;
327 fill_chars_buf.clear();
329 fill_chars_buf.extend(fill.chars());
330 for l in 0..length - graphemes_buf.len() {
331 let c = *fill_chars_buf
332 .get(l % fill_chars_buf.len())
333 .unwrap();
334 builder.write_char(c)?;
335 }
336 builder.append_value("");
337 }
338 }
339 }
340 _ => builder.append_null(),
341 }
342 Ok(())
343 },
344 )?;
345 }
346 }
347
348 Ok(Arc::new(builder.finish()) as ArrayRef)
349}
350
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::rpad::RPadFunc;
    use crate::utils::test::test_function;

    /// Exercises `rpad` with scalar arguments: default (space) padding,
    /// explicit fill strings, truncation, null inputs and non-ASCII text.
    #[test]
    fn test_functions() -> Result<()> {
        // Non-ASCII string, default fill: 4 characters + 1 space.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("josé ")),
            &str,
            Utf8,
            StringArray
        );
        // ASCII string, default fill: 2 characters + 3 spaces.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        // Target length 0 yields the empty string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Null length yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Null string yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Explicit fill, ending on a partial repetition of the fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyx")),
            &str,
            Utf8,
            StringArray
        );
        // Longer fill: three full repetitions plus one leftover character.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(21i64)),
                ColumnarValue::Scalar(ScalarValue::from("abcdef")),
            ],
            Ok(Some("hiabcdefabcdefabcdefa")),
            &str,
            Utf8,
            StringArray
        );
        // Explicit single-space fill behaves like the default fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(" ")),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        // Empty fill leaves the string unchanged.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some("hi")),
            &str,
            Utf8,
            StringArray
        );
        // Null string with a fill yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Null length with a fill yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Null fill yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // A string longer than the target length is truncated.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hello")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("he")),
            &str,
            Utf8,
            StringArray
        );
        // Padding that is an exact multiple of the fill length.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(6i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyxy")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII string with an ASCII fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII string with a non-ASCII fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("éñ")),
            ],
            Ok(Some("josééñéñéñ")),
            &str,
            Utf8,
            StringArray
        );
        // Only compiled when the `unicode_expressions` feature is disabled.
        // NOTE(review): `internal_err!` is not imported in this module, so
        // this case looks stale -- confirm it still builds without the feature.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            RPadFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            internal_err!(
                "function rpad requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}