1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
23use arrow::datatypes::DataType;
24use datafusion_common::utils::take_function_args;
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
26use datafusion_expr::type_coercion::binary::{
27 binary_to_string_coercion, string_coercion,
28};
29
30use datafusion_common::types::logical_string;
31use datafusion_common::{Result, ScalarValue, exec_err};
32use datafusion_expr::{
33 Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
34 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
35};
36use datafusion_macros::user_doc;
37
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string starts with a substring.",
    syntax_example = "starts_with(str, substr)",
    sql_example = r#"```sql
> select starts_with('datafusion','data');
+----------------------------------------------+
| starts_with(Utf8("datafusion"),Utf8("data")) |
+----------------------------------------------+
| true                                         |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
/// Scalar UDF implementing `starts_with(str, substr)`.
///
/// The user-facing documentation (description, syntax and SQL example) is
/// declared in the `#[user_doc]` attribute above.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StartsWithFunc {
    /// Argument signature and volatility of the function; populated by
    /// [`StartsWithFunc::new`].
    signature: Signature,
}
57
58impl Default for StartsWithFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl StartsWithFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 ],
72 Volatility::Immutable,
73 ),
74 }
75 }
76}
77
78impl ScalarUDFImpl for StartsWithFunc {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn name(&self) -> &str {
84 "starts_with"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
92 Ok(DataType::Boolean)
93 }
94
95 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96 let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
97
98 let coercion_type = string_coercion(
100 &str_arg.data_type(),
101 &prefix_arg.data_type(),
102 )
103 .or_else(|| {
104 binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
105 });
106
107 let Some(coercion_type) = coercion_type else {
108 return exec_err!(
109 "Unsupported data types {:?}, {:?} for function `starts_with`.",
110 str_arg.data_type(),
111 prefix_arg.data_type()
112 );
113 };
114
115 let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
117 if arr.data_type() == target {
118 Ok(Arc::clone(arr))
119 } else {
120 Ok(arrow::compute::kernels::cast::cast(arr, target)?)
121 }
122 };
123
124 match (str_arg, prefix_arg) {
125 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
127 let str_arr = str_scalar.to_array_of_size(1)?;
128 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
129 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
130 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
131 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
132 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
133 &result, 0,
134 )?))
135 }
136 (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
138 let str_arr = maybe_cast(str_arr, &coercion_type)?;
139 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
140 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
141 let prefix_scalar = Scalar::new(prefix_arr);
142 let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
143 Ok(ColumnarValue::Array(Arc::new(result)))
144 }
145 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
147 let str_arr = str_scalar.to_array_of_size(1)?;
148 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
149 let str_scalar = Scalar::new(str_arr);
150 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
151 let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
152 Ok(ColumnarValue::Array(Arc::new(result)))
153 }
154 (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
156 let str_arr = maybe_cast(str_arr, &coercion_type)?;
157 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
158 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
159 Ok(ColumnarValue::Array(Arc::new(result)))
160 }
161 }
162 }
163
164 fn simplify(
165 &self,
166 args: Vec<Expr>,
167 info: &dyn SimplifyInfo,
168 ) -> Result<ExprSimplifyResult> {
169 if let Expr::Literal(scalar_value, _) = &args[1] {
170 let like_expr = match scalar_value {
176 ScalarValue::Utf8(Some(pattern))
177 | ScalarValue::LargeUtf8(Some(pattern))
178 | ScalarValue::Utf8View(Some(pattern)) => {
179 let escaped_pattern = pattern
180 .replace("\\", "\\\\")
181 .replace("%", "\\%")
182 .replace("_", "\\_");
183 let like_pattern = format!("{escaped_pattern}%");
184 Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
185 }
186 _ => return Ok(ExprSimplifyResult::Original(args)),
187 };
188
189 let expr_data_type = info.get_data_type(&args[0])?;
190 let pattern_data_type = info.get_data_type(&like_expr)?;
191
192 if let Some(coercion_data_type) =
193 string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
194 binary_to_string_coercion(&expr_data_type, &pattern_data_type)
195 })
196 {
197 let expr = if expr_data_type == coercion_data_type {
198 args[0].clone()
199 } else {
200 cast(args[0].clone(), coercion_data_type.clone())
201 };
202
203 let pattern = if pattern_data_type == coercion_data_type {
204 like_expr
205 } else {
206 cast(like_expr, coercion_data_type)
207 };
208
209 return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
210 negated: false,
211 expr: Box::new(expr),
212 pattern: Box::new(pattern),
213 escape_char: None,
214 case_insensitive: false,
215 })));
216 }
217 }
218
219 Ok(ExprSimplifyResult::Original(args))
220 }
221
222 fn documentation(&self) -> Option<&Documentation> {
223 self.doc()
224 }
225}
226
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::*;

    /// Four-row `Utf8` column used as the `str` argument in the
    /// array/scalar tests.
    fn alphabet_column() -> StringArray {
        StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ])
    }

    /// Invokes `starts_with` directly on two nullable `Utf8` arguments and
    /// checks the materialized result row by row: `Some(v)` compares the
    /// value, `None` requires a null.
    fn assert_starts_with(args: Vec<ColumnarValue>, expected: &[Option<bool>]) {
        let rows = expected.len();
        let output = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap()
            .into_array(rows)
            .unwrap();
        let flags = output.as_any().downcast_ref::<BooleanArray>().unwrap();

        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(value) => assert_eq!(flags.value(row), *value),
                None => assert!(flags.is_null(row)),
            }
        }
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        let cases = [
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
        ];

        for (haystack, prefix, expected) in cases {
            let owned = |value: Option<&str>| value.map(str::to_string);

            // Every case runs once per string representation, in this order.
            let per_type_args = [
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8(owned(haystack))),
                    ColumnarValue::Scalar(ScalarValue::Utf8(owned(prefix))),
                ],
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8(owned(haystack))),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8(owned(prefix))),
                ],
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View(owned(haystack))),
                    ColumnarValue::Scalar(ScalarValue::Utf8View(owned(prefix))),
                ],
            ];

            for args in per_type_args {
                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(expected),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Array(Arc::new(alphabet_column())),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        // NOTE(review): a single expected value is supplied for a four-row
        // input; `test_function!` presumably inspects only the first row —
        // confirm against the macro. The full result is covered by
        // `test_array_scalar_full_result`.
        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        let args = vec![
            ColumnarValue::Array(Arc::new(alphabet_column())),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        assert_starts_with(args, &[Some(true), Some(true), Some(false), None]);
    }

    #[test]
    fn test_scalar_array() {
        let prefixes = StringArray::from(vec![
            Some("alph"),
            Some("bet"),
            Some("alpha"),
            None,
        ]);
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
            ColumnarValue::Array(Arc::new(prefixes)),
        ];

        assert_starts_with(args, &[Some(true), Some(false), Some(true), None]);
    }

    #[test]
    fn test_array_array() {
        let strings = StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]);
        let prefixes = StringArray::from(vec![
            Some("alph"),
            Some("ru"),
            Some("hello"),
            Some("test"),
        ]);
        let args = vec![
            ColumnarValue::Array(Arc::new(strings)),
            ColumnarValue::Array(Arc::new(prefixes)),
        ];

        assert_starts_with(args, &[Some(true), Some(true), Some(false), None]);
    }
}