use std::any::Any;
use std::sync::LazyLock;
use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_scalar::Scalar;
use crate::arrays::ConstantArray;
use crate::compute::{
BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
boolean, compare,
};
use crate::vtable::VTable;
use crate::{Array, ArrayRef, Canonical, IntoArray};
/// The `between` compute function.
///
/// Built on first access: every [`BetweenKernelRef`] gathered by `inventory` is registered as a
/// kernel before the function is handed out.
static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let between_fn = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
    for registered in inventory::iter::<BetweenKernelRef> {
        between_fn.register_kernel(registered.0.clone());
    }
    between_fn
});
/// Forces initialization of `BETWEEN_FN` (and therefore kernel registration).
///
/// Returns the number of kernels registered on it.
pub(crate) fn warm_up_vtable() -> usize {
    let registered = BETWEEN_FN.kernels();
    registered.len()
}
/// Computes, element-wise, whether each value of `arr` lies between the corresponding values of
/// `lower` and `upper`.
///
/// Each bound's strictness comes from `options`. The generic fallback evaluates
/// `lower OP arr AND arr OP upper`, where `OP` is `<` for [`StrictComparison::Strict`] and `<=`
/// for [`StrictComparison::NonStrict`].
///
/// The result is a boolean array. It is nullable if any of the three inputs is nullable. If
/// either bound is a constant null, the result is an all-null constant.
///
/// # Errors
///
/// Returns an error if `lower` or `upper` has a different dtype from `arr` (ignoring
/// nullability). The vtable's `return_len` also rejects inputs of differing lengths, presumably
/// checked by the compute framework.
pub fn between(
    arr: &dyn Array,
    lower: &dyn Array,
    upper: &dyn Array,
    options: &BetweenOptions,
) -> VortexResult<ArrayRef> {
    // Input order matters: `BetweenArgs` reads them back as (array, lower, upper).
    let args = InvocationArgs {
        inputs: &[arr.into(), lower.into(), upper.into()],
        options,
    };
    BETWEEN_FN.invoke(&args)?.unwrap_array()
}
/// A type-erased `between` kernel, collected through `inventory`.
///
/// Every collected instance is registered on `BETWEEN_FN` when that function is first
/// initialized.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
// Declares `BetweenKernelRef` as an `inventory` collection so `BETWEEN_FN` can iterate all
// submitted kernels at initialization time.
inventory::collect!(BetweenKernelRef);
/// Encoding-specific implementation of [`between`].
///
/// Implement this on an encoding's [`VTable`] to provide a specialized `between` for arrays of
/// that encoding. To register it, wrap the vtable in a [`BetweenKernelAdapter`] and turn that
/// into a [`BetweenKernelRef`] with [`BetweenKernelAdapter::lift`].
pub trait BetweenKernel: VTable {
    /// Evaluates `between` for `arr` against `lower` and `upper` under `options`.
    ///
    /// Implementations are expected to agree with the generic semantics:
    /// `lower OP arr AND arr OP upper`, with `OP` chosen per bound by `options`.
    ///
    /// Return `Ok(None)` to decline the input. Dispatch then continues with the other registered
    /// kernels and eventually falls back to a generic comparison-based implementation.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Bridges a [`BetweenKernel`] implemented on vtable `V` to the type-erased [`Kernel`]
/// interface used by the compute dispatcher.
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);
impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
    /// Wraps this `'static` adapter in a [`BetweenKernelRef`] so it can be collected by
    /// `inventory`.
    ///
    /// Because this is a `const fn`, it can be called in a static or const initializer
    /// (presumably an `inventory::submit!` elsewhere).
    pub const fn lift(&'static self) -> BetweenKernelRef {
        BetweenKernelRef(ArcRef::new_ref(self))
    }
}
impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
let inputs = BetweenArgs::try_from(args)?;
let Some(array) = inputs.array.as_opt::<V>() else {
return Ok(None);
};
Ok(
V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
.map(|array| array.into()),
)
}
}
/// The [`ComputeFnVTable`] backing `BETWEEN_FN`.
struct Between;

impl ComputeFnVTable for Between {
    /// Dispatches a `between` invocation.
    ///
    /// Resolution order:
    /// 1. empty `array` → empty array of the return dtype;
    /// 2. either bound is a constant null → all-null constant of the return dtype;
    /// 3. the first registered kernel that returns `Some`;
    /// 4. the array's own `invoke` hook for `BETWEEN_FN`;
    /// 5. generic fallback: `compare(lower, array) AND compare(array, upper)`.
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let BetweenArgs {
            array,
            lower,
            upper,
            options,
        } = BetweenArgs::try_from(args)?;

        // Also validates that the bound dtypes match the array's (ignoring nullability).
        let return_dtype = self.return_dtype(args)?;

        if array.is_empty() {
            return Ok(Canonical::empty(&return_dtype).into_array().into());
        }

