Skip to main content

arrow_string/
regexp.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines kernels to match strings against regular expressions and to
//! extract the matching substrings of a \[Large\]StringArray / StringViewArray
20
21use crate::like::StringArrayType;
22
23use arrow_array::builder::{
24    BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder,
25};
26use arrow_array::cast::AsArray;
27use arrow_array::*;
28use arrow_buffer::{BooleanBuffer, NullBuffer};
29use arrow_data::ArrayDataBuilder;
30use arrow_schema::{ArrowError, DataType, Field};
31use regex::Regex;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36/// Return BooleanArray indicating which strings in an array match an array of
37/// regular expressions.
38///
39/// This is equivalent to the SQL `array ~ regex_array`, supporting
40/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
41///
42/// If `regex_array` element has an empty value, the corresponding result value is always true.
43///
44/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
45/// which allow special search modes, such as case-insensitive and multi-line mode.
46/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
47/// for more information.
48///
49/// # See Also
50/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings
51/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
52///
53/// # Example
54/// ```
55/// # use arrow_array::{StringArray, BooleanArray};
56/// # use arrow_string::regexp::regexp_is_match;
57/// // First array is the array of strings to match
58/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
59/// // Second array is the array of regular expressions to match against
60/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]);
61/// // Third array is the array of flags to use for each regular expression, if desired
62/// // (the type must be provided to satisfy type inference for the third parameter)
63/// let flags_array: Option<&StringArray> = None;
64/// // The result is a BooleanArray indicating when each string in `array`
65/// // matches the corresponding regular expression in `regex_array`
66/// let result = regexp_is_match(&array, &regex_array, flags_array).unwrap();
67/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true]));
68/// ```
69pub fn regexp_is_match<'a, S1, S2, S3>(
70    array: &'a S1,
71    regex_array: &'a S2,
72    flags_array: Option<&'a S3>,
73) -> Result<BooleanArray, ArrowError>
74where
75    &'a S1: StringArrayType<'a>,
76    &'a S2: StringArrayType<'a>,
77    &'a S3: StringArrayType<'a>,
78{
79    if array.len() != regex_array.len() {
80        return Err(ArrowError::ComputeError(
81            "Cannot perform comparison operation on arrays of different length".to_string(),
82        ));
83    }
84
85    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
86
87    let mut patterns: HashMap<String, Regex> = HashMap::new();
88    let mut result = BooleanBufferBuilder::new(array.len());
89
90    let complete_pattern = match flags_array {
91        Some(flags) => Box::new(
92            regex_array
93                .iter()
94                .zip(flags.iter())
95                .map(|(pattern, flags)| {
96                    pattern.map(|pattern| match flags {
97                        Some(flag) => format!("(?{flag}){pattern}"),
98                        None => pattern.to_string(),
99                    })
100                }),
101        ) as Box<dyn Iterator<Item = Option<String>>>,
102        None => Box::new(
103            regex_array
104                .iter()
105                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
106        ),
107    };
108
109    array
110        .iter()
111        .zip(complete_pattern)
112        .map(|(value, pattern)| {
113            match (value, pattern) {
114                // Required for Postgres compatibility:
115                // SELECT 'foobarbequebaz' ~ ''); = true
116                (Some(_), Some(pattern)) if pattern == *"" => {
117                    result.append(true);
118                }
119                (Some(value), Some(pattern)) => {
120                    let existing_pattern = patterns.get(&pattern);
121                    let re = match existing_pattern {
122                        Some(re) => re,
123                        None => {
124                            let re = Regex::new(pattern.as_str()).map_err(|e| {
125                                ArrowError::ComputeError(format!(
126                                    "Regular expression did not compile: {e:?}"
127                                ))
128                            })?;
129                            patterns.entry(pattern).or_insert(re)
130                        }
131                    };
132                    result.append(re.is_match(value));
133                }
134                _ => result.append(false),
135            }
136            Ok(())
137        })
138        .collect::<Result<Vec<()>, ArrowError>>()?;
139
140    let data = unsafe {
141        ArrayDataBuilder::new(DataType::Boolean)
142            .len(array.len())
143            .buffers(vec![result.into()])
144            .nulls(nulls)
145            .build_unchecked()
146    };
147
148    Ok(BooleanArray::from(data))
149}
150
151/// Return BooleanArray indicating which strings in an array match a single regular expression.
152///
153/// This is equivalent to the SQL `array ~ regex_array`, supporting
154/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
155///
156/// See the documentation on [`regexp_is_match`] for more details on arguments
157///
158/// # See Also
159/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings
160/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
161///
162/// # Example
163/// ```
164/// # use arrow_array::{StringArray, BooleanArray};
165/// # use arrow_string::regexp::regexp_is_match_scalar;
166/// // array of strings to match
167/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
168/// let regexp = "^Foo"; // regular expression to match against
169/// let flags: Option<&str> = None;  // flags can control the matching behavior
170/// // The result is a BooleanArray indicating when each string in `array`
171/// // matches the regular expression `regexp`
172/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap();
173/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false]));
174/// ```
175pub fn regexp_is_match_scalar<'a, S>(
176    array: &'a S,
177    regex: &str,
178    flag: Option<&str>,
179) -> Result<BooleanArray, ArrowError>
180where
181    &'a S: StringArrayType<'a>,
182{
183    let mut result = BooleanBufferBuilder::new(array.len());
184
185    let pattern = match flag {
186        Some(flag) => format!("(?{flag}){regex}"),
187        None => regex.to_string(),
188    };
189
190    if pattern.is_empty() {
191        result.append_n(array.len(), true);
192    } else {
193        let re = Regex::new(pattern.as_str()).map_err(|e| {
194            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
195        })?;
196        for i in 0..array.len() {
197            let value = array.value(i);
198            result.append(re.is_match(value));
199        }
200    }
201
202    let values = BooleanBuffer::from(result);
203    let nulls = array
204        .nulls()
205        .map(|n| n.inner().sliced())
206        .map(|b| NullBuffer::new(BooleanBuffer::new(b, 0, array.len())))
207        .filter(|n| n.null_count() > 0);
208    Ok(BooleanArray::new(values, nulls))
209}
210
// Shared body of the array-pattern `regexp_match` kernels.
//
// Pairs each value of `$array` with the pattern (and optional flag) at the same
// position and appends one list entry per pair to `$list_builder`:
// * empty pattern and non-null value -> a list holding one empty string,
// * match -> the captured groups (or the whole match when the pattern has no groups),
// * no match, null value or null pattern -> a null list entry.
//
// Pairing uses `zip`, so it stops at the shortest input. Expands to statements and
// uses `?`, so it must be invoked inside a function returning
// `Result<_, ArrowError>`.
macro_rules! process_regexp_array_match {
    ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => {
        // Compiled regexes keyed by the complete pattern text (flags included).
        let mut compiled: HashMap<String, Regex> = HashMap::new();

        let full_patterns = match $flags_array {
            Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map(
                |(pattern, flag)| {
                    pattern.map(|pattern| match flag {
                        Some(flag) => format!("(?{flag}){pattern}"),
                        None => pattern.to_string(),
                    })
                },
            )) as Box<dyn Iterator<Item = Option<String>>>,
            None => Box::new(
                $regex_array
                    .iter()
                    .map(|pattern| pattern.map(|pattern| pattern.to_string())),
            ),
        };

        for (value, pattern) in $array.iter().zip(full_patterns) {
            match (value, pattern) {
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                (Some(_), Some(pattern)) if pattern.is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                (Some(value), Some(pattern)) => {
                    let re = match compiled.entry(pattern) {
                        ::std::collections::hash_map::Entry::Occupied(cached) => cached.into_mut(),
                        ::std::collections::hash_map::Entry::Vacant(slot) => {
                            let re = Regex::new(slot.key().as_str()).map_err(|e| {
                                ArrowError::ComputeError(format!(
                                    "Regular expression did not compile: {e:?}"
                                ))
                            })?;
                            slot.insert(re)
                        }
                    };
                    if let Some(caps) = re.captures(value) {
                        // Group 0 is the whole match; it is only reported when
                        // the pattern defines no capture groups of its own.
                        let skip = usize::from(caps.len() > 1);
                        // `flatten` drops groups that did not participate in the match.
                        for group in caps.iter().skip(skip).flatten() {
                            $list_builder.values().append_value(group.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
                _ => $list_builder.append(false),
            }
        }
    };
}
277
278fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
279    array: &GenericStringArray<OffsetSize>,
280    regex_array: &GenericStringArray<OffsetSize>,
281    flags_array: Option<&GenericStringArray<OffsetSize>>,
282) -> Result<ArrayRef, ArrowError> {
283    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
284    let mut list_builder = ListBuilder::new(builder);
285
286    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
287
288    Ok(Arc::new(list_builder.finish()))
289}
290
291fn regexp_array_match_utf8view(
292    array: &StringViewArray,
293    regex_array: &StringViewArray,
294    flags_array: Option<&StringViewArray>,
295) -> Result<ArrayRef, ArrowError> {
296    let builder = StringViewBuilder::with_capacity(0);
297    let mut list_builder = ListBuilder::new(builder);
298
299    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
300
301    Ok(Arc::new(list_builder.finish()))
302}
303
304fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
305    regex_array: &'a dyn Array,
306    flag_array: Option<&'a dyn Array>,
307) -> (Option<&'a str>, Option<&'a str>) {
308    let regex = regex_array.as_string::<OffsetSize>();
309    let regex = regex.is_valid(0).then(|| regex.value(0));
310
311    if let Some(flag_array) = flag_array {
312        let flag = flag_array.as_string::<OffsetSize>();
313        (regex, flag.is_valid(0).then(|| flag.value(0)))
314    } else {
315        (regex, None)
316    }
317}
318
319fn get_scalar_pattern_flag_utf8view<'a>(
320    regex_array: &'a dyn Array,
321    flag_array: Option<&'a dyn Array>,
322) -> (Option<&'a str>, Option<&'a str>) {
323    let regex = regex_array.as_string_view();
324    let regex = regex.is_valid(0).then(|| regex.value(0));
325
326    if let Some(flag_array) = flag_array {
327        let flag = flag_array.as_string_view();
328        (regex, flag.is_valid(0).then(|| flag.value(0)))
329    } else {
330        (regex, None)
331    }
332}
333
// Shared body of the scalar-pattern `regexp_match` kernels.
//
// Applies the already compiled `$regex` to every value of `$array` and appends
// one list entry per value to `$list_builder`:
// * empty pattern and non-null value -> a list holding one empty string,
// * match -> the captured groups (or the whole match when the pattern has no groups),
// * no match or null value -> a null list entry.
macro_rules! process_regexp_match {
    ($array:expr, $regex:expr, $list_builder:expr) => {
        for value in $array.iter() {
            match value {
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                Some(_) if $regex.as_str().is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                Some(value) => {
                    if let Some(caps) = $regex.captures(value) {
                        // Group 0 is the whole match; it is only reported when
                        // the pattern defines no capture groups of its own.
                        let skip = usize::from(caps.len() > 1);
                        // `flatten` drops groups that did not participate in the match.
                        for group in caps.iter().skip(skip).flatten() {
                            $list_builder.values().append_value(group.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
                None => $list_builder.append(false),
            }
        }
    };
}
366
367fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
368    array: &GenericStringArray<OffsetSize>,
369    regex: &Regex,
370) -> Result<ArrayRef, ArrowError> {
371    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
372    let mut list_builder = ListBuilder::new(builder);
373
374    process_regexp_match!(array, regex, list_builder);
375
376    Ok(Arc::new(list_builder.finish()))
377}
378
379fn regexp_scalar_match_utf8view(
380    array: &StringViewArray,
381    regex: &Regex,
382) -> Result<ArrayRef, ArrowError> {
383    let builder = StringViewBuilder::with_capacity(0);
384    let mut list_builder = ListBuilder::new(builder);
385
386    process_regexp_match!(array, regex, list_builder);
387
388    Ok(Arc::new(list_builder.finish()))
389}
390
391/// Extract all groups matched by a regular expression for a given String array.
392///
393/// Modelled after the Postgres [regexp_match].
394///
395/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
396/// match of the corresponding index in `regex_array` to string in `array`
397///
398/// If there is no match, the list element is NULL.
399///
400/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
401/// then the list element is a single-element [`GenericStringArray`] containing the substring
402/// matching the whole pattern.
403///
404/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
405/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
406/// the n'th capturing parenthesized subexpression of the pattern.
407///
408/// The flags parameter is an optional text string containing zero or more single-letter flags
409/// that change the function's behavior.
410///
411/// # See Also
412/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings
413///
414/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
415pub fn regexp_match(
416    array: &dyn Array,
417    regex_array: &dyn Datum,
418    flags_array: Option<&dyn Datum>,
419) -> Result<ArrayRef, ArrowError> {
420    let (rhs, is_rhs_scalar) = regex_array.get();
421
422    if array.data_type() != rhs.data_type() {
423        return Err(ArrowError::ComputeError(
424            "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8"
425                .to_string(),
426        ));
427    }
428
429    let (flags, is_flags_scalar) = match flags_array {
430        Some(flags) => {
431            let (flags, is_flags_scalar) = flags.get();
432            (Some(flags), Some(is_flags_scalar))
433        }
434        None => (None, None),
435    };
436
437    if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
438        return Err(ArrowError::ComputeError(
439            "regexp_match() requires both pattern and flags to be either scalar or array"
440                .to_string(),
441        ));
442    }
443
444    if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
445        return Err(ArrowError::ComputeError(
446            "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8"
447                .to_string(),
448        ));
449    }
450
451    if is_rhs_scalar {
452        // Regex and flag is scalars
453        let (regex, flag) = match rhs.data_type() {
454            DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags),
455            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
456            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
457            _ => {
458                return Err(ArrowError::ComputeError(
459                    "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8"
460                        .to_string(),
461                ));
462            }
463        };
464
465        if regex.is_none() {
466            return Ok(new_null_array(
467                &DataType::List(Arc::new(Field::new_list_field(
468                    array.data_type().clone(),
469                    true,
470                ))),
471                array.len(),
472            ));
473        }
474
475        let regex = regex.unwrap();
476
477        let pattern = if let Some(flag) = flag {
478            format!("(?{flag}){regex}")
479        } else {
480            regex.to_string()
481        };
482
483        let re = Regex::new(pattern.as_str()).map_err(|e| {
484            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
485        })?;
486
487        match array.data_type() {
488            DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re),
489            DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
490            DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
491            _ => Err(ArrowError::ComputeError(
492                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
493                    .to_string(),
494            )),
495        }
496    } else {
497        match array.data_type() {
498            DataType::Utf8View => {
499                let regex_array = rhs.as_string_view();
500                let flags_array = flags.map(|flags| flags.as_string_view());
501                regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array)
502            }
503            DataType::Utf8 => {
504                let regex_array = rhs.as_string();
505                let flags_array = flags.map(|flags| flags.as_string());
506                regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
507            }
508            DataType::LargeUtf8 => {
509                let regex_array = rhs.as_string();
510                let flags_array = flags.map(|flags| flags.as_string());
511                regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
512            }
513            _ => Err(ArrowError::ComputeError(
514                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
515                    .to_string(),
516            )),
517        }
518    }
519}
520
521#[cfg(test)]
522mod tests {
523    use super::*;
524
525    macro_rules! test_match_single_group {
526        ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => {
527            #[test]
528            fn $test_name() {
529                let array: $arr_type = <$arr_type>::from($values);
530                let pattern: $arr_type = <$arr_type>::from($patterns);
531
532                let actual = regexp_match(&array, &pattern, None).unwrap();
533
534                let elem_builder: $builder_type = <$builder_type>::new();
535                let mut expected_builder = ListBuilder::new(elem_builder);
536
537                for val in $expected {
538                    match val {
539                        Some(v) => {
540                            expected_builder.values().append_value(v);
541                            expected_builder.append(true);
542                        }
543                        None => expected_builder.append(false),
544                    }
545                }
546
547                let expected = expected_builder.finish();
548                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
549                assert_eq!(&expected, result);
550            }
551        };
552    }
553
    // Array patterns without flags. Expected per row: the captured digits for
    // the first two rows; a null list for a non-matching value, for a null
    // value, and for a two-group pattern that does not match; and a single
    // empty string for the empty pattern.
    test_match_single_group!(
        match_single_group_string,
        vec![
            Some("abc-005-def"),
            Some("X-7-5"),
            Some("X545"),
            None,
            Some("foobarbequebaz"),
            Some("foobarbequebaz"),
        ],
        vec![
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r"(bar)(bequ1e)",
            ""
        ],
        StringArray,
        GenericStringBuilder<i32>,
        [Some("005"), Some("7"), None, None, None, Some("")]
    );
    // Same cases as above, exercised on StringViewArray.
    test_match_single_group!(
        match_single_group_string_view,
        vec![
            Some("abc-005-def"),
            Some("X-7-5"),
            Some("X545"),
            None,
            Some("foobarbequebaz"),
            Some("foobarbequebaz"),
        ],
        vec![
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r".*-(\d*)-.*",
            r"(bar)(bequ1e)",
            ""
        ],
        StringViewArray,
        StringViewBuilder,
        [Some("005"), Some("7"), None, None, None, Some("")]
    );
598
599    macro_rules! test_match_single_group_with_flags {
600        ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
601            #[test]
602            fn $test_name() {
603                let array: $array_type = <$array_type>::from($values);
604                let pattern: $array_type = <$array_type>::from($patterns);
605                let flags: $array_type = <$array_type>::from($flags);
606
607                let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
608
609                let elem_builder: $builder_type = <$builder_type>::new();
610                let mut expected_builder = ListBuilder::new(elem_builder);
611
612                for val in $expected {
613                    match val {
614                        Some(v) => {
615                            expected_builder.values().append_value(v);
616                            expected_builder.append(true);
617                        }
618                        None => {
619                            expected_builder.append(false);
620                        }
621                    }
622                }
623
624                let expected = expected_builder.finish();
625                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
626                assert_eq!(&expected, result);
627            }
628        };
629    }
630
    // Array patterns with the case-insensitive flag "i": the lowercase `x` in
    // the pattern is expected to match the uppercase "X-7-5"; the other rows
    // (including the null value) are expected to produce null lists.
    test_match_single_group_with_flags!(
        match_single_group_with_flags_string,
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
        vec![r"x.*-(\d*)-.*"; 4],
        vec!["i"; 4],
        StringArray,
        GenericStringBuilder<i32>,
        [None, Some("7"), None, None]
    );
    // Same cases as above, exercised on StringViewArray.
    test_match_single_group_with_flags!(
        match_single_group_with_flags_stringview,
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
        vec![r"x.*-(\d*)-.*"; 4],
        vec!["i"; 4],
        StringViewArray,
        StringViewBuilder,
        [None, Some("7"), None, None]
    );
