Skip to main content

datafusion_functions/string/
starts_with.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, Scalar};
21use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
22use arrow::datatypes::DataType;
23use datafusion_common::types::logical_string;
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, exec_err};
26use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
27use datafusion_expr::type_coercion::binary::{
28    binary_to_string_coercion, string_coercion,
29};
30use datafusion_expr::{
31    Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
32    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
33};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementation backing the SQL function `starts_with(str, substr)`.
///
/// The `user_doc` attribute below carries the user-facing documentation
/// (section, description, syntax and SQL example) for this function.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string starts with a substring.",
    syntax_example = "starts_with(str, substr)",
    sql_example = r#"```sql
> select starts_with('datafusion','data');
+----------------------------------------------+
| starts_with(Utf8("datafusion"),Utf8("data")) |
+----------------------------------------------+
| true                                         |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StartsWithFunc {
    /// Accepted argument types and volatility; populated by `StartsWithFunc::new`.
    signature: Signature,
}
55
56impl Default for StartsWithFunc {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl StartsWithFunc {
63    pub fn new() -> Self {
64        Self {
65            signature: Signature::coercible(
66                vec![
67                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
68                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
69                ],
70                Volatility::Immutable,
71            ),
72        }
73    }
74}
75
76impl ScalarUDFImpl for StartsWithFunc {
77    fn name(&self) -> &str {
78        "starts_with"
79    }
80
81    fn signature(&self) -> &Signature {
82        &self.signature
83    }
84
85    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
86        Ok(DataType::Boolean)
87    }
88
89    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90        let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?;
91
92        // Determine the common type for coercion
93        let coercion_type = string_coercion(
94            &str_arg.data_type(),
95            &prefix_arg.data_type(),
96        )
97        .or_else(|| {
98            binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
99        });
100
101        let Some(coercion_type) = coercion_type else {
102            return exec_err!(
103                "Unsupported data types {:?}, {:?} for function `starts_with`.",
104                str_arg.data_type(),
105                prefix_arg.data_type()
106            );
107        };
108
109        // Helper to cast an array if needed
110        let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
111            if arr.data_type() == target {
112                Ok(Arc::clone(arr))
113            } else {
114                Ok(arrow::compute::kernels::cast::cast(arr, target)?)
115            }
116        };
117
118        match (str_arg, prefix_arg) {
119            // Both scalars - just compute directly
120            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
121                let str_arr = str_scalar.to_array_of_size(1)?;
122                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
123                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
124                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
125                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
126                Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
127                    &result, 0,
128                )?))
129            }
130            // String is array, prefix is scalar - use Scalar wrapper for optimization
131            (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
132                let str_arr = maybe_cast(str_arr, &coercion_type)?;
133                let prefix_arr = prefix_scalar.to_array_of_size(1)?;
134                let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
135                let prefix_scalar = Scalar::new(prefix_arr);
136                let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
137                Ok(ColumnarValue::Array(Arc::new(result)))
138            }
139            // String is scalar, prefix is array - use Scalar wrapper for string
140            (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
141                let str_arr = str_scalar.to_array_of_size(1)?;
142                let str_arr = maybe_cast(&str_arr, &coercion_type)?;
143                let str_scalar = Scalar::new(str_arr);
144                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
145                let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
146                Ok(ColumnarValue::Array(Arc::new(result)))
147            }
148            // Both arrays - pass directly
149            (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
150                let str_arr = maybe_cast(str_arr, &coercion_type)?;
151                let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
152                let result = arrow_starts_with(&str_arr, &prefix_arr)?;
153                Ok(ColumnarValue::Array(Arc::new(result)))
154            }
155        }
156    }
157
158    fn simplify(
159        &self,
160        args: Vec<Expr>,
161        info: &SimplifyContext,
162    ) -> Result<ExprSimplifyResult> {
163        if let Expr::Literal(scalar_value, _) = &args[1] {
164            // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
165            // Escapes pattern characters: starts_with(col, 'j\_a%') -> col LIKE 'j\\\_a\%%'
166            //   1. 'j\_a%'         (input pattern)
167            //   2. 'j\\\_a\%'       (escape special chars '%', '_' and '\')
168            //   3. 'j\\\_a\%%'      (add unescaped % suffix for starts_with)
169            let like_expr = match scalar_value {
170                ScalarValue::Utf8(Some(pattern))
171                | ScalarValue::LargeUtf8(Some(pattern))
172                | ScalarValue::Utf8View(Some(pattern)) => {
173                    let escaped_pattern = pattern
174                        .replace("\\", "\\\\")
175                        .replace("%", "\\%")
176                        .replace("_", "\\_");
177                    let like_pattern = format!("{escaped_pattern}%");
178                    Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None)
179                }
180                _ => return Ok(ExprSimplifyResult::Original(args)),
181            };
182
183            let expr_data_type = info.get_data_type(&args[0])?;
184            let pattern_data_type = info.get_data_type(&like_expr)?;
185
186            if let Some(coercion_data_type) =
187                string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
188                    binary_to_string_coercion(&expr_data_type, &pattern_data_type)
189                })
190            {
191                let expr = if expr_data_type == coercion_data_type {
192                    args[0].clone()
193                } else {
194                    cast(args[0].clone(), coercion_data_type.clone())
195                };
196
197                let pattern = if pattern_data_type == coercion_data_type {
198                    like_expr
199                } else {
200                    cast(like_expr, coercion_data_type)
201                };
202
203                return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
204                    negated: false,
205                    expr: Box::new(expr),
206                    pattern: Box::new(pattern),
207                    escape_char: None,
208                    case_insensitive: false,
209                })));
210            }
211        }
212
213        Ok(ExprSimplifyResult::Original(args))
214    }
215
216    fn documentation(&self) -> Option<&Documentation> {
217        self.doc()
218    }
219}
220
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray, StringArray};
    use arrow::datatypes::DataType::Boolean;
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Runs `starts_with` on two `Utf8` arguments and returns the result
    /// materialised as a `BooleanArray` of `number_rows` rows.
    fn invoke_utf8(args: Vec<ColumnarValue>, number_rows: usize) -> BooleanArray {
        let output = StartsWithFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![
                    Field::new("a", DataType::Utf8, true).into(),
                    Field::new("b", DataType::Utf8, true).into(),
                ],
                number_rows,
                return_field: Field::new("f", Boolean, true).into(),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let array = output.into_array(number_rows).unwrap();
        array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_scalar_scalar() -> Result<()> {
        // Scalar + Scalar, exercised once per string representation.
        let cases = [
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
        ];
        let string_variants: [fn(Option<String>) -> ScalarValue; 3] = [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ];

        for (haystack, prefix, expected) in cases {
            for to_scalar in string_variants {
                let args = vec![
                    ColumnarValue::Scalar(to_scalar(haystack.map(String::from))),
                    ColumnarValue::Scalar(to_scalar(prefix.map(String::from))),
                ];

                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(expected),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_array_scalar() -> Result<()> {
        // Array + Scalar (the optimized path); only the first row is checked here.
        let haystacks = StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);
        let args = vec![
            ColumnarValue::Array(Arc::new(haystacks)),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
        ];

        test_function!(
            StartsWithFunc::new(),
            args,
            Ok(Some(true)), // First element result
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    #[test]
    fn test_array_scalar_full_result() {
        // Array + Scalar, checking every row of the output.
        let haystacks = StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            Some("beta"),
            None,
        ]);
        let outcome = invoke_utf8(
            vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
            ],
            4,
        );

        assert!(outcome.value(0)); // "alphabet" starts with "alph"
        assert!(outcome.value(1)); // "alphabet" starts with "alph"
        assert!(!outcome.value(2)); // "beta" does not start with "alph"
        assert!(outcome.is_null(3)); // null input -> null output
    }

    #[test]
    fn test_scalar_array() {
        // Scalar + Array: one string tested against a column of prefixes.
        let prefixes = StringArray::from(vec![
            Some("alph"),
            Some("bet"),
            Some("alpha"),
            None,
        ]);
        let outcome = invoke_utf8(
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
                ColumnarValue::Array(Arc::new(prefixes)),
            ],
            4,
        );

        assert!(outcome.value(0)); // "alphabet" starts with "alph"
        assert!(!outcome.value(1)); // "alphabet" does not start with "bet"
        assert!(outcome.value(2)); // "alphabet" starts with "alpha"
        assert!(outcome.is_null(3)); // null prefix -> null output
    }

    #[test]
    fn test_array_array() {
        // Array + Array: rows are compared pairwise.
        let haystacks = StringArray::from(vec![
            Some("alphabet"),
            Some("rust"),
            Some("datafusion"),
            None,
        ]);
        let prefixes = StringArray::from(vec![
            Some("alph"),
            Some("ru"),
            Some("hello"),
            Some("test"),
        ]);
        let outcome = invoke_utf8(
            vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Array(Arc::new(prefixes)),
            ],
            4,
        );

        assert!(outcome.value(0)); // "alphabet" starts with "alph"
        assert!(outcome.value(1)); // "rust" starts with "ru"
        assert!(!outcome.value(2)); // "datafusion" does not start with "hello"
        assert!(outcome.is_null(3)); // null string -> null output
    }
}