use arrow_array::{Array, ArrayAccessor, BinaryViewArray, BooleanArray};
use arrow_buffer::BooleanBuffer;
use memchr::memmem::Finder;
use std::iter::zip;
/// A byte-level predicate that can be evaluated against binary values.
///
/// Construct the `Contains` variant through [`BinaryPredicate::contains`] so its
/// substring searcher is built once; the prefix/suffix variants borrow the needle
/// directly.
#[derive(Debug, Clone)]
pub enum BinaryPredicate<'a> {
    /// Matches values in which the searcher's needle occurs anywhere.
    Contains(Finder<'a>),
    /// Matches values whose leading bytes equal the needle (an empty needle always matches).
    StartsWith(&'a [u8]),
    /// Matches values whose trailing bytes equal the needle (an empty needle always matches).
    EndsWith(&'a [u8]),
}
impl<'a> BinaryPredicate<'a> {
    /// Builds a `Contains` predicate matching values that contain `needle` anywhere.
    ///
    /// The [`Finder`] is constructed once here, so the same searcher is reused for every
    /// value handed to [`Self::evaluate`] or [`Self::evaluate_array`].
    pub fn contains(needle: &'a [u8]) -> Self {
        Self::Contains(Finder::new(needle))
    }

    /// Returns `true` if the single value `haystack` satisfies this predicate.
    pub fn evaluate(&self, haystack: &[u8]) -> bool {
        match self {
            Self::Contains(finder) => finder.find(haystack).is_some(),
            Self::StartsWith(v) => starts_with(haystack, v, equals_kernel),
            Self::EndsWith(v) => ends_with(haystack, v, equals_kernel),
        }
    }

    /// Evaluates this predicate against every value of `array`.
    ///
    /// Each output slot holds the predicate result XOR-ed with `negate` (`result != negate`),
    /// so `negate = true` produces the inverted predicate.
    ///
    /// When `array` is a [`BinaryViewArray`], `StartsWith`/`EndsWith` take a fast path.
    /// Only the leading/trailing `needle.len()` bytes of each view are fetched and compared,
    /// and the output validity comes from the array's `logical_nulls()`. Every other case goes
    /// through [`BooleanArray::from_unary`]. NOTE(review): that path relies on `from_unary`
    /// to carry the input's nulls into the output; confirm against the arrow docs.
    #[inline(never)]
    pub fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
    where
        T: ArrayAccessor<Item = &'i [u8]>,
    {
        match self {
            Self::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
                finder.find(haystack).is_some() != negate
            }),
            Self::StartsWith(v) => {
                if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
                    let prefixes = view_array.prefix_bytes_iter(v.len());
                    Self::view_affixes_equal(view_array, prefixes, v, negate)
                } else {
                    BooleanArray::from_unary(array, |haystack| {
                        starts_with(haystack, v, equals_kernel) != negate
                    })
                }
            }
            Self::EndsWith(v) => {
                if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
                    let suffixes = view_array.suffix_bytes_iter(v.len());
                    Self::view_affixes_equal(view_array, suffixes, v, negate)
                } else {
                    BooleanArray::from_unary(array, |haystack| {
                        ends_with(haystack, v, equals_kernel) != negate
                    })
                }
            }
        }
    }

    /// Shared tail of the `BinaryViewArray` fast paths used by `StartsWith` and `EndsWith`.
    ///
    /// The helper works in three steps:
    /// - it checks each extracted prefix/suffix in `affixes` for exact equality with `needle`;
    /// - it applies `negate`;
    /// - it attaches `view_array`'s logical nulls as the result's validity.
    ///
    /// NOTE(review): this assumes `prefix_bytes_iter`/`suffix_bytes_iter` yield a slice whose
    /// length differs from `needle.len()` for values shorter than the needle. The length check
    /// in `equals_bytes` then rejects those values. Confirm against the arrow-array docs.
    fn view_affixes_equal<'v>(
        view_array: &BinaryViewArray,
        affixes: impl Iterator<Item = &'v [u8]>,
        needle: &[u8],
        negate: bool,
    ) -> BooleanArray {
        let values: Vec<bool> = affixes
            .map(|affix| equals_bytes(affix, needle, equals_kernel) != negate)
            .collect();
        BooleanArray::new(BooleanBuffer::from(values), view_array.logical_nulls())
    }
}
/// Returns `true` when `lhs` and `rhs` have equal length and `byte_eq_kernel` accepts
/// every aligned byte pair.
///
/// Pairs are handed to the kernel as `(lhs_byte, rhs_byte)` from front to back. Evaluation
/// stops at the first rejected pair.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter().zip(rhs.iter()).all(byte_eq_kernel)
}
/// Returns `true` if the first `needle.len()` bytes of `haystack` match `needle` under
/// `byte_eq_kernel`.
///
/// The kernel receives `(haystack_byte, needle_byte)` pairs front to back. A needle longer
/// than the haystack never matches, and an empty needle always does.
fn starts_with(
    haystack: &[u8],
    needle: &[u8],
    byte_eq_kernel: impl Fn((&u8, &u8)) -> bool,
) -> bool {
    // Once the length guard passes, `zip` stops after the needle is exhausted,
    // so only the leading bytes of the haystack are compared.
    needle.len() <= haystack.len() && haystack.iter().zip(needle).all(byte_eq_kernel)
}
/// Returns `true` if the last `needle.len()` bytes of `haystack` match `needle` under
/// `byte_eq_kernel`.
///
/// The kernel receives `(haystack_byte, needle_byte)` pairs from the end backwards. A needle
/// longer than the haystack never matches, and an empty needle always does.
fn ends_with(haystack: &[u8], needle: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let needle_fits = haystack.len() >= needle.len();
    // Reversing both sides aligns their final bytes; `zip` then stops once the
    // (no longer) needle runs out.
    needle_fits
        && haystack
            .iter()
            .rev()
            .zip(needle.iter().rev())
            .all(byte_eq_kernel)
}
/// Exact byte-equality kernel: accepts a pair only when both bytes are identical.
fn equals_kernel((left, right): (&u8, &u8)) -> bool {
    *left == *right
}
#[cfg(test)]
mod tests {
    use super::BinaryPredicate;

