1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
23use arrow::datatypes::DataType;
24use datafusion_common::types::logical_string;
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
28use datafusion_expr::type_coercion::binary::{
29 binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32 Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
33 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
34};
35use datafusion_macros::user_doc;
36
37#[user_doc(
38 doc_section(label = "String Functions"),
39 description = "Tests if a string starts with a substring.",
40 syntax_example = "starts_with(str, substr)",
41 sql_example = r#"```sql
42> select starts_with('datafusion','data');
43+----------------------------------------------+
44| starts_with(Utf8("datafusion"),Utf8("data")) |
45+----------------------------------------------+
46| true |
47+----------------------------------------------+
48```"#,
49 standard_argument(name = "str", prefix = "String"),
50 argument(name = "substr", description = "Substring to test for.")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct StartsWithFunc {
54 signature: Signature,
55}
56
57impl Default for StartsWithFunc {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl StartsWithFunc {
64 pub fn new() -> Self {
65 Self {
66 signature: Signature::coercible(
67 vec![
68 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 ],
71 Volatility::Immutable,
72 ),
73 }
74 }
75}
76
77impl ScalarUDFImpl for StartsWithFunc {
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81
82 fn name(&self) -> &str {
83 "starts_with"
84 }
85
86 fn signature(&self) -> &Signature {
87 &self.signature
88 }
89
90 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
91 Ok(DataType::Boolean)
92 }
93
94 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95 let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
96
97 let coercion_type = string_coercion(
99 &str_arg.data_type(),
100 &prefix_arg.data_type(),
101 )
102 .or_else(|| {
103 binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
104 });
105
106 let Some(coercion_type) = coercion_type else {
107 return exec_err!(
108 "Unsupported data types {:?}, {:?} for function `starts_with`.",
109 str_arg.data_type(),
110 prefix_arg.data_type()
111 );
112 };
113
114 let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
116 if arr.data_type() == target {
117 Ok(Arc::clone(arr))
118 } else {
119 Ok(arrow::compute::kernels::cast::cast(arr, target)?)
120 }
121 };
122
123 match (str_arg, prefix_arg) {
124 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
126 let str_arr = str_scalar.to_array_of_size(1)?;
127 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
128 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
129 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
130 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
131 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
132 &result, 0,
133 )?))
134 }
135 (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
137 let str_arr = maybe_cast(str_arr, &coercion_type)?;
138 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
139 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
140 let prefix_scalar = Scalar::new(prefix_arr);
141 let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
142 Ok(ColumnarValue::Array(Arc::new(result)))
143 }
144 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
146 let str_arr = str_scalar.to_array_of_size(1)?;
147 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
148 let str_scalar = Scalar::new(str_arr);
149 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
150 let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
151 Ok(ColumnarValue::Array(Arc::new(result)))
152 }
153 (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
155 let str_arr = maybe_cast(str_arr, &coercion_type)?;
156 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
157 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
158 Ok(ColumnarValue::Array(Arc::new(result)))
159 }
160 }
161 }
162
163 fn simplify(
164 &self,
165 args: Vec<Expr>,
166 info: &SimplifyContext,
167 ) -> Result<ExprSimplifyResult> {
168 if let Expr::Literal(scalar_value, _) = &args[1] {
169 let like_expr = match scalar_value {
175 ScalarValue::Utf8(Some(pattern))
176 | ScalarValue::LargeUtf8(Some(pattern))
177 | ScalarValue::Utf8View(Some(pattern)) => {
178 let escaped_pattern = pattern
179 .replace("\\", "\\\\")
180 .replace("%", "\\%")
181 .replace("_", "\\_");
182 let like_pattern = format!("{escaped_pattern}%");
183 Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
184 }
185 _ => return Ok(ExprSimplifyResult::Original(args)),
186 };
187
188 let expr_data_type = info.get_data_type(&args[0])?;
189 let pattern_data_type = info.get_data_type(&like_expr)?;
190
191 if let Some(coercion_data_type) =
192 string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
193 binary_to_string_coercion(&expr_data_type, &pattern_data_type)
194 })
195 {
196 let expr = if expr_data_type == coercion_data_type {
197 args[0].clone()
198 } else {
199 cast(args[0].clone(), coercion_data_type.clone())
200 };
201
202 let pattern = if pattern_data_type == coercion_data_type {
203 like_expr
204 } else {
205 cast(like_expr, coercion_data_type)
206 };
207
208 return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
209 negated: false,
210 expr: Box::new(expr),
211 pattern: Box::new(pattern),
212 escape_char: None,
213 case_insensitive: false,
214 })));
215 }
216 }
217
218 Ok(ExprSimplifyResult::Original(args))
219 }
220
221 fn documentation(&self) -> Option<&Documentation> {
222 self.doc()
223 }
224}
225
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::*;

    /// Wraps the given optional strings in a `Utf8` array argument.
    fn utf8_array(values: Vec<Option<&str>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from(values)))
    }

    /// Wraps a non-null string in a `Utf8` scalar argument.
    fn utf8_scalar(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Invokes `starts_with` directly on two `Utf8` arguments and checks every
    /// output row: `Some(flag)` compares the row value, `None` requires the
    /// row to be null. The number of rows is taken from `expected`.
    fn assert_starts_with_rows(args: Vec<ColumnarValue>, expected: &[Option<bool>]) {
        let rows = expected.len();

        let output = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let output = output.into_array(rows).unwrap();
        let flags = output.as_any().downcast_ref::<BooleanArray>().unwrap();

        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(flag) => assert_eq!(flags.value(row), *flag),
                None => assert!(flags.is_null(row)),
            }
        }
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // (string, prefix, expected result)
        let cases = [
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
        ];

        // Every case is run once per string representation.
        let constructors: [fn(Option<String>) -> ScalarValue; 3] = [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ];

        for (string, prefix, expected) in cases {
            for make_scalar in constructors {
                let args = vec![
                    ColumnarValue::Scalar(make_scalar(string.map(str::to_string))),
                    ColumnarValue::Scalar(make_scalar(prefix.map(str::to_string))),
                ];

                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(expected),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        let args = vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]),
            utf8_scalar("alph"),
        ];

        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        let args = vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]),
            utf8_scalar("alph"),
        ];

        assert_starts_with_rows(args, &[Some(true), Some(true), Some(false), None]);
    }

    #[test]
    fn test_scalar_array() {
        let args = vec![
            utf8_scalar("alphabet"),
            utf8_array(vec![Some("alph"), Some("bet"), Some("alpha"), None]),
        ];

        assert_starts_with_rows(args, &[Some(true), Some(false), Some(true), None]);
    }

    #[test]
    fn test_array_array() {
        let args = vec![
            utf8_array(vec![
                Some("alphabet"),
                Some("rust"),
                Some("datafusion"),
                None,
            ]),
            utf8_array(vec![
                Some("alph"),
                Some("ru"),
                Some("hello"),
                Some("test"),
            ]),
        ];

        assert_starts_with_rows(args, &[Some(true), Some(true), Some(false), None]);
    }
}