649
650    macro_rules! test_match_scalar_pattern {
651        ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
652            #[test]
653            fn $test_name() {
654                let array: $array_type = <$array_type>::from($values);
655
656                let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1]));
657                let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1]));
658
659                let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap();
660
661                let elem_builder: $builder_type = <$builder_type>::new();
662                let mut expected_builder = ListBuilder::new(elem_builder);
663
664                for val in $expected {
665                    match val {
666                        Some(v) => {
667                            expected_builder.values().append_value(v);
668                            expected_builder.append(true);
669                        }
670                        None => expected_builder.append(false),
671                    }
672                }
673
674                let expected = expected_builder.finish();
675                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
676                assert_eq!(&expected, result);
677            }
678        };
679    }
680
    // Scalar pattern with the case-insensitive flag "i": both "x-7-5" and
    // "X-0-Y" are expected to match.
    test_match_scalar_pattern!(
        match_scalar_pattern_string_with_flags,
        vec![
            Some("abc-005-def"),
            Some("x-7-5"),
            Some("X-0-Y"),
            Some("X545"),
            None
        ],
        r"x.*-(\d*)-.*",
        Some("i"),
        StringArray,
        GenericStringBuilder<i32>,
        [None, Some("7"), Some("0"), None, None]
    );
    // Same cases as above, exercised on StringViewArray.
    test_match_scalar_pattern!(
        match_scalar_pattern_stringview_with_flags,
        vec![
            Some("abc-005-def"),
            Some("x-7-5"),
            Some("X-0-Y"),
            Some("X545"),
            None
        ],
        r"x.*-(\d*)-.*",
        Some("i"),
        StringViewArray,
        StringViewBuilder,
        [None, Some("7"), Some("0"), None, None]
    );

    // Scalar pattern with a null flag scalar: matching stays case-sensitive, so
    // only the lowercase "x-7-5" is expected to match.
    test_match_scalar_pattern!(
        match_scalar_pattern_string_no_flags,
        vec![
            Some("abc-005-def"),
            Some("x-7-5"),
            Some("X-0-Y"),
            Some("X545"),
            None
        ],
        r"x.*-(\d*)-.*",
        None::<&str>,
        StringArray,
        GenericStringBuilder<i32>,
        [None, Some("7"), None, None, None]
    );
    // Same cases as above, exercised on StringViewArray.
    test_match_scalar_pattern!(
        match_scalar_pattern_stringview_no_flags,
        vec![
            Some("abc-005-def"),
            Some("x-7-5"),
            Some("X-0-Y"),
            Some("X545"),
            None
        ],
        r"x.*-(\d*)-.*",
        None::<&str>,
        StringViewArray,
        StringViewBuilder,
        [None, Some("7"), None, None, None]
    );