    /// Evaluates a freshly built `Contains` predicate for `needle` against `haystack`.
    fn check_contains(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::contains(needle).evaluate(haystack)
    }

    /// Evaluates a `StartsWith` predicate for `needle` against `haystack`.
    fn check_starts_with(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::StartsWith(needle).evaluate(haystack)
    }

    /// Evaluates an `EndsWith` predicate for `needle` against `haystack`.
    fn check_ends_with(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::EndsWith(needle).evaluate(haystack)
    }

    #[test]
    fn test_contains() {
        // Matches at the start, the end, the middle, and across embedded NUL bytes.
        assert!(check_contains(b"hay", b"haystack"));
        assert!(check_contains(b"haystack", b"haystack"));
        assert!(check_contains(b"h", b"haystack"));
        assert!(check_contains(b"k", b"haystack"));
        assert!(check_contains(b"stack", b"haystack"));
        assert!(check_contains(b"sta", b"haystack"));
        assert!(check_contains(b"stack", b"hay\0stack"));
        assert!(check_contains(b"\0s", b"hay\0stack"));
        assert!(check_contains(b"\0", b"hay\0stack"));
        assert!(check_contains(b"a", b"a"));
        // Non-contiguous, overlong, and absent needles must not match.
        assert!(!check_contains(b"hy", b"haystack"));
        assert!(!check_contains(b"stackx", b"haystack"));
        assert!(!check_contains(b"x", b"haystack"));
        assert!(!check_contains(b"haystack haystack", b"haystack"));
    }

    #[test]
    fn test_starts_with() {
        assert!(check_starts_with(b"hay", b"haystack"));
        assert!(check_starts_with(b"h\0ay", b"h\0aystack"));
        assert!(check_starts_with(b"haystack", b"haystack"));
        assert!(check_starts_with(b"ha", b"haystack"));
        assert!(check_starts_with(b"h", b"haystack"));
        assert!(check_starts_with(b"", b"haystack"));
        // Suffixes, overlong needles, case differences, and NUL mismatches are rejected.
        assert!(!check_starts_with(b"stack", b"haystack"));
        assert!(!check_starts_with(b"haystacks", b"haystack"));
        assert!(!check_starts_with(b"HAY", b"haystack"));
        assert!(!check_starts_with(b"h\0ay", b"haystack"));
        assert!(!check_starts_with(b"hay", b"h\0aystack"));
    }

    #[test]
    fn test_ends_with() {
        assert!(check_ends_with(b"stack", b"haystack"));
        assert!(check_ends_with(b"st\0ack", b"hayst\0ack"));
        assert!(check_ends_with(b"haystack", b"haystack"));
        assert!(check_ends_with(b"ck", b"haystack"));
        assert!(check_ends_with(b"k", b"haystack"));
        assert!(check_ends_with(b"", b"haystack"));
        // Prefixes, case differences, overlong needles, and NUL mismatches are rejected.
        assert!(!check_ends_with(b"hay", b"haystack"));
        assert!(!check_ends_with(b"STACK", b"haystack"));
        assert!(!check_ends_with(b"haystacks", b"haystack"));
        assert!(!check_ends_with(b"xhaystack", b"haystack"));
        assert!(!check_ends_with(b"st\0ack", b"haystack"));
        assert!(!check_ends_with(b"stack", b"hayst\0ack"));
    }
}