Skip to main content

datafusion_functions/string/
ends_with.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, Scalar};
21use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
22use arrow::datatypes::DataType;
23
24use datafusion_common::types::logical_string;
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
28use datafusion_expr::{
29    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
/// Scalar UDF type for the SQL function `ends_with(str, substr)`.
///
/// The `user_doc` attribute below carries the description, syntax string,
/// SQL example and argument descriptions for the function.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string ends with a substring.",
    syntax_example = "ends_with(str, substr)",
    sql_example = r#"```sql
>  select ends_with('datafusion', 'soin');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("soin")) |
+--------------------------------------------+
| false                                      |
+--------------------------------------------+
> select ends_with('datafusion', 'sion');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("sion")) |
+--------------------------------------------+
| true                                       |
+--------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct EndsWithFunc {
    // Argument signature; set in `EndsWithFunc::new` and handed out by
    // `ScalarUDFImpl::signature`.
    signature: Signature,
}
59
60impl Default for EndsWithFunc {
61    fn default() -> Self {
62        EndsWithFunc::new()
63    }
64}
65
66impl EndsWithFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                ],
74                Volatility::Immutable,
75            ),
76        }
77    }
78}
79
80impl ScalarUDFImpl for EndsWithFunc {
81    fn name(&self) -> &str {
82        "ends_with"
83    }
84
85    fn signature(&self) -> &Signature {
86        &self.signature
87    }
88
89    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
90        Ok(DataType::Boolean)
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?;
95
96        // Determine the common type for coercion
97        let coercion_type = string_coercion(
98            &str_arg.data_type(),
99            &suffix_arg.data_type(),
100        )
101        .or_else(|| {
102            binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
103        });
104
105        let Some(coercion_type) = coercion_type else {
106            return exec_err!(
107                "Unsupported data types {:?}, {:?} for function `ends_with`.",
108                str_arg.data_type(),
109                suffix_arg.data_type()
110            );
111        };
112
113        // Helper to cast an array if needed
114        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
115            if arr.data_type() == target {
116                Ok(Arc::clone(arr))
117            } else {
118                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
119            }
120        };
121
122        match (str_arg, suffix_arg) {
123            // Both scalars - just compute directly
124            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
125                let str_arr = str_scalar.to_array_of_size(1)?;
126                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
127                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
128                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
129                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
130                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
131                    &result, 0,
132                )?))
133            }
134            // String is array, suffix is scalar - use Scalar wrapper for optimization
135            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
136                let str_arr = maybe_cast(str_arr, &coercion_type)?;
137                let suffix_arr = suffix_scalar.to_array_of_size(1)?;
138                let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
139                let suffix_scalar = Scalar::new(suffix_arr);
140                let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
141                Ok(ColumnarValue::Array(Arc::new(result)))
142            }
143            // String is scalar, suffix is array - use Scalar wrapper for string
144            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
145                let str_arr = str_scalar.to_array_of_size(1)?;
146                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
147                let str_scalar = Scalar::new(str_arr);
148                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
149                let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
150                Ok(ColumnarValue::Array(Arc::new(result)))
151            }
152            // Both arrays - pass directly
153            (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
154                let str_arr = maybe_cast(str_arr, &coercion_type)?;
155                let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
156                let result = arrow_ends_with(&str_arr, &suffix_arr)?;
157                Ok(ColumnarValue::Array(Arc::new(result)))
158            }
159        }
160    }
161
162    fn documentation(&self) -> Option<&Documentation> {
163        self.doc()
164    }
165}
166
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::string::ends_with::EndsWithFunc;
    use crate::utils::test::test_function;

    /// Wraps optional string values in a `Utf8` array argument.
    fn utf8_array(values: Vec<Option<&str>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from(values)))
    }

    /// Wraps a string in a non-null `Utf8` scalar argument.
    fn utf8_scalar(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Invokes `ends_with` on two nullable `Utf8` arguments and returns the
    /// result materialised as a `BooleanArray` of `rows` elements.
    fn run_ends_with(args: Vec<ColumnarValue>, rows: usize) -> BooleanArray {
        let output = EndsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let array = output.into_array(rows).unwrap();
        array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("ends_with must produce a BooleanArray")
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Plain Utf8 scalars: non-matching suffix.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from("alph")),
            ],
            Ok(Some(false)),
            bool,
            Boolean,
            BooleanArray
        );
        // Plain Utf8 scalars: matching suffix.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from("bet")),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );
        // Null string -> null result.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("alph")),
            ],
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );
        // Null suffix -> null result.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );

        // LargeUtf8 on both sides.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                    "alphabet".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        // Utf8View on both sides.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                    "alphabet".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        // Array + Scalar through the shared `test_function!` harness.
        let args = vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]),
            utf8_scalar("bet"),
        ];
        test_function!(
            EndsWithFunc::new(),
            args,
            Ok(Some(true)), // First element result: "alphabet" ends with "bet"
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        // Array + Scalar, checking every row including validity and length.
        let strings = utf8_array(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);

        let actual = run_ends_with(vec![strings, utf8_scalar("bet")], 4);

        let expected = BooleanArray::from(vec![
            Some(true),  // "alphabet" ends with "bet"
            Some(true),  // "alphabet" ends with "bet"
            Some(false), // "beta" does not end with "bet"
            None,        // null input -> null output
        ]);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_array() {
        // Scalar + Array, checking every row including validity and length.
        let suffixes = utf8_array(vec![Some("bet"), Some("alph"), Some("phabet"), None]);

        let actual = run_ends_with(vec![utf8_scalar("alphabet"), suffixes], 4);

        let expected = BooleanArray::from(vec![
            Some(true),  // "alphabet" ends with "bet"
            Some(false), // "alphabet" does not end with "alph"
            Some(true),  // "alphabet" ends with "phabet"
            None,        // null suffix -> null output
        ]);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_array_array() {
        // Array + Array, checking every row including validity and length.
        let strings = utf8_array(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]);
        let suffixes =
            utf8_array(vec![Some("bet"), Some("st"), Some("hello"), Some("test")]);

        let actual = run_ends_with(vec![strings, suffixes], 4);

        let expected = BooleanArray::from(vec![
            Some(true),  // "alphabet" ends with "bet"
            Some(true),  // "rust" ends with "st"
            Some(false), // "datafusion" does not end with "hello"
            None,        // null string -> null output
        ]);
        assert_eq!(actual, expected);
    }
}