Skip to main content

datafusion_functions/unicode/
rpad.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::{make_scalar_function, utf8_to_str_type};
19use DataType::{LargeUtf8, Utf8, Utf8View};
20use arrow::array::{
21    ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
22    OffsetSizeTrait, StringArrayType, StringViewArray,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::DataFusionError;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::{
30    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_macros::user_doc;
33use std::any::Any;
34use std::fmt::Write;
35use std::sync::Arc;
36use unicode_segmentation::UnicodeSegmentation;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Pads the right side of a string with another string to a specified string length.",
41    syntax_example = "rpad(str, n[, padding_str])",
42    sql_example = r#"```sql
43>  select rpad('datafusion', 20, '_-');
44+-----------------------------------------------+
45| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
46+-----------------------------------------------+
47| datafusion_-_-_-_-_-                          |
48+-----------------------------------------------+
49```"#,
50    standard_argument(name = "str", prefix = "String"),
51    argument(
52        name = "n",
53        description = "String length to pad to. If the input string is longer than this length, it is truncated."
54    ),
55    argument(
56        name = "padding_str",
57        description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
58    ),
59    related_udf(name = "lpad")
60)]
61#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct RPadFunc {
63    signature: Signature,
64}
65
66impl Default for RPadFunc {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl RPadFunc {
73    pub fn new() -> Self {
74        use DataType::*;
75        Self {
76            signature: Signature::one_of(
77                vec![
78                    Exact(vec![Utf8View, Int64]),
79                    Exact(vec![Utf8View, Int64, Utf8View]),
80                    Exact(vec![Utf8View, Int64, Utf8]),
81                    Exact(vec![Utf8View, Int64, LargeUtf8]),
82                    Exact(vec![Utf8, Int64]),
83                    Exact(vec![Utf8, Int64, Utf8View]),
84                    Exact(vec![Utf8, Int64, Utf8]),
85                    Exact(vec![Utf8, Int64, LargeUtf8]),
86                    Exact(vec![LargeUtf8, Int64]),
87                    Exact(vec![LargeUtf8, Int64, Utf8View]),
88                    Exact(vec![LargeUtf8, Int64, Utf8]),
89                    Exact(vec![LargeUtf8, Int64, LargeUtf8]),
90                ],
91                Volatility::Immutable,
92            ),
93        }
94    }
95}
96
97impl ScalarUDFImpl for RPadFunc {
98    fn as_any(&self) -> &dyn Any {
99        self
100    }
101
102    fn name(&self) -> &str {
103        "rpad"
104    }
105
106    fn signature(&self) -> &Signature {
107        &self.signature
108    }
109
110    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
111        utf8_to_str_type(&arg_types[0], "rpad")
112    }
113
114    fn invoke_with_args(
115        &self,
116        args: datafusion_expr::ScalarFunctionArgs,
117    ) -> Result<ColumnarValue> {
118        let args = &args.args;
119        match (
120            args.len(),
121            args[0].data_type(),
122            args.get(2).map(|arg| arg.data_type()),
123        ) {
124            (2, Utf8 | Utf8View, _) => {
125                make_scalar_function(rpad::<i32, i32>, vec![])(args)
126            }
127            (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
128            (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
129                make_scalar_function(rpad::<i32, i32>, vec![])(args)
130            }
131            (3, LargeUtf8, Some(LargeUtf8)) => {
132                make_scalar_function(rpad::<i64, i64>, vec![])(args)
133            }
134            (3, Utf8 | Utf8View, Some(LargeUtf8)) => {
135                make_scalar_function(rpad::<i32, i64>, vec![])(args)
136            }
137            (3, LargeUtf8, Some(Utf8 | Utf8View)) => {
138                make_scalar_function(rpad::<i64, i32>, vec![])(args)
139            }
140            (_, _, _) => {
141                exec_err!("Unsupported combination of data types for function rpad")
142            }
143        }
144    }
145
146    fn documentation(&self) -> Option<&Documentation> {
147        self.doc()
148    }
149}
150
151fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
152    args: &[ArrayRef],
153) -> Result<ArrayRef> {
154    if args.len() < 2 || args.len() > 3 {
155        return exec_err!(
156            "rpad was called with {} arguments. It requires 2 or 3 arguments.",
157            args.len()
158        );
159    }
160
161    let length_array = as_int64_array(&args[1])?;
162    match (
163        args.len(),
164        args[0].data_type(),
165        args.get(2).map(|arg| arg.data_type()),
166    ) {
167        (2, Utf8View, _) => {
168            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
169                &args[0].as_string_view(),
170                length_array,
171                None,
172            )
173        }
174        (3, Utf8View, Some(Utf8View)) => {
175            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
176                &args[0].as_string_view(),
177                length_array,
178                Some(args[2].as_string_view()),
179            )
180        }
181        (3, Utf8View, Some(Utf8 | LargeUtf8)) => {
182            rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
183                &args[0].as_string_view(),
184                length_array,
185                Some(args[2].as_string::<FillArrayLen>()),
186            )
187        }
188        (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
189            &GenericStringArray<StringArrayLen>,
190            &StringViewArray,
191            StringArrayLen,
192        >(
193            &args[0].as_string::<StringArrayLen>(),
194            length_array,
195            Some(args[2].as_string_view()),
196        ),
197        (_, _, _) => rpad_impl::<
198            &GenericStringArray<StringArrayLen>,
199            &GenericStringArray<FillArrayLen>,
200            StringArrayLen,
201        >(
202            &args[0].as_string::<StringArrayLen>(),
203            length_array,
204            args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
205        ),
206    }
207}
208
209/// Extends the string to length 'length' by appending the characters fill (a space by default).
210/// If the string is already longer than length then it is truncated (on the right).
211/// rpad('hi', 5, 'xy') = 'hixyx'
212fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
213    string_array: &StringArrType,
214    length_array: &Int64Array,
215    fill_array: Option<FillArrType>,
216) -> Result<ArrayRef>
217where
218    StringArrType: StringArrayType<'a>,
219    FillArrType: StringArrayType<'a>,
220    StringArrayLen: OffsetSizeTrait,
221{
222    let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
223    let mut graphemes_buf = Vec::new();
224    let mut fill_chars_buf = Vec::new();
225
226    match fill_array {
227        None => {
228            string_array.iter().zip(length_array.iter()).try_for_each(
229                |(string, length)| -> Result<(), DataFusionError> {
230                    match (string, length) {
231                        (Some(string), Some(length)) => {
232                            if length > i32::MAX as i64 {
233                                return exec_err!(
234                                    "rpad requested length {} too large",
235                                    length
236                                );
237                            }
238                            let length = if length < 0 { 0 } else { length as usize };
239                            if length == 0 {
240                                builder.append_value("");
241                            } else if string.is_ascii() {
242                                // ASCII fast path: byte length == character length
243                                let str_len = string.len();
244                                if length < str_len {
245                                    builder.append_value(&string[..length]);
246                                } else {
247                                    builder.write_str(string)?;
248                                    for _ in 0..(length - str_len) {
249                                        builder.write_str(" ")?;
250                                    }
251                                    builder.append_value("");
252                                }
253                            } else {
254                                // Reuse buffer by clearing and refilling
255                                graphemes_buf.clear();
256                                graphemes_buf.extend(string.graphemes(true));
257
258                                if length < graphemes_buf.len() {
259                                    builder
260                                        .append_value(graphemes_buf[..length].concat());
261                                } else {
262                                    builder.write_str(string)?;
263                                    for _ in 0..(length - graphemes_buf.len()) {
264                                        builder.write_str(" ")?;
265                                    }
266                                    builder.append_value("");
267                                }
268                            }
269                        }
270                        _ => builder.append_null(),
271                    }
272                    Ok(())
273                },
274            )?;
275        }
276        Some(fill_array) => {
277            string_array
278                .iter()
279                .zip(length_array.iter())
280                .zip(fill_array.iter())
281                .try_for_each(
282                    |((string, length), fill)| -> Result<(), DataFusionError> {
283                        match (string, length, fill) {
284                            (Some(string), Some(length), Some(fill)) => {
285                                if length > i32::MAX as i64 {
286                                    return exec_err!(
287                                        "rpad requested length {} too large",
288                                        length
289                                    );
290                                }
291                                let length = if length < 0 { 0 } else { length as usize };
292                                if string.is_ascii() && fill.is_ascii() {
293                                    // ASCII fast path: byte length == character length,
294                                    // so we skip expensive grapheme segmentation.
295                                    let str_len = string.len();
296                                    if length < str_len {
297                                        builder.append_value(&string[..length]);
298                                    } else if fill.is_empty() {
299                                        builder.append_value(string);
300                                    } else {
301                                        let pad_len = length - str_len;
302                                        let fill_len = fill.len();
303                                        let full_reps = pad_len / fill_len;
304                                        let remainder = pad_len % fill_len;
305                                        builder.write_str(string)?;
306                                        for _ in 0..full_reps {
307                                            builder.write_str(fill)?;
308                                        }
309                                        if remainder > 0 {
310                                            builder.write_str(&fill[..remainder])?;
311                                        }
312                                        builder.append_value("");
313                                    }
314                                } else {
315                                    // Reuse buffer by clearing and refilling
316                                    graphemes_buf.clear();
317                                    graphemes_buf.extend(string.graphemes(true));
318
319                                    if length < graphemes_buf.len() {
320                                        builder.append_value(
321                                            graphemes_buf[..length].concat(),
322                                        );
323                                    } else if fill.is_empty() {
324                                        builder.append_value(string);
325                                    } else {
326                                        builder.write_str(string)?;
327                                        // Reuse fill_chars_buf by clearing and refilling
328                                        fill_chars_buf.clear();
329                                        fill_chars_buf.extend(fill.chars());
330                                        for l in 0..length - graphemes_buf.len() {
331                                            let c = *fill_chars_buf
332                                                .get(l % fill_chars_buf.len())
333                                                .unwrap();
334                                            builder.write_char(c)?;
335                                        }
336                                        builder.append_value("");
337                                    }
338                                }
339                            }
340                            _ => builder.append_null(),
341                        }
342                        Ok(())
343                    },
344                )?;
345        }
346    }
347
348    Ok(Arc::new(builder.finish()) as ArrayRef)
349}
350
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::rpad::RPadFunc;
    use crate::utils::test::test_function;

    /// Exercises `rpad` with scalar arguments: default and explicit fill,
    /// null propagation, truncation, negative lengths and non-ASCII input.
    #[test]
    fn test_functions() -> Result<()> {
        // --- two-argument form (space fill) ---
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("josé ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // --- three-argument form (explicit fill) ---
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyx")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(21i64)),
                ColumnarValue::Scalar(ScalarValue::from("abcdef")),
            ],
            Ok(Some("hiabcdefabcdefabcdefa")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(" ")),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        // An empty fill cannot extend the string, so it is returned unchanged.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some("hi")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hello")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("he")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(6i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyxy")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("éñ")),
            ],
            Ok(Some("josééñéñéñ")),
            &str,
            Utf8,
            StringArray
        );
        // --- additional edge cases ---
        // A negative length is clamped to zero (two-argument form).
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A negative length is clamped to zero (three-argument form).
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation of a non-ASCII string (two-argument form).
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("jo")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation with an explicit fill, ASCII string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hello")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("he")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation with an explicit fill, non-ASCII string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("jo")),
            &str,
            Utf8,
            StringArray
        );
        // An empty string is padded purely from the fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from("ab")),
            ],
            Ok(Some("aba")),
            &str,
            Utf8,
            StringArray
        );
        // NOTE(review): this case is only compiled when the
        // `unicode_expressions` feature is disabled. `internal_err!` is not
        // imported in this module, so it looks like it would not build if the
        // cfg were ever active — confirm whether it is still needed.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            RPadFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            internal_err!(
                "function rpad requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}