1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
23use arrow::datatypes::DataType;
24
25use datafusion_common::types::logical_string;
26use datafusion_common::utils::take_function_args;
27use datafusion_common::{Result, ScalarValue, exec_err};
28use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
29use datafusion_expr::{
30 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31 TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36 doc_section(label = "String Functions"),
37 description = "Tests if a string ends with a substring.",
38 syntax_example = "ends_with(str, substr)",
39 sql_example = r#"```sql
40> select ends_with('datafusion', 'soin');
41+--------------------------------------------+
42| ends_with(Utf8("datafusion"),Utf8("soin")) |
43+--------------------------------------------+
44| false |
45+--------------------------------------------+
46> select ends_with('datafusion', 'sion');
47+--------------------------------------------+
48| ends_with(Utf8("datafusion"),Utf8("sion")) |
49+--------------------------------------------+
50| true |
51+--------------------------------------------+
52```"#,
53 standard_argument(name = "str", prefix = "String"),
54 argument(name = "substr", description = "Substring to test for.")
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct EndsWithFunc {
58 signature: Signature,
59}
60
61impl Default for EndsWithFunc {
62 fn default() -> Self {
63 EndsWithFunc::new()
64 }
65}
66
67impl EndsWithFunc {
68 pub fn new() -> Self {
69 Self {
70 signature: Signature::coercible(
71 vec![
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 ],
75 Volatility::Immutable,
76 ),
77 }
78 }
79}
80
81impl ScalarUDFImpl for EndsWithFunc {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "ends_with"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95 Ok(DataType::Boolean)
96 }
97
98 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99 let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?;
100
101 let coercion_type = string_coercion(
103 &str_arg.data_type(),
104 &suffix_arg.data_type(),
105 )
106 .or_else(|| {
107 binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
108 });
109
110 let Some(coercion_type) = coercion_type else {
111 return exec_err!(
112 "Unsupported data types {:?}, {:?} for function `ends_with`.",
113 str_arg.data_type(),
114 suffix_arg.data_type()
115 );
116 };
117
118 let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
120 if arr.data_type() == target {
121 Ok(Arc::clone(arr))
122 } else {
123 Ok(arrow::compute::kernels::cast::cast(arr, target)?)
124 }
125 };
126
127 match (str_arg, suffix_arg) {
128 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
130 let str_arr = str_scalar.to_array_of_size(1)?;
131 let suffix_arr = suffix_scalar.to_array_of_size(1)?;
132 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
133 let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
134 let result = arrow_ends_with(&str_arr, &suffix_arr)?;
135 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
136 &result, 0,
137 )?))
138 }
139 (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
141 let str_arr = maybe_cast(str_arr, &coercion_type)?;
142 let suffix_arr = suffix_scalar.to_array_of_size(1)?;
143 let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
144 let suffix_scalar = Scalar::new(suffix_arr);
145 let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
146 Ok(ColumnarValue::Array(Arc::new(result)))
147 }
148 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
150 let str_arr = str_scalar.to_array_of_size(1)?;
151 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
152 let str_scalar = Scalar::new(str_arr);
153 let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
154 let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
155 Ok(ColumnarValue::Array(Arc::new(result)))
156 }
157 (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
159 let str_arr = maybe_cast(str_arr, &coercion_type)?;
160 let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
161 let result = arrow_ends_with(&str_arr, &suffix_arr)?;
162 Ok(ColumnarValue::Array(Arc::new(result)))
163 }
164 }
165 }
166
167 fn documentation(&self) -> Option<&Documentation> {
168 self.doc()
169 }
170}
171
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::string::ends_with::EndsWithFunc;
    use crate::utils::test::test_function;

    /// Invokes `ends_with` directly with two nullable `Utf8` argument fields
    /// and returns the result materialised as a `BooleanArray` of
    /// `number_rows` rows.
    fn invoke_utf8(args: Vec<ColumnarValue>, number_rows: usize) -> BooleanArray {
        let output = EndsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let array = output.into_array(number_rows).unwrap();
        array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    /// Checks that the first three rows hold `expected` and that the row
    /// following them is null.
    fn assert_rows_then_null(result: &BooleanArray, expected: [bool; 3]) {
        for (row, want) in expected.iter().copied().enumerate() {
            assert_eq!(result.value(row), want);
        }
        assert!(result.is_null(expected.len()));
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        let utf8 = |text: &str| ColumnarValue::Scalar(ScalarValue::from(text));
        let null_utf8 = || ColumnarValue::Scalar(ScalarValue::Utf8(None));

        // (arguments, expected result), in the order they are checked.
        let cases: Vec<(Vec<ColumnarValue>, Result<Option<bool>>)> = vec![
            (vec![utf8("alphabet"), utf8("alph")], Ok(Some(false))),
            (vec![utf8("alphabet"), utf8("bet")], Ok(Some(true))),
            // A null on either side gives a null result.
            (vec![null_utf8(), utf8("alph")], Ok(None)),
            (vec![utf8("alphabet"), null_utf8()], Ok(None)),
            // LargeUtf8 arguments.
            (
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                        "alphabet".to_string(),
                    ))),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
                ],
                Ok(Some(true)),
            ),
            // Utf8View arguments.
            (
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                        "alphabet".to_string(),
                    ))),
                    ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))),
                ],
                Ok(Some(true)),
            ),
        ];

        for (args, expected) in cases {
            test_function!(
                EndsWithFunc::new(),
                args,
                expected,
                bool,
                Boolean,
                BooleanArray
            );
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]))),
            ColumnarValue::Scalar(ScalarValue::from("bet")),
        ];

        // NOTE(review): a single `Some(true)` is expected for four input rows,
        // so presumably `test_function!` compares only the first row; confirm
        // against the macro definition.
        test_function!(
            EndsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        let haystacks = StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);

        let result = invoke_utf8(
            vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Scalar(ScalarValue::from("bet")),
            ],
            4,
        );

        // "beta" does not end with "bet"; the null input row stays null.
        assert_rows_then_null(&result, [true, true, false]);
    }

    #[test]
    fn test_scalar_array() {
        let suffixes = StringArray::from(vec![
            Some("bet"),
            Some("alph"),
            Some("phabet"),
            None,
        ]);

        let result = invoke_utf8(
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Array(Arc::new(suffixes)),
            ],
            4,
        );

        // "alph" is a prefix of "alphabet", not a suffix; the null suffix row is null.
        assert_rows_then_null(&result, [true, false, true]);
    }

    #[test]
    fn test_array_array() {
        let haystacks = StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]);
        let suffixes = StringArray::from(vec![
            Some("bet"),
            Some("st"),
            Some("hello"),
            Some("test"),
        ]);

        let result = invoke_utf8(
            vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Array(Arc::new(suffixes)),
            ],
            4,
        );

        // Rows are compared pairwise; a null string makes the row null even
        // though its suffix is present.
        assert_rows_then_null(&result, [true, true, false]);
    }
}