742
743    macro_rules! test_match_scalar_no_pattern {
744        ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => {
745            #[test]
746            fn $test_name() {
747                let array: $array_type = <$array_type>::from($values);
748                let pattern = Scalar::new(new_null_array(&$pattern_type, 1));
749
750                let actual = regexp_match(&array, &pattern, None).unwrap();
751
752                let elem_builder: $builder_type = <$builder_type>::new();
753                let mut expected_builder = ListBuilder::new(elem_builder);
754
755                for val in $expected {
756                    match val {
757                        Some(v) => {
758                            expected_builder.values().append_value(v);
759                            expected_builder.append(true);
760                        }
761                        None => expected_builder.append(false),
762                    }
763                }
764
765                let expected = expected_builder.finish();
766                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
767                assert_eq!(&expected, result);
768            }
769        };
770    }
771
    // A null scalar pattern is expected to produce a null list for every row.
    test_match_scalar_no_pattern!(
        match_scalar_no_pattern_string,
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
        StringArray,
        DataType::Utf8,
        GenericStringBuilder<i32>,
        [None::<&str>, None, None, None]
    );
    // Same cases as above, exercised on StringViewArray.
    test_match_scalar_no_pattern!(
        match_scalar_no_pattern_stringview,
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
        StringViewArray,
        DataType::Utf8View,
        StringViewBuilder,
        [None::<&str>, None, None, None]
    );