        // A constant null bound makes every comparison null, so the whole result is null.
        // This single check also covers the case where both bounds are constant, so no separate
        // fast path is needed. A null scalar implies a nullable bound dtype, hence a nullable
        // `return_dtype`.
        if lower.as_constant().is_some_and(|v| v.is_null())
            || upper.as_constant().is_some_and(|v| v.is_null())
        {
            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
                .into_array()
                .into());
        }

        for kernel in kernels {
            if let Some(output) = kernel.invoke(args)? {
                return Ok(output);
            }
        }
        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
            return Ok(output);
        }

        // Generic fallback: lower OP array AND array OP upper.
        Ok(boolean(
            &compare(lower, array, options.lower_strict.to_operator())?,
            &compare(array, upper, options.upper_strict.to_operator())?,
            BooleanOperator::And,
        )?
        .into())
    }

    /// Returns `Bool`, nullable if any input is nullable.
    ///
    /// Errors if `lower` or `upper` has a different dtype from `array`, ignoring nullability.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        let BetweenArgs {
            array,
            lower,
            upper,
            options: _,
        } = BetweenArgs::try_from(args)?;

        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
            vortex_bail!(
                "Array and lower bound types do not match: {:?} != {:?}",
                array.dtype(),
                lower.dtype()
            );
        }
        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
            vortex_bail!(
                "Array and upper bound types do not match: {:?} != {:?}",
                array.dtype(),
                upper.dtype()
            );
        }

        Ok(DType::Bool(
            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
        ))
    }

    /// Returns the shared length of the three inputs.
    ///
    /// Errors if the lengths differ.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
        let BetweenArgs {
            array,
            lower,
            upper,
            options: _,
        } = BetweenArgs::try_from(args)?;

        if array.len() != lower.len() || array.len() != upper.len() {
            vortex_bail!(
                "Array lengths do not match: array:{} lower:{} upper:{}",
                array.len(),
                lower.len(),
                upper.len()
            );
        }

        Ok(array.len())
    }

    /// `between` is computed independently for each row.
    fn is_elementwise(&self) -> bool {
        true
    }
}
/// Typed view over the [`InvocationArgs`] of a `between` call.
struct BetweenArgs<'a> {
    /// The values being tested.
    array: &'a dyn Array,
    /// Lower bound, compared as `lower OP array`.
    lower: &'a dyn Array,
    /// Upper bound, compared as `array OP upper`.
    upper: &'a dyn Array,
    /// Strictness of each bound.
    options: &'a BetweenOptions,
}

impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
    type Error = VortexError;

    /// Extracts the three array inputs, in the order (`array`, `lower`, `upper`), and the
    /// [`BetweenOptions`] from `value`.
    ///
    /// # Errors
    ///
    /// Returns an error if there are not exactly three inputs, or if any input is not an array.
    ///
    /// # Panics
    ///
    /// Panics if the options are not [`BetweenOptions`]. `BETWEEN_FN` is private to this module
    /// and is only invoked with options built by [`between`], so a mismatch is a programming
    /// error.
    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
        let [array, lower, upper] = value.inputs else {
            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
        };

        let array = array
            .array()
            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
        let lower = lower
            .array()
            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
        let upper = upper
            .array()
            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;

        let options = value
            .options
            .as_any()
            .downcast_ref::<BetweenOptions>()
            .vortex_expect("Expected options to be BetweenOptions");

        Ok(BetweenArgs {
            array,
            lower,
            upper,
            options,
        })
    }
}
/// Options for [`between`]: whether each bound is strict (exclusive) or non-strict (inclusive).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower bound: `lower < arr` when strict, `lower <= arr` otherwise.
    pub lower_strict: StrictComparison,
    /// Strictness of the upper bound: `arr < upper` when strict, `arr <= upper` otherwise.
    pub upper_strict: StrictComparison,
}
/// Lets [`BetweenOptions`] travel through [`InvocationArgs`] as `&dyn Options` and be recovered
/// by downcasting.
impl Options for BetweenOptions {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Whether a bound excludes (`Strict`) or includes (`NonStrict`) values equal to it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}

impl StrictComparison {
    /// Returns the less-than style [`Operator`] for this strictness.
    ///
    /// This is [`Operator::Lt`] for `Strict` and [`Operator::Lte`] for `NonStrict`. The bound
    /// or value on the smaller side is always the left-hand operand.
    pub const fn to_operator(&self) -> Operator {
        match *self {
            Self::NonStrict => Operator::Lte,
            Self::Strict => Operator::Lt,
        }
    }
}
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs [`between`] with the given per-bound strictness and unwraps the result.
    fn eval_between(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        between(array, lower, upper, &options).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );

        assert_eq!(to_int_indices(result.to_bool()).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let inclusive = StrictComparison::NonStrict;

        // A null constant upper bound nulls out every row, so nothing matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert!(to_int_indices(result.to_bool()).unwrap().is_empty());

        // Non-constant lower bound with a constant upper bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), vec![0, 1, 3]);

        // Both bounds constant.
        let const_lower = ConstantArray::new(Scalar::from(0), 5);
        let result = eval_between(
            array.as_ref(),
            const_lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), vec![0, 1, 2, 3, 4]);
    }
}