use arrow_array::{BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_schema::ArrowError;
use regex::{Regex, RegexBuilder};
/// A string predicate compiled from a SQL `LIKE` / `ILIKE` pattern.
///
/// Patterns that reduce to a plain equality, prefix, suffix or substring test are
/// represented by the corresponding literal variant; all other patterns are
/// evaluated with a compiled [`Regex`]. When built by `like`/`ilike`, the `&'a str`
/// payloads are slices of the pattern string that was passed in.
#[derive(Debug, Clone)]
pub enum Predicate<'a> {
    /// Matches strings exactly equal to the literal.
    Eq(&'a str),
    /// Matches strings that contain the literal as a substring.
    Contains(&'a str),
    /// Matches strings that begin with the literal.
    StartsWith(&'a str),
    /// Matches strings that end with the literal.
    EndsWith(&'a str),
    /// Matches strings equal to the literal, ignoring ASCII case only.
    IEqAscii(&'a str),
    /// Matches strings that begin with the literal, ignoring ASCII case only.
    IStartsWithAscii(&'a str),
    /// Matches strings that end with the literal, ignoring ASCII case only.
    IEndsWithAscii(&'a str),
    /// Matches strings accepted by the compiled regular expression.
    Regex(Regex),
}
impl<'a> Predicate<'a> {
pub fn like(pattern: &'a str) -> Result<Self, ArrowError> {
if !pattern.contains(is_like_pattern) {
Ok(Self::Eq(pattern))
} else if pattern.ends_with('%')
&& !pattern.ends_with("\\%")
&& !pattern[..pattern.len() - 1].contains(is_like_pattern)
{
Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
} else if pattern.starts_with('%') && !pattern[1..].contains(is_like_pattern) {
Ok(Self::EndsWith(&pattern[1..]))
} else if pattern.starts_with('%')
&& pattern.ends_with('%')
&& !pattern.ends_with("\\%")
&& !pattern[1..pattern.len() - 1].contains(is_like_pattern)
{
Ok(Self::Contains(&pattern[1..pattern.len() - 1]))
} else {
Ok(Self::Regex(regex_like(pattern, false)?))
}
}
pub fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
if is_ascii && pattern.is_ascii() {
if !pattern.contains(is_like_pattern) {
return Ok(Self::IEqAscii(pattern));
} else if pattern.ends_with('%')
&& !pattern.ends_with("\\%")
&& !pattern[..pattern.len() - 1].contains(is_like_pattern)
{
return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
} else if pattern.starts_with('%') && !pattern[1..].contains(is_like_pattern) {
return Ok(Self::IEndsWithAscii(&pattern[1..]));
}
}
Ok(Self::Regex(regex_like(pattern, true)?))
}
pub fn evaluate(&self, haystack: &str) -> bool {
match self {
Predicate::Eq(v) => *v == haystack,
Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
Predicate::Contains(v) => haystack.contains(v),
Predicate::StartsWith(v) => haystack.starts_with(v),
Predicate::IStartsWithAscii(v) => starts_with_ignore_ascii_case(haystack, v),
Predicate::EndsWith(v) => haystack.ends_with(v),
Predicate::IEndsWithAscii(v) => ends_with_ignore_ascii_case(haystack, v),
Predicate::Regex(v) => v.is_match(haystack),
}
}
#[inline(never)]
pub fn evaluate_array<O: OffsetSizeTrait>(
&self,
array: &GenericStringArray<O>,
negate: bool,
) -> BooleanArray {
match self {
Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
(haystack.len() == v.len() && haystack == *v) != negate
}),
Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
haystack.eq_ignore_ascii_case(v) != negate
}),
Predicate::Contains(v) => {
BooleanArray::from_unary(array, |haystack| haystack.contains(v) != negate)
}
Predicate::StartsWith(v) => {
BooleanArray::from_unary(array, |haystack| haystack.starts_with(v) != negate)
}
Predicate::IStartsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
starts_with_ignore_ascii_case(haystack, v) != negate
}),
Predicate::EndsWith(v) => {
BooleanArray::from_unary(array, |haystack| haystack.ends_with(v) != negate)
}
Predicate::IEndsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
ends_with_ignore_ascii_case(haystack, v) != negate
}),
Predicate::Regex(v) => {
BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
}
}
}
}
/// Returns `true` if `haystack` begins with `needle`, comparing ASCII letters
/// case-insensitively and all other bytes exactly.
fn starts_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    // No char-boundary check is needed: non-ASCII bytes must match exactly, so a
    // successful match reproduces `needle`'s complete final character and thus
    // ends on a character boundary of `haystack`.
    match haystack.as_bytes().get(..needle.len()) {
        Some(head) => head.eq_ignore_ascii_case(needle.as_bytes()),
        None => false,
    }
}
/// Returns `true` if `haystack` ends with `needle`, comparing ASCII letters
/// case-insensitively and all other bytes exactly.
fn ends_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    // No char-boundary check is needed: a valid UTF-8 `needle` never starts with
    // a continuation byte, and non-ASCII bytes must match exactly, so a match
    // always begins on a character boundary of `haystack`.
    match haystack.len().checked_sub(needle.len()) {
        Some(start) => haystack.as_bytes()[start..].eq_ignore_ascii_case(needle.as_bytes()),
        None => false,
    }
}
/// Translates a SQL `LIKE` pattern into an anchored [`Regex`].
///
/// `%` becomes `.*` and `_` becomes `.`. Regex meta characters are escaped so
/// they match literally. A backslash followed by `%` or `_` emits that character
/// literally. Any other backslash is matched literally and leaves the following
/// character to be processed normally. `.` is configured to also match newlines.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the generated expression
/// fails to compile.
fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
    let mut expr = String::with_capacity(pattern.len() * 2);
    expr.push('^');
    let mut chars = pattern.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '\\' => match chars.next_if(|&next| is_like_pattern(next)) {
                // Escaped wildcard: `%` / `_` are not regex meta characters, so
                // they can be emitted as-is to match literally.
                Some(wildcard) => expr.push(wildcard),
                // Any other backslash is matched literally; the next character is
                // not consumed here.
                None => expr.push_str("\\\\"),
            },
            c if regex_syntax::is_meta_character(c) => {
                expr.push('\\');
                expr.push(c);
            }
            '%' => expr.push_str(".*"),
            '_' => expr.push('.'),
            c => expr.push(c),
        }
    }
    expr.push('$');
    RegexBuilder::new(&expr)
        .case_insensitive(case_insensitive)
        .dot_matches_new_line(true)
        .build()
        .map_err(|e| {
            ArrowError::InvalidArgumentError(format!(
                "Unable to build regex from LIKE pattern: {e}"
            ))
        })
}
/// Returns `true` if `c` is one of the SQL `LIKE` wildcards `%` or `_`.
fn is_like_pattern(c: char) -> bool {
    matches!(c, '%' | '_')
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;

    #[test]
    fn test_replace_like_wildcards() {
        let a_eq = "_%";
        let expected = "^..*$";
        let r = regex_like(a_eq, false).unwrap();
        assert_eq!(r.to_string(), expected);
    }

    #[test]
    fn test_replace_like_wildcards_leave_like_meta_chars() {
        let a_eq = "\\%\\_";
        let expected = "^%_$";
        let r = regex_like(a_eq, false).unwrap();
        assert_eq!(r.to_string(), expected);
    }

    #[test]
    fn test_replace_like_wildcards_with_multiple_escape_chars() {
        let a_eq = "\\\\%";
        let expected = "^\\\\%$";
        let r = regex_like(a_eq, false).unwrap();
        assert_eq!(r.to_string(), expected);
    }

    #[test]
    fn test_replace_like_wildcards_escape_regex_meta_char() {
        let a_eq = ".";
        let expected = "^\\.$";
        let r = regex_like(a_eq, false).unwrap();
        assert_eq!(r.to_string(), expected);
    }

    /// `like` picks the cheap literal variants only for patterns they can express.
    #[test]
    fn test_like_selects_fast_paths() {
        assert!(matches!(Predicate::like("abc").unwrap(), Predicate::Eq("abc")));
        assert!(matches!(
            Predicate::like("abc%").unwrap(),
            Predicate::StartsWith("abc")
        ));
        assert!(matches!(
            Predicate::like("%abc").unwrap(),
            Predicate::EndsWith("abc")
        ));
        assert!(matches!(
            Predicate::like("%abc%").unwrap(),
            Predicate::Contains("abc")
        ));
        // An escaped trailing `%` is a literal, so no prefix fast path applies.
        assert!(matches!(
            Predicate::like("abc\\%").unwrap(),
            Predicate::Regex(_)
        ));
        assert!(matches!(Predicate::like("a_c").unwrap(), Predicate::Regex(_)));
    }

    /// The regex fallback honours `_`, `%`, escapes, and newline matching.
    #[test]
    fn test_like_evaluate() {
        let p = Predicate::like("a_c%").unwrap();
        assert!(p.evaluate("abcdef"));
        assert!(p.evaluate("a\nc"));
        assert!(!p.evaluate("ab"));

        let p = Predicate::like("abc\\%").unwrap();
        assert!(p.evaluate("abc%"));
        assert!(!p.evaluate("abcd"));
    }

    /// `ilike` uses the ASCII fast paths only when allowed to.
    #[test]
    fn test_ilike_ascii_fast_paths() {
        let p = Predicate::ilike("ab%", true).unwrap();
        assert!(matches!(p, Predicate::IStartsWithAscii("ab")));
        assert!(p.evaluate("ABc"));
        assert!(!p.evaluate("xab"));

        let p = Predicate::ilike("%Ab", true).unwrap();
        assert!(matches!(p, Predicate::IEndsWithAscii("Ab")));
        assert!(p.evaluate("xxaB"));

        let p = Predicate::ilike("aBc", true).unwrap();
        assert!(matches!(p, Predicate::IEqAscii("aBc")));
        assert!(p.evaluate("ABC"));

        let p = Predicate::ilike("ab%", false).unwrap();
        assert!(matches!(p, Predicate::Regex(_)));
        assert!(p.evaluate("ABc"));
    }

    /// `evaluate_array` keeps nulls and flips results when `negate` is set.
    #[test]
    fn test_evaluate_array_negate_and_nulls() {
        let array = StringArray::from(vec![Some("abc"), None, Some("xbc")]);
        let p = Predicate::like("a%").unwrap();

        let result = p.evaluate_array(&array, false);
        assert_eq!(
            result,
            BooleanArray::from(vec![Some(true), None, Some(false)])
        );

        let result = p.evaluate_array(&array, true);
        assert_eq!(
            result,
            BooleanArray::from(vec![Some(false), None, Some(true)])
        );
    }
}