788
789    macro_rules! test_match_single_group_not_skip {
790        ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
791            #[test]
792            fn $test_name() {
793                let array: $array_type = <$array_type>::from($values);
794                let pattern: $array_type = <$array_type>::from(vec![$pattern]);
795
796                let actual = regexp_match(&array, &pattern, None).unwrap();
797
798                let elem_builder: $builder_type = <$builder_type>::new();
799                let mut expected_builder = ListBuilder::new(elem_builder);
800
801                for val in $expected {
802                    match val {
803                        Some(v) => {
804                            expected_builder.values().append_value(v);
805                            expected_builder.append(true);
806                        }
807                        None => expected_builder.append(false),
808                    }
809                }
810
811                let expected = expected_builder.finish();
812                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
813                assert_eq!(&expected, result);
814            }
815        };
816    }
817
818    test_match_single_group_not_skip!(
819        match_single_group_not_skip_string,
820        vec![Some("foo"), Some("bar")],
821        r"foo",
822        StringArray,
823        GenericStringBuilder<i32>,
824        [Some("foo")]
825    );
826    test_match_single_group_not_skip!(
827        match_single_group_not_skip_stringview,
828        vec![Some("foo"), Some("bar")],
829        r"foo",
830        StringViewArray,
831        StringViewBuilder,
832        [Some("foo")]
833    );
834
835    macro_rules! test_flag_utf8 {
836        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
837            #[test]
838            fn $test_name() {
839                let left = $left;
840                let right = $right;
841                let res = $op(&left, &right, None).unwrap();
842                let expected = $expected;
843                assert_eq!(expected.len(), res.len());
844                for i in 0..res.len() {
845                    let v = res.value(i);
846                    assert_eq!(v, expected[i]);
847                }
848            }
849        };
850        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
851            #[test]
852            fn $test_name() {
853                let left = $left;
854                let right = $right;
855                let flag = Some($flag);
856                let res = $op(&left, &right, flag.as_ref()).unwrap();
857                let expected = $expected;
858                assert_eq!(expected.len(), res.len());
859                for i in 0..res.len() {
860                    let v = res.value(i);
861                    assert_eq!(v, expected[i]);
862                }
863            }
864        };
865    }
866
867    macro_rules! test_flag_utf8_scalar {
868        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
869            #[test]
870            fn $test_name() {
871                let left = $left;
872                let res = $op(&left, $right, None).unwrap();
873                let expected = $expected;
874                assert_eq!(expected.len(), res.len());
875                for i in 0..res.len() {
876                    let v = res.value(i);
877                    assert_eq!(
878                        v,
879                        expected[i],
880                        "unexpected result when comparing {} at position {} to {} ",
881                        left.value(i),
882                        i,
883                        $right
884                    );
885                }
886            }
887        };
888        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
889            #[test]
890            fn $test_name() {
891                let left = $left;
892                let flag = Some($flag);
893                let res = $op(&left, $right, flag).unwrap();
894                let expected = $expected;
895                assert_eq!(expected.len(), res.len());
896                for i in 0..res.len() {
897                    let v = res.value(i);
898                    assert_eq!(
899                        v,
900                        expected[i],
901                        "unexpected result when comparing {} at position {} to {} ",
902                        left.value(i),
903                        i,
904                        $right
905                    );
906                }
907            }
908        };
909    }
910
911    test_flag_utf8!(
912        test_array_regexp_is_match_utf8,
913        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
914        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
915        regexp_is_match::<StringArray, StringArray, StringArray>,
916        [true, false, true, false, false, true]
917    );
918    test_flag_utf8!(
919        test_array_regexp_is_match_utf8_insensitive,
920        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
921        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
922        StringArray::from(vec!["i"; 6]),
923        regexp_is_match,
924        [true, true, true, true, false, true]
925    );
926
927    test_flag_utf8_scalar!(
928        test_array_regexp_is_match_utf8_scalar,
929        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
930        "^ar",
931        regexp_is_match_scalar,
932        [true, false, false, false]
933    );
934    test_flag_utf8_scalar!(
935        test_array_regexp_is_match_utf8_scalar_empty,
936        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
937        "",
938        regexp_is_match_scalar,
939        [true, true, true, true]
940    );
941    test_flag_utf8_scalar!(
942        test_array_regexp_is_match_utf8_scalar_insensitive,
943        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
944        "^ar",
945        "i",
946        regexp_is_match_scalar,
947        [true, true, false, false]
948    );
949
950    test_flag_utf8!(
951        tes_array_regexp_is_match,
952        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
953        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
954        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
955        [true, false, true, false, false, true]
956    );
957    test_flag_utf8!(
958        test_array_regexp_is_match_2,
959        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
960        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
961        regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
962        [true, false, true, false, false, true]
963    );
964    test_flag_utf8!(
965        test_array_regexp_is_match_insensitive,
966        StringViewArray::from(vec![
967            "Official Rust implementation of Apache Arrow",
968            "apache/arrow-rs",
969            "apache/arrow-rs",
970            "parquet",
971            "parquet",
972            "row",
973            "row",
974        ]),
975        StringViewArray::from(vec![
976            ".*rust implement.*",
977            "^ap",
978            "^AP",
979            "et$",
980            "ET$",
981            "foo",
982            ""
983        ]),
984        StringViewArray::from(vec!["i"; 7]),
985        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
986        [true, true, true, true, true, false, true]
987    );
988    test_flag_utf8!(
989        test_array_regexp_is_match_insensitive_2,
990        LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
991        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
992        StringArray::from(vec!["i"; 6]),
993        regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
994        [true, true, true, true, false, true]
995    );
996
997    test_flag_utf8_scalar!(
998        test_array_regexp_is_match_scalar,
999        StringViewArray::from(vec![
1000            "apache/arrow-rs",
1001            "APACHE/ARROW-RS",
1002            "parquet",
1003            "PARQUET",
1004        ]),
1005        "^ap",
1006        regexp_is_match_scalar::<StringViewArray>,
1007        [true, false, false, false]
1008    );
1009    test_flag_utf8_scalar!(
1010        test_array_regexp_is_match_scalar_empty,
1011        StringViewArray::from(vec![
1012            "apache/arrow-rs",
1013            "APACHE/ARROW-RS",
1014            "parquet",
1015            "PARQUET",
1016        ]),
1017        "",
1018        regexp_is_match_scalar::<StringViewArray>,
1019        [true, true, true, true]
1020    );
1021    test_flag_utf8_scalar!(
1022        test_array_regexp_is_match_scalar_insensitive,
1023        StringViewArray::from(vec![
1024            "apache/arrow-rs",
1025            "APACHE/ARROW-RS",
1026            "parquet",
1027            "PARQUET",
1028        ]),
1029        "^ap",
1030        "i",
1031        regexp_is_match_scalar::<StringViewArray>,
1032        [true, true, false, false]
1033    );
1034}