datafusion_functions/unicode/
rpad.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::{make_scalar_function, utf8_to_str_type};
19use DataType::{LargeUtf8, Utf8, Utf8View};
20use arrow::array::{
21    ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
22    OffsetSizeTrait, StringArrayType, StringViewArray,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::DataFusionError;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::{
30    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_macros::user_doc;
33use std::any::Any;
34use std::fmt::Write;
35use std::sync::Arc;
36use unicode_segmentation::UnicodeSegmentation;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Pads the right side of a string with another string to a specified string length.",
41    syntax_example = "rpad(str, n[, padding_str])",
42    sql_example = r#"```sql
43>  select rpad('datafusion', 20, '_-');
44+-----------------------------------------------+
45| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
46+-----------------------------------------------+
47| datafusion_-_-_-_-_-                          |
48+-----------------------------------------------+
49```"#,
50    standard_argument(name = "str", prefix = "String"),
51    argument(name = "n", description = "String length to pad to."),
52    argument(
53        name = "padding_str",
54        description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
55    ),
56    related_udf(name = "lpad")
57)]
58#[derive(Debug, PartialEq, Eq, Hash)]
59pub struct RPadFunc {
60    signature: Signature,
61}
62
63impl Default for RPadFunc {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl RPadFunc {
70    pub fn new() -> Self {
71        use DataType::*;
72        Self {
73            signature: Signature::one_of(
74                vec![
75                    Exact(vec![Utf8View, Int64]),
76                    Exact(vec![Utf8View, Int64, Utf8View]),
77                    Exact(vec![Utf8View, Int64, Utf8]),
78                    Exact(vec![Utf8View, Int64, LargeUtf8]),
79                    Exact(vec![Utf8, Int64]),
80                    Exact(vec![Utf8, Int64, Utf8View]),
81                    Exact(vec![Utf8, Int64, Utf8]),
82                    Exact(vec![Utf8, Int64, LargeUtf8]),
83                    Exact(vec![LargeUtf8, Int64]),
84                    Exact(vec![LargeUtf8, Int64, Utf8View]),
85                    Exact(vec![LargeUtf8, Int64, Utf8]),
86                    Exact(vec![LargeUtf8, Int64, LargeUtf8]),
87                ],
88                Volatility::Immutable,
89            ),
90        }
91    }
92}
93
94impl ScalarUDFImpl for RPadFunc {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "rpad"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        utf8_to_str_type(&arg_types[0], "rpad")
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        let args = &args.args;
116        match (
117            args.len(),
118            args[0].data_type(),
119            args.get(2).map(|arg| arg.data_type()),
120        ) {
121            (2, Utf8 | Utf8View, _) => {
122                make_scalar_function(rpad::<i32, i32>, vec![])(args)
123            }
124            (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
125            (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
126                make_scalar_function(rpad::<i32, i32>, vec![])(args)
127            }
128            (3, LargeUtf8, Some(LargeUtf8)) => {
129                make_scalar_function(rpad::<i64, i64>, vec![])(args)
130            }
131            (3, Utf8 | Utf8View, Some(LargeUtf8)) => {
132                make_scalar_function(rpad::<i32, i64>, vec![])(args)
133            }
134            (3, LargeUtf8, Some(Utf8 | Utf8View)) => {
135                make_scalar_function(rpad::<i64, i32>, vec![])(args)
136            }
137            (_, _, _) => {
138                exec_err!("Unsupported combination of data types for function rpad")
139            }
140        }
141    }
142
143    fn documentation(&self) -> Option<&Documentation> {
144        self.doc()
145    }
146}
147
148fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
149    args: &[ArrayRef],
150) -> Result<ArrayRef> {
151    if args.len() < 2 || args.len() > 3 {
152        return exec_err!(
153            "rpad was called with {} arguments. It requires 2 or 3 arguments.",
154            args.len()
155        );
156    }
157
158    let length_array = as_int64_array(&args[1])?;
159    match (
160        args.len(),
161        args[0].data_type(),
162        args.get(2).map(|arg| arg.data_type()),
163    ) {
164        (2, Utf8View, _) => {
165            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
166                &args[0].as_string_view(),
167                length_array,
168                None,
169            )
170        }
171        (3, Utf8View, Some(Utf8View)) => {
172            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
173                &args[0].as_string_view(),
174                length_array,
175                Some(args[2].as_string_view()),
176            )
177        }
178        (3, Utf8View, Some(Utf8 | LargeUtf8)) => {
179            rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
180                &args[0].as_string_view(),
181                length_array,
182                Some(args[2].as_string::<FillArrayLen>()),
183            )
184        }
185        (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
186            &GenericStringArray<StringArrayLen>,
187            &StringViewArray,
188            StringArrayLen,
189        >(
190            &args[0].as_string::<StringArrayLen>(),
191            length_array,
192            Some(args[2].as_string_view()),
193        ),
194        (_, _, _) => rpad_impl::<
195            &GenericStringArray<StringArrayLen>,
196            &GenericStringArray<FillArrayLen>,
197            StringArrayLen,
198        >(
199            &args[0].as_string::<StringArrayLen>(),
200            length_array,
201            args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
202        ),
203    }
204}
205
206/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
207/// rpad('hi', 5, 'xy') = 'hixyx'
208fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
209    string_array: &StringArrType,
210    length_array: &Int64Array,
211    fill_array: Option<FillArrType>,
212) -> Result<ArrayRef>
213where
214    StringArrType: StringArrayType<'a>,
215    FillArrType: StringArrayType<'a>,
216    StringArrayLen: OffsetSizeTrait,
217{
218    let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
219    let mut graphemes_buf = Vec::new();
220    let mut fill_chars_buf = Vec::new();
221
222    match fill_array {
223        None => {
224            string_array.iter().zip(length_array.iter()).try_for_each(
225                |(string, length)| -> Result<(), DataFusionError> {
226                    match (string, length) {
227                        (Some(string), Some(length)) => {
228                            if length > i32::MAX as i64 {
229                                return exec_err!(
230                                    "rpad requested length {} too large",
231                                    length
232                                );
233                            }
234                            let length = if length < 0 { 0 } else { length as usize };
235                            if length == 0 {
236                                builder.append_value("");
237                            } else {
238                                // Reuse buffer by clearing and refilling
239                                graphemes_buf.clear();
240                                graphemes_buf.extend(string.graphemes(true));
241
242                                if length < graphemes_buf.len() {
243                                    builder
244                                        .append_value(graphemes_buf[..length].concat());
245                                } else {
246                                    builder.write_str(string)?;
247                                    builder.write_str(
248                                        &" ".repeat(length - graphemes_buf.len()),
249                                    )?;
250                                    builder.append_value("");
251                                }
252                            }
253                        }
254                        _ => builder.append_null(),
255                    }
256                    Ok(())
257                },
258            )?;
259        }
260        Some(fill_array) => {
261            string_array
262                .iter()
263                .zip(length_array.iter())
264                .zip(fill_array.iter())
265                .try_for_each(
266                    |((string, length), fill)| -> Result<(), DataFusionError> {
267                        match (string, length, fill) {
268                            (Some(string), Some(length), Some(fill)) => {
269                                if length > i32::MAX as i64 {
270                                    return exec_err!(
271                                        "rpad requested length {} too large",
272                                        length
273                                    );
274                                }
275                                let length = if length < 0 { 0 } else { length as usize };
276                                // Reuse buffer by clearing and refilling
277                                graphemes_buf.clear();
278                                graphemes_buf.extend(string.graphemes(true));
279
280                                if length < graphemes_buf.len() {
281                                    builder
282                                        .append_value(graphemes_buf[..length].concat());
283                                } else if fill.is_empty() {
284                                    builder.append_value(string);
285                                } else {
286                                    builder.write_str(string)?;
287                                    // Reuse fill_chars_buf by clearing and refilling
288                                    fill_chars_buf.clear();
289                                    fill_chars_buf.extend(fill.chars());
290                                    for l in 0..length - graphemes_buf.len() {
291                                        let c = *fill_chars_buf
292                                            .get(l % fill_chars_buf.len())
293                                            .unwrap();
294                                        builder.write_char(c)?;
295                                    }
296                                    builder.append_value("");
297                                }
298                            }
299                            _ => builder.append_null(),
300                        }
301                        Ok(())
302                    },
303                )?;
304        }
305    }
306
307    Ok(Arc::new(builder.finish()) as ArrayRef)
308}
309
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::rpad::RPadFunc;
    use crate::utils::test::test_function;

    /// Exercises `rpad` with scalar arguments: default and explicit fill,
    /// truncation, zero/negative lengths, null propagation and non-ASCII input.
    #[test]
    fn test_functions() -> Result<()> {
        // --- Default (space) fill ---
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("josé ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A negative length is treated as zero.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A string longer than the requested length is truncated.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("jo")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // --- Explicit fill ---
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyx")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(21i64)),
                ColumnarValue::Scalar(ScalarValue::from("abcdef")),
            ],
            Ok(Some("hiabcdefabcdefabcdefa")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(" ")),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some("hi")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation ignores the fill string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("jos")),
            &str,
            Utf8,
            StringArray
        );
        // A zero length with an explicit fill yields an empty string.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // --- Non-ASCII input and fill ---
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("éñ")),
            ],
            Ok(Some("josééñéñéñ")),
            &str,
            Utf8,
            StringArray
        );
        // NOTE(review): this case only compiles when `unicode_expressions` is
        // disabled, uses `internal_err!` without importing it in this module,
        // and expects an error that the implementation in this file never
        // produces. Confirm whether it is still needed.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            RPadFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            internal_err!(
                "function rpad requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}