1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, Scalar};
21use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
22use arrow::datatypes::DataType;
23
24use datafusion_common::types::logical_string;
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
28use datafusion_expr::{
29 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
// User-facing documentation for `ends_with`. `ScalarUDFImpl::documentation`
// returns it via `self.doc()`, which is not defined in this file section and is
// presumably generated by this attribute macro — confirm in `datafusion_macros`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string ends with a substring.",
    syntax_example = "ends_with(str, substr)",
    sql_example = r#"```sql
> select ends_with('datafusion', 'soin');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("soin")) |
+--------------------------------------------+
| false                                      |
+--------------------------------------------+
> select ends_with('datafusion', 'sion');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("sion")) |
+--------------------------------------------+
| true                                       |
+--------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
/// Scalar UDF implementing `ends_with(str, substr)`: returns a boolean telling
/// whether `str` ends with `substr`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct EndsWithFunc {
    /// Built in `EndsWithFunc::new`: two string-class arguments, immutable volatility.
    signature: Signature,
}
59
60impl Default for EndsWithFunc {
61 fn default() -> Self {
62 EndsWithFunc::new()
63 }
64}
65
66impl EndsWithFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 ],
74 Volatility::Immutable,
75 ),
76 }
77 }
78}
79
80impl ScalarUDFImpl for EndsWithFunc {
81 fn name(&self) -> &str {
82 "ends_with"
83 }
84
85 fn signature(&self) -> &Signature {
86 &self.signature
87 }
88
89 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
90 Ok(DataType::Boolean)
91 }
92
93 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94 let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?;
95
96 let coercion_type = string_coercion(
98 &str_arg.data_type(),
99 &suffix_arg.data_type(),
100 )
101 .or_else(|| {
102 binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
103 });
104
105 let Some(coercion_type) = coercion_type else {
106 return exec_err!(
107 "Unsupported data types {:?}, {:?} for function `ends_with`.",
108 str_arg.data_type(),
109 suffix_arg.data_type()
110 );
111 };
112
113 let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
115 if arr.data_type() == target {
116 Ok(Arc::clone(arr))
117 } else {
118 Ok(arrow::compute::kernels::cast::cast(arr, target)?)
119 }
120 };
121
122 match (str_arg, suffix_arg) {
123 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
125 let str_arr = str_scalar.to_array_of_size(1)?;
126 let suffix_arr = suffix_scalar.to_array_of_size(1)?;
127 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
128 let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
129 let result = arrow_ends_with(&str_arr, &suffix_arr)?;
130 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
131 &result, 0,
132 )?))
133 }
134 (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
136 let str_arr = maybe_cast(str_arr, &coercion_type)?;
137 let suffix_arr = suffix_scalar.to_array_of_size(1)?;
138 let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
139 let suffix_scalar = Scalar::new(suffix_arr);
140 let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
141 Ok(ColumnarValue::Array(Arc::new(result)))
142 }
143 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
145 let str_arr = str_scalar.to_array_of_size(1)?;
146 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
147 let str_scalar = Scalar::new(str_arr);
148 let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
149 let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
150 Ok(ColumnarValue::Array(Arc::new(result)))
151 }
152 (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
154 let str_arr = maybe_cast(str_arr, &coercion_type)?;
155 let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
156 let result = arrow_ends_with(&str_arr, &suffix_arr)?;
157 Ok(ColumnarValue::Array(Arc::new(result)))
158 }
159 }
160 }
161
162 fn documentation(&self) -> Option<&Documentation> {
163 self.doc()
164 }
165}
166
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::string::ends_with::EndsWithFunc;
    use crate::utils::test::test_function;

    /// Calls `ends_with` directly with two nullable `Utf8` argument fields.
    /// The output is expanded into a `BooleanArray` of `number_rows` entries.
    fn run_utf8(args: Vec<ColumnarValue>, number_rows: usize) -> BooleanArray {
        let output = EndsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        output
            .into_array(number_rows)
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Each case is (string, suffix, expected). The cases cover NULL on
        // either side and the Utf8 / LargeUtf8 / Utf8View flavours.
        let cases = vec![
            (ScalarValue::from("alphabet"), ScalarValue::from("alph"), Some(false)),
            (ScalarValue::from("alphabet"), ScalarValue::from("bet"), Some(true)),
            (ScalarValue::Utf8(None), ScalarValue::from("alph"), None),
            (ScalarValue::from("alphabet"), ScalarValue::Utf8(None), None),
            (
                ScalarValue::LargeUtf8(Some("alphabet".to_string())),
                ScalarValue::LargeUtf8(Some("bet".to_string())),
                Some(true),
            ),
            (
                ScalarValue::Utf8View(Some("alphabet".to_string())),
                ScalarValue::Utf8View(Some("bet".to_string())),
                Some(true),
            ),
        ];

        for (text, suffix, want) in cases {
            let args = vec![ColumnarValue::Scalar(text), ColumnarValue::Scalar(suffix)];
            test_function!(
                EndsWithFunc::new(),
                args,
                Ok(want),
                bool,
                Boolean,
                BooleanArray
            );
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        let strings = StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);
        let args = vec![
            ColumnarValue::Array(Arc::new(strings)),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))),
        ];

        test_function!(
            EndsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]));
        let got = run_utf8(
            vec![
                ColumnarValue::Array(strings),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))),
            ],
            4,
        );

        assert!(got.value(0));
        assert!(got.value(1));
        assert!(!got.value(2));
        assert!(got.is_null(3));
    }

    #[test]
    fn test_scalar_array() {
        let suffixes = Arc::new(StringArray::from(vec![
            Some("bet"),
            Some("alph"),
            Some("phabet"),
            None,
        ]));
        let got = run_utf8(
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
                ColumnarValue::Array(suffixes),
            ],
            4,
        );

        assert!(got.value(0));
        assert!(!got.value(1));
        assert!(got.value(2));
        assert!(got.is_null(3));
    }

    #[test]
    fn test_array_array() {
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]));
        let suffixes = Arc::new(StringArray::from(vec![
            Some("bet"),
            Some("st"),
            Some("hello"),
            Some("test"),
        ]));
        let got = run_utf8(
            vec![ColumnarValue::Array(strings), ColumnarValue::Array(suffixes)],
            4,
        );

        assert!(got.value(0));
        assert!(got.value(1));
        assert!(!got.value(2));
        assert!(got.is_null(3));
    }
}