Skip to main content

datafusion_functions/unicode/
reverse.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::strings::{
19    BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder,
20};
21use crate::utils::make_scalar_function;
22use DataType::{LargeUtf8, Utf8, Utf8View};
23use arrow::array::{Array, ArrayRef, AsArray, StringArrayType};
24use arrow::datatypes::DataType;
25use datafusion_common::{Result, exec_err};
26use datafusion_expr::{
27    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    Volatility,
29};
30use datafusion_macros::user_doc;
31
32#[user_doc(
33    doc_section(label = "String Functions"),
34    description = "Reverses the character order of a string.",
35    syntax_example = "reverse(str)",
36    sql_example = r#"```sql
37> select reverse('datafusion');
38+-----------------------------+
39| reverse(Utf8("datafusion")) |
40+-----------------------------+
41| noisufatad                  |
42+-----------------------------+
43```"#,
44    standard_argument(name = "str", prefix = "String")
45)]
46#[derive(Debug, PartialEq, Eq, Hash)]
47pub struct ReverseFunc {
48    signature: Signature,
49}
50
51impl Default for ReverseFunc {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl ReverseFunc {
58    pub fn new() -> Self {
59        use DataType::*;
60        Self {
61            signature: Signature::uniform(
62                1,
63                vec![Utf8View, Utf8, LargeUtf8],
64                Volatility::Immutable,
65            ),
66        }
67    }
68}
69
70impl ScalarUDFImpl for ReverseFunc {
71    fn name(&self) -> &str {
72        "reverse"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
80        Ok(arg_types[0].clone())
81    }
82
83    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84        let args = &args.args;
85        match args[0].data_type() {
86            Utf8 | Utf8View | LargeUtf8 => make_scalar_function(reverse, vec![])(args),
87            other => {
88                exec_err!("Unsupported data type {other:?} for function reverse")
89            }
90        }
91    }
92
93    fn documentation(&self) -> Option<&Documentation> {
94        self.doc()
95    }
96}
97
98/// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`.
99/// The implementation uses UTF-8 code points as characters
100fn reverse(args: &[ArrayRef]) -> Result<ArrayRef> {
101    let len = args[0].len();
102
103    match args[0].data_type() {
104        LargeUtf8 => reverse_impl(
105            &args[0].as_string::<i64>(),
106            GenericStringArrayBuilder::<i64>::with_capacity(len, 1024),
107        ),
108        Utf8 => reverse_impl(
109            &args[0].as_string::<i32>(),
110            GenericStringArrayBuilder::<i32>::with_capacity(len, 1024),
111        ),
112        Utf8View => reverse_impl(
113            &args[0].as_string_view(),
114            StringViewArrayBuilder::with_capacity(len),
115        ),
116        _ => unreachable!(
117            "Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types"
118        ),
119    }
120}
121
122fn reverse_impl<'a, StringArrType, B>(
123    string_array: &StringArrType,
124    mut array_builder: B,
125) -> Result<ArrayRef>
126where
127    StringArrType: StringArrayType<'a>,
128    B: BulkNullStringArrayBuilder,
129{
130    let item_len = string_array.len();
131    // Null-preserving: reuse the input null buffer as the output null buffer.
132    let nulls = string_array.nulls().cloned();
133    let mut string_buf = String::new();
134    let mut byte_buf = Vec::<u8>::new();
135
136    if let Some(ref n) = nulls {
137        for i in 0..item_len {
138            if n.is_null(i) {
139                array_builder.append_placeholder();
140            } else {
141                // SAFETY: `n.is_null(i)` was false in the branch above.
142                let s = unsafe { string_array.value_unchecked(i) };
143                append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf);
144            }
145        }
146    } else {
147        for i in 0..item_len {
148            // SAFETY: no null buffer means every index is valid.
149            let s = unsafe { string_array.value_unchecked(i) };
150            append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf);
151        }
152    }
153
154    array_builder.finish(nulls)
155}
156
157#[inline]
158fn append_reversed<B: BulkNullStringArrayBuilder>(
159    s: &str,
160    builder: &mut B,
161    byte_buf: &mut Vec<u8>,
162    string_buf: &mut String,
163) {
164    if s.is_ascii() {
165        // reverse bytes directly since ASCII characters are single bytes
166        byte_buf.extend(s.as_bytes());
167        byte_buf.reverse();
168        // SAFETY: input was ASCII, so reversed bytes are still valid UTF-8.
169        let reversed = unsafe { std::str::from_utf8_unchecked(byte_buf) };
170        builder.append_value(reversed);
171        byte_buf.clear();
172    } else {
173        string_buf.extend(s.chars().rev());
174        builder.append_value(string_buf);
175        string_buf.clear();
176    }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, LargeStringArray, StringArray, StringViewArray,
    };
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::reverse::{ReverseFunc, reverse};
    use crate::utils::test::test_function;

    /// Runs one scalar `reverse` case against all three string encodings
    /// (`Utf8`, `LargeUtf8`, `Utf8View`).
    macro_rules! test_reverse {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_function!(
                ReverseFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );

            test_function!(
                ReverseFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );

            test_function!(
                ReverseFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                &str,
                Utf8View,
                StringViewArray
            );
        };
    }

    /// Scalar inputs: ASCII, a string containing a combining mark, and NULL.
    #[test]
    fn test_functions() -> Result<()> {
        test_reverse!(Some("abcde".into()), Ok(Some("edcba")));
        // Reversal is per code point, so the combining mark ends up attached
        // to a different base letter.
        test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol")));
        test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol")));
        test_reverse!(None, Ok(None));
        #[cfg(not(feature = "unicode_expressions"))]
        test_reverse!(
            Some("abcde".into()),
            internal_err!(
                "function reverse requires compilation with feature flag: unicode_expressions."
            ),
        );

        Ok(())
    }

    /// Array inputs mixing ASCII, non-ASCII and NULL rows, for every supported
    /// string array type.
    #[test]
    fn test_array_with_nulls() {
        let rows = vec![Some("abcd"), None, Some("XYZ"), Some("héllo"), None];
        let want: Vec<Option<&str>> =
            vec![Some("dcba"), None, Some("ZYX"), Some("olléh"), None];

        let utf8_input: ArrayRef = Arc::new(StringArray::from(rows.clone()));
        let large_utf8_input: ArrayRef =
            Arc::new(LargeStringArray::from(rows.clone()));
        let view_input: ArrayRef = Arc::new(StringViewArray::from(rows.clone()));

        let cases: Vec<(&str, ArrayRef)> = vec![
            ("StringArray", utf8_input),
            ("LargeStringArray", large_utf8_input),
            ("StringViewArray", view_input),
        ];

        for (label, input) in cases {
            let out = reverse(&[input]).unwrap();
            assert_eq!(out.len(), want.len(), "{label}: length mismatch");

            // Downcast according to the output type so the rows can be
            // compared as plain `Option<&str>` values.
            let any = out.as_any();
            let actual: Vec<Option<&str>> = match out.data_type() {
                Utf8 => any
                    .downcast_ref::<StringArray>()
                    .unwrap()
                    .iter()
                    .collect(),
                LargeUtf8 => any
                    .downcast_ref::<LargeStringArray>()
                    .unwrap()
                    .iter()
                    .collect(),
                Utf8View => any
                    .downcast_ref::<StringViewArray>()
                    .unwrap()
                    .iter()
                    .collect(),
                other => panic!("{label}: unexpected output type {other:?}"),
            };
            assert_eq!(actual, want, "{label}: value mismatch");
        }
    }
}