use crate::binary_predicate::BinaryPredicate;
use arrow_array::cast::AsArray;
use arrow_array::*;
use arrow_schema::*;
use arrow_select::take::take;
/// The byte-oriented comparison operators supported by the kernels in this
/// module.
///
/// In every variant the left operand is the value being searched and the
/// right operand is the pattern (see `op_binary` / `op_scalar`).
#[derive(Debug)]
pub(crate) enum Op {
    /// Substring search of the right operand inside the left operand.
    Contains,
    /// Evaluated through `BinaryPredicate::StartsWith` built from the right
    /// operand.
    StartsWith,
    /// Evaluated through `BinaryPredicate::EndsWith` built from the right
    /// operand.
    EndsWith,
}
/// Narrows the general `crate::like::Op` to the subset of operators that are
/// defined for binary data.
impl TryFrom<crate::like::Op> for Op {
    type Error = ArrowError;

    /// Converts `value` into the matching binary operator.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` for every `crate::like::Op`
    /// variant other than `Contains`, `StartsWith` and `EndsWith`.
    fn try_from(value: crate::like::Op) -> Result<Self, Self::Error> {
        let converted = match value {
            crate::like::Op::Contains => Self::Contains,
            crate::like::Op::StartsWith => Self::StartsWith,
            crate::like::Op::EndsWith => Self::EndsWith,
            _ => {
                // Anything else has no binary counterpart in this module.
                let message = format!("Invalid binary operation: {value}");
                return Err(ArrowError::InvalidArgumentError(message));
            }
        };
        Ok(converted)
    }
}
/// Renders the operator as its upper-case, underscore-separated name.
impl std::fmt::Display for Op {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Select the static name first so there is a single write call.
        let name = match self {
            Op::Contains => "CONTAINS",
            Op::StartsWith => "STARTS_WITH",
            Op::EndsWith => "ENDS_WITH",
        };
        f.write_str(name)
    }
}
/// Applies `op` between the left operand `l` and the right operand `r`,
/// producing one boolean per row.
///
/// * `l_s` / `r_s` - when `true`, only the first element of that side is
///   read and it is paired with every row of the other side.
/// * `l_v` / `r_v` - when `Some`, the corresponding array holds dictionary
///   values and the rows are obtained by looking them up through the
///   dictionary's keys.
///
/// When `r_s` is set and the right value is null (either through a null
/// dictionary key or a null value), the result is an all-null array with the
/// length of the left side.
///
/// # Errors
///
/// Propagates any error returned by `op_scalar` or `op_binary`.
pub(crate) fn binary_apply<'a, 'i, T: BinaryArrayType<'a> + 'a>(
    op: Op,
    l: T,
    l_s: bool,
    l_v: Option<&'a dyn AnyDictionaryArray>,
    r: T,
    r_s: bool,
    r_v: Option<&'a dyn AnyDictionaryArray>,
) -> Result<BooleanArray, ArrowError> {
    // Right side is a single value: evaluate it against the whole left side.
    if r_s {
        let out_len = match l_v {
            Some(dict) => dict.len(),
            None => l.len(),
        };
        let idx = match r_v {
            None => 0,
            Some(dict) => {
                if dict.null_count() != 0 {
                    return Ok(BooleanArray::new_null(out_len));
                }
                dict.normalized_keys()[0]
            }
        };
        if r.is_null(idx) {
            return Ok(BooleanArray::new_null(out_len));
        }
        return op_scalar::<T>(op, l, l_v, r.value(idx));
    }

    // Left side is a single value: resolve it once and repeat it for every
    // row of the right side.
    if l_s {
        let scalar_idx = match l_v {
            None => Some(0),
            Some(dict) => dict.is_valid(0).then(|| dict.normalized_keys()[0]),
        };
        let scalar = scalar_idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
        let lhs = std::iter::repeat(scalar);
        return match r_v {
            None => op_binary(op, lhs, r.iter()),
            Some(dict) => op_binary(op, lhs, vectored_iter(r, dict)),
        };
    }

    // Both sides are arrays: pair them row by row, going through the
    // dictionary keys wherever one was supplied.
    match (l_v, r_v) {
        (None, None) => op_binary(op, l.iter(), r.iter()),
        (Some(l_dict), None) => op_binary(op, vectored_iter(l, l_dict), r.iter()),
        (None, Some(r_dict)) => op_binary(op, l.iter(), vectored_iter(r, r_dict)),
        (Some(l_dict), Some(r_dict)) => {
            op_binary(op, vectored_iter(l, l_dict), vectored_iter(r, r_dict))
        }
    }
}
/// Evaluates `op` with the single right-hand value `r` against every element
/// of `l`.
///
/// When `l_v` is `Some`, `l` is treated as dictionary values: the predicate is
/// evaluated once per value and the result is then gathered through the
/// dictionary's keys with `take`.
///
/// # Errors
///
/// Returns the error produced by `take`, if any.
#[inline(never)]
fn op_scalar<'a, T: BinaryArrayType<'a>>(
    op: Op,
    l: T,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: &[u8],
) -> Result<BooleanArray, ArrowError> {
    let predicate = match op {
        Op::Contains => BinaryPredicate::contains(r),
        Op::StartsWith => BinaryPredicate::StartsWith(r),
        Op::EndsWith => BinaryPredicate::EndsWith(r),
    };
    let on_values = predicate.evaluate_array(l, false);

    match l_v {
        None => Ok(on_values),
        Some(dict) => {
            // The predicate ran over the dictionary values; map the result
            // back onto the rows via the keys.
            let gathered = take(&on_values, dict.keys(), None)?;
            Ok(gathered.as_boolean().clone())
        }
    }
}
/// Iterates over the values of `a` in the order given by the keys of `a_v`.
///
/// One item is produced per key. The item is `None` when the key slot is null
/// in `a_v`, or when the value it refers to is null in `a`; otherwise it is
/// the referenced value.
fn vectored_iter<'a, T: BinaryArrayType<'a> + 'a>(
    a: T,
    a_v: &'a dyn AnyDictionaryArray,
) -> impl Iterator<Item = Option<&'a [u8]>> + 'a {
    let key_nulls = a_v.nulls();
    a_v.normalized_keys()
        .into_iter()
        .enumerate()
        .map(move |(row, value_idx)| {
            let key_is_null = match key_nulls {
                Some(buffer) => buffer.is_null(row),
                None => false,
            };
            // Check the key first so `a` is only consulted for non-null keys.
            if key_is_null || a.is_null(value_idx) {
                None
            } else {
                Some(a.value(value_idx))
            }
        })
}
/// Evaluates `op` pairwise over two streams of optional byte slices.
///
/// Items are paired with `zip`, so the output has as many rows as the shorter
/// input. A row is null whenever either side of the pair is `None`; otherwise
/// it holds the result of applying `op` with the left item as the searched
/// value and the right item as the pattern.
#[inline(never)]
fn op_binary<'a>(
    op: Op,
    l: impl Iterator<Item = Option<&'a [u8]>>,
    r: impl Iterator<Item = Option<&'a [u8]>>,
) -> Result<BooleanArray, ArrowError> {
    // `Option::zip` yields `None` as soon as either side is missing.
    let pairs = l.zip(r).map(|(value, pattern)| value.zip(pattern));

    let result: BooleanArray = match op {
        Op::Contains => pairs
            .map(|pair| pair.map(|(value, pattern)| bytes_contains(value, pattern)))
            .collect(),
        Op::StartsWith => pairs
            .map(|pair| {
                pair.map(|(value, pattern)| BinaryPredicate::StartsWith(pattern).evaluate(value))
            })
            .collect(),
        Op::EndsWith => pairs
            .map(|pair| {
                pair.map(|(value, pattern)| BinaryPredicate::EndsWith(pattern).evaluate(value))
            })
            .collect(),
    };
    Ok(result)
}
/// Reports whether `needle` is found inside `haystack`.
///
/// Thin wrapper over `memchr::memmem::find`: returns `true` when that search
/// yields a match position and `false` when it yields `None`.
fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool {
memchr::memmem::find(haystack, needle).is_some()
}