datafusion_functions/string/
starts_with.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
23use arrow::datatypes::DataType;
24use datafusion_common::utils::take_function_args;
25use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
26use datafusion_expr::type_coercion::binary::{
27    binary_to_string_coercion, string_coercion,
28};
29
30use datafusion_common::types::logical_string;
31use datafusion_common::{Result, ScalarValue, exec_err};
32use datafusion_expr::{
33    Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
34    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Tests if a string starts with a substring.",
41    syntax_example = "starts_with(str, substr)",
42    sql_example = r#"```sql
43> select starts_with('datafusion','data');
44+----------------------------------------------+
45| starts_with(Utf8("datafusion"),Utf8("data")) |
46+----------------------------------------------+
47| true                                         |
48+----------------------------------------------+
49```"#,
50    standard_argument(name = "str", prefix = "String"),
51    argument(name = "substr", description = "Substring to test for.")
52)]
53#[derive(Debug, PartialEq, Eq, Hash)]
54pub struct StartsWithFunc {
55    signature: Signature,
56}
57
58impl Default for StartsWithFunc {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl StartsWithFunc {
65    pub fn new() -> Self {
66        Self {
67            signature: Signature::coercible(
68                vec![
69                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71                ],
72                Volatility::Immutable,
73            ),
74        }
75    }
76}
77
78impl ScalarUDFImpl for StartsWithFunc {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &str {
84        "starts_with"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
92        Ok(DataType::Boolean)
93    }
94
95    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96        let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
97
98        // Determine the common type for coercion
99        let coercion_type = string_coercion(
100            &str_arg.data_type(),
101            &prefix_arg.data_type(),
102        )
103        .or_else(|| {
104            binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
105        });
106
107        let Some(coercion_type) = coercion_type else {
108            return exec_err!(
109                "Unsupported data types {:?}, {:?} for function `starts_with`.",
110                str_arg.data_type(),
111                prefix_arg.data_type()
112            );
113        };
114
115        // Helper to cast an array if needed
116        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
117            if arr.data_type() == target {
118                Ok(Arc::clone(arr))
119            } else {
120                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
121            }
122        };
123
124        match (str_arg, prefix_arg) {
125            // Both scalars - just compute directly
126            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
127                let str_arr = str_scalar.to_array_of_size(1)?;
128                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
129                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
130                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
131                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
132                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
133                    &result, 0,
134                )?))
135            }
136            // String is array, prefix is scalar - use Scalar wrapper for optimization
137            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
138                let str_arr = maybe_cast(str_arr, &coercion_type)?;
139                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
140                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
141                let prefix_scalar = Scalar::new(prefix_arr);
142                let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
143                Ok(ColumnarValue::Array(Arc::new(result)))
144            }
145            // String is scalar, prefix is array - use Scalar wrapper for string
146            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
147                let str_arr = str_scalar.to_array_of_size(1)?;
148                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
149                let str_scalar = Scalar::new(str_arr);
150                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
151                let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
152                Ok(ColumnarValue::Array(Arc::new(result)))
153            }
154            // Both arrays - pass directly
155            (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
156                let str_arr = maybe_cast(str_arr, &coercion_type)?;
157                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
158                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
159                Ok(ColumnarValue::Array(Arc::new(result)))
160            }
161        }
162    }
163
164    fn simplify(
165        &self,
166        args: Vec<Expr>,
167        info: &dyn SimplifyInfo,
168    ) -> Result<ExprSimplifyResult> {
169        if let Expr::Literal(scalar_value, _) = &args[1] {
170            // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
171            // Escapes pattern characters: starts_with(col, 'j\_a%') -> col LIKE 'j\\\_a\%%'
172            //   1. 'j\_a%'         (input pattern)
173            //   2. 'j\\\_a\%'       (escape special chars '%', '_' and '\')
174            //   3. 'j\\\_a\%%'      (add unescaped % suffix for starts_with)
175            let like_expr = match scalar_value {
176                ScalarValue::Utf8(Some(pattern))
177                | ScalarValue::LargeUtf8(Some(pattern))
178                | ScalarValue::Utf8View(Some(pattern)) => {
179                    let escaped_pattern = pattern
180                        .replace("\\", "\\\\")
181                        .replace("%", "\\%")
182                        .replace("_", "\\_");
183                    let like_pattern = format!("{escaped_pattern}%");
184                    Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
185                }
186                _ => return Ok(ExprSimplifyResult::Original(args)),
187            };
188
189            let expr_data_type = info.get_data_type(&args[0])?;
190            let pattern_data_type = info.get_data_type(&like_expr)?;
191
192            if let Some(coercion_data_type) =
193                string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
194                    binary_to_string_coercion(&expr_data_type, &pattern_data_type)
195                })
196            {
197                let expr = if expr_data_type == coercion_data_type {
198                    args[0].clone()
199                } else {
200                    cast(args[0].clone(), coercion_data_type.clone())
201                };
202
203                let pattern = if pattern_data_type == coercion_data_type {
204                    like_expr
205                } else {
206                    cast(like_expr, coercion_data_type)
207                };
208
209                return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
210                    negated: false,
211                    expr: Box::new(expr),
212                    pattern: Box::new(pattern),
213                    escape_char: None,
214                    case_insensitive: false,
215                })));
216            }
217        }
218
219        Ok(ExprSimplifyResult::Original(args))
220    }
221
222    fn documentation(&self) -> Option<&Documentation> {
223        self.doc()
224    }
225}
226
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::*;

    /// Calls `StartsWithFunc::invoke_with_args` with two nullable `Utf8`
    /// argument fields. The output is expanded to `rows` rows and returned
    /// as an owned `BooleanArray`.
    fn invoke_utf8(args: Vec<ColumnarValue>, rows: usize) -> BooleanArray {
        let output = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();
        let array = output.into_array(rows).unwrap();
        array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Each (string, prefix, expected) case runs once per string flavor.
        let cases = [
            ("alphabet", "alph", true),
            ("alphabet", "bet", false),
            ("somewhat large string", "somewhat large", true),
            ("somewhat large string", "large", false),
        ];
        let scalar_ctors: [fn(Option<String>) -> ScalarValue; 3] = [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ];

        for (haystack, prefix, want) in cases {
            for to_scalar in scalar_ctors {
                let args = vec![
                    ColumnarValue::Scalar(to_scalar(Some(haystack.to_string()))),
                    ColumnarValue::Scalar(to_scalar(Some(prefix.to_string()))),
                ];
                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(Some(want)),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        // Array string + scalar prefix; `test_function!` checks row 0.
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]));
        let args = vec![
            ColumnarValue::Array(strings),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        // Array string + scalar prefix, checking every output row.
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]));
        let out = invoke_utf8(
            vec![
                ColumnarValue::Array(strings),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
            ],
            4,
        );

        assert!(out.value(0)); // "alphabet" starts with "alph"
        assert!(out.value(1)); // "alphabet" starts with "alph"
        assert!(!out.value(2)); // "beta" does not start with "alph"
        assert!(out.is_null(3)); // null input -> null output
    }

    #[test]
    fn test_scalar_array() {
        // Scalar string + array of prefixes.
        let prefixes = Arc::new(StringArray::from(vec![
            Some("alph"),
            Some("bet"),
            Some("alpha"),
            None,
        ]));
        let out = invoke_utf8(
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
                ColumnarValue::Array(prefixes),
            ],
            4,
        );

        assert!(out.value(0)); // "alphabet" starts with "alph"
        assert!(!out.value(1)); // "alphabet" does not start with "bet"
        assert!(out.value(2)); // "alphabet" starts with "alpha"
        assert!(out.is_null(3)); // null prefix -> null output
    }

    #[test]
    fn test_array_array() {
        // Element-wise comparison of two arrays.
        let strings = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]));
        let prefixes = Arc::new(StringArray::from(vec![
            Some("alph"),
            Some("ru"),
            Some("hello"),
            Some("test"),
        ]));
        let out = invoke_utf8(
            vec![ColumnarValue::Array(strings), ColumnarValue::Array(prefixes)],
            4,
        );

        assert!(out.value(0)); // "alphabet" starts with "alph"
        assert!(out.value(1)); // "rust" starts with "ru"
        assert!(!out.value(2)); // "datafusion" does not start with "hello"
        assert!(out.is_null(3)); // null string -> null output
    }
}
436}