datafusion_functions/regex/
regexpcount.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::regex::{compile_and_cache_regex, compile_regex};
19use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::datatypes::{
22    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
23};
24use arrow::error::ArrowError;
25use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
26use datafusion_expr::{
27    ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
28    TypeSignature::Uniform, Volatility,
29};
30use datafusion_macros::user_doc;
31use itertools::izip;
32use regex::Regex;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36#[user_doc(
37    doc_section(label = "Regular Expression Functions"),
38    description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
39    syntax_example = "regexp_count(str, regexp[, start, flags])",
40    sql_example = r#"```sql
41> select regexp_count('abcAbAbc', 'abc', 2, 'i');
42+---------------------------------------------------------------+
43| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
44+---------------------------------------------------------------+
45| 1                                                             |
46+---------------------------------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    standard_argument(name = "regexp", prefix = "Regular"),
50    argument(
51        name = "start",
52        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
53    ),
54    argument(
55        name = "flags",
56        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
57  - **i**: case-insensitive: letters match both upper and lower case
58  - **m**: multi-line mode: ^ and $ match begin/end of line
59  - **s**: allow . to match \n
60  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
61  - **U**: swap the meaning of x* and x*?"#
62    )
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct RegexpCountFunc {
66    signature: Signature,
67}
68
69impl Default for RegexpCountFunc {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl RegexpCountFunc {
76    pub fn new() -> Self {
77        Self {
78            signature: Signature::one_of(
79                vec![
80                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81                    Exact(vec![Utf8View, Utf8View, Int64]),
82                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83                    Exact(vec![Utf8, Utf8, Int64]),
84                    Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85                    Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86                    Exact(vec![Utf8, Utf8, Int64, Utf8]),
87                ],
88                Volatility::Immutable,
89            ),
90        }
91    }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95    fn as_any(&self) -> &dyn std::any::Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "regexp_count"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108        Ok(Int64)
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        let args = &args.args;
116
117        let len = args
118            .iter()
119            .fold(Option::<usize>::None, |acc, arg| match arg {
120                ColumnarValue::Scalar(_) => acc,
121                ColumnarValue::Array(a) => Some(a.len()),
122            });
123
124        let is_scalar = len.is_none();
125        let inferred_length = len.unwrap_or(1);
126        let args = args
127            .iter()
128            .map(|arg| arg.to_array(inferred_length))
129            .collect::<Result<Vec<_>>>()?;
130
131        let result = regexp_count_func(&args);
132        if is_scalar {
133            // If all inputs are scalar, keeps output as scalar
134            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
135            result.map(ColumnarValue::Scalar)
136        } else {
137            result.map(ColumnarValue::Array)
138        }
139    }
140
141    fn documentation(&self) -> Option<&Documentation> {
142        self.doc()
143    }
144}
145
146pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
147    let args_len = args.len();
148    if !(2..=4).contains(&args_len) {
149        return exec_err!(
150            "regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."
151        );
152    }
153
154    let values = &args[0];
155    match values.data_type() {
156        Utf8 | LargeUtf8 | Utf8View => (),
157        other => {
158            return internal_err!(
159                "Unsupported data type {other:?} for function regexp_count"
160            );
161        }
162    }
163
164    regexp_count(
165        values,
166        &args[1],
167        if args_len > 2 { Some(&args[2]) } else { None },
168        if args_len > 3 { Some(&args[3]) } else { None },
169    )
170    .map_err(|e| e.into())
171}
172
173/// `arrow-rs` style implementation of `regexp_count` function.
174/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern
175/// within a string array. It supports optional start positions and flags for case insensitivity.
176///
177/// The function accepts a variable number of arguments:
178/// - `values`: The array of strings to search within.
179/// - `regex_array`: The array of regular expression patterns to search for.
180/// - `start_array` (optional): The array of start positions for the search.
181/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
182///
183/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
184/// and flags. It uses a cache to store compiled regular expressions for efficiency.
185///
186/// # Errors
187/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
188fn regexp_count(
189    values: &dyn Array,
190    regex_array: &dyn Datum,
191    start_array: Option<&dyn Datum>,
192    flags_array: Option<&dyn Datum>,
193) -> Result<ArrayRef, ArrowError> {
194    let (regex_array, is_regex_scalar) = regex_array.get();
195    let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
196        let (start, is_start_scalar) = start.get();
197        (Some(start), is_start_scalar)
198    });
199    let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
200        let (flags, is_flags_scalar) = flags.get();
201        (Some(flags), is_flags_scalar)
202    });
203
204    match (values.data_type(), regex_array.data_type(), flags_array) {
205        (Utf8, Utf8, None) => regexp_count_inner(
206            &values.as_string::<i32>(),
207            &regex_array.as_string::<i32>(),
208            is_regex_scalar,
209            start_array.map(|start| start.as_primitive::<Int64Type>()),
210            is_start_scalar,
211            None,
212            is_flags_scalar,
213        ),
214        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
215            &values.as_string::<i32>(),
216            &regex_array.as_string::<i32>(),
217            is_regex_scalar,
218            start_array.map(|start| start.as_primitive::<Int64Type>()),
219            is_start_scalar,
220            Some(&flags_array.as_string::<i32>()),
221            is_flags_scalar,
222        ),
223        (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
224            &values.as_string::<i64>(),
225            &regex_array.as_string::<i64>(),
226            is_regex_scalar,
227            start_array.map(|start| start.as_primitive::<Int64Type>()),
228            is_start_scalar,
229            None,
230            is_flags_scalar,
231        ),
232        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
233            &values.as_string::<i64>(),
234            &regex_array.as_string::<i64>(),
235            is_regex_scalar,
236            start_array.map(|start| start.as_primitive::<Int64Type>()),
237            is_start_scalar,
238            Some(&flags_array.as_string::<i64>()),
239            is_flags_scalar,
240        ),
241        (Utf8View, Utf8View, None) => regexp_count_inner(
242            &values.as_string_view(),
243            &regex_array.as_string_view(),
244            is_regex_scalar,
245            start_array.map(|start| start.as_primitive::<Int64Type>()),
246            is_start_scalar,
247            None,
248            is_flags_scalar,
249        ),
250        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
251            &values.as_string_view(),
252            &regex_array.as_string_view(),
253            is_regex_scalar,
254            start_array.map(|start| start.as_primitive::<Int64Type>()),
255            is_start_scalar,
256            Some(&flags_array.as_string_view()),
257            is_flags_scalar,
258        ),
259        _ => Err(ArrowError::ComputeError(
260            "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
261        )),
262    }
263}
264
265fn regexp_count_inner<'a, S>(
266    values: &S,
267    regex_array: &S,
268    is_regex_scalar: bool,
269    start_array: Option<&Int64Array>,
270    is_start_scalar: bool,
271    flags_array: Option<&S>,
272    is_flags_scalar: bool,
273) -> Result<ArrayRef, ArrowError>
274where
275    S: StringArrayType<'a>,
276{
277    let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
278        (Some(regex_array.value(0)), true)
279    } else {
280        (None, false)
281    };
282
283    let (start_array, start_scalar, is_start_scalar) =
284        if let Some(start_array) = start_array {
285            if is_start_scalar || start_array.len() == 1 {
286                (None, Some(start_array.value(0)), true)
287            } else {
288                (Some(start_array), None, false)
289            }
290        } else {
291            (None, Some(1), true)
292        };
293
294    let (flags_array, flags_scalar, is_flags_scalar) =
295        if let Some(flags_array) = flags_array {
296            if is_flags_scalar || flags_array.len() == 1 {
297                (None, Some(flags_array.value(0)), true)
298            } else {
299                (Some(flags_array), None, false)
300            }
301        } else {
302            (None, None, true)
303        };
304
305    let mut regex_cache = HashMap::new();
306
307    match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
308        (true, true, true) => {
309            let regex = match regex_scalar {
310                None | Some("") => {
311                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
312                }
313                Some(regex) => regex,
314            };
315
316            let pattern = compile_regex(regex, flags_scalar)?;
317
318            Ok(Arc::new(
319                values
320                    .iter()
321                    .map(|value| count_matches(value, &pattern, start_scalar))
322                    .collect::<Result<Int64Array, ArrowError>>()?,
323            ))
324        }
325        (true, true, false) => {
326            let regex = match regex_scalar {
327                None | Some("") => {
328                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
329                }
330                Some(regex) => regex,
331            };
332
333            let flags_array = flags_array.unwrap();
334            if values.len() != flags_array.len() {
335                return Err(ArrowError::ComputeError(format!(
336                    "flags_array must be the same length as values array; got {} and {}",
337                    flags_array.len(),
338                    values.len(),
339                )));
340            }
341
342            Ok(Arc::new(
343                values
344                    .iter()
345                    .zip(flags_array.iter())
346                    .map(|(value, flags)| {
347                        let pattern =
348                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
349                        count_matches(value, pattern, start_scalar)
350                    })
351                    .collect::<Result<Int64Array, ArrowError>>()?,
352            ))
353        }
354        (true, false, true) => {
355            let regex = match regex_scalar {
356                None | Some("") => {
357                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
358                }
359                Some(regex) => regex,
360            };
361
362            let pattern = compile_regex(regex, flags_scalar)?;
363
364            let start_array = start_array.unwrap();
365
366            Ok(Arc::new(
367                values
368                    .iter()
369                    .zip(start_array.iter())
370                    .map(|(value, start)| count_matches(value, &pattern, start))
371                    .collect::<Result<Int64Array, ArrowError>>()?,
372            ))
373        }
374        (true, false, false) => {
375            let regex = match regex_scalar {
376                None | Some("") => {
377                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
378                }
379                Some(regex) => regex,
380            };
381
382            let flags_array = flags_array.unwrap();
383            if values.len() != flags_array.len() {
384                return Err(ArrowError::ComputeError(format!(
385                    "flags_array must be the same length as values array; got {} and {}",
386                    flags_array.len(),
387                    values.len(),
388                )));
389            }
390
391            Ok(Arc::new(
392                izip!(
393                    values.iter(),
394                    start_array.unwrap().iter(),
395                    flags_array.iter()
396                )
397                .map(|(value, start, flags)| {
398                    let pattern =
399                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
400
401                    count_matches(value, pattern, start)
402                })
403                .collect::<Result<Int64Array, ArrowError>>()?,
404            ))
405        }
406        (false, true, true) => {
407            if values.len() != regex_array.len() {
408                return Err(ArrowError::ComputeError(format!(
409                    "regex_array must be the same length as values array; got {} and {}",
410                    regex_array.len(),
411                    values.len(),
412                )));
413            }
414
415            Ok(Arc::new(
416                values
417                    .iter()
418                    .zip(regex_array.iter())
419                    .map(|(value, regex)| {
420                        let regex = match regex {
421                            None | Some("") => return Ok(0),
422                            Some(regex) => regex,
423                        };
424
425                        let pattern = compile_and_cache_regex(
426                            regex,
427                            flags_scalar,
428                            &mut regex_cache,
429                        )?;
430                        count_matches(value, pattern, start_scalar)
431                    })
432                    .collect::<Result<Int64Array, ArrowError>>()?,
433            ))
434        }
435        (false, true, false) => {
436            if values.len() != regex_array.len() {
437                return Err(ArrowError::ComputeError(format!(
438                    "regex_array must be the same length as values array; got {} and {}",
439                    regex_array.len(),
440                    values.len(),
441                )));
442            }
443
444            let flags_array = flags_array.unwrap();
445            if values.len() != flags_array.len() {
446                return Err(ArrowError::ComputeError(format!(
447                    "flags_array must be the same length as values array; got {} and {}",
448                    flags_array.len(),
449                    values.len(),
450                )));
451            }
452
453            Ok(Arc::new(
454                izip!(values.iter(), regex_array.iter(), flags_array.iter())
455                    .map(|(value, regex, flags)| {
456                        let regex = match regex {
457                            None | Some("") => return Ok(0),
458                            Some(regex) => regex,
459                        };
460
461                        let pattern =
462                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
463
464                        count_matches(value, pattern, start_scalar)
465                    })
466                    .collect::<Result<Int64Array, ArrowError>>()?,
467            ))
468        }
469        (false, false, true) => {
470            if values.len() != regex_array.len() {
471                return Err(ArrowError::ComputeError(format!(
472                    "regex_array must be the same length as values array; got {} and {}",
473                    regex_array.len(),
474                    values.len(),
475                )));
476            }
477
478            let start_array = start_array.unwrap();
479            if values.len() != start_array.len() {
480                return Err(ArrowError::ComputeError(format!(
481                    "start_array must be the same length as values array; got {} and {}",
482                    start_array.len(),
483                    values.len(),
484                )));
485            }
486
487            Ok(Arc::new(
488                izip!(values.iter(), regex_array.iter(), start_array.iter())
489                    .map(|(value, regex, start)| {
490                        let regex = match regex {
491                            None | Some("") => return Ok(0),
492                            Some(regex) => regex,
493                        };
494
495                        let pattern = compile_and_cache_regex(
496                            regex,
497                            flags_scalar,
498                            &mut regex_cache,
499                        )?;
500                        count_matches(value, pattern, start)
501                    })
502                    .collect::<Result<Int64Array, ArrowError>>()?,
503            ))
504        }
505        (false, false, false) => {
506            if values.len() != regex_array.len() {
507                return Err(ArrowError::ComputeError(format!(
508                    "regex_array must be the same length as values array; got {} and {}",
509                    regex_array.len(),
510                    values.len(),
511                )));
512            }
513
514            let start_array = start_array.unwrap();
515            if values.len() != start_array.len() {
516                return Err(ArrowError::ComputeError(format!(
517                    "start_array must be the same length as values array; got {} and {}",
518                    start_array.len(),
519                    values.len(),
520                )));
521            }
522
523            let flags_array = flags_array.unwrap();
524            if values.len() != flags_array.len() {
525                return Err(ArrowError::ComputeError(format!(
526                    "flags_array must be the same length as values array; got {} and {}",
527                    flags_array.len(),
528                    values.len(),
529                )));
530            }
531
532            Ok(Arc::new(
533                izip!(
534                    values.iter(),
535                    regex_array.iter(),
536                    start_array.iter(),
537                    flags_array.iter()
538                )
539                .map(|(value, regex, start, flags)| {
540                    let regex = match regex {
541                        None | Some("") => return Ok(0),
542                        Some(regex) => regex,
543                    };
544
545                    let pattern =
546                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
547                    count_matches(value, pattern, start)
548                })
549                .collect::<Result<Int64Array, ArrowError>>()?,
550            ))
551        }
552    }
553}
554
555fn count_matches(
556    value: Option<&str>,
557    pattern: &Regex,
558    start: Option<i64>,
559) -> Result<i64, ArrowError> {
560    let value = match value {
561        None | Some("") => return Ok(0),
562        Some(value) => value,
563    };
564
565    if let Some(start) = start {
566        if start < 1 {
567            return Err(ArrowError::ComputeError(
568                "regexp_count() requires start to be 1 based".to_string(),
569            ));
570        }
571
572        // Find the byte offset for the start position (1-based character index)
573        let byte_offset = value
574            .char_indices()
575            .nth((start as usize).saturating_sub(1))
576            .map(|(idx, _)| idx)
577            .unwrap_or(value.len());
578
579        // Use string slicing instead of collecting chars into a new String
580        let find_slice = &value[byte_offset..];
581        let count = pattern.find_iter(find_slice).count();
582        Ok(count as i64)
583    } else {
584        let count = pattern.find_iter(value).count();
585        Ok(count as i64)
586    }
587}
588
589#[cfg(test)]
590mod tests {
591    use super::*;
592    use arrow::array::{GenericStringArray, StringViewArray};
593    use arrow::datatypes::Field;
594    use datafusion_common::config::ConfigOptions;
595    use datafusion_expr::ScalarFunctionArgs;
596
597    #[test]
598    fn test_regexp_count() {
599        test_case_sensitive_regexp_count_scalar();
600        test_case_sensitive_regexp_count_scalar_start();
601        test_case_insensitive_regexp_count_scalar_flags();
602        test_case_sensitive_regexp_count_start_scalar_complex();
603
604        test_case_sensitive_regexp_count_array::<GenericStringArray<i32>>();
605        test_case_sensitive_regexp_count_array::<GenericStringArray<i64>>();
606        test_case_sensitive_regexp_count_array::<StringViewArray>();
607
608        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i32>>();
609        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i64>>();
610        test_case_sensitive_regexp_count_array_start::<StringViewArray>();
611
612        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i32>>();
613        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i64>>();
614        test_case_insensitive_regexp_count_array_flags::<StringViewArray>();
615
616        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
617        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
618        test_case_sensitive_regexp_count_array_complex::<StringViewArray>();
619
620        test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
621    }
622
623    fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
624        let args_values = args
625            .iter()
626            .map(|sv| ColumnarValue::Scalar(sv.clone()))
627            .collect();
628
629        let arg_fields = args
630            .iter()
631            .enumerate()
632            .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into())
633            .collect::<Vec<_>>();
634
635        RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
636            args: args_values,
637            arg_fields,
638            number_rows: args.len(),
639            return_field: Field::new("f", Int64, true).into(),
640            config_options: Arc::new(ConfigOptions::default()),
641        })
642    }
643
644    fn test_case_sensitive_regexp_count_scalar() {
645        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
646        let regex = "abc";
647        let expected: Vec<i64> = vec![0, 1, 2, 1, 3];
648
649        values.iter().enumerate().for_each(|(pos, &v)| {
650            // utf8
651            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
652            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
653            let expected = expected.get(pos).cloned();
654            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
655            match re {
656                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
657                    assert_eq!(v, expected, "regexp_count scalar test failed");
658                }
659                _ => panic!("Unexpected result"),
660            }
661
662            // largeutf8
663            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
664            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
665            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
666            match re {
667                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
668                    assert_eq!(v, expected, "regexp_count scalar test failed");
669                }
670                _ => panic!("Unexpected result"),
671            }
672
673            // utf8view
674            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
675            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
676            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
677            match re {
678                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
679                    assert_eq!(v, expected, "regexp_count scalar test failed");
680                }
681                _ => panic!("Unexpected result"),
682            }
683        });
684    }
685
686    fn test_case_sensitive_regexp_count_scalar_start() {
687        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
688        let regex = "abc";
689        let start = 2;
690        let expected: Vec<i64> = vec![0, 1, 1, 0, 2];
691
692        values.iter().enumerate().for_each(|(pos, &v)| {
693            // utf8
694            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
695            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
696            let start_sv = ScalarValue::Int64(Some(start));
697            let expected = expected.get(pos).cloned();
698            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
699            match re {
700                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
701                    assert_eq!(v, expected, "regexp_count scalar test failed");
702                }
703                _ => panic!("Unexpected result"),
704            }
705
706            // largeutf8
707            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
708            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
709            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
710            match re {
711                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
712                    assert_eq!(v, expected, "regexp_count scalar test failed");
713                }
714                _ => panic!("Unexpected result"),
715            }
716
717            // utf8view
718            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
719            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
720            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
721            match re {
722                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
723                    assert_eq!(v, expected, "regexp_count scalar test failed");
724                }
725                _ => panic!("Unexpected result"),
726            }
727        });
728    }
729
730    fn test_case_insensitive_regexp_count_scalar_flags() {
731        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
732        let regex = "abc";
733        let start = 1;
734        let flags = "i";
735        let expected: Vec<i64> = vec![0, 1, 2, 2, 3];
736
737        values.iter().enumerate().for_each(|(pos, &v)| {
738            // utf8
739            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
740            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
741            let start_sv = ScalarValue::Int64(Some(start));
742            let flags_sv = ScalarValue::Utf8(Some(flags.to_string()));
743            let expected = expected.get(pos).cloned();
744
745            let re = regexp_count_with_scalar_values(&[
746                v_sv,
747                regex_sv,
748                start_sv.clone(),
749                flags_sv.clone(),
750            ]);
751            match re {
752                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
753                    assert_eq!(v, expected, "regexp_count scalar test failed");
754                }
755                _ => panic!("Unexpected result"),
756            }
757
758            // largeutf8
759            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
760            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
761            let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string()));
762
763            let re = regexp_count_with_scalar_values(&[
764                v_sv,
765                regex_sv,
766                start_sv.clone(),
767                flags_sv.clone(),
768            ]);
769            match re {
770                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
771                    assert_eq!(v, expected, "regexp_count scalar test failed");
772                }
773                _ => panic!("Unexpected result"),
774            }
775
776            // utf8view
777            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
778            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
779            let flags_sv = ScalarValue::Utf8View(Some(flags.to_string()));
780
781            let re = regexp_count_with_scalar_values(&[
782                v_sv,
783                regex_sv,
784                start_sv.clone(),
785                flags_sv.clone(),
786            ]);
787            match re {
788                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
789                    assert_eq!(v, expected, "regexp_count scalar test failed");
790                }
791                _ => panic!("Unexpected result"),
792            }
793        });
794    }
795
796    fn test_case_sensitive_regexp_count_array<A>()
797    where
798        A: From<Vec<&'static str>> + Array + 'static,
799    {
800        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
801        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
802
803        let expected = Int64Array::from(vec![0, 1, 2, 2, 2]);
804
805        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
806        assert_eq!(re.as_ref(), &expected);
807    }
808
809    fn test_case_sensitive_regexp_count_array_start<A>()
810    where
811        A: From<Vec<&'static str>> + Array + 'static,
812    {
813        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
814        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
815        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
816
817        let expected = Int64Array::from(vec![0, 0, 1, 1, 0]);
818
819        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
820            .unwrap();
821        assert_eq!(re.as_ref(), &expected);
822    }
823
824    fn test_case_insensitive_regexp_count_array_flags<A>()
825    where
826        A: From<Vec<&'static str>> + Array + 'static,
827    {
828        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
829        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
830        let start = Int64Array::from(vec![1]);
831        let flags = A::from(vec!["", "i", "", "", "i"]);
832
833        let expected = Int64Array::from(vec![0, 1, 2, 2, 3]);
834
835        let re = regexp_count_func(&[
836            Arc::new(values),
837            Arc::new(regex),
838            Arc::new(start),
839            Arc::new(flags),
840        ])
841        .unwrap();
842        assert_eq!(re.as_ref(), &expected);
843    }
844
845    fn test_case_sensitive_regexp_count_start_scalar_complex() {
846        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
847        let regex = ["", "abc", "a", "bc", "ab"];
848        let start = 5;
849        let flags = ["", "i", "", "", "i"];
850        let expected: Vec<i64> = vec![0, 0, 0, 1, 1];
851
852        values.iter().enumerate().for_each(|(pos, &v)| {
853            // utf8
854            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
855            let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| (*s).to_string()));
856            let start_sv = ScalarValue::Int64(Some(start));
857            let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| (*f).to_string()));
858            let expected = expected.get(pos).cloned();
859            let re = regexp_count_with_scalar_values(&[
860                v_sv,
861                regex_sv,
862                start_sv.clone(),
863                flags_sv.clone(),
864            ]);
865            match re {
866                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
867                    assert_eq!(v, expected, "regexp_count scalar test failed");
868                }
869                _ => panic!("Unexpected result"),
870            }
871
872            // largeutf8
873            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
874            let regex_sv =
875                ScalarValue::LargeUtf8(regex.get(pos).map(|s| (*s).to_string()));
876            let flags_sv =
877                ScalarValue::LargeUtf8(flags.get(pos).map(|f| (*f).to_string()));
878            let re = regexp_count_with_scalar_values(&[
879                v_sv,
880                regex_sv,
881                start_sv.clone(),
882                flags_sv.clone(),
883            ]);
884            match re {
885                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
886                    assert_eq!(v, expected, "regexp_count scalar test failed");
887                }
888                _ => panic!("Unexpected result"),
889            }
890
891            // utf8view
892            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
893            let regex_sv =
894                ScalarValue::Utf8View(regex.get(pos).map(|s| (*s).to_string()));
895            let flags_sv =
896                ScalarValue::Utf8View(flags.get(pos).map(|f| (*f).to_string()));
897            let re = regexp_count_with_scalar_values(&[
898                v_sv,
899                regex_sv,
900                start_sv.clone(),
901                flags_sv.clone(),
902            ]);
903            match re {
904                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
905                    assert_eq!(v, expected, "regexp_count scalar test failed");
906                }
907                _ => panic!("Unexpected result"),
908            }
909        });
910    }
911
912    fn test_case_sensitive_regexp_count_array_complex<A>()
913    where
914        A: From<Vec<&'static str>> + Array + 'static,
915    {
916        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
917        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
918        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
919        let flags = A::from(vec!["", "i", "", "", "i"]);
920
921        let expected = Int64Array::from(vec![0, 1, 1, 1, 1]);
922
923        let re = regexp_count_func(&[
924            Arc::new(values),
925            Arc::new(regex),
926            Arc::new(start),
927            Arc::new(flags),
928        ])
929        .unwrap();
930        assert_eq!(re.as_ref(), &expected);
931    }
932
933    fn test_case_regexp_count_cache_check<A>()
934    where
935        A: From<Vec<&'static str>> + Array + 'static,
936    {
937        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
938        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
939        let start = Int64Array::from(vec![1, 1, 1]);
940        let flags = A::from(vec!["", "i", ""]);
941
942        let expected = Int64Array::from(vec![1, 1, 1]);
943
944        let re = regexp_count_func(&[
945            Arc::new(values),
946            Arc::new(regex),
947            Arc::new(start),
948            Arc::new(flags),
949        ])
950        .unwrap();
951        assert_eq!(re.as_ref(), &expected);
952    }
953}