use crate::datatypes::DataType;
use crate::error::{ArrowError, Result};
use crate::scalar::{Scalar, Utf8Scalar};
use crate::{array::*, bitmap::Bitmap};
use super::{super::utils::combine_validities, Operator};
/// Evaluates `op` on every pair of values taken position-wise from `lhs` and `rhs`
/// and packs the outcomes into a [`BooleanArray`].
///
/// The validity of the result is whatever `combine_validities` yields for the
/// validities of the two inputs. `op` is invoked for every slot yielded by
/// `values_iter`, independently of the validity bitmaps.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] when `lhs` and `rhs` differ in length.
fn compare_op<O, F>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>, op: F) -> Result<BooleanArray>
where
    O: Offset,
    F: Fn(&str, &str) -> bool,
{
    if lhs.len() == rhs.len() {
        let pairs = lhs.values_iter().zip(rhs.values_iter());
        let bits = Bitmap::from_trusted_len_iter(pairs.map(|(left, right)| op(left, right)));
        let validity = combine_validities(lhs.validity(), rhs.validity());
        Ok(BooleanArray::from_data(DataType::Boolean, bits, validity))
    } else {
        Err(ArrowError::InvalidArgumentError(
            "Cannot perform comparison operation on arrays of different length".to_string(),
        ))
    }
}
/// Evaluates `op(value, rhs)` for every value of `lhs` and packs the outcomes
/// into a [`BooleanArray`].
///
/// The result carries a clone of `lhs`'s validity; `op` is invoked for every
/// slot yielded by `values_iter`, independently of that validity.
fn compare_op_scalar<O, F>(lhs: &Utf8Array<O>, rhs: &str, op: F) -> BooleanArray
where
    O: Offset,
    F: Fn(&str, &str) -> bool,
{
    let bits = Bitmap::from_trusted_len_iter(lhs.values_iter().map(|value| op(value, rhs)));
    BooleanArray::from_data(DataType::Boolean, bits, lhs.validity().cloned())
}
/// Position-wise `lhs == rhs` between two [`Utf8Array`]s.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn eq<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_equal = |left: &str, right: &str| left == right;
    compare_op(lhs, rhs, is_equal)
}
/// Tests every value of `lhs` for equality with the string `rhs`.
fn eq_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_equal = |value: &str, scalar: &str| value == scalar;
    compare_op_scalar(lhs, rhs, is_equal)
}
/// Position-wise `lhs != rhs` between two [`Utf8Array`]s.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn neq<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_different = |left: &str, right: &str| left != right;
    compare_op(lhs, rhs, is_different)
}
/// Tests every value of `lhs` for inequality with the string `rhs`.
fn neq_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_different = |value: &str, scalar: &str| value != scalar;
    compare_op_scalar(lhs, rhs, is_different)
}
/// Position-wise `lhs < rhs` between two [`Utf8Array`]s, using `str`'s ordering.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn lt<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_less = |left: &str, right: &str| left < right;
    compare_op(lhs, rhs, is_less)
}
/// Tests whether every value of `lhs` is strictly less than the string `rhs`,
/// using `str`'s ordering.
fn lt_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_less = |value: &str, scalar: &str| value < scalar;
    compare_op_scalar(lhs, rhs, is_less)
}
/// Position-wise `lhs <= rhs` between two [`Utf8Array`]s, using `str`'s ordering.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn lt_eq<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_less_or_equal = |left: &str, right: &str| left <= right;
    compare_op(lhs, rhs, is_less_or_equal)
}
/// Tests whether every value of `lhs` is less than or equal to the string `rhs`,
/// using `str`'s ordering.
fn lt_eq_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_less_or_equal = |value: &str, scalar: &str| value <= scalar;
    compare_op_scalar(lhs, rhs, is_less_or_equal)
}
/// Position-wise `lhs > rhs` between two [`Utf8Array`]s, using `str`'s ordering.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn gt<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_greater = |left: &str, right: &str| left > right;
    compare_op(lhs, rhs, is_greater)
}
/// Tests whether every value of `lhs` is strictly greater than the string `rhs`,
/// using `str`'s ordering.
fn gt_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_greater = |value: &str, scalar: &str| value > scalar;
    compare_op_scalar(lhs, rhs, is_greater)
}
/// Position-wise `lhs >= rhs` between two [`Utf8Array`]s, using `str`'s ordering.
///
/// Delegates to `compare_op`, whose error is propagated unchanged.
fn gt_eq<O>(lhs: &Utf8Array<O>, rhs: &Utf8Array<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    let is_greater_or_equal = |left: &str, right: &str| left >= right;
    compare_op(lhs, rhs, is_greater_or_equal)
}
/// Tests whether every value of `lhs` is greater than or equal to the string `rhs`,
/// using `str`'s ordering.
fn gt_eq_scalar<O>(lhs: &Utf8Array<O>, rhs: &str) -> BooleanArray
where
    O: Offset,
{
    let is_greater_or_equal = |value: &str, scalar: &str| value >= scalar;
    compare_op_scalar(lhs, rhs, is_greater_or_equal)
}
pub fn compare<O: Offset>(
lhs: &Utf8Array<O>,
rhs: &Utf8Array<O>,
op: Operator,
) -> Result<BooleanArray> {
match op {
Operator::Eq => eq(lhs, rhs),
Operator::Neq => neq(lhs, rhs),
Operator::Gt => gt(lhs, rhs),
Operator::GtEq => gt_eq(lhs, rhs),
Operator::Lt => lt(lhs, rhs),
Operator::LtEq => lt_eq(lhs, rhs),
}
}
/// Compares every value of `lhs` against the scalar `rhs` with the comparison selected by `op`.
///
/// When `rhs` is not valid, the result is an all-null [`BooleanArray`] with the
/// same length as `lhs`; otherwise the work is handed to
/// [`compare_scalar_non_null`] with the scalar's string value.
pub fn compare_scalar<O: Offset>(
    lhs: &Utf8Array<O>,
    rhs: &Utf8Scalar<O>,
    op: Operator,
) -> BooleanArray {
    if rhs.is_valid() {
        compare_scalar_non_null(lhs, rhs.value(), op)
    } else {
        BooleanArray::new_null(DataType::Boolean, lhs.len())
    }
}
pub fn compare_scalar_non_null<O: Offset>(
lhs: &Utf8Array<O>,
rhs: &str,
op: Operator,
) -> BooleanArray {
match op {
Operator::Eq => eq_scalar(lhs, rhs),
Operator::Neq => neq_scalar(lhs, rhs),
Operator::Gt => gt_scalar(lhs, rhs),
Operator::GtEq => gt_eq_scalar(lhs, rhs),
Operator::Lt => lt_scalar(lhs, rhs),
Operator::LtEq => lt_eq_scalar(lhs, rhs),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two arrays from `lhs` and `rhs`, applies the array-vs-array kernel `op`
    /// and asserts that the (unwrapped) result equals `expected`.
    fn test_generic<O: Offset, F: Fn(&Utf8Array<O>, &Utf8Array<O>) -> Result<BooleanArray>>(
        lhs: Vec<&str>,
        rhs: Vec<&str>,
        op: F,
        expected: Vec<bool>,
    ) {
        let lhs = Utf8Array::<O>::from_slice(lhs);
        let rhs = Utf8Array::<O>::from_slice(rhs);
        let expected = BooleanArray::from_slice(expected);
        assert_eq!(op(&lhs, &rhs).unwrap(), expected);
    }

    /// Builds an array from `lhs`, applies the array-vs-scalar kernel `op` with `rhs`
    /// and asserts that the result equals `expected`.
    fn test_generic_scalar<O: Offset, F: Fn(&Utf8Array<O>, &str) -> BooleanArray>(
        lhs: Vec<&str>,
        rhs: &str,
        op: F,
        expected: Vec<bool>,
    ) {
        let lhs = Utf8Array::<O>::from_slice(lhs);
        let expected = BooleanArray::from_slice(expected);
        assert_eq!(op(&lhs, rhs), expected);
    }

    #[test]
    fn test_gt_eq() {
        test_generic::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            vec!["flight", "flight", "flight", "flight"],
            gt_eq,
            vec![false, false, true, true],
        )
    }

    #[test]
    fn test_gt_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            "flight",
            gt_eq_scalar,
            vec![false, false, true, true],
        )
    }

    #[test]
    fn test_gt() {
        test_generic::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            vec!["flight", "flight", "flight", "flight"],
            gt,
            vec![false, false, false, true],
        )
    }

    #[test]
    fn test_gt_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            "flight",
            gt_scalar,
            vec![false, false, false, true],
        )
    }

    #[test]
    fn test_lt() {
        test_generic::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            vec!["flight", "flight", "flight", "flight"],
            lt,
            vec![true, true, false, false],
        )
    }

    #[test]
    fn test_lt_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            "flight",
            lt_scalar,
            vec![true, true, false, false],
        )
    }

    #[test]
    fn test_lt_eq() {
        test_generic::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            vec!["flight", "flight", "flight", "flight"],
            lt_eq,
            vec![true, true, true, false],
        )
    }

    #[test]
    fn test_lt_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "datafusion", "flight", "parquet"],
            "flight",
            lt_eq_scalar,
            vec![true, true, true, false],
        )
    }

    #[test]
    fn test_eq() {
        test_generic::<i32, _>(
            vec!["arrow", "arrow", "arrow", "arrow"],
            vec!["arrow", "parquet", "datafusion", "flight"],
            eq,
            vec![true, false, false, false],
        )
    }

    #[test]
    fn test_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "parquet", "datafusion", "flight"],
            "arrow",
            eq_scalar,
            vec![true, false, false, false],
        )
    }

    #[test]
    fn test_neq() {
        test_generic::<i32, _>(
            vec!["arrow", "arrow", "arrow", "arrow"],
            vec!["arrow", "parquet", "datafusion", "flight"],
            neq,
            vec![false, true, true, true],
        )
    }

    #[test]
    fn test_neq_scalar() {
        test_generic_scalar::<i32, _>(
            vec!["arrow", "parquet", "datafusion", "flight"],
            "arrow",
            neq_scalar,
            vec![false, true, true, true],
        )
    }

    /// Arrays of different lengths must be rejected with an error instead of
    /// being silently truncated to the shorter one.
    #[test]
    fn test_compare_different_lengths() {
        let lhs = Utf8Array::<i32>::from_slice(vec!["arrow", "flight"]);
        let rhs = Utf8Array::<i32>::from_slice(vec!["arrow"]);
        assert!(compare(&lhs, &rhs, Operator::Eq).is_err());
    }
}