use std::sync::Arc;
use crate::array::*;
use crate::buffer::Buffer;
use crate::compute::util::apply_bin_op_to_option_bitmap;
use crate::datatypes::{ArrowNumericType, BooleanType, DataType, ToByteSlice};
use crate::error::{ArrowError, Result};
pub fn compare_op<T, F>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
op: F,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> bool,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}
let null_bit_buffer = apply_bin_op_to_option_bitmap(
left.data().null_bitmap(),
right.data().null_bitmap(),
|a, b| a & b,
)?;
let mut values = Vec::with_capacity(left.len());
for i in 0..left.len() {
values.push(op(left.value(i), right.value(i)));
}
let data = ArrayData::new(
DataType::Boolean,
left.len(),
None,
null_bit_buffer,
left.offset(),
vec![Buffer::from(values.to_byte_slice())],
vec![],
);
Ok(PrimitiveArray::<BooleanType>::from(Arc::new(data)))
}
/// SIMD counterpart of [`compare_op`].
///
/// Compares `left` and `right` `T::lanes()` elements at a time using the
/// vectorised `op`, and returns the results as a bit-packed [`BooleanArray`].
///
/// This function is only compiled on x86/x86_64 with the `simd` feature. The
/// output null bitmap is the bitwise AND of the inputs' null bitmaps.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] if the arrays have different lengths.
/// It also propagates errors from combining the null bitmaps or from appending
/// to the result buffer.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
fn simd_compare_op<T, F>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
    op: F,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
    F: Fn(T::Simd, T::Simd) -> T::SimdMask,
{
    let len = left.len();
    if len != right.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform comparison operation on arrays of different length"
                .to_string(),
        ));
    }

    let null_bit_buffer = apply_bin_op_to_option_bitmap(
        left.data().null_bitmap(),
        right.data().null_bitmap(),
        |a, b| a & b,
    )?;

    let lanes = T::lanes();
    let mut result = BooleanBufferBuilder::new(len);
    for chunk_start in (0..len).step_by(lanes) {
        // NOTE(review): on the last, partial chunk this still requests `lanes`
        // values. Whether that is in-bounds depends on `value_slice`'s bounds
        // handling and on buffer padding, which are not visible here. Confirm.
        let simd_left = T::load(left.value_slice(chunk_start, lanes));
        let simd_right = T::load(right.value_slice(chunk_start, lanes));
        let simd_result = op(simd_left, simd_right);

        // Emit bits only for slots that actually exist, so the result buffer
        // holds exactly `len` values.
        let live_lanes = lanes.min(len - chunk_start);
        for lane in 0..live_lanes {
            result.append(T::mask_get(&simd_result, lane))?;
        }
    }

    let data = ArrayData::new(
        DataType::Boolean,
        len,
        None,
        null_bit_buffer,
        left.offset(),
        vec![result.finish()],
        vec![],
    );
    Ok(PrimitiveArray::<BooleanType>::from(Arc::new(data)))
}
/// Element-wise `left == right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::eq(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs == rhs)
}
/// Element-wise `left != right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::ne(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs != rhs)
}
/// Element-wise `left < right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn lt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::lt(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs < rhs)
}
/// Element-wise `left <= right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn lt_eq<T>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::le(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs <= rhs)
}
/// Element-wise `left > right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn gt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::gt(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs > rhs)
}
/// Element-wise `left >= right` over two primitive arrays.
///
/// Uses the SIMD kernel when built for x86/x86_64 with the `simd` feature,
/// and the scalar [`compare_op`] kernel otherwise.
///
/// # Errors
///
/// Returns an error if the two arrays have different lengths.
pub fn gt_eq<T>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
    return simd_compare_op(left, right, |lhs, rhs| T::ge(lhs, rhs));

    // Scalar fallback. It is dead code when the SIMD `return` above is
    // compiled in, hence the lint allowance.
    #[allow(unreachable_code)]
    compare_op(left, right, |lhs, rhs| lhs >= rhs)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Int32Array;

    /// Asserts that `actual.value(i) == expected[i]` for every index in `expected`.
    fn assert_values(actual: &BooleanArray, expected: &[bool]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(want, actual.value(i), "unexpected value at index {}", i);
        }
    }

    #[test]
    fn test_primitive_array_eq() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = eq(&a, &b).unwrap();
        assert_values(&c, &[false, false, true, false, false]);
    }

    #[test]
    fn test_primitive_array_neq() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = neq(&a, &b).unwrap();
        assert_values(&c, &[true, true, false, true, true]);
    }

    #[test]
    fn test_primitive_array_lt() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = lt(&a, &b).unwrap();
        assert_values(&c, &[false, false, false, true, true]);
    }

    #[test]
    fn test_primitive_array_lt_nulls() {
        let a = Int32Array::from(vec![None, None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(1), None]);
        let c = lt(&a, &b).unwrap();
        assert_values(&c, &[false, true, false]);
    }

    #[test]
    fn test_primitive_array_lt_eq() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = lt_eq(&a, &b).unwrap();
        assert_values(&c, &[false, false, true, true, true]);
    }

    #[test]
    fn test_primitive_array_lt_eq_nulls() {
        let a = Int32Array::from(vec![None, None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(1), None]);
        let c = lt_eq(&a, &b).unwrap();
        assert_values(&c, &[true, true, false]);
    }

    #[test]
    fn test_primitive_array_gt() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = gt(&a, &b).unwrap();
        assert_values(&c, &[true, true, false, false, false]);
    }

    #[test]
    fn test_primitive_array_gt_nulls() {
        let a = Int32Array::from(vec![None, None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(1), None]);
        let c = gt(&a, &b).unwrap();
        assert_values(&c, &[false, false, true]);
    }

    #[test]
    fn test_primitive_array_gt_eq() {
        let a = Int32Array::from(vec![8, 8, 8, 8, 8]);
        let b = Int32Array::from(vec![6, 7, 8, 9, 10]);
        let c = gt_eq(&a, &b).unwrap();
        assert_values(&c, &[true, true, true, false, false]);
    }

    #[test]
    fn test_primitive_array_gt_eq_nulls() {
        let a = Int32Array::from(vec![None, None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(1), None]);
        let c = gt_eq(&a, &b).unwrap();
        assert_values(&c, &[true, false, true]);
    }

    #[test]
    fn test_primitive_array_compare_different_lengths_is_error() {
        let a = Int32Array::from(vec![1, 2]);
        let b = Int32Array::from(vec![1, 2, 3]);
        assert!(eq(&a, &b).is_err());
        assert!(neq(&a, &b).is_err());
        assert!(lt(&a, &b).is_err());
        assert!(lt_eq(&a, &b).is_err());
        assert!(gt(&a, &b).is_err());
        assert!(gt_eq(&a, &b).is_err());
    }
}