Skip to main content

datafusion_functions/string/
starts_with.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Scalar};
22use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
23use arrow::datatypes::DataType;
24use datafusion_common::types::logical_string;
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
28use datafusion_expr::type_coercion::binary::{
29    binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32    Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
33    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
34};
35use datafusion_macros::user_doc;
36
37#[user_doc(
38    doc_section(label = "String Functions"),
39    description = "Tests if a string starts with a substring.",
40    syntax_example = "starts_with(str, substr)",
41    sql_example = r#"```sql
42> select starts_with('datafusion','data');
43+----------------------------------------------+
44| starts_with(Utf8("datafusion"),Utf8("data")) |
45+----------------------------------------------+
46| true                                         |
47+----------------------------------------------+
48```"#,
49    standard_argument(name = "str", prefix = "String"),
50    argument(name = "substr", description = "Substring to test for.")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct StartsWithFunc {
54    signature: Signature,
55}
56
57impl Default for StartsWithFunc {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl StartsWithFunc {
64    pub fn new() -> Self {
65        Self {
66            signature: Signature::coercible(
67                vec![
68                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
69                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70                ],
71                Volatility::Immutable,
72            ),
73        }
74    }
75}
76
77impl ScalarUDFImpl for StartsWithFunc {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn name(&self) -> &str {
83        "starts_with"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
91        Ok(DataType::Boolean)
92    }
93
94    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95        let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
96
97        // Determine the common type for coercion
98        let coercion_type = string_coercion(
99            &str_arg.data_type(),
100            &prefix_arg.data_type(),
101        )
102        .or_else(|| {
103            binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
104        });
105
106        let Some(coercion_type) = coercion_type else {
107            return exec_err!(
108                "Unsupported data types {:?}, {:?} for function `starts_with`.",
109                str_arg.data_type(),
110                prefix_arg.data_type()
111            );
112        };
113
114        // Helper to cast an array if needed
115        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
116            if arr.data_type() == target {
117                Ok(Arc::clone(arr))
118            } else {
119                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
120            }
121        };
122
123        match (str_arg, prefix_arg) {
124            // Both scalars - just compute directly
125            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
126                let str_arr = str_scalar.to_array_of_size(1)?;
127                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
128                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
129                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
130                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
131                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
132                    &result, 0,
133                )?))
134            }
135            // String is array, prefix is scalar - use Scalar wrapper for optimization
136            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
137                let str_arr = maybe_cast(str_arr, &coercion_type)?;
138                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
139                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
140                let prefix_scalar = Scalar::new(prefix_arr);
141                let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
142                Ok(ColumnarValue::Array(Arc::new(result)))
143            }
144            // String is scalar, prefix is array - use Scalar wrapper for string
145            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
146                let str_arr = str_scalar.to_array_of_size(1)?;
147                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
148                let str_scalar = Scalar::new(str_arr);
149                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
150                let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
151                Ok(ColumnarValue::Array(Arc::new(result)))
152            }
153            // Both arrays - pass directly
154            (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
155                let str_arr = maybe_cast(str_arr, &coercion_type)?;
156                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
157                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
158                Ok(ColumnarValue::Array(Arc::new(result)))
159            }
160        }
161    }
162
163    fn simplify(
164        &self,
165        args: Vec<Expr>,
166        info: &SimplifyContext,
167    ) -> Result<ExprSimplifyResult> {
168        if let Expr::Literal(scalar_value, _) = &args[1] {
169            // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
170            // Escapes pattern characters: starts_with(col, 'j\_a%') -> col LIKE 'j\\\_a\%%'
171            //   1. 'j\_a%'         (input pattern)
172            //   2. 'j\\\_a\%'       (escape special chars '%', '_' and '\')
173            //   3. 'j\\\_a\%%'      (add unescaped % suffix for starts_with)
174            let like_expr = match scalar_value {
175                ScalarValue::Utf8(Some(pattern))
176                | ScalarValue::LargeUtf8(Some(pattern))
177                | ScalarValue::Utf8View(Some(pattern)) => {
178                    let escaped_pattern = pattern
179                        .replace("\\", "\\\\")
180                        .replace("%", "\\%")
181                        .replace("_", "\\_");
182                    let like_pattern = format!("{escaped_pattern}%");
183                    Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
184                }
185                _ => return Ok(ExprSimplifyResult::Original(args)),
186            };
187
188            let expr_data_type = info.get_data_type(&args[0])?;
189            let pattern_data_type = info.get_data_type(&like_expr)?;
190
191            if let Some(coercion_data_type) =
192                string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
193                    binary_to_string_coercion(&expr_data_type, &pattern_data_type)
194                })
195            {
196                let expr = if expr_data_type == coercion_data_type {
197                    args[0].clone()
198                } else {
199                    cast(args[0].clone(), coercion_data_type.clone())
200                };
201
202                let pattern = if pattern_data_type == coercion_data_type {
203                    like_expr
204                } else {
205                    cast(like_expr, coercion_data_type)
206                };
207
208                return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
209                    negated: false,
210                    expr: Box::new(expr),
211                    pattern: Box::new(pattern),
212                    escape_char: None,
213                    case_insensitive: false,
214                })));
215            }
216        }
217
218        Ok(ExprSimplifyResult::Original(args))
219    }
220
221    fn documentation(&self) -> Option<&Documentation> {
222        self.doc()
223    }
224}
225
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::*;

    /// Invokes `starts_with` on two `Utf8` arguments for `rows` rows and
    /// returns the result as a `BooleanArray`.
    fn invoke_utf8(args: Vec<ColumnarValue>, rows: usize) -> BooleanArray {
        let output = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows: rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let output = output.into_array(rows).unwrap();
        output
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    /// Checks each row of `actual`: `Some(flag)` compares the stored value,
    /// `None` requires the row to be null.
    fn assert_rows(actual: &BooleanArray, expected: &[Option<bool>]) {
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(flag) => assert_eq!(actual.value(row), *flag, "row {row}"),
                None => assert!(actual.is_null(row), "row {row}"),
            }
        }
    }

    /// Wraps a string literal as a `Utf8` scalar argument.
    fn utf8_scalar(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Scalar + Scalar, repeated for every string encoding.
        let cases = [
            ("alphabet", "alph", true),
            ("alphabet", "bet", false),
            ("somewhat large string", "somewhat large", true),
            ("somewhat large string", "large", false),
        ];
        let encodings: [fn(Option<String>) -> ScalarValue; 3] = [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ];

        for (text, prefix, expected) in cases {
            for make_scalar in encodings {
                let args = vec![
                    ColumnarValue::Scalar(make_scalar(Some(text.to_string()))),
                    ColumnarValue::Scalar(make_scalar(Some(prefix.to_string()))),
                ];
                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(Some(expected)),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        // Array + Scalar (the optimized path)
        let args = vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                Some("alphabet"),
                Some("alphabet"),
                Some("beta"),
                None,
            ]))),
            utf8_scalar("alph"),
        ];

        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)), // Expected result for the first element
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        // Array + Scalar, checking every row of the output.
        let strings = StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(strings)), utf8_scalar("alph")];

        let result = invoke_utf8(args, 4);

        assert_rows(
            &result,
            &[
                Some(true),  // "alphabet" starts with "alph"
                Some(true),  // "alphabet" starts with "alph"
                Some(false), // "beta" does not start with "alph"
                None,        // null input -> null output
            ],
        );
    }

    #[test]
    fn test_scalar_array() {
        // Scalar + Array
        let prefixes =
            StringArray::from(vec![Some("alph"), Some("bet"), Some("alpha"), None]);
        let args = vec![
            utf8_scalar("alphabet"),
            ColumnarValue::Array(Arc::new(prefixes)),
        ];

        let result = invoke_utf8(args, 4);

        assert_rows(
            &result,
            &[
                Some(true),  // "alphabet" starts with "alph"
                Some(false), // "alphabet" does not start with "bet"
                Some(true),  // "alphabet" starts with "alpha"
                None,        // null prefix -> null output
            ],
        );
    }

    #[test]
    fn test_array_array() {
        // Array + Array
        let strings = StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]);
        let prefixes = StringArray::from(vec![
            Some("alph"),
            Some("ru"),
            Some("hello"),
            Some("test"),
        ]);
        let args = vec![
            ColumnarValue::Array(Arc::new(strings)),
            ColumnarValue::Array(Arc::new(prefixes)),
        ];

        let result = invoke_utf8(args, 4);

        assert_rows(
            &result,
            &[
                Some(true),  // "alphabet" starts with "alph"
                Some(true),  // "rust" starts with "ru"
                Some(false), // "datafusion" does not start with "hello"
                None,        // null string -> null output
            ],
        );
    }
}