use std::any::Any;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::sync::Arc;
use itertools::Itertools;
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
use crate::arrays::ConstantArray;
use crate::compute::Operator as Op;
use crate::operator::{LengthBounds, Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
use crate::pipeline::bits::BitView;
use crate::pipeline::vec::Selection;
use crate::pipeline::view::ViewMut;
use crate::pipeline::{
BindContext, Element, Kernel, KernelContext, PipelinedOperator, RowSelection, VectorId,
};
/// Element-wise comparison `lhs <op> rhs` of two operators that share the same dtype.
///
/// The result is a boolean operator whose nullability is the union of the children's
/// nullabilities.
#[derive(Debug)]
pub struct CompareOperator {
    /// The operands, in `[lhs, rhs]` order.
    children: [OperatorRef; 2],
    /// The comparison applied as `lhs <op> rhs`.
    op: Op,
    /// The result dtype; always `DType::Bool(_)`, computed once in `try_new`.
    dtype: DType,
}

impl CompareOperator {
    /// Creates the comparison `lhs <op> rhs`.
    ///
    /// # Errors
    ///
    /// Returns an error if `lhs` and `rhs` do not have exactly the same dtype.
    pub fn try_new(lhs: OperatorRef, rhs: OperatorRef, op: Op) -> VortexResult<CompareOperator> {
        if lhs.dtype() != rhs.dtype() {
            vortex_bail!(
                "Cannot compare arrays with different dtypes: {} and {}",
                lhs.dtype(),
                rhs.dtype()
            );
        }

        // NOTE(review): constant-folding when both operands are `ConstantArray`s is not
        // implemented (it used to be an empty placeholder branch here), so both operands are
        // always kept as children.

        // With the dtype check above this is just the shared nullability; the union keeps the
        // result correct should that check ever be relaxed to ignore nullability.
        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
        Ok(CompareOperator {
            children: [lhs, rhs],
            op,
            dtype: DType::Bool(nullability),
        })
    }

    /// Returns the comparison applied as `lhs <op> rhs`.
    pub fn op(&self) -> Op {
        self.op
    }
}
impl OperatorHash for CompareOperator {
    /// Feeds the comparison kind, the result dtype, and then each child (`lhs` first, then
    /// `rhs`) into `state`.
    fn operator_hash<H: Hasher>(&self, state: &mut H) {
        self.op.hash(state);
        self.dtype.hash(state);
        for child in &self.children {
            child.operator_hash(state);
        }
    }
}
impl OperatorEq for CompareOperator {
    /// Structural equality: same comparison, same result dtype, and pairwise-equal children.
    ///
    /// Children are only compared once the cheap `op`/`dtype` checks have passed.
    fn operator_eq(&self, other: &Self) -> bool {
        if self.op != other.op || self.dtype != other.dtype {
            return false;
        }
        self.children
            .iter()
            .zip(&other.children)
            .all(|(mine, theirs)| mine.operator_eq(theirs))
    }
}
impl Operator for CompareOperator {
fn id(&self) -> OperatorId {
OperatorId::from("vortex.compare")
}
fn as_any(&self) -> &dyn Any {
self
}
fn dtype(&self) -> &DType {
&self.dtype
}
fn bounds(&self) -> LengthBounds {
self.children[0].bounds() & self.children[1].bounds()
}
fn children(&self) -> &[OperatorRef] {
&self.children
}
fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
let (lhs, rhs) = children
.into_iter()
.tuples()
.next()
.vortex_expect("missing");
Ok(Arc::new(CompareOperator {
children: [lhs, rhs],
op: self.op,
dtype: self.dtype.clone(),
}))
}
fn as_pipelined(&self) -> Option<&dyn PipelinedOperator> {
if let Some((left, right)) = self.children[0]
.as_pipelined()
.zip(self.children[1].as_pipelined())
&& left.row_selection() != right.row_selection()
{
return None;
}
Some(self)
}
}
// Dispatches a runtime comparison `Op` to one of the marker types below.
//
// `$body` is expanded once per `Op` variant. In each expansion `$enc` is a local type alias for
// the matching marker struct (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt`, `Lte`), and the arm selected
// by `$self` is evaluated.
macro_rules! match_each_compare_op {
    // Internal rule: one match arm, aliasing `$enc` to `$marker` while `$body` is evaluated.
    (@arm $marker:ident, $enc:ident, $body:block) => {{
        type $enc = $marker;
        $body
    }};
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            Op::Eq => match_each_compare_op!(@arm Eq, $enc, $body),
            Op::NotEq => match_each_compare_op!(@arm NotEq, $enc, $body),
            Op::Gt => match_each_compare_op!(@arm Gt, $enc, $body),
            Op::Gte => match_each_compare_op!(@arm Gte, $enc, $body),
            Op::Lt => match_each_compare_op!(@arm Lt, $enc, $body),
            Op::Lte => match_each_compare_op!(@arm Lte, $enc, $body),
        }
    }};
}
impl PipelinedOperator for CompareOperator {
    /// Row selection of the first pipelined child, or `RowSelection::All` if neither child is
    /// pipelined.
    ///
    /// Previously only the left child was consulted. A non-pipelined left operand (e.g. a
    /// constant) therefore reported `All` even when the right operand had a narrower
    /// selection. `as_pipelined` already rejects children whose selections disagree, so any
    /// pipelined child is representative.
    fn row_selection(&self) -> RowSelection {
        self.children
            .iter()
            .find_map(|child| child.as_pipelined())
            .map(|p| p.row_selection())
            .unwrap_or(RowSelection::All)
    }

