Skip to main content

datafusion_functions/unicode/
lpad.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Write;
20use std::sync::Arc;
21
22use DataType::{LargeUtf8, Utf8, Utf8View};
23use arrow::array::{
24    Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
25    OffsetSizeTrait, StringArrayType, StringViewArray,
26};
27use arrow::datatypes::DataType;
28use unicode_segmentation::UnicodeSegmentation;
29
30use crate::utils::{make_scalar_function, utf8_to_str_type};
31use datafusion_common::cast::as_int64_array;
32use datafusion_common::{Result, exec_err};
33use datafusion_expr::TypeSignature::Exact;
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
36};
37use datafusion_macros::user_doc;
38
/// Scalar UDF implementing SQL `lpad(str, n[, padding_str])`.
///
/// The user-facing SQL documentation is attached through `#[user_doc]`; the
/// per-row padding/truncation is performed by the `lpad` kernel in this module.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Pads the left side of a string with another string to a specified string length.",
    syntax_example = "lpad(str, n[, padding_str])",
    sql_example = r#"```sql
> select lpad('Dolly', 10, 'hello');
+---------------------------------------------+
| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) |
+---------------------------------------------+
| helloDolly                                  |
+---------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "n",
        description = "String length to pad to. If the input string is longer than this length, it is truncated (on the right)."
    ),
    argument(
        name = "padding_str",
        description = "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
    ),
    related_udf(name = "rpad")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct LPadFunc {
    /// Accepted argument-type combinations; built in `LPadFunc::new`.
    signature: Signature,
}
66
67impl Default for LPadFunc {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl LPadFunc {
74    pub fn new() -> Self {
75        use DataType::*;
76        Self {
77            signature: Signature::one_of(
78                vec![
79                    Exact(vec![Utf8View, Int64]),
80                    Exact(vec![Utf8View, Int64, Utf8View]),
81                    Exact(vec![Utf8View, Int64, Utf8]),
82                    Exact(vec![Utf8View, Int64, LargeUtf8]),
83                    Exact(vec![Utf8, Int64]),
84                    Exact(vec![Utf8, Int64, Utf8View]),
85                    Exact(vec![Utf8, Int64, Utf8]),
86                    Exact(vec![Utf8, Int64, LargeUtf8]),
87                    Exact(vec![LargeUtf8, Int64]),
88                    Exact(vec![LargeUtf8, Int64, Utf8View]),
89                    Exact(vec![LargeUtf8, Int64, Utf8]),
90                    Exact(vec![LargeUtf8, Int64, LargeUtf8]),
91                ],
92                Volatility::Immutable,
93            ),
94        }
95    }
96}
97
98impl ScalarUDFImpl for LPadFunc {
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102
103    fn name(&self) -> &str {
104        "lpad"
105    }
106
107    fn signature(&self) -> &Signature {
108        &self.signature
109    }
110
111    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
112        utf8_to_str_type(&arg_types[0], "lpad")
113    }
114
115    fn invoke_with_args(
116        &self,
117        args: datafusion_expr::ScalarFunctionArgs,
118    ) -> Result<ColumnarValue> {
119        let args = &args.args;
120        match args[0].data_type() {
121            Utf8 | Utf8View => make_scalar_function(lpad::<i32>, vec![])(args),
122            LargeUtf8 => make_scalar_function(lpad::<i64>, vec![])(args),
123            other => exec_err!("Unsupported data type {other:?} for function lpad"),
124        }
125    }
126
127    fn documentation(&self) -> Option<&Documentation> {
128        self.doc()
129    }
130}
131
132/// Extends the string to length 'length' by prepending the characters fill (a space by default).
133/// If the string is already longer than length then it is truncated (on the right).
134/// lpad('hi', 5, 'xy') = 'xyxhi'
135fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
136    if args.len() <= 1 || args.len() > 3 {
137        return exec_err!(
138            "lpad was called with {} arguments. It requires at least 2 and at most 3.",
139            args.len()
140        );
141    }
142
143    let length_array = as_int64_array(&args[1])?;
144
145    match (args.len(), args[0].data_type()) {
146        (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray<i32>, T>(
147            &args[0].as_string_view(),
148            length_array,
149            None,
150        ),
151        (2, Utf8 | LargeUtf8) => lpad_impl::<
152            &GenericStringArray<T>,
153            &GenericStringArray<T>,
154            T,
155        >(&args[0].as_string::<T>(), length_array, None),
156        (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>(
157            &args[0].as_string_view(),
158            length_array,
159            &args[2],
160        ),
161        (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray<T>, T>(
162            &args[0].as_string::<T>(),
163            length_array,
164            &args[2],
165        ),
166        (_, _) => unreachable!("lpad"),
167    }
168}
169
170fn lpad_with_replace<'a, V, T: OffsetSizeTrait>(
171    string_array: &V,
172    length_array: &Int64Array,
173    fill_array: &'a ArrayRef,
174) -> Result<ArrayRef>
175where
176    V: StringArrayType<'a>,
177{
178    match fill_array.data_type() {
179        Utf8View => lpad_impl::<V, &StringViewArray, T>(
180            string_array,
181            length_array,
182            Some(fill_array.as_string_view()),
183        ),
184        LargeUtf8 => lpad_impl::<V, &GenericStringArray<i64>, T>(
185            string_array,
186            length_array,
187            Some(fill_array.as_string::<i64>()),
188        ),
189        Utf8 => lpad_impl::<V, &GenericStringArray<i32>, T>(
190            string_array,
191            length_array,
192            Some(fill_array.as_string::<i32>()),
193        ),
194        other => {
195            exec_err!("Unsupported data type {other:?} for function lpad")
196        }
197    }
198}
199
200fn lpad_impl<'a, V, V2, T>(
201    string_array: &V,
202    length_array: &Int64Array,
203    fill_array: Option<V2>,
204) -> Result<ArrayRef>
205where
206    V: StringArrayType<'a>,
207    V2: StringArrayType<'a>,
208    T: OffsetSizeTrait,
209{
210    let array = if let Some(fill_array) = fill_array {
211        let mut builder: GenericStringBuilder<T> = GenericStringBuilder::new();
212        let mut graphemes_buf = Vec::new();
213        let mut fill_chars_buf = Vec::new();
214
215        for ((string, length), fill) in string_array
216            .iter()
217            .zip(length_array.iter())
218            .zip(fill_array.iter())
219        {
220            if let (Some(string), Some(length), Some(fill)) = (string, length, fill) {
221                if length > i32::MAX as i64 {
222                    return exec_err!("lpad requested length {length} too large");
223                }
224
225                let length = if length < 0 { 0 } else { length as usize };
226                if length == 0 {
227                    builder.append_value("");
228                    continue;
229                }
230
231                if string.is_ascii() && fill.is_ascii() {
232                    // ASCII fast path: byte length == character length,
233                    // so we skip expensive grapheme segmentation.
234                    let str_len = string.len();
235                    if length < str_len {
236                        builder.append_value(&string[..length]);
237                    } else if fill.is_empty() {
238                        builder.append_value(string);
239                    } else {
240                        let pad_len = length - str_len;
241                        let fill_len = fill.len();
242                        let full_reps = pad_len / fill_len;
243                        let remainder = pad_len % fill_len;
244                        for _ in 0..full_reps {
245                            builder.write_str(fill)?;
246                        }
247                        if remainder > 0 {
248                            builder.write_str(&fill[..remainder])?;
249                        }
250                        builder.append_value(string);
251                    }
252                } else {
253                    // Reuse buffers by clearing and refilling
254                    graphemes_buf.clear();
255                    graphemes_buf.extend(string.graphemes(true));
256
257                    fill_chars_buf.clear();
258                    fill_chars_buf.extend(fill.chars());
259
260                    if length < graphemes_buf.len() {
261                        builder.append_value(graphemes_buf[..length].concat());
262                    } else if fill_chars_buf.is_empty() {
263                        builder.append_value(string);
264                    } else {
265                        for l in 0..length - graphemes_buf.len() {
266                            let c =
267                                *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap();
268                            builder.write_char(c)?;
269                        }
270                        builder.append_value(string);
271                    }
272                }
273            } else {
274                builder.append_null();
275            }
276        }
277
278        builder.finish()
279    } else {
280        let mut builder: GenericStringBuilder<T> = GenericStringBuilder::new();
281        let mut graphemes_buf = Vec::new();
282
283        for (string, length) in string_array.iter().zip(length_array.iter()) {
284            if let (Some(string), Some(length)) = (string, length) {
285                if length > i32::MAX as i64 {
286                    return exec_err!("lpad requested length {length} too large");
287                }
288
289                let length = if length < 0 { 0 } else { length as usize };
290                if length == 0 {
291                    builder.append_value("");
292                    continue;
293                }
294
295                if string.is_ascii() {
296                    // ASCII fast path: byte length == character length
297                    let str_len = string.len();
298                    if length < str_len {
299                        builder.append_value(&string[..length]);
300                    } else {
301                        for _ in 0..(length - str_len) {
302                            builder.write_str(" ")?;
303                        }
304                        builder.append_value(string);
305                    }
306                } else {
307                    // Reuse buffer by clearing and refilling
308                    graphemes_buf.clear();
309                    graphemes_buf.extend(string.graphemes(true));
310
311                    if length < graphemes_buf.len() {
312                        builder.append_value(graphemes_buf[..length].concat());
313                    } else {
314                        for _ in 0..(length - graphemes_buf.len()) {
315                            builder.write_str(" ")?;
316                        }
317                        builder.append_value(string);
318                    }
319                }
320            } else {
321                builder.append_null();
322            }
323        }
324
325        builder.finish()
326    };
327
328    Ok(Arc::new(array) as ArrayRef)
329}
330
#[cfg(test)]
mod tests {
    use crate::unicode::lpad::LPadFunc;
    use crate::utils::test::test_function;

    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Runs `lpad` over every supported combination of string input types
    /// (and fill types, for the four-argument form) and checks the result.
    macro_rules! test_lpad {
        ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => {
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );

            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );

            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
        };

        ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => {
            // utf8, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );

            // largeutf8, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );
            // largeutf8, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );
            // largeutf8, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );

            // utf8view, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8view, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8view, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(5i64)),
            Ok(Some(" josé"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Ok(Some("   hi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(0i64)),
            Ok(Some(""))
        );
        test_lpad!(Some("hi".into()), ScalarValue::Int64(None), Ok(None));
        test_lpad!(None, ScalarValue::Int64(Some(5i64)), Ok(None));
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some("xy".into()),
            Ok(Some("xyxhi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(21i64)),
            Some("abcdef".into()),
            Ok(Some("abcdefabcdefabcdefahi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some(" ".into()),
            Ok(Some("   hi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some("".into()),
            Ok(Some("hi"))
        );
        test_lpad!(
            None,
            ScalarValue::Int64(Some(5i64)),
            Some("xy".into()),
            Ok(None)
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(None),
            Some("xy".into()),
            Ok(None)
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            None,
            Ok(None)
        );
        test_lpad!(
            Some("hello".into()),
            ScalarValue::Int64(Some(2i64)),
            Ok(Some("he"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(6i64)),
            Some("xy".into()),
            Ok(Some("xyxyhi"))
        );
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(10i64)),
            Some("xy".into()),
            Ok(Some("xyxyxyjosé"))
        );
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(10i64)),
            Some("éñ".into()),
            Ok(Some("éñéñéñjosé"))
        );

        Ok(())
    }

    #[test]
    fn test_lengths_and_graphemes() -> Result<()> {
        // Negative lengths behave like zero.
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(-1i64)),
            Ok(Some(""))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(-1i64)),
            Some("xy".into()),
            Ok(Some(""))
        );
        // Already the requested length: returned unchanged.
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(4i64)),
            Some("xy".into()),
            Ok(Some("josé"))
        );
        // Non-ASCII truncation keeps whole characters.
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(3i64)),
            Ok(Some("jos"))
        );
        // Lengths count grapheme clusters: "e" + U+0301 is a single cluster.
        test_lpad!(
            Some("e\u{301}".into()),
            ScalarValue::Int64(Some(3i64)),
            Ok(Some("  e\u{301}"))
        );
        test_lpad!(
            Some("e\u{301}".into()),
            ScalarValue::Int64(Some(3i64)),
            Some("xy".into()),
            Ok(Some("xye\u{301}"))
        );
        test_lpad!(
            Some("e\u{301}llo".into()),
            ScalarValue::Int64(Some(1i64)),
            Ok(Some("e\u{301}"))
        );

        Ok(())
    }
}