1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, Scalar};
21use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
22use arrow::datatypes::DataType;
23use datafusion_common::types::logical_string;
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, exec_err};
26use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
27use datafusion_expr::type_coercion::binary::{
28 binary_to_string_coercion, string_coercion,
29};
30use datafusion_expr::{
31 Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
32 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
33};
34use datafusion_macros::user_doc;
35
// Scalar UDF implementing `starts_with(str, substr)`, which tests whether `str`
// begins with `substr` and produces a boolean.
//
// The `user_doc` attribute carries the user-facing documentation; the
// `ScalarUDFImpl::documentation` method below returns it through `self.doc()`
// (presumably generated by this macro — not visible in this file).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string starts with a substring.",
    syntax_example = "starts_with(str, substr)",
    sql_example = r#"```sql
> select starts_with('datafusion','data');
+----------------------------------------------+
| starts_with(Utf8("datafusion"),Utf8("data")) |
+----------------------------------------------+
| true |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StartsWithFunc {
    // Accepted argument types and volatility; constructed in `StartsWithFunc::new`.
    signature: Signature,
}
55
56impl Default for StartsWithFunc {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl StartsWithFunc {
63 pub fn new() -> Self {
64 Self {
65 signature: Signature::coercible(
66 vec![
67 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
68 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
69 ],
70 Volatility::Immutable,
71 ),
72 }
73 }
74}
75
76impl ScalarUDFImpl for StartsWithFunc {
77 fn name(&self) -> &str {
78 "starts_with"
79 }
80
81 fn signature(&self) -> &Signature {
82 &self.signature
83 }
84
85 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
86 Ok(DataType::Boolean)
87 }
88
89 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90 let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
91
92 let coercion_type = string_coercion(
94 &str_arg.data_type(),
95 &prefix_arg.data_type(),
96 )
97 .or_else(|| {
98 binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
99 });
100
101 let Some(coercion_type) = coercion_type else {
102 return exec_err!(
103 "Unsupported data types {:?}, {:?} for function `starts_with`.",
104 str_arg.data_type(),
105 prefix_arg.data_type()
106 );
107 };
108
109 let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
111 if arr.data_type() == target {
112 Ok(Arc::clone(arr))
113 } else {
114 Ok(arrow::compute::kernels::cast::cast(arr, target)?)
115 }
116 };
117
118 match (str_arg, prefix_arg) {
119 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
121 let str_arr = str_scalar.to_array_of_size(1)?;
122 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
123 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
124 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
125 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
126 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
127 &result, 0,
128 )?))
129 }
130 (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
132 let str_arr = maybe_cast(str_arr, &coercion_type)?;
133 let prefix_arr = prefix_scalar.to_array_of_size(1)?;
134 let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
135 let prefix_scalar = Scalar::new(prefix_arr);
136 let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
137 Ok(ColumnarValue::Array(Arc::new(result)))
138 }
139 (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
141 let str_arr = str_scalar.to_array_of_size(1)?;
142 let str_arr = maybe_cast(&str_arr, &coercion_type)?;
143 let str_scalar = Scalar::new(str_arr);
144 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
145 let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
146 Ok(ColumnarValue::Array(Arc::new(result)))
147 }
148 (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
150 let str_arr = maybe_cast(str_arr, &coercion_type)?;
151 let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
152 let result = arrow_starts_with(&str_arr, &prefix_arr)?;
153 Ok(ColumnarValue::Array(Arc::new(result)))
154 }
155 }
156 }
157
158 fn simplify(
159 &self,
160 args: Vec<Expr>,
161 info: &SimplifyContext,
162 ) -> Result<ExprSimplifyResult> {
163 if let Expr::Literal(scalar_value, _) = &args[1] {
164 let like_expr = match scalar_value {
170 ScalarValue::Utf8(Some(pattern))
171 | ScalarValue::LargeUtf8(Some(pattern))
172 | ScalarValue::Utf8View(Some(pattern)) => {
173 let escaped_pattern = pattern
174 .replace("\\", "\\\\")
175 .replace("%", "\\%")
176 .replace("_", "\\_");
177 let like_pattern = format!("{escaped_pattern}%");
178 Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
179 }
180 _ => return Ok(ExprSimplifyResult::Original(args)),
181 };
182
183 let expr_data_type = info.get_data_type(&args[0])?;
184 let pattern_data_type = info.get_data_type(&like_expr)?;
185
186 if let Some(coercion_data_type) =
187 string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
188 binary_to_string_coercion(&expr_data_type, &pattern_data_type)
189 })
190 {
191 let expr = if expr_data_type == coercion_data_type {
192 args[0].clone()
193 } else {
194 cast(args[0].clone(), coercion_data_type.clone())
195 };
196
197 let pattern = if pattern_data_type == coercion_data_type {
198 like_expr
199 } else {
200 cast(like_expr, coercion_data_type)
201 };
202
203 return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
204 negated: false,
205 expr: Box::new(expr),
206 pattern: Box::new(pattern),
207 escape_char: None,
208 case_insensitive: false,
209 })));
210 }
211 }
212
213 Ok(ExprSimplifyResult::Original(args))
214 }
215
216 fn documentation(&self) -> Option<&Documentation> {
217 self.doc()
218 }
219}
220
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Invokes `starts_with` directly on `args` and returns the result
    /// materialized as a `BooleanArray` of `number_rows` rows.
    ///
    /// Argument fields are derived from each argument's actual data type.
    fn invoke_to_bool_array(args: Vec<ColumnarValue>, number_rows: usize) -> BooleanArray {
        let arg_fields: Vec<_> = args
            .iter()
            .enumerate()
            .map(|(i, arg)| Arc::new(Field::new(format!("arg_{i}"), arg.data_type(), true)))
            .collect();

        let result = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields,
                number_rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        result
            .into_array(number_rows)
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        let test_cases = vec![
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
            // Every string starts with the empty prefix.
            (Some("alphabet"), Some(""), Some(true)),
            // A null on either side yields a null result.
            (None, Some("alph"), None),
            (Some("alphabet"), None, None),
        ]
        .into_iter()
        .flat_map(|(a, b, c)| {
            let utf_8_args = vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))),
            ];

            let large_utf_8_args = vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))),
            ];

            let utf_8_view_args = vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))),
            ];

            vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)]
        });

        for (args, expected) in test_cases {
            test_function!(
                StartsWithFunc::new(),
                args,
                Ok(expected),
                bool,
                Boolean,
                BooleanArray
            );
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ])));
        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string())));

        let args = vec![array, scalar];
        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        let array = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]));
        let args = vec![
            ColumnarValue::Array(array),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        let bool_array = invoke_to_bool_array(args, 4);

        assert_eq!(bool_array.len(), 4);
        assert!(bool_array.value(0));
        assert!(bool_array.value(1));
        assert!(!bool_array.value(2));
        assert!(bool_array.is_null(3));
    }

    #[test]
    fn test_scalar_array() {
        let prefixes = Arc::new(StringArray::from(vec![
            Some("alph"),
            Some("bet"),
            Some("alpha"),
            None,
        ]));
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
            ColumnarValue::Array(prefixes),
        ];

        let bool_array = invoke_to_bool_array(args, 4);

        assert_eq!(bool_array.len(), 4);
        assert!(bool_array.value(0));
        assert!(!bool_array.value(1));
        assert!(bool_array.value(2));
        assert!(bool_array.is_null(3));
    }

    #[test]
    fn test_array_array() {
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]));
        let prefixes = Arc::new(StringArray::from(vec![
            Some("alph"),
            Some("ru"),
            Some("hello"),
            Some("test"),
        ]));
        let args = vec![
            ColumnarValue::Array(strings),
            ColumnarValue::Array(prefixes),
        ];

        let bool_array = invoke_to_bool_array(args, 4);

        assert_eq!(bool_array.len(), 4);
        assert!(bool_array.value(0));
        assert!(bool_array.value(1));
        assert!(!bool_array.value(2));
        assert!(bool_array.is_null(3));
    }

    /// A `LargeUtf8` column with a `Utf8` prefix exercises the coercion/cast path.
    #[test]
    fn test_mixed_string_types() {
        let strings = Arc::new(LargeStringArray::from(vec![
            Some("alphabet"),
            Some("beta"),
            None,
        ]));
        let args = vec![
            ColumnarValue::Array(strings),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        let bool_array = invoke_to_bool_array(args, 3);

        assert_eq!(bool_array.len(), 3);
        assert!(bool_array.value(0));
        assert!(!bool_array.value(1));
        assert!(bool_array.is_null(2));
    }
}