use crate::SparkError;
use arrow::array::builder::BooleanBuilder;
use arrow::array::types::Int32Type;
use arrow::array::{Array, BooleanArray, DictionaryArray, RecordBatch, StringArray};
use arrow::compute::take;
use arrow::datatypes::{DataType, Schema};
use datafusion::common::{internal_err, Result};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr_common::physical_expr::DynEq;
use datafusion::physical_plan::ColumnarValue;
use regex::Regex;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// Physical expression that matches the string values produced by a child
/// expression against a regular expression compiled once at construction.
#[derive(Debug)]
pub struct RLike {
/// Expression whose evaluated values are tested against the pattern.
child: Arc<dyn PhysicalExpr>,
/// Original pattern text, kept alongside the compiled form for display,
/// hashing and comparison.
pattern_str: String,
/// Compiled form of `pattern_str`.
pattern: Regex,
}
/// Hashes an [`RLike`] by feeding the raw bytes of its pattern text to the
/// hasher.
///
/// Only `pattern_str` contributes to the hash; the child expression and the
/// compiled regex are not hashed.
impl Hash for RLike {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let pattern_bytes = self.pattern_str.as_bytes();
        Hasher::write(state, pattern_bytes);
    }
}
/// Dynamic equality for [`RLike`].
///
/// Two expressions are equal only when `other` is also an `RLike`, both carry
/// the same pattern text, and their child expressions compare equal.
/// Comparing the pattern alone would make `a RLIKE 'x'` and `b RLIKE 'x'`
/// indistinguishable even though they read different inputs.
impl DynEq for RLike {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other.downcast_ref::<Self>().is_some_and(|other| {
            // Cheap string comparison first; the child comparison may walk a
            // whole expression tree.
            self.pattern_str == other.pattern_str && self.child.eq(&other.child)
        })
    }
}
impl RLike {
    /// Builds an `RLike` over `child` that matches against `pattern`.
    ///
    /// # Errors
    ///
    /// Returns an error wrapping `SparkError::Internal` when `pattern` is not
    /// a valid regular expression.
    pub fn try_new(child: Arc<dyn PhysicalExpr>, pattern: &str) -> Result<Self> {
        let compiled = Regex::new(pattern).map_err(|e| {
            SparkError::Internal(format!("Failed to compile pattern {pattern}: {e}"))
        })?;
        Ok(Self {
            child,
            pattern_str: pattern.to_owned(),
            pattern: compiled,
        })
    }

    /// Tests every element of `inputs` against the compiled pattern.
    ///
    /// The result has one entry per input element; a null input yields a null
    /// result.
    fn is_match(&self, inputs: &StringArray) -> BooleanArray {
        let row_count = inputs.len();
        let mut matches = BooleanBuilder::with_capacity(row_count);
        if !inputs.is_nullable() {
            // Fast path: the array reports no nulls, so skip per-row null checks.
            for text in (0..row_count).map(|row| inputs.value(row)) {
                matches.append_value(self.pattern.is_match(text));
            }
        } else {
            for slot in inputs.iter() {
                matches.append_option(slot.map(|text| self.pattern.is_match(text)));
            }
        }
        matches.finish()
    }
}
/// Renders the expression as `RLike [child: <child>, pattern: <pattern>] `
/// (including the trailing space).
impl Display for RLike {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            child, pattern_str, ..
        } = self;
        f.write_fmt(format_args!(
            "RLike [child: {}, pattern: {}] ",
            child, pattern_str
        ))
    }
}
impl PhysicalExpr for RLike {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::Boolean)
}
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
self.child.nullable(input_schema)
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
match self.child.evaluate(batch)? {
ColumnarValue::Array(array) if array.as_any().is::<DictionaryArray<Int32Type>>() => {
let dict_array = array
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.expect("dict array");
let dict_values = dict_array
.values()
.as_any()
.downcast_ref::<StringArray>()
.expect("strings");
let new_values = self.is_match(dict_values);
let result = take(&new_values, dict_array.keys(), None)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Array(array) => {
let inputs = array
.as_any()
.downcast_ref::<StringArray>()
.expect("string array");
let array = self.is_match(inputs);
Ok(ColumnarValue::Array(Arc::new(array)))
}
ColumnarValue::Scalar(_) => {
internal_err!("non scalar regexp patterns are not supported")
}
}
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.child]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
assert!(children.len() == 1);
Ok(Arc::new(RLike::try_new(
Arc::clone(&children[0]),
&self.pattern_str,
)?))
}
fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
unimplemented!()
}
}