use std::iter::zip;
#[cfg(feature = "extract_groups")]
use arrow::array::{Array, StructArray};
use arrow::array::{MutablePlString, Utf8ViewArray};
use polars_core::prelude::arity::{try_binary_mut_with_options, try_unary_mut_with_options};
use regex::Regex;
use super::*;
/// Extracts all capture groups of `reg` from each string in `arr` into a struct array.
///
/// One string builder is created per entry in `names`; builder `i` receives capture
/// group `i + 1` (group 0, the whole match, is skipped). `names` is only used for
/// its length here; the field names themselves come from `dtype`.
///
/// Rows whose input is null, or whose input does not match `reg`, get a null pushed
/// into every field. The struct-level validity is taken from `arr`'s own validity,
/// which is what distinguishes "input was null" from "input did not match".
#[cfg(feature = "extract_groups")]
fn extract_groups_array(
    arr: &Utf8ViewArray,
    reg: &Regex,
    names: &[&str],
    dtype: ArrowDataType,
) -> PolarsResult<ArrayRef> {
    let n_rows = arr.len();
    let mut field_builders: Vec<_> = names
        .iter()
        .map(|_| MutablePlString::with_capacity(n_rows))
        .collect();

    // Reused across rows so the capture locations are not reallocated per string.
    let mut locs = reg.capture_locations();
    for opt_v in arr {
        match opt_v {
            Some(s) if reg.captures_read(&mut locs, s).is_some() => {
                // Capture groups are 1-based; index 0 is the entire match.
                for (builder, group) in zip(field_builders.iter_mut(), 1..) {
                    builder.push(locs.get(group).map(|(start, stop)| &s[start..stop]));
                }
            },
            // Null input or no match: every field of this row is null.
            _ => {
                for builder in &mut field_builders {
                    builder.push_null();
                }
            },
        }
    }

    let values = field_builders
        .into_iter()
        .map(|builder| builder.freeze().boxed())
        .collect();
    let validity = arr.validity().cloned();
    Ok(StructArray::new(dtype, n_rows, values, validity).boxed())
}
/// Extracts all capture groups of `pat` from every string in `ca` as a struct series.
///
/// `dtype` must be a `DataType::Struct`; its fields supply the output field names and
/// its Arrow conversion is used as the dtype of every produced chunk.
///
/// If the compiled regex has no capture groups besides the implicit whole-match
/// group (`captures_len() == 1`), a struct with no fields and `ca.len()` rows is
/// returned.
///
/// # Errors
/// Returns an error if `pat` fails to compile, if `dtype` cannot be converted to an
/// Arrow dtype, or if building the resulting series fails.
///
/// # Panics
/// Panics if `dtype` is not a struct type when the regex has capture groups.
#[cfg(feature = "extract_groups")]
pub(super) fn extract_groups(
    ca: &StringChunked,
    pat: &str,
    dtype: &DataType,
) -> PolarsResult<Series> {
    let reg = polars_utils::regex_cache::compile_regex(pat)?;
    let n_fields = reg.captures_len();
    if n_fields == 1 {
        // Only the implicit whole-match group exists: nothing to extract.
        return StructChunked::from_series(ca.name().clone(), ca.len(), [].iter())
            .map(|ca| ca.into_series());
    }

    let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?;
    let DataType::Struct(fields) = dtype else {
        // Reaching this with a non-struct dtype is an implementation error upstream.
        unreachable!()
    };
    let names = fields
        .iter()
        .map(|fld| fld.name.as_str())
        .collect::<Vec<_>>();

    let chunks = ca
        .downcast_iter()
        .map(|array| extract_groups_array(array, &reg, &names, arrow_dtype.clone()))
        .collect::<PolarsResult<Vec<_>>>()?;

    Series::try_from((ca.name().clone(), chunks))
}
/// Extracts capture group `group_index` of one pre-compiled regex from every string in `arr`.
///
/// A row is null in the output when the input is null or the regex does not match.
/// When the regex matches, the row holds the slice of the input covered by the
/// requested group, or `None` if `locs.get(group_index)` reports no span for it.
fn extract_group_reg_lit(
    arr: &Utf8ViewArray,
    reg: &Regex,
    group_index: usize,
) -> PolarsResult<Utf8ViewArray> {
    let mut out = MutablePlString::with_capacity(arr.len());
    // Reused across rows so the capture locations are not reallocated per string.
    let mut locs = reg.capture_locations();

    for opt_v in arr {
        match opt_v {
            Some(s) if reg.captures_read(&mut locs, s).is_some() => {
                let span = locs.get(group_index);
                out.push(span.map(|(start, stop)| &s[start..stop]));
            },
            _ => out.push_null(),
        }
    }

    Ok(out.into())
}
/// Extracts capture group `group_index` from a single string `s`, once per pattern in `pat`.
///
/// Each non-null pattern is compiled through the shared regex cache and matched
/// against `s`. A row is null when the pattern is null or does not match `s`;
/// otherwise it holds the slice of `s` covered by the requested group, or `None`
/// if `locs.get(group_index)` reports no span for it.
///
/// # Errors
/// Returns an error as soon as a pattern fails to compile.
fn extract_group_array_lit(
    s: &str,
    pat: &Utf8ViewArray,
    group_index: usize,
) -> PolarsResult<Utf8ViewArray> {
    let mut out = MutablePlString::with_capacity(pat.len());

    for opt_pat in pat {
        let Some(pattern) = opt_pat else {
            out.push_null();
            continue;
        };

        let reg = polars_utils::regex_cache::compile_regex(pattern)?;
        let mut locs = reg.capture_locations();
        if reg.captures_read(&mut locs, s).is_none() {
            out.push_null();
        } else {
            out.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));
        }
    }

    Ok(out.into())
}
/// Extracts capture group `group_index` row by row, pairing each string in `arr`
/// with the pattern at the same position in `pat`.
///
/// A row is null when either the string or the pattern is null, or when the
/// pattern does not match the string. Patterns are compiled through the shared
/// regex cache, and only for rows where both sides are non-null.
///
/// # Errors
/// Returns an error as soon as a pattern fails to compile.
fn extract_group_binary(
    arr: &Utf8ViewArray,
    pat: &Utf8ViewArray,
    group_index: usize,
) -> PolarsResult<Utf8ViewArray> {
    let mut out = MutablePlString::with_capacity(arr.len());

    for (opt_s, opt_pat) in zip(arr, pat) {
        let (Some(s), Some(pattern)) = (opt_s, opt_pat) else {
            out.push_null();
            continue;
        };

        let reg = polars_utils::regex_cache::compile_regex(pattern)?;
        let mut locs = reg.capture_locations();
        if reg.captures_read(&mut locs, s).is_none() {
            out.push_null();
        } else {
            out.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));
        }
    }

    Ok(out.into())
}
/// Extracts capture group `group_index` from `ca` using the pattern(s) in `pat`.
///
/// The two inputs are combined according to their lengths:
/// - `pat` has length 1: the pattern is compiled once and applied to every string
///   in `ca`. A null pattern yields an all-null result of length `ca.len()`.
/// - `ca` has length 1: the single string is matched against every pattern in
///   `pat`. A null string yields an all-null result of length `pat.len()`.
/// - equal lengths: strings and patterns are paired element-wise.
///
/// The result is named after `ca` in every branch.
///
/// # Errors
/// Returns a `ComputeError` when the lengths differ and neither side has length 1,
/// and propagates any regex compilation error.
pub(super) fn extract_group(
    ca: &StringChunked,
    pat: &StringChunked,
    group_index: usize,
) -> PolarsResult<StringChunked> {
    match (ca.len(), pat.len()) {
        (_, 1) => {
            if let Some(pat) = pat.get(0) {
                // Single pattern: compile once and reuse it for every chunk.
                let reg = polars_utils::regex_cache::compile_regex(pat)?;
                try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, &reg, group_index))
            } else {
                Ok(StringChunked::full_null(ca.name().clone(), ca.len()))
            }
        },
        (1, _) => {
            if let Some(s) = ca.get(0) {
                let out: StringChunked = try_unary_mut_with_options(pat, |pat| {
                    extract_group_array_lit(s, pat, group_index)
                })?;
                // The result is built from `pat`'s chunks; set the name explicitly so
                // it matches the null path below and the equal-length path.
                Ok(out.with_name(ca.name().clone()))
            } else {
                Ok(StringChunked::full_null(ca.name().clone(), pat.len()))
            }
        },
        (len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options(
            ca,
            pat,
            |ca, pat| extract_group_binary(ca, pat, group_index),
            ca.name().clone(),
        ),
        _ => {
            polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len())
        },
    }
}