datafusion_functions/string/
rtrim.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, OffsetSizeTrait};
19use arrow::datatypes::DataType;
20use std::any::Any;
21use std::sync::Arc;
22
23use crate::string::common::*;
24use crate::utils::make_scalar_function;
25use datafusion_common::types::logical_string;
26use datafusion_common::{Result, exec_err};
27use datafusion_expr::function::Hint;
28use datafusion_expr::{
29    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    TypeSignature, TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
34/// Returns the longest string  with trailing characters removed. If the characters are not specified, whitespace is removed.
35/// rtrim('testxxzx', 'xyz') = 'test'
36fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
37    let use_string_view = args[0].data_type() == &DataType::Utf8View;
38    let args = if args.len() > 1 {
39        let arg1 = arrow::compute::kernels::cast::cast(&args[1], args[0].data_type())?;
40        vec![Arc::clone(&args[0]), arg1]
41    } else {
42        args.to_owned()
43    };
44    general_trim::<T, TrimRight>(&args, use_string_view)
45}
46
/// Scalar UDF implementing the SQL `rtrim(str[, trim_str])` function.
///
/// Removes trailing characters from a string: the characters listed in
/// `trim_str` when it is given, otherwise whitespace. The user-facing
/// documentation is declared in the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.",
    syntax_example = "rtrim(str[, trim_str])",
    alternative_syntax = "trim(TRAILING trim_str FROM str)",
    sql_example = r#"```sql
> select rtrim('  datafusion  ');
+-------------------------------+
| rtrim(Utf8("  datafusion  ")) |
+-------------------------------+
|   datafusion                  |
+-------------------------------+
> select rtrim('___datafusion___', '_');
+-------------------------------------------+
| rtrim(Utf8("___datafusion___"),Utf8("_")) |
+-------------------------------------------+
| ___datafusion                             |
+-------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "trim_str",
        description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._"
    ),
    related_udf(name = "btrim"),
    related_udf(name = "ltrim")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RtrimFunc {
    /// Argument signatures accepted by `rtrim`; populated by `RtrimFunc::new`.
    signature: Signature,
}
78
79impl Default for RtrimFunc {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl RtrimFunc {
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::one_of(
89                vec![
90                    TypeSignature::Coercible(vec![
91                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
92                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
93                    ]),
94                    TypeSignature::Coercible(vec![Coercion::new_exact(
95                        TypeSignatureClass::Native(logical_string()),
96                    )]),
97                ],
98                Volatility::Immutable,
99            ),
100        }
101    }
102}
103
104impl ScalarUDFImpl for RtrimFunc {
105    fn as_any(&self) -> &dyn Any {
106        self
107    }
108
109    fn name(&self) -> &str {
110        "rtrim"
111    }
112
113    fn signature(&self) -> &Signature {
114        &self.signature
115    }
116
117    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
118        Ok(arg_types[0].clone())
119    }
120
121    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
122        match args.args[0].data_type() {
123            DataType::Utf8 | DataType::Utf8View => make_scalar_function(
124                rtrim::<i32>,
125                vec![Hint::Pad, Hint::AcceptsSingular],
126            )(&args.args),
127            DataType::LargeUtf8 => make_scalar_function(
128                rtrim::<i64>,
129                vec![Hint::Pad, Hint::AcceptsSingular],
130            )(&args.args),
131            other => exec_err!(
132                "Unsupported data type {other:?} for function rtrim,\
133                expected Utf8, LargeUtf8 or Utf8View."
134            ),
135        }
136    }
137
138    fn documentation(&self) -> Option<&Documentation> {
139        self.doc()
140    }
141}
142
/// Unit tests for `RtrimFunc`, run against scalar `Utf8View` and `Utf8` inputs.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{Utf8, Utf8View};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::rtrim::RtrimFunc;
    use crate::utils::test::test_function;

    /// Exercises `rtrim` with one and two arguments, a trim set that does not
    /// match, and a NULL trim argument, for both string representations.
    #[test]
    fn test_functions() {
        // String view cases for checking normal logic
        // Single argument: trailing spaces are removed.
        test_function!(
            RtrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                String::from("alphabet  ")
            ))),],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Single argument: leading spaces are kept, only trailing ones go.
        test_function!(
            RtrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                String::from("  alphabet  ")
            ))),],
            Ok(Some("  alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Two arguments: the trailing 't' is in the trim set "t " and is removed.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t ")))),
            ],
            Ok(Some("alphabe")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Two arguments: the last character 't' is not in the trim set, so the
        // input is returned unchanged.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabe"
                )))),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // A NULL trim argument produces a NULL result.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Special string view case for checking non-inlined output (len > 12)
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabetalphabetxxx"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("x")))),
            ],
            Ok(Some("alphabetalphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // String cases
        // Single argument: trailing spaces are removed.
        test_function!(
            RtrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                String::from("alphabet  ")
            ))),],
            Ok(Some("alphabet")),
            &str,
            Utf8,
            StringArray
        );
        // Single argument: leading spaces are kept, only trailing ones go.
        test_function!(
            RtrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                String::from("  alphabet  ")
            ))),],
            Ok(Some("  alphabet")),
            &str,
            Utf8,
            StringArray
        );
        // Two arguments: the trailing 't' is in the trim set "t " and is removed.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))),
            ],
            Ok(Some("alphabe")),
            &str,
            Utf8,
            StringArray
        );
        // Two arguments: the last character 't' is not in the trim set, so the
        // input is returned unchanged.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8,
            StringArray
        );
        // A NULL trim argument produces a NULL result.
        test_function!(
            RtrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
    }
}