datafusion_functions/regex/
regexpinstr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
29    TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
39#[user_doc(
40    doc_section(label = "Regular Expression Functions"),
41    description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
42    syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
43    sql_example = r#"```sql
44> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
45+---------------------------------------------------------------+
46| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)"))                 |
47+---------------------------------------------------------------+
48| 3                                                             |
49+---------------------------------------------------------------+
50```"#,
51    standard_argument(name = "str", prefix = "String"),
52    standard_argument(name = "regexp", prefix = "Regular"),
53    argument(
54        name = "start",
55        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
56    ),
57    argument(
58        name = "N",
59        description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
60    ),
61    argument(
62        name = "flags",
63        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
64  - **i**: case-insensitive: letters match both upper and lower case
65  - **m**: multi-line mode: ^ and $ match begin/end of line
66  - **s**: allow . to match \n
67  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
68  - **U**: swap the meaning of x* and x*?"#
69    ),
70    argument(
71        name = "subexpr",
72        description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
73    )
74)]
75#[derive(Debug, PartialEq, Eq, Hash)]
76pub struct RegexpInstrFunc {
77    signature: Signature,
78}
79
80impl Default for RegexpInstrFunc {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl RegexpInstrFunc {
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::one_of(
90                vec![
91                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92                    Exact(vec![Utf8View, Utf8View, Int64]),
93                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94                    Exact(vec![Utf8, Utf8, Int64]),
95                    Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97                    Exact(vec![Utf8, Utf8, Int64, Int64]),
98                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101                    Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103                    Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104                ],
105                Volatility::Immutable,
106            ),
107        }
108    }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112    fn as_any(&self) -> &dyn std::any::Any {
113        self
114    }
115
116    fn name(&self) -> &str {
117        "regexp_instr"
118    }
119
120    fn signature(&self) -> &Signature {
121        &self.signature
122    }
123
124    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125        Ok(Int64)
126    }
127
128    fn invoke_with_args(
129        &self,
130        args: datafusion_expr::ScalarFunctionArgs,
131    ) -> Result<ColumnarValue> {
132        let args = &args.args;
133
134        let len = args
135            .iter()
136            .fold(Option::<usize>::None, |acc, arg| match arg {
137                ColumnarValue::Scalar(_) => acc,
138                ColumnarValue::Array(a) => Some(a.len()),
139            });
140
141        let is_scalar = len.is_none();
142        let inferred_length = len.unwrap_or(1);
143        let args = args
144            .iter()
145            .map(|arg| arg.to_array(inferred_length))
146            .collect::<Result<Vec<_>>>()?;
147
148        let result = regexp_instr_func(&args);
149        if is_scalar {
150            // If all inputs are scalar, keeps output as scalar
151            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
152            result.map(ColumnarValue::Scalar)
153        } else {
154            result.map(ColumnarValue::Array)
155        }
156    }
157
158    fn documentation(&self) -> Option<&Documentation> {
159        self.doc()
160    }
161}
162
163pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
164    let args_len = args.len();
165    if !(2..=6).contains(&args_len) {
166        return exec_err!(
167            "regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."
168        );
169    }
170
171    let values = &args[0];
172    match values.data_type() {
173        Utf8 | LargeUtf8 | Utf8View => (),
174        other => {
175            return internal_err!(
176                "Unsupported data type {other:?} for function regexp_instr"
177            );
178        }
179    }
180
181    regexp_instr(
182        values,
183        &args[1],
184        if args_len > 2 { Some(&args[2]) } else { None },
185        if args_len > 3 { Some(&args[3]) } else { None },
186        if args_len > 4 { Some(&args[4]) } else { None },
187        if args_len > 5 { Some(&args[5]) } else { None },
188    )
189    .map_err(|e| e.into())
190}
191
192/// `arrow-rs` style implementation of `regexp_instr` function.
193/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern
194/// within a string array. It supports optional start positions and flags for case insensitivity.
195///
196/// The function accepts a variable number of arguments:
197/// - `values`: The array of strings to search within.
198/// - `regex_array`: The array of regular expression patterns to search for.
199/// - `start_array` (optional): The array of start positions for the search.
200/// - `nth_array` (optional): The array of start nth for the search.
201/// - `endoption_array` (optional): The array of endoption positions for the search.
202/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
203/// - `subexpr_array` (optional): The array of subexpr positions for the search.
204///
205/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
206/// and flags. It uses a cache to store compiled regular expressions for efficiency.
207///
208/// # Errors
209/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
210fn regexp_instr(
211    values: &dyn Array,
212    regex_array: &dyn Datum,
213    start_array: Option<&dyn Datum>,
214    nth_array: Option<&dyn Datum>,
215    flags_array: Option<&dyn Datum>,
216    subexpr_array: Option<&dyn Datum>,
217) -> Result<ArrayRef, ArrowError> {
218    let (regex_array, _) = regex_array.get();
219    let start_array = start_array.map(|start| {
220        let (start, _) = start.get();
221        start
222    });
223    let nth_array = nth_array.map(|nth| {
224        let (nth, _) = nth.get();
225        nth
226    });
227    let flags_array = flags_array.map(|flags| {
228        let (flags, _) = flags.get();
229        flags
230    });
231    let subexpr_array = subexpr_array.map(|subexpr| {
232        let (subexpr, _) = subexpr.get();
233        subexpr
234    });
235
236    match (values.data_type(), regex_array.data_type(), flags_array) {
237        (Utf8, Utf8, None) => regexp_instr_inner(
238            &values.as_string::<i32>(),
239            &regex_array.as_string::<i32>(),
240            start_array.map(|start| start.as_primitive::<Int64Type>()),
241            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
242            None,
243            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
244        ),
245        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
246            &values.as_string::<i32>(),
247            &regex_array.as_string::<i32>(),
248            start_array.map(|start| start.as_primitive::<Int64Type>()),
249            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
250            Some(flags_array.as_string::<i32>()),
251            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
252        ),
253        (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
254            &values.as_string::<i64>(),
255            &regex_array.as_string::<i64>(),
256            start_array.map(|start| start.as_primitive::<Int64Type>()),
257            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
258            None,
259            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
260        ),
261        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
262            &values.as_string::<i64>(),
263            &regex_array.as_string::<i64>(),
264            start_array.map(|start| start.as_primitive::<Int64Type>()),
265            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
266            Some(flags_array.as_string::<i64>()),
267            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
268        ),
269        (Utf8View, Utf8View, None) => regexp_instr_inner(
270            &values.as_string_view(),
271            &regex_array.as_string_view(),
272            start_array.map(|start| start.as_primitive::<Int64Type>()),
273            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
274            None,
275            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
276        ),
277        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
278            &values.as_string_view(),
279            &regex_array.as_string_view(),
280            start_array.map(|start| start.as_primitive::<Int64Type>()),
281            nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
282            Some(flags_array.as_string_view()),
283            subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
284        ),
285        _ => Err(ArrowError::ComputeError(
286            "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
287        )),
288    }
289}
290
291fn regexp_instr_inner<'a, S>(
292    values: &S,
293    regex_array: &S,
294    start_array: Option<&Int64Array>,
295    nth_array: Option<&Int64Array>,
296    flags_array: Option<S>,
297    subexp_array: Option<&Int64Array>,
298) -> Result<ArrayRef, ArrowError>
299where
300    S: StringArrayType<'a>,
301{
302    let len = values.len();
303
304    let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
305    let start_array = start_array.unwrap_or(&default_start_array);
306    let start_input: Vec<i64> = (0..start_array.len())
307        .map(|i| start_array.value(i)) // handle nulls as 0
308        .collect();
309
310    let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
311    let nth_array = nth_array.unwrap_or(&default_nth_array);
312    let nth_input: Vec<i64> = (0..nth_array.len())
313        .map(|i| nth_array.value(i)) // handle nulls as 0
314        .collect();
315
316    let flags_input = match flags_array {
317        Some(flags) => flags.iter().collect(),
318        None => vec![None; len],
319    };
320
321    let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
322    let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
323    let subexp_input: Vec<i64> = (0..subexp_array.len())
324        .map(|i| subexp_array.value(i)) // handle nulls as 0
325        .collect();
326
327    let mut regex_cache = HashMap::new();
328
329    let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
330        values.iter(),
331        regex_array.iter(),
332        start_input.iter(),
333        nth_input.iter(),
334        flags_input.iter(),
335        subexp_input.iter()
336    )
337    .map(|(value, regex, start, nth, flags, subexp)| match regex {
338        None => Ok(None),
339        Some("") => Ok(Some(0)),
340        Some(regex) => get_index(
341            value,
342            regex,
343            *start,
344            *nth,
345            *subexp,
346            *flags,
347            &mut regex_cache,
348        ),
349    })
350    .collect();
351    Ok(Arc::new(Int64Array::from(result?)))
352}
353
354fn handle_subexp(
355    pattern: &Regex,
356    search_slice: &str,
357    subexpr: i64,
358    value: &str,
359    byte_start_offset: usize,
360) -> Result<Option<i64>, ArrowError> {
361    if let Some(captures) = pattern.captures(search_slice)
362        && let Some(matched) = captures.get(subexpr as usize)
363    {
364        // Convert byte offset relative to search_slice back to 1-based character offset
365        // relative to the original `value` string.
366        let start_char_offset =
367            value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
368        return Ok(Some(start_char_offset));
369    }
370    Ok(Some(0)) // Return 0 if the subexpression was not found
371}
372
373fn get_nth_match(
374    pattern: &Regex,
375    search_slice: &str,
376    n: i64,
377    byte_start_offset: usize,
378    value: &str,
379) -> Result<Option<i64>, ArrowError> {
380    if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
381        // Convert byte offset relative to search_slice back to 1-based character offset
382        // relative to the original `value` string.
383        let match_start_byte_offset = byte_start_offset + mat.start();
384        let match_start_char_offset =
385            value[..match_start_byte_offset].chars().count() as i64 + 1;
386        Ok(Some(match_start_char_offset))
387    } else {
388        Ok(Some(0)) // Return 0 if the N-th match was not found
389    }
390}
391fn get_index<'strings, 'cache>(
392    value: Option<&str>,
393    pattern: &'strings str,
394    start: i64,
395    n: i64,
396    subexpr: i64,
397    flags: Option<&'strings str>,
398    regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
399) -> Result<Option<i64>, ArrowError>
400where
401    'strings: 'cache,
402{
403    let value = match value {
404        None => return Ok(None),
405        Some("") => return Ok(Some(0)),
406        Some(value) => value,
407    };
408    let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
409    // println!("get_index: value = {}, pattern = {}, start = {}, n = {}, subexpr = {}, flags = {:?}", value, pattern, start, n, subexpr, flags);
410    if start < 1 {
411        return Err(ArrowError::ComputeError(
412            "regexp_instr() requires start to be 1-based".to_string(),
413        ));
414    }
415
416    if n < 1 {
417        return Err(ArrowError::ComputeError(
418            "N must be 1 or greater".to_string(),
419        ));
420    }
421
422    // --- Simplified byte_start_offset calculation ---
423    let total_chars = value.chars().count() as i64;
424    let byte_start_offset: usize = if start > total_chars {
425        // If start is beyond the total characters, it means we start searching
426        // after the string effectively. No matches possible.
427        return Ok(Some(0));
428    } else {
429        // Get the byte offset for the (start - 1)-th character (0-based)
430        value
431            .char_indices()
432            .nth((start - 1) as usize)
433            .map(|(idx, _)| idx)
434            .unwrap_or(0) // Should not happen if start is valid and <= total_chars
435    };
436    // --- End simplified calculation ---
437
438    let search_slice = &value[byte_start_offset..];
439
440    // Handle subexpression capturing first, as it takes precedence
441    if subexpr > 0 {
442        return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
443    }
444
445    // Use nth to get the N-th match (n is 1-based, nth is 0-based)
446    get_nth_match(pattern, search_slice, n, byte_start_offset, value)
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    /// Constructors for the three string scalar types accepted by `regexp_instr`.
    const STRING_SCALARS: [fn(Option<String>) -> ScalarValue; 3] = [
        ScalarValue::Utf8,
        ScalarValue::LargeUtf8,
        ScalarValue::Utf8View,
    ];

    /// Single entry point that runs every scenario, instantiating the generic
    /// array tests once per string array type.
    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with all-scalar arguments.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values: Vec<ColumnarValue> = args
            .iter()
            .map(|sv| ColumnarValue::Scalar(sv.clone()))
            .collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| {
                Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true))
            })
            .collect::<Vec<_>>();

        RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // All inputs are scalars, so the call describes a single row.
            number_rows: 1,
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that a scalar invocation returns the `Int64` scalar `expected`.
    fn assert_scalar_result(args: &[ScalarValue], expected: Option<i64>) {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(actual))) => {
                assert_eq!(actual, expected, "regexp_instr scalar test failed");
            }
            other => panic!("Unexpected result: {other:?}"),
        }
    }

    /// Runs one scalar case for each of Utf8, LargeUtf8 and Utf8View.
    ///
    /// `start_and_nth` supplies the leading integer arguments (zero, one or
    /// two of them); `flags_and_subexpr` supplies the trailing pair.
    fn check_scalar(
        value: Option<&str>,
        regex: Option<&str>,
        start_and_nth: &[i64],
        flags_and_subexpr: Option<(&str, i64)>,
        expected: Option<i64>,
    ) {
        for string_scalar in STRING_SCALARS {
            let mut args = vec![
                string_scalar(value.map(str::to_string)),
                string_scalar(regex.map(str::to_string)),
            ];
            args.extend(start_and_nth.iter().map(|&v| ScalarValue::Int64(Some(v))));
            if let Some((flags, subexpr)) = flags_and_subexpr {
                args.push(string_scalar(Some(flags.to_string())));
                args.push(ScalarValue::Int64(Some(subexpr)));
            }
            assert_scalar_result(&args, expected);
        }
    }

    fn test_case_sensitive_regexp_instr_nulls() {
        // Empty value and empty pattern produce 0.
        check_scalar(Some(""), Some(""), &[], None, Some(0));
        // A NULL value or a NULL pattern produces NULL.
        check_scalar(None, Some("abc"), &[], None, None);
        check_scalar(Some("abc"), None, &[], None, None);
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        let cases = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            // Positions are counted in characters, not bytes.
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, regex, expected) in cases {
            check_scalar(Some(value), Some(regex), &[], None, Some(expected));
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        let cases = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, regex, start, expected) in cases {
            check_scalar(Some(value), Some(regex), &[start], None, Some(expected));
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        let cases = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, regex, start, nth, expected) in cases {
            check_scalar(
                Some(value),
                Some(regex),
                &[start, nth],
                None,
                Some(expected),
            );
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        check_scalar(
            Some("12 abc def ghi 34"),
            Some("(abc) (def) (ghi)"),
            &[1, 1],
            Some(("i", 2)),
            Some(8),
        );
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let regex = A::from(vec!["o", "d", "123", "z", "gg"]);

        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);
        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let regex = A::from(vec!["abc", "abc", "gg"]);
        let start = Int64Array::from(vec![4, 5, 5]);
        let expected = Int64Array::from(vec![4, 7, 0]);

        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
        let start = Int64Array::from(vec![1, 1, 1, 1]);
        let nth = Int64Array::from(vec![1, 2, 3, 4]);
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let re = regexp_instr_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(nth),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }
}