datafusion_functions/string/
ends_with.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
23use arrow::datatypes::DataType;
24
25use datafusion_common::types::logical_string;
26use datafusion_common::utils::take_function_args;
27use datafusion_common::{Result, ScalarValue, exec_err};
28use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36    doc_section(label = "String Functions"),
37    description = "Tests if a string ends with a substring.",
38    syntax_example = "ends_with(str, substr)",
39    sql_example = r#"```sql
40>  select ends_with('datafusion', 'soin');
41+--------------------------------------------+
42| ends_with(Utf8("datafusion"),Utf8("soin")) |
43+--------------------------------------------+
44| false                                      |
45+--------------------------------------------+
46> select ends_with('datafusion', 'sion');
47+--------------------------------------------+
48| ends_with(Utf8("datafusion"),Utf8("sion")) |
49+--------------------------------------------+
50| true                                       |
51+--------------------------------------------+
52```"#,
53    standard_argument(name = "str", prefix = "String"),
54    argument(name = "substr", description = "Substring to test for.")
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct EndsWithFunc {
58    signature: Signature,
59}
60
61impl Default for EndsWithFunc {
62    fn default() -> Self {
63        EndsWithFunc::new()
64    }
65}
66
67impl EndsWithFunc {
68    pub fn new() -> Self {
69        Self {
70            signature: Signature::coercible(
71                vec![
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74                ],
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl ScalarUDFImpl for EndsWithFunc {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "ends_with"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95        Ok(DataType::Boolean)
96    }
97
98    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99        let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?;
100
101        // Determine the common type for coercion
102        let coercion_type = string_coercion(
103            &str_arg.data_type(),
104            &suffix_arg.data_type(),
105        )
106        .or_else(|| {
107            binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
108        });
109
110        let Some(coercion_type) = coercion_type else {
111            return exec_err!(
112                "Unsupported data types {:?}, {:?} for function `ends_with`.",
113                str_arg.data_type(),
114                suffix_arg.data_type()
115            );
116        };
117
118        // Helper to cast an array if needed
119        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
120            if arr.data_type() == target {
121                Ok(Arc::clone(arr))
122            } else {
123                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
124            }
125        };
126
127        match (str_arg, suffix_arg) {
128            // Both scalars - just compute directly
129            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
130                let str_arr = str_scalar.to_array_of_size(1)?;
131                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
132                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
133                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
134                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
135                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
136                    &result, 0,
137                )?))
138            }
139            // String is array, suffix is scalar - use Scalar wrapper for optimization
140            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
141                let str_arr = maybe_cast(str_arr, &coercion_type)?;
142                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
143                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
144                let suffix_scalar = Scalar::new(suffix_arr);
145                let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
146                Ok(ColumnarValue::Array(Arc::new(result)))
147            }
148            // String is scalar, suffix is array - use Scalar wrapper for string
149            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
150                let str_arr = str_scalar.to_array_of_size(1)?;
151                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
152                let str_scalar = Scalar::new(str_arr);
153                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
154                let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
155                Ok(ColumnarValue::Array(Arc::new(result)))
156            }
157            // Both arrays - pass directly
158            (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
159                let str_arr = maybe_cast(str_arr, &coercion_type)?;
160                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
161                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
162                Ok(ColumnarValue::Array(Arc::new(result)))
163            }
164        }
165    }
166
167    fn documentation(&self) -> Option<&Documentation> {
168        self.doc()
169    }
170}
171
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::string::ends_with::EndsWithFunc;
    use crate::utils::test::test_function;

    /// Row count shared by every array fixture in this module.
    const ROWS: usize = 4;

    /// Builds a `Utf8` array argument from optional string values.
    fn utf8_array(values: Vec<Option<&str>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from(values)))
    }

    /// Builds a non-null `Utf8` scalar argument.
    fn utf8_scalar(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Calls `ends_with` directly with two nullable `Utf8` argument fields
    /// and returns the result expanded to a `BooleanArray` of `ROWS` rows.
    fn evaluate(args: Vec<ColumnarValue>) -> BooleanArray {
        let output = EndsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: ROWS,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let output = output.into_array(ROWS).unwrap();
        output
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Utf8 scalars: non-matching suffix.
        let prefix_only = vec![
            ColumnarValue::Scalar(ScalarValue::from("alphabet")),
            ColumnarValue::Scalar(ScalarValue::from("alph")),
        ];
        test_function!(
            EndsWithFunc::new(),
            prefix_only,
            Ok(Some(false)),
            bool,
            Boolean,
            BooleanArray
        );

        // Utf8 scalars: matching suffix.
        let real_suffix = vec![
            ColumnarValue::Scalar(ScalarValue::from("alphabet")),
            ColumnarValue::Scalar(ScalarValue::from("bet")),
        ];
        test_function!(
            EndsWithFunc::new(),
            real_suffix,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        // A null string yields null.
        let null_string = vec![
            ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ColumnarValue::Scalar(ScalarValue::from("alph")),
        ];
        test_function!(
            EndsWithFunc::new(),
            null_string,
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );

        // A null suffix yields null.
        let null_suffix = vec![
            ColumnarValue::Scalar(ScalarValue::from("alphabet")),
            ColumnarValue::Scalar(ScalarValue::Utf8(None)),
        ];
        test_function!(
            EndsWithFunc::new(),
            null_suffix,
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );

        // LargeUtf8 on both sides.
        let large_utf8 = vec![
            ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("alphabet".to_string()))),
            ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
        ];
        test_function!(
            EndsWithFunc::new(),
            large_utf8,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        // Utf8View on both sides.
        let utf8_view = vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("alphabet".to_string()))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))),
        ];
        test_function!(
            EndsWithFunc::new(),
            utf8_view,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        // Array string with a scalar suffix, checked through the shared macro.
        let args = vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]),
            utf8_scalar("bet"),
        ];
        test_function!(
            EndsWithFunc::new(),
            args,
            Ok(Some(true)), // First element result: "alphabet" ends with "bet"
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        // Array string with a scalar suffix, verifying every row.
        let matches = evaluate(vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]),
            utf8_scalar("bet"),
        ]);

        assert!(matches.value(0)); // "alphabet" ends with "bet"
        assert!(matches.value(1)); // "alphabet" ends with "bet"
        assert!(!matches.value(2)); // "beta" does not end with "bet"
        assert!(matches.is_null(3)); // null input -> null output
    }

    #[test]
    fn test_scalar_array() {
        // Scalar string with an array of suffixes.
        let matches = evaluate(vec![
            utf8_scalar("alphabet"),
            utf8_array(vec![Some("bet"), Some("alph"), Some("phabet"), None]),
        ]);

        assert!(matches.value(0)); // "alphabet" ends with "bet"
        assert!(!matches.value(1)); // "alphabet" does not end with "alph"
        assert!(matches.value(2)); // "alphabet" ends with "phabet"
        assert!(matches.is_null(3)); // null suffix -> null output
    }

    #[test]
    fn test_array_array() {
        // Array string with an array of suffixes, compared row by row.
        let matches = evaluate(vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("rust"),
                Some("datafusion"),
                None,
            ]),
            utf8_array(vec![Some("bet"), Some("st"), Some("hello"), Some("test")]),
        ]);

        assert!(matches.value(0)); // "alphabet" ends with "bet"
        assert!(matches.value(1)); // "rust" ends with "st"
        assert!(!matches.value(2)); // "datafusion" does not end with "hello"
        assert!(matches.is_null(3)); // null string -> null output
    }
}
417}