Skip to main content

datafusion_functions/regex/
regexpinstr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29    TypeSignature::Exact, TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
/// Scalar UDF implementing `regexp_instr`, which returns the 1-based position of a
/// regular expression match within a string (user-facing docs live in `#[user_doc]`).
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
    syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
    sql_example = r#"```sql
> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
+---------------------------------------------------------------+
| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)"))                 |
+---------------------------------------------------------------+
| 3                                                             |
+---------------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(name = "regexp", prefix = "Regular"),
    argument(
        name = "start",
        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
    ),
    argument(
        name = "N",
        description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
  - **i**: case-insensitive: letters match both upper and lower case
  - **m**: multi-line mode: ^ and $ match begin/end of line
  - **s**: allow . to match \n
  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
  - **U**: swap the meaning of x* and x*?"#
    ),
    argument(
        name = "subexpr",
        description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpInstrFunc {
    // The accepted argument type combinations, built in `RegexpInstrFunc::new`.
    signature: Signature,
}
79
80impl Default for RegexpInstrFunc {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl RegexpInstrFunc {
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::one_of(
90                vec![
91                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92                    Exact(vec![Utf8View, Utf8View, Int64]),
93                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94                    Exact(vec![Utf8, Utf8, Int64]),
95                    Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97                    Exact(vec![Utf8, Utf8, Int64, Int64]),
98                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104                ],
105                Volatility::Immutable,
106            ),
107        }
108    }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112    fn name(&self) -> &str {
113        "regexp_instr"
114    }
115
116    fn signature(&self) -> &Signature {
117        &self.signature
118    }
119
120    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
121        Ok(Int64)
122    }
123
124    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125        let args = &args.args;
126
127        let len = args
128            .iter()
129            .fold(Option::<usize>::None, |acc, arg| match arg {
130                ColumnarValue::Scalar(_) => acc,
131                ColumnarValue::Array(a) => Some(a.len()),
132            });
133
134        let is_scalar = len.is_none();
135        let inferred_length = len.unwrap_or(1);
136        let args = args
137            .iter()
138            .map(|arg| arg.to_array(inferred_length))
139            .collect::<Result<Vec<_>>>()?;
140
141        let result = regexp_instr_func(&args);
142        if is_scalar {
143            // If all inputs are scalar, keeps output as scalar
144            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
145            result.map(ColumnarValue::Scalar)
146        } else {
147            result.map(ColumnarValue::Array)
148        }
149    }
150
151    fn documentation(&self) -> Option<&Documentation> {
152        self.doc()
153    }
154}
155
156pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
157    let args_len = args.len();
158    if !(2..=6).contains(&args_len) {
159        return exec_err!(
160            "regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."
161        );
162    }
163
164    let values = &args[0];
165    match values.data_type() {
166        Utf8 | LargeUtf8 | Utf8View => (),
167        other => {
168            return internal_err!(
169                "Unsupported data type {other:?} for function regexp_instr"
170            );
171        }
172    }
173
174    regexp_instr(
175        values,
176        &args[1],
177        if args_len > 2 { Some(&args[2]) } else { None },
178        if args_len > 3 { Some(&args[3]) } else { None },
179        if args_len > 4 { Some(&args[4]) } else { None },
180        if args_len > 5 { Some(&args[5]) } else { None },
181    )
182    .map_err(|e| e.into())
183}
184
185/// `arrow-rs` style implementation of `regexp_instr` function.
186/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern
187/// within a string array. It supports optional start positions and flags for case insensitivity.
188///
189/// The function accepts a variable number of arguments:
190/// - `values`: The array of strings to search within.
191/// - `regex_array`: The array of regular expression patterns to search for.
192/// - `start_array` (optional): The array of start positions for the search.
193/// - `nth_array` (optional): The array of start nth for the search.
194/// - `endoption_array` (optional): The array of endoption positions for the search.
195/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
196/// - `subexpr_array` (optional): The array of subexpr positions for the search.
197///
198/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
199/// and flags. It uses a cache to store compiled regular expressions for efficiency.
200///
201/// # Errors
202/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
203fn regexp_instr(
204    values: &dyn Array,
205    regex_array: &dyn Datum,
206    start_array: Option<&dyn Datum>,
207    nth_array: Option<&dyn Datum>,
208    flags_array: Option<&dyn Datum>,
209    subexpr_array: Option<&dyn Datum>,
210) -> Result<ArrayRef, ArrowError> {
211    let (regex_array, _) = regex_array.get();
212    let start_array = start_array.map(|start| {
213        let (start, _) = start.get();
214        start
215    });
216    let nth_array = nth_array.map(|nth| {
217        let (nth, _) = nth.get();
218        nth
219    });
220    let flags_array = flags_array.map(|flags| {
221        let (flags, _) = flags.get();
222        flags
223    });
224    let subexpr_array = subexpr_array.map(|subexpr| {
225        let (subexpr, _) = subexpr.get();
226        subexpr
227    });
228
229    match (values.data_type(), regex_array.data_type(), flags_array) {
230        (Utf8, Utf8, None) => regexp_instr_inner(
231            &values.as_string::<i32>(),
232            &regex_array.as_string::<i32>(),
233            start_array.map(|start| start.as_primitive::<Int64Type>()),
234            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
235            None,
236            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
237        ),
238        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
239            &values.as_string::<i32>(),
240            &regex_array.as_string::<i32>(),
241            start_array.map(|start| start.as_primitive::<Int64Type>()),
242            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
243            Some(flags_array.as_string::<i32>()),
244            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
245        ),
246        (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
247            &values.as_string::<i64>(),
248            &regex_array.as_string::<i64>(),
249            start_array.map(|start| start.as_primitive::<Int64Type>()),
250            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
251            None,
252            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
253        ),
254        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
255            &values.as_string::<i64>(),
256            &regex_array.as_string::<i64>(),
257            start_array.map(|start| start.as_primitive::<Int64Type>()),
258            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
259            Some(flags_array.as_string::<i64>()),
260            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
261        ),
262        (Utf8View, Utf8View, None) => regexp_instr_inner(
263            &values.as_string_view(),
264            &regex_array.as_string_view(),
265            start_array.map(|start| start.as_primitive::<Int64Type>()),
266            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
267            None,
268            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
269        ),
270        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
271            &values.as_string_view(),
272            &regex_array.as_string_view(),
273            start_array.map(|start| start.as_primitive::<Int64Type>()),
274            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
275            Some(flags_array.as_string_view()),
276            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
277        ),
278        _ => Err(ArrowError::ComputeError(
279            "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
280        )),
281    }
282}
283
284fn regexp_instr_inner<'a, S>(
285    values: &S,
286    regex_array: &S,
287    start_array: Option<&Int64Array>,
288    nth_array: Option<&Int64Array>,
289    flags_array: Option<S>,
290    subexp_array: Option<&Int64Array>,
291) -> Result<ArrayRef, ArrowError>
292where
293    S: StringArrayType<'a>,
294{
295    let len = values.len();
296
297    let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
298    let start_array = start_array.unwrap_or(&default_start_array);
299    let start_input: Vec<i64> = (0..start_array.len())
300        .map(|i| start_array.value(i)) // handle nulls as 0
301        .collect();
302
303    let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
304    let nth_array = nth_array.unwrap_or(&default_nth_array);
305    let nth_input: Vec<i64> = (0..nth_array.len())
306        .map(|i| nth_array.value(i)) // handle nulls as 0
307        .collect();
308
309    let flags_input = match flags_array {
310        Some(flags) => flags.iter().collect(),
311        None => vec![None; len],
312    };
313
314    let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
315    let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
316    let subexp_input: Vec<i64> = (0..subexp_array.len())
317        .map(|i| subexp_array.value(i)) // handle nulls as 0
318        .collect();
319
320    let mut regex_cache = HashMap::new();
321
322    let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
323        values.iter(),
324        regex_array.iter(),
325        start_input.iter(),
326        nth_input.iter(),
327        flags_input.iter(),
328        subexp_input.iter()
329    )
330    .map(|(value, regex, start, nth, flags, subexp)| match regex {
331        None => Ok(None),
332        Some("") => Ok(Some(0)),
333        Some(regex) => get_index(
334            value,
335            regex,
336            *start,
337            *nth,
338            *subexp,
339            *flags,
340            &mut regex_cache,
341        ),
342    })
343    .collect();
344    Ok(Arc::new(Int64Array::from(result?)))
345}
346
347fn handle_subexp(
348    pattern: &Regex,
349    search_slice: &str,
350    subexpr: i64,
351    value: &str,
352    byte_start_offset: usize,
353) -> Result<Option<i64>, ArrowError> {
354    if let Some(captures) = pattern.captures(search_slice)
355        && let Some(matched) = captures.get(subexpr as usize)
356    {
357        // Convert byte offset relative to search_slice back to 1-based character offset
358        // relative to the original `value` string.
359        let start_char_offset =
360            value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
361        return Ok(Some(start_char_offset));
362    }
363    Ok(Some(0)) // Return 0 if the subexpression was not found
364}
365
366fn get_nth_match(
367    pattern: &Regex,
368    search_slice: &str,
369    n: i64,
370    byte_start_offset: usize,
371    value: &str,
372) -> Result<Option<i64>, ArrowError> {
373    if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
374        // Convert byte offset relative to search_slice back to 1-based character offset
375        // relative to the original `value` string.
376        let match_start_byte_offset = byte_start_offset + mat.start();
377        let match_start_char_offset =
378            value[..match_start_byte_offset].chars().count() as i64 + 1;
379        Ok(Some(match_start_char_offset))
380    } else {
381        Ok(Some(0)) // Return 0 if the N-th match was not found
382    }
383}
384fn get_index<'strings, 'cache>(
385    value: Option<&str>,
386    pattern: &'strings str,
387    start: i64,
388    n: i64,
389    subexpr: i64,
390    flags: Option<&'strings str>,
391    regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
392) -> Result<Option<i64>, ArrowError>
393where
394    'strings: 'cache,
395{
396    let value = match value {
397        None => return Ok(None),
398        Some("") => return Ok(Some(0)),
399        Some(value) => value,
400    };
401    let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
402    // println!("get_index: value = {}, pattern = {}, start = {}, n = {}, subexpr = {}, flags = {:?}", value, pattern, start, n, subexpr, flags);
403    if start < 1 {
404        return Err(ArrowError::ComputeError(
405            "regexp_instr() requires start to be 1-based".to_string(),
406        ));
407    }
408
409    if n < 1 {
410        return Err(ArrowError::ComputeError(
411            "N must be 1 or greater".to_string(),
412        ));
413    }
414
415    // --- Simplified byte_start_offset calculation ---
416    let total_chars = value.chars().count() as i64;
417    let byte_start_offset: usize = if start > total_chars {
418        // If start is beyond the total characters, it means we start searching
419        // after the string effectively. No matches possible.
420        return Ok(Some(0));
421    } else {
422        // Get the byte offset for the (start - 1)-th character (0-based)
423        value
424            .char_indices()
425            .nth((start - 1) as usize)
426            .map(|(idx, _)| idx)
427            .unwrap_or(0) // Should not happen if start is valid and <= total_chars
428    };
429    // --- End simplified calculation ---
430
431    let search_slice = &value[byte_start_offset..];
432
433    // Handle subexpression capturing first, as it takes precedence
434    if subexpr > 0 {
435        return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
436    }
437
438    // Use nth to get the N-th match (n is 1-based, nth is 0-based)
439    get_nth_match(pattern, search_slice, n, byte_start_offset, value)
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    #[test]
    fn test_regexp_instr() {
        test_regexp_instr_empty_strings();
        test_regexp_instr_null_inputs();
        test_regexp_instr_invalid_arguments();
        test_regexp_instr_multibyte_start();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Constructors for the three string scalar types every scalar case runs against.
    fn string_scalar_ctors() -> [fn(Option<String>) -> ScalarValue; 3] {
        [ScalarValue::Utf8, ScalarValue::LargeUtf8, ScalarValue::Utf8View]
    }

    /// Invokes `regexp_instr` through `ScalarUDFImpl` with all-scalar arguments.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values: Vec<ColumnarValue> =
            args.iter().cloned().map(ColumnarValue::Scalar).collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| {
                Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true))
            })
            .collect::<Vec<_>>();

        RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // Every argument is a scalar, so this invocation covers exactly one row.
            number_rows: 1,
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Invokes `regexp_instr` on scalar arguments and unwraps the `Int64` scalar result.
    fn scalar_result(args: &[ScalarValue]) -> Option<i64> {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => v,
            other => panic!("Unexpected result: {other:?}"),
        }
    }

    fn test_regexp_instr_empty_strings() {
        for ctor in string_scalar_ctors() {
            let args = [ctor(Some(String::new())), ctor(Some(String::new()))];
            assert_eq!(scalar_result(&args), Some(0), "empty string and pattern");
        }
    }

    fn test_regexp_instr_null_inputs() {
        for ctor in string_scalar_ctors() {
            let null_value = [ctor(None), ctor(Some("a".to_string()))];
            assert_eq!(scalar_result(&null_value), None, "NULL string must yield NULL");

            let null_pattern = [ctor(Some("abc".to_string())), ctor(None)];
            assert_eq!(scalar_result(&null_pattern), None, "NULL pattern must yield NULL");
        }
    }

    fn test_regexp_instr_invalid_arguments() {
        for ctor in string_scalar_ctors() {
            let value = ctor(Some("abc".to_string()));
            let pattern = ctor(Some("b".to_string()));

            let zero_start = [value.clone(), pattern.clone(), ScalarValue::Int64(Some(0))];
            assert!(
                regexp_instr_with_scalar_values(&zero_start).is_err(),
                "start < 1 must be rejected"
            );

            let zero_nth = [
                value,
                pattern,
                ScalarValue::Int64(Some(1)),
                ScalarValue::Int64(Some(0)),
            ];
            assert!(
                regexp_instr_with_scalar_values(&zero_nth).is_err(),
                "N < 1 must be rejected"
            );
        }
    }

    fn test_regexp_instr_multibyte_start() {
        // `start` counts characters, not bytes.
        let cases = [("ДатаФусион", "Ф", 3, 5), ("数据融合数据", "数据", 2, 5)];
        for (value, regex, start, expected) in cases {
            for ctor in string_scalar_ctors() {
                let args = [
                    ctor(Some(value.to_string())),
                    ctor(Some(regex.to_string())),
                    ScalarValue::Int64(Some(start)),
                ];
                assert_eq!(
                    scalar_result(&args),
                    Some(expected),
                    "regexp_instr({value:?}, {regex:?}, {start}) failed"
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        let cases = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];
        for (value, regex, expected) in cases {
            for ctor in string_scalar_ctors() {
                let args = [ctor(Some(value.to_string())), ctor(Some(regex.to_string()))];
                assert_eq!(
                    scalar_result(&args),
                    Some(expected),
                    "regexp_instr({value:?}, {regex:?}) failed"
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        let cases = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];
        for (value, regex, start, expected) in cases {
            for ctor in string_scalar_ctors() {
                let args = [
                    ctor(Some(value.to_string())),
                    ctor(Some(regex.to_string())),
                    ScalarValue::Int64(Some(start)),
                ];
                assert_eq!(
                    scalar_result(&args),
                    Some(expected),
                    "regexp_instr({value:?}, {regex:?}, {start}) failed"
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        let cases = [(1, 1), (2, 4), (3, 7), (4, 0)];
        for (nth, expected) in cases {
            for ctor in string_scalar_ctors() {
                let args = [
                    ctor(Some("abcabcabc".to_string())),
                    ctor(Some("abc".to_string())),
                    ScalarValue::Int64(Some(1)),
                    ScalarValue::Int64(Some(nth)),
                ];
                assert_eq!(
                    scalar_result(&args),
                    Some(expected),
                    "regexp_instr(N = {nth}) failed"
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        let cases = [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];
        for (value, regex, start, nth, flags, subexp, expected) in cases {
            for ctor in string_scalar_ctors() {
                let args = [
                    ctor(Some(value.to_string())),
                    ctor(Some(regex.to_string())),
                    ScalarValue::Int64(Some(start)),
                    ScalarValue::Int64(Some(nth)),
                    ctor(Some(flags.to_string())),
                    ScalarValue::Int64(Some(subexp)),
                ];
                assert_eq!(
                    scalar_result(&args),
                    Some(expected),
                    "regexp_instr({value:?}, {regex:?}, subexpr = {subexp}) failed"
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let regex = A::from(vec!["o", "d", "123", "z", "gg"]);

        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);
        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let regex = A::from(vec!["abc", "abc", "gg"]);
        let start = Int64Array::from(vec![4, 5, 5]);
        let expected = Int64Array::from(vec![4, 7, 0]);

        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
        let start = Int64Array::from(vec![1, 1, 1, 1]);
        let nth = Int64Array::from(vec![1, 2, 3, 4]);
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let re = regexp_instr_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(nth),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }
}