    /// Builds a comparison kernel specialised on the primitive type and the comparison.
    ///
    /// A `ConstantArray` operand is folded into a scalar kernel; otherwise both operands are
    /// read as vectors.
    ///
    /// # Errors
    ///
    /// Returns an error if the operands are not primitive, or if a constant operand is null
    /// (the kernels have no way to represent a null comparison result).
    #[allow(clippy::cognitive_complexity)]
    fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
        debug_assert_eq!(self.children[0].dtype(), self.children[1].dtype());
        let DType::Primitive(ptype, _) = self.children[0].dtype() else {
            vortex_bail!(
                "Unsupported type for comparison: {}",
                self.children[0].dtype()
            )
        };

        // `const <op> x` is evaluated as `x <op.swap()> const`, so the scalar kernel only ever
        // holds the constant on its right-hand side.
        let lhs_const = self.children[0].as_any().downcast_ref::<ConstantArray>();
        if let Some(lhs_const) = lhs_const {
            return match_each_native_ptype!(ptype, |T| {
                let Some(constant) = lhs_const.scalar().as_primitive().typed_value::<T>() else {
                    vortex_bail!(
                        "Pipelined comparison requires a non-null constant of type {}",
                        ptype
                    );
                };
                match_each_compare_op!(self.op.swap(), |Op| {
                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
                        lhs: ctx.children()[1],
                        rhs: constant,
                        _phantom: PhantomData,
                    }) as Box<dyn Kernel>)
                })
            });
        }

        let rhs_const = self.children[1].as_any().downcast_ref::<ConstantArray>();
        if let Some(rhs_const) = rhs_const {
            return match_each_native_ptype!(ptype, |T| {
                let Some(constant) = rhs_const.scalar().as_primitive().typed_value::<T>() else {
                    vortex_bail!(
                        "Pipelined comparison requires a non-null constant of type {}",
                        ptype
                    );
                };
                match_each_compare_op!(self.op, |Op| {
                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
                        lhs: ctx.children()[0],
                        rhs: constant,
                        _phantom: PhantomData,
                    }) as Box<dyn Kernel>)
                })
            });
        }

        match_each_native_ptype!(ptype, |T| {
            match_each_compare_op!(self.op, |Op| {
                Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
                    lhs: ctx.children()[0],
                    rhs: ctx.children()[1],
                    _phantom: PhantomData,
                }) as Box<dyn Kernel>)
            })
        })
    }

    /// Both operands are consumed as vectors.
    fn vector_children(&self) -> Vec<usize> {
        vec![0, 1]
    }

    fn batch_children(&self) -> Vec<usize> {
        vec![]
    }
}
/// Kernel that compares two primitive vectors element-wise with `Op` (`lhs[i] <op> rhs[i]`).
pub struct ComparePrimitiveKernel<T, Op> {
    /// Input vector for the left-hand operand.
    lhs: VectorId,
    /// Input vector for the right-hand operand.
    rhs: VectorId,
    _phantom: PhantomData<(T, Op)>,
}

impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel for ComparePrimitiveKernel<T, Op> {
    /// Compares every selected row and writes the results densely from index 0 of `out`.
    ///
    /// Each input is addressed according to its own `Selection`:
    /// - a `Prefix` input is read densely from index 0;
    /// - a `Mask` input is read at the positions of the set bits of `selection`.
    fn step(
        &self,
        ctx: &KernelContext,
        _chunk_idx: usize,
        selection: &BitView,
        out: &mut ViewMut,
    ) -> VortexResult<()> {
        let left_vector = ctx.vector(self.lhs);
        let left = left_vector.as_array::<T>();
        let right_vector = ctx.vector(self.rhs);
        let right = right_vector.as_array::<T>();
        let results = out.as_array_mut::<bool>();

        match (left_vector.selection(), right_vector.selection()) {
            (Selection::Prefix, Selection::Prefix) => {
                let selected = selection.true_count();
                for row in 0..selected {
                    results[row] = Op::compare(&left[row], &right[row]);
                }
            }
            (Selection::Mask, Selection::Mask) => {
                let mut dense = 0;
                selection.iter_ones(|row| {
                    results[dense] = Op::compare(&left[row], &right[row]);
                    dense += 1;
                });
            }
            (Selection::Mask, Selection::Prefix) => {
                let mut dense = 0;
                selection.iter_ones(|row| {
                    results[dense] = Op::compare(&left[row], &right[dense]);
                    dense += 1;
                });
            }
            (Selection::Prefix, Selection::Mask) => {
                let mut dense = 0;
                selection.iter_ones(|row| {
                    results[dense] = Op::compare(&left[dense], &right[row]);
                    dense += 1;
                });
            }
        }

        // Every branch above packs its results densely from index 0.
        out.set_selection(Selection::Prefix);
        Ok(())
    }
}
/// Kernel that compares a primitive vector against a constant: `lhs[i] <op> rhs`.
///
/// Trait bounds live on the `Kernel` impl rather than the struct definition, matching
/// `ComparePrimitiveKernel`.
struct ScalarComparePrimitiveKernel<T, Op> {
    /// Input vector holding the non-constant operand.
    lhs: VectorId,
    /// The constant right-hand operand.
    rhs: T,
    _phantom: PhantomData<Op>,
}

impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel
    for ScalarComparePrimitiveKernel<T, Op>
{
    /// Compares every selected row against the constant.
    ///
    /// The output mirrors the input's `Selection`:
    /// - a `Prefix` input produces results packed densely from index 0;
    /// - a `Mask` input produces results at the positions of the set bits of `selection`,
    ///   and the output is marked as `Mask`.
    fn step(
        &self,
        ctx: &KernelContext,
        _chunk_idx: usize,
        selection: &BitView,
        out: &mut ViewMut,
    ) -> VortexResult<()> {
        let lhs_vec = ctx.vector(self.lhs);
        let lhs = lhs_vec.as_array::<T>();
        let bools = out.as_array_mut::<bool>();
        match lhs_vec.selection() {
            Selection::Prefix => {
                for i in 0..selection.true_count() {
                    bools[i] = Op::compare(&lhs[i], &self.rhs);
                }
                out.set_selection(Selection::Prefix)
            }
            Selection::Mask => {
                selection.iter_ones(|idx| {
                    bools[idx] = Op::compare(&lhs[idx], &self.rhs);
                });
                out.set_selection(Selection::Mask)
            }
        }
        Ok(())
    }
}
/// A statically-dispatched binary comparison, used to monomorphise the compare kernels.
pub(crate) trait CompareOp<T> {
    /// Evaluates `lhs <op> rhs`.
    fn compare(lhs: &T, rhs: &T) -> bool;
}

/// Marker for `lhs == rhs`.
pub struct Eq;
impl<T: PartialEq> CompareOp<T> for Eq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::eq(lhs, rhs)
    }
}

/// Marker for `lhs != rhs`.
pub struct NotEq;
impl<T: PartialEq> CompareOp<T> for NotEq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::ne(lhs, rhs)
    }
}

/// Marker for `lhs > rhs`; unordered pairs (e.g. NaN) compare as `false`.
pub struct Gt;
impl<T: PartialOrd> CompareOp<T> for Gt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::gt(lhs, rhs)
    }
}

/// Marker for `lhs >= rhs`; unordered pairs (e.g. NaN) compare as `false`.
pub struct Gte;
impl<T: PartialOrd> CompareOp<T> for Gte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::ge(lhs, rhs)
    }
}

/// Marker for `lhs < rhs`; unordered pairs (e.g. NaN) compare as `false`.
pub struct Lt;
impl<T: PartialOrd> CompareOp<T> for Lt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::lt(lhs, rhs)
    }
}

/// Marker for `lhs <= rhs`; unordered pairs (e.g. NaN) compare as `false`.
pub struct Lte;
impl<T: PartialOrd> CompareOp<T> for Lte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::le(lhs, rhs)
    }
}