use crate::array::*;
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::Scalar;
mod binary;
mod boolean;
mod primitive;
mod utf8;
mod simd;
pub use simd::{Simd8, Simd8Lanes};
pub use binary::compare as binary_compare;
pub use binary::compare_scalar as binary_compare_scalar;
pub use boolean::compare as boolean_compare;
pub use boolean::compare_scalar as boolean_compare_scalar;
pub use primitive::compare as primitive_compare;
pub use primitive::compare_scalar as primitive_compare_scalar;
pub(crate) use primitive::compare_values_op as primitive_compare_values_op;
pub use utf8::compare as utf8_compare;
pub use utf8::compare_scalar as utf8_compare_scalar;
/// Comparison operator selected by callers of [`compare`] and [`compare_scalar`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Operator {
    /// Less than (`<`).
    Lt,
    /// Less than or equal to (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal to (`>=`).
    GtEq,
    /// Equal to (`==`).
    Eq,
    /// Not equal to (`!=`).
    Neq,
}
/// Compares two arrays with `operator` by dispatching to the kernel that matches
/// their logical type.
///
/// Both arrays are downcast to the concrete array type expected by the selected
/// kernel (`boolean`, `primitive`, `utf8` or `binary`), and the kernel's result is
/// returned unchanged.
///
/// # Errors
/// * [`ArrowError::InvalidArgumentError`] when `lhs` and `rhs` do not share the same
///   logical [`DataType`].
/// * [`ArrowError::NotYetImplemented`] when the logical type has no comparison kernel
///   wired up here. This includes `Float16`, which is reported as an error rather than
///   aborting the process.
/// * Any error returned by the selected kernel.
///
/// # Panics
/// Panics if an array's concrete type does not match the one implied by its
/// `data_type()` (the downcast is unwrapped).
pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result<BooleanArray> {
    let data_type = lhs.data_type();
    if data_type != rhs.data_type() {
        return Err(ArrowError::InvalidArgumentError(
            "Comparison is only supported for arrays of the same logical type".to_string(),
        ));
    }

    // Downcasts both operands to whatever concrete array type `$kernel` expects
    // (inferred from the kernel's signature) and invokes it.
    macro_rules! dispatch {
        ($kernel:expr) => {{
            let lhs = lhs.as_any().downcast_ref().unwrap();
            let rhs = rhs.as_any().downcast_ref().unwrap();
            $kernel(lhs, rhs, operator)
        }};
    }

    match data_type {
        DataType::Boolean => dispatch!(boolean::compare),
        DataType::Int8 => dispatch!(primitive::compare::<i8>),
        DataType::Int16 => dispatch!(primitive::compare::<i16>),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => dispatch!(primitive::compare::<i32>),
        DataType::Int64
        | DataType::Timestamp(_, None)
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Duration(_) => dispatch!(primitive::compare::<i64>),
        DataType::UInt8 => dispatch!(primitive::compare::<u8>),
        DataType::UInt16 => dispatch!(primitive::compare::<u16>),
        DataType::UInt32 => dispatch!(primitive::compare::<u32>),
        DataType::UInt64 => dispatch!(primitive::compare::<u64>),
        DataType::Float32 => dispatch!(primitive::compare::<f32>),
        DataType::Float64 => dispatch!(primitive::compare::<f64>),
        DataType::Decimal(_, _) => dispatch!(primitive::compare::<i128>),
        DataType::Utf8 => dispatch!(utf8::compare::<i32>),
        DataType::LargeUtf8 => dispatch!(utf8::compare::<i64>),
        DataType::Binary => dispatch!(binary::compare::<i32>),
        DataType::LargeBinary => dispatch!(binary::compare::<i64>),
        // Everything else (including `Float16`) has no kernel here; report it as an
        // error instead of panicking.
        _ => Err(ArrowError::NotYetImplemented(format!(
            "Comparison between {:?} is not supported",
            data_type
        ))),
    }
}
/// Compares an array against a scalar with `operator` by dispatching to the kernel
/// that matches their logical type.
///
/// `lhs` and `rhs` are downcast to the concrete array and scalar types expected by the
/// selected `compare_scalar` kernel (`boolean`, `primitive`, `utf8` or `binary`), and
/// the kernel's result is wrapped in `Ok`.
///
/// # Errors
/// * [`ArrowError::InvalidArgumentError`] when `lhs` and `rhs` do not share the same
///   logical [`DataType`].
/// * [`ArrowError::NotYetImplemented`] when the logical type has no comparison kernel
///   wired up here. This includes `Float16`, which is reported as an error rather than
///   aborting the process.
///
/// # Panics
/// Panics if the concrete type of `lhs` or `rhs` does not match the one implied by its
/// `data_type()` (the downcast is unwrapped).
pub fn compare_scalar(
    lhs: &dyn Array,
    rhs: &dyn Scalar,
    operator: Operator,
) -> Result<BooleanArray> {
    let data_type = lhs.data_type();
    if data_type != rhs.data_type() {
        return Err(ArrowError::InvalidArgumentError(
            "Comparison is only supported for the same logical type".to_string(),
        ));
    }

    // Downcasts the array and the scalar to whatever concrete types `$kernel` expects
    // (inferred from the kernel's signature) and invokes it.
    macro_rules! dispatch {
        ($kernel:expr) => {{
            let lhs = lhs.as_any().downcast_ref().unwrap();
            let rhs = rhs.as_any().downcast_ref().unwrap();
            $kernel(lhs, rhs, operator)
        }};
    }

    Ok(match data_type {
        DataType::Boolean => dispatch!(boolean::compare_scalar),
        DataType::Int8 => dispatch!(primitive::compare_scalar::<i8>),
        DataType::Int16 => dispatch!(primitive::compare_scalar::<i16>),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dispatch!(primitive::compare_scalar::<i32>)
        }
        DataType::Int64
        | DataType::Timestamp(_, None)
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Duration(_) => dispatch!(primitive::compare_scalar::<i64>),
        DataType::UInt8 => dispatch!(primitive::compare_scalar::<u8>),
        DataType::UInt16 => dispatch!(primitive::compare_scalar::<u16>),
        DataType::UInt32 => dispatch!(primitive::compare_scalar::<u32>),
        DataType::UInt64 => dispatch!(primitive::compare_scalar::<u64>),
        DataType::Float32 => dispatch!(primitive::compare_scalar::<f32>),
        DataType::Float64 => dispatch!(primitive::compare_scalar::<f64>),
        DataType::Decimal(_, _) => dispatch!(primitive::compare_scalar::<i128>),
        DataType::Utf8 => dispatch!(utf8::compare_scalar::<i32>),
        DataType::LargeUtf8 => dispatch!(utf8::compare_scalar::<i64>),
        DataType::Binary => dispatch!(binary::compare_scalar::<i32>),
        DataType::LargeBinary => dispatch!(binary::compare_scalar::<i64>),
        // Everything else (including `Float16`) has no kernel here; report it as an
        // error instead of panicking.
        _ => {
            return Err(ArrowError::NotYetImplemented(format!(
                "Comparison between {:?} is not supported",
                data_type
            )))
        }
    })
}
pub fn can_compare(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(_)
| DataType::Int64
| DataType::Timestamp(_, None)
| DataType::Date64
| DataType::Time64(_)
| DataType::Duration(_)
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64
| DataType::Utf8
| DataType::LargeUtf8
| DataType::Decimal(_, _)
| DataType::Binary
| DataType::LargeBinary
)
}