datafusion_functions/regex/
regexpinstr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
29    TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
    syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
    sql_example = r#"```sql
> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
+---------------------------------------------------------------+
| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)"))                 |
+---------------------------------------------------------------+
| 3                                                             |
+---------------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(name = "regexp", prefix = "Regular"),
    argument(
        name = "start",
        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
    ),
    argument(
        name = "N",
        description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
  - **i**: case-insensitive: letters match both upper and lower case
  - **m**: multi-line mode: ^ and $ match begin/end of line
  - **s**: allow . to match \n
  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
  - **U**: swap the meaning of x* and x*?"#
    ),
    argument(
        name = "subexpr",
        description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
    )
)]
/// Scalar UDF `regexp_instr`: reports the 1-based character position at which
/// a regular expression matches within a string, or 0 when there is no match.
///
/// The user-facing documentation is generated from the `user_doc` attribute
/// above; this type only stores the accepted argument signature.
#[derive(Debug)]
pub struct RegexpInstrFunc {
    /// Accepted argument type lists and volatility, built once in `new`.
    signature: Signature,
}
79
80impl Default for RegexpInstrFunc {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl RegexpInstrFunc {
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::one_of(
90                vec![
91                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92                    Exact(vec![Utf8View, Utf8View, Int64]),
93                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94                    Exact(vec![Utf8, Utf8, Int64]),
95                    Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97                    Exact(vec![Utf8, Utf8, Int64, Int64]),
98                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104                ],
105                Volatility::Immutable,
106            ),
107        }
108    }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112    fn as_any(&self) -> &dyn std::any::Any {
113        self
114    }
115
116    fn name(&self) -> &str {
117        "regexp_instr"
118    }
119
120    fn signature(&self) -> &Signature {
121        &self.signature
122    }
123
124    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125        Ok(Int64)
126    }
127
128    fn invoke_with_args(
129        &self,
130        args: datafusion_expr::ScalarFunctionArgs,
131    ) -> Result<ColumnarValue> {
132        let args = &args.args;
133
134        let len = args
135            .iter()
136            .fold(Option::<usize>::None, |acc, arg| match arg {
137                ColumnarValue::Scalar(_) => acc,
138                ColumnarValue::Array(a) => Some(a.len()),
139            });
140
141        let is_scalar = len.is_none();
142        let inferred_length = len.unwrap_or(1);
143        let args = args
144            .iter()
145            .map(|arg| arg.to_array(inferred_length))
146            .collect::<Result<Vec<_>>>()?;
147
148        let result = regexp_instr_func(&args);
149        if is_scalar {
150            // If all inputs are scalar, keeps output as scalar
151            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
152            result.map(ColumnarValue::Scalar)
153        } else {
154            result.map(ColumnarValue::Array)
155        }
156    }
157
158    fn documentation(&self) -> Option<&Documentation> {
159        self.doc()
160    }
161}
162
163pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
164    let args_len = args.len();
165    if !(2..=6).contains(&args_len) {
166        return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6.");
167    }
168
169    let values = &args[0];
170    match values.data_type() {
171        Utf8 | LargeUtf8 | Utf8View => (),
172        other => {
173            return internal_err!(
174                "Unsupported data type {other:?} for function regexp_instr"
175            );
176        }
177    }
178
179    regexp_instr(
180        values,
181        &args[1],
182        if args_len > 2 { Some(&args[2]) } else { None },
183        if args_len > 3 { Some(&args[3]) } else { None },
184        if args_len > 4 { Some(&args[4]) } else { None },
185        if args_len > 5 { Some(&args[5]) } else { None },
186    )
187    .map_err(|e| e.into())
188}
189
190/// `arrow-rs` style implementation of `regexp_instr` function.
191/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern
192/// within a string array. It supports optional start positions and flags for case insensitivity.
193///
194/// The function accepts a variable number of arguments:
195/// - `values`: The array of strings to search within.
196/// - `regex_array`: The array of regular expression patterns to search for.
197/// - `start_array` (optional): The array of start positions for the search.
198/// - `nth_array` (optional): The array of start nth for the search.
199/// - `endoption_array` (optional): The array of endoption positions for the search.
200/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
201/// - `subexpr_array` (optional): The array of subexpr positions for the search.
202///
203/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
204/// and flags. It uses a cache to store compiled regular expressions for efficiency.
205///
206/// # Errors
207/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
208pub fn regexp_instr(
209    values: &dyn Array,
210    regex_array: &dyn Datum,
211    start_array: Option<&dyn Datum>,
212    nth_array: Option<&dyn Datum>,
213    flags_array: Option<&dyn Datum>,
214    subexpr_array: Option<&dyn Datum>,
215) -> Result<ArrayRef, ArrowError> {
216    let (regex_array, _) = regex_array.get();
217    let start_array = start_array.map(|start| {
218        let (start, _) = start.get();
219        start
220    });
221    let nth_array = nth_array.map(|nth| {
222        let (nth, _) = nth.get();
223        nth
224    });
225    let flags_array = flags_array.map(|flags| {
226        let (flags, _) = flags.get();
227        flags
228    });
229    let subexpr_array = subexpr_array.map(|subexpr| {
230        let (subexpr, _) = subexpr.get();
231        subexpr
232    });
233
234    match (values.data_type(), regex_array.data_type(), flags_array) {
235        (Utf8, Utf8, None) => regexp_instr_inner(
236            values.as_string::<i32>(),
237            regex_array.as_string::<i32>(),
238            start_array.map(|start| start.as_primitive::<Int64Type>()),
239            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
240            None,
241            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
242        ),
243        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
244            values.as_string::<i32>(),
245            regex_array.as_string::<i32>(),
246            start_array.map(|start| start.as_primitive::<Int64Type>()),
247            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
248            Some(flags_array.as_string::<i32>()),
249            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
250        ),
251        (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
252            values.as_string::<i64>(),
253            regex_array.as_string::<i64>(),
254            start_array.map(|start| start.as_primitive::<Int64Type>()),
255            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
256            None,
257            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
258        ),
259        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
260            values.as_string::<i64>(),
261            regex_array.as_string::<i64>(),
262            start_array.map(|start| start.as_primitive::<Int64Type>()),
263            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
264            Some(flags_array.as_string::<i64>()),
265            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
266        ),
267        (Utf8View, Utf8View, None) => regexp_instr_inner(
268            values.as_string_view(),
269            regex_array.as_string_view(),
270            start_array.map(|start| start.as_primitive::<Int64Type>()),
271            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
272            None,
273            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
274        ),
275        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
276            values.as_string_view(),
277            regex_array.as_string_view(),
278            start_array.map(|start| start.as_primitive::<Int64Type>()),
279            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
280            Some(flags_array.as_string_view()),
281            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
282        ),
283        _ => Err(ArrowError::ComputeError(
284            "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
285        )),
286    }
287}
288
289#[allow(clippy::too_many_arguments)]
290pub fn regexp_instr_inner<'a, S>(
291    values: S,
292    regex_array: S,
293    start_array: Option<&Int64Array>,
294    nth_array: Option<&Int64Array>,
295    flags_array: Option<S>,
296    subexp_array: Option<&Int64Array>,
297) -> Result<ArrayRef, ArrowError>
298where
299    S: StringArrayType<'a>,
300{
301    let len = values.len();
302
303    let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
304    let start_array = start_array.unwrap_or(&default_start_array);
305    let start_input: Vec<i64> = (0..start_array.len())
306        .map(|i| start_array.value(i)) // handle nulls as 0
307        .collect();
308
309    let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
310    let nth_array = nth_array.unwrap_or(&default_nth_array);
311    let nth_input: Vec<i64> = (0..nth_array.len())
312        .map(|i| nth_array.value(i)) // handle nulls as 0
313        .collect();
314
315    let flags_input = match flags_array {
316        Some(flags) => flags.iter().collect(),
317        None => vec![None; len],
318    };
319
320    let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
321    let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
322    let subexp_input: Vec<i64> = (0..subexp_array.len())
323        .map(|i| subexp_array.value(i)) // handle nulls as 0
324        .collect();
325
326    let mut regex_cache = HashMap::new();
327
328    let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
329        values.iter(),
330        regex_array.iter(),
331        start_input.iter(),
332        nth_input.iter(),
333        flags_input.iter(),
334        subexp_input.iter()
335    )
336    .map(|(value, regex, start, nth, flags, subexp)| match regex {
337        None => Ok(None),
338        Some("") => Ok(Some(0)),
339        Some(regex) => get_index(
340            value,
341            regex,
342            *start,
343            *nth,
344            *subexp,
345            *flags,
346            &mut regex_cache,
347        ),
348    })
349    .collect();
350    Ok(Arc::new(Int64Array::from(result?)))
351}
352
353fn handle_subexp(
354    pattern: &Regex,
355    search_slice: &str,
356    subexpr: i64,
357    value: &str,
358    byte_start_offset: usize,
359) -> Result<Option<i64>, ArrowError> {
360    if let Some(captures) = pattern.captures(search_slice) {
361        if let Some(matched) = captures.get(subexpr as usize) {
362            // Convert byte offset relative to search_slice back to 1-based character offset
363            // relative to the original `value` string.
364            let start_char_offset =
365                value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
366            return Ok(Some(start_char_offset));
367        }
368    }
369    Ok(Some(0)) // Return 0 if the subexpression was not found
370}
371
372fn get_nth_match(
373    pattern: &Regex,
374    search_slice: &str,
375    n: i64,
376    byte_start_offset: usize,
377    value: &str,
378) -> Result<Option<i64>, ArrowError> {
379    if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
380        // Convert byte offset relative to search_slice back to 1-based character offset
381        // relative to the original `value` string.
382        let match_start_byte_offset = byte_start_offset + mat.start();
383        let match_start_char_offset =
384            value[..match_start_byte_offset].chars().count() as i64 + 1;
385        Ok(Some(match_start_char_offset))
386    } else {
387        Ok(Some(0)) // Return 0 if the N-th match was not found
388    }
389}
390fn get_index<'strings, 'cache>(
391    value: Option<&str>,
392    pattern: &'strings str,
393    start: i64,
394    n: i64,
395    subexpr: i64,
396    flags: Option<&'strings str>,
397    regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
398) -> Result<Option<i64>, ArrowError>
399where
400    'strings: 'cache,
401{
402    let value = match value {
403        None => return Ok(None),
404        Some("") => return Ok(Some(0)),
405        Some(value) => value,
406    };
407    let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
408    // println!("get_index: value = {}, pattern = {}, start = {}, n = {}, subexpr = {}, flags = {:?}", value, pattern, start, n, subexpr, flags);
409    if start < 1 {
410        return Err(ArrowError::ComputeError(
411            "regexp_instr() requires start to be 1-based".to_string(),
412        ));
413    }
414
415    if n < 1 {
416        return Err(ArrowError::ComputeError(
417            "N must be 1 or greater".to_string(),
418        ));
419    }
420
421    // --- Simplified byte_start_offset calculation ---
422    let total_chars = value.chars().count() as i64;
423    let byte_start_offset: usize = if start > total_chars {
424        // If start is beyond the total characters, it means we start searching
425        // after the string effectively. No matches possible.
426        return Ok(Some(0));
427    } else {
428        // Get the byte offset for the (start - 1)-th character (0-based)
429        value
430            .char_indices()
431            .nth((start - 1) as usize)
432            .map(|(idx, _)| idx)
433            .unwrap_or(0) // Should not happen if start is valid and <= total_chars
434    };
435    // --- End simplified calculation ---
436
437    let search_slice = &value[byte_start_offset..];
438
439    // Handle subexpression capturing first, as it takes precedence
440    if subexpr > 0 {
441        return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
442    }
443
444    // Use nth to get the N-th match (n is 1-based, nth is 0-based)
445    get_nth_match(pattern, search_slice, n, byte_start_offset, value)
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    /// Constructor of one of the string-typed `ScalarValue` variants.
    type StringScalarCtor = fn(Option<String>) -> ScalarValue;

    /// String scalar constructors in the order the scalar tests exercise
    /// them: utf8, largeutf8, utf8view.
    fn string_scalar_ctors() -> [StringScalarCtor; 3] {
        [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ]
    }

    /// Single test entry point that runs every scenario below.
    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with all-scalar arguments.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let mut args_values = Vec::with_capacity(args.len());
        let mut arg_fields = Vec::with_capacity(args.len());
        for (idx, scalar) in args.iter().enumerate() {
            args_values.push(ColumnarValue::Scalar(scalar.clone()));
            arg_fields.push(Arc::new(Field::new(
                format!("arg_{idx}"),
                scalar.data_type(),
                true,
            )));
        }

        RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            number_rows: args.len(),
            return_field: Arc::new(Field::new("f", Int64, true)),
        })
    }

    /// Runs the UDF on scalar `args` and checks that it yields an `Int64`
    /// scalar equal to `expected`.
    fn assert_scalar_position(args: &[ScalarValue], expected: Option<i64>) {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
                assert_eq!(v, expected, "regexp_instr scalar test failed");
            }
            _ => panic!("Unexpected result"),
        }
    }

    /// An empty string searched with an empty pattern yields 0.
    fn test_case_sensitive_regexp_instr_nulls() {
        let value_sv = ScalarValue::from(String::new());
        let regex_sv = ScalarValue::Utf8(Some(String::new()));
        assert_scalar_position(&[value_sv, regex_sv], Some(0));
    }

    /// Two-argument form on scalars, including a multi-byte string.
    fn test_case_sensitive_regexp_instr_scalar() {
        let cases: [(&str, &str, i64); 6] = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, pattern, expected) in cases {
            for make_string in string_scalar_ctors() {
                assert_scalar_position(
                    &[
                        make_string(Some(value.to_string())),
                        make_string(Some(pattern.to_string())),
                    ],
                    Some(expected),
                );
            }
        }
    }

    /// Three-argument form on scalars: explicit start position.
    fn test_case_sensitive_regexp_instr_scalar_start() {
        let cases: [(&str, &str, i64, i64); 3] = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, pattern, start, expected) in cases {
            for make_string in string_scalar_ctors() {
                assert_scalar_position(
                    &[
                        make_string(Some(value.to_string())),
                        make_string(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    /// Four-argument form on scalars: explicit occurrence number.
    fn test_case_sensitive_regexp_instr_scalar_nth() {
        let cases: [(&str, &str, i64, i64, i64); 4] = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, pattern, start, nth, expected) in cases {
            for make_string in string_scalar_ctors() {
                assert_scalar_position(
                    &[
                        make_string(Some(value.to_string())),
                        make_string(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    /// Six-argument form on scalars: flags plus a capture group index.
    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        let cases: [(&str, &str, i64, i64, &str, i64, i64); 1] =
            [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];

        for (value, pattern, start, nth, flags, subexp, expected) in cases {
            for make_string in string_scalar_ctors() {
                assert_scalar_position(
                    &[
                        make_string(Some(value.to_string())),
                        make_string(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                        make_string(Some(flags.to_string())),
                        ScalarValue::Int64(Some(subexp)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    /// Two-argument form on arrays of string type `A`.
    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let patterns = A::from(vec!["o", "d", "123", "z", "gg"]);
        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);

        let actual =
            regexp_instr_func(&[Arc::new(haystacks), Arc::new(patterns)]).unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }

    /// Three-argument form on arrays of string type `A`.
    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let patterns = A::from(vec!["abc", "abc", "gg"]);
        let starts = Int64Array::from(vec![4, 5, 5]);
        let expected = Int64Array::from(vec![4, 7, 0]);

        let actual = regexp_instr_func(&[
            Arc::new(haystacks),
            Arc::new(patterns),
            Arc::new(starts),
        ])
        .unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }

    /// Four-argument form on arrays of string type `A`.
    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks =
            A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let patterns = A::from(vec!["abc", "abc", "abc", "abc"]);
        let starts = Int64Array::from(vec![1, 1, 1, 1]);
        let occurrences = Int64Array::from(vec![1, 2, 3, 4]);
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let actual = regexp_instr_func(&[
            Arc::new(haystacks),
            Arc::new(patterns),
            Arc::new(starts),
            Arc::new(occurrences),
        ])
        .unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }
}