use arrow_buffer::BooleanBuffer;
use vortex_dtype::Nullability;
use vortex_error::{VortexResult, vortex_bail};
use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, IntoArray, register_kernel};
/// `between` kernel for decimal arrays, specialised for constant bounds.
///
/// When either bound is not a constant array this kernel declines by returning `Ok(None)`
/// (presumably letting another implementation handle the call).
impl BetweenKernel for DecimalVTable {
    fn between(
        &self,
        arr: &DecimalArray,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only constant bounds are supported by this fast path.
        let (lower, upper) = match (lower.as_constant(), upper.as_constant()) {
            (Some(lower), Some(upper)) => (lower, upper),
            _ => return Ok(None),
        };

        // The output nullability is the union of the input's and both bounds' nullability.
        let array_nullability = arr.dtype.nullability();
        let nullability =
            array_nullability | lower.dtype().nullability() | upper.dtype().nullability();

        // Dispatch to the native storage type backing the decimal values.
        match_each_decimal_value_type!(arr.values_type(), |D| {
            between_unpack::<D>(arr, lower, upper, nullability, options)
        })
    }
}
/// Extracts the constant bounds as native values of type `T` and evaluates the `between`
/// predicate over `arr`'s raw storage values.
///
/// Comparisons are done directly on the stored `T` values of `arr` against the bound values
/// cast to `T`. NOTE(review): this assumes both bounds share `arr`'s decimal scale; presumably
/// the caller has already checked dtype compatibility.
///
/// # Errors
///
/// Bails if either bound has no decimal value (`decimal_value()` returned `None`) or if that
/// value cannot be cast to `T`.
fn between_unpack<T: NativeDecimalType>(
    arr: &DecimalArray,
    lower: Scalar,
    upper: Scalar,
    nullability: Nullability,
    options: &BetweenOptions,
) -> VortexResult<Option<ArrayRef>> {
    let Some(lower_value) = lower
        .as_decimal()
        .decimal_value()
        .and_then(|v| v.cast::<T>())
    else {
        vortex_bail!(
            "invalid lower bound Scalar: {lower}, expected {:?}",
            T::VALUES_TYPE
        )
    };
    let Some(upper_value) = upper
        .as_decimal()
        .decimal_value()
        .and_then(|v| v.cast::<T>())
    else {
        vortex_bail!(
            "invalid upper bound Scalar: {upper}, expected {:?}",
            T::VALUES_TYPE
        )
    };

    // Dispatch on both strictness flags up front so every combination gets its own
    // monomorphised `between_impl` with the comparisons inlined into the per-element loop.
    // Picking closures through a `match` instead coerces them to `fn` pointers, which turns
    // every comparison in the hot loop into an indirect call.
    let result = match (&options.lower_strict, &options.upper_strict) {
        (StrictComparison::Strict, StrictComparison::Strict) => between_impl::<T>(
            arr,
            lower_value,
            upper_value,
            nullability,
            |a, b| a < b,
            |a, b| a < b,
        ),
        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl::<T>(
            arr,
            lower_value,
            upper_value,
            nullability,
            |a, b| a < b,
            |a, b| a <= b,
        ),
        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl::<T>(
            arr,
            lower_value,
            upper_value,
            nullability,
            |a, b| a <= b,
            |a, b| a < b,
        ),
        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl::<T>(
            arr,
            lower_value,
            upper_value,
            nullability,
            |a, b| a <= b,
            |a, b| a <= b,
        ),
    };

    Ok(Some(result))
}
// Registers the decimal `between` kernel, wrapped in `BetweenKernelAdapter` and lifted, via
// `register_kernel!` — presumably so that `between` dispatches to it for `DecimalVTable` arrays.
register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
/// Builds a boolean array flagging which stored values of `arr` fall between `lower` and
/// `upper`.
///
/// For each value `v`, the result bit is `lower_op(lower, v) & upper_op(v, upper)`; both
/// comparisons are evaluated for every element. The output validity is `arr`'s validity
/// combined with `nullability` via `union_nullability`.
fn between_impl<T: NativeDecimalType>(
    arr: &DecimalArray,
    lower: T,
    upper: T,
    nullability: Nullability,
    lower_op: impl Fn(T, T) -> bool,
    upper_op: impl Fn(T, T) -> bool,
) -> ArrayRef {
    let values = arr.buffer::<T>();

    // Bitwise `&` (not `&&`) always evaluates both sides — presumably to keep the loop
    // body branch-free.
    let in_range = |idx: usize| {
        let v = values[idx];
        lower_op(lower, v) & upper_op(v, upper)
    };
    let bits = BooleanBuffer::collect_bool(values.len(), in_range);

    let validity = arr.validity().clone().union_nullability(nullability);
    BoolArray::from_bool_buffer(bits, validity).into_array()
}
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Builds a non-nullable constant decimal array of length `len` holding `value`.
    fn decimal_constant(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(value),
                decimal_type,
                Nullability::NonNullable,
            ),
            len,
        )
    }

    /// Exercises every combination of lower/upper strictness against constant bounds that
    /// coincide with the first and last array values.
    #[test]
    fn test_between() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable);
        let lower = decimal_constant(100, decimal_type, array.len());
        let upper = decimal_constant(400, decimal_type, array.len());

        let cases = [
            (
                StrictComparison::Strict,
                StrictComparison::Strict,
                vec![false, true, true, false],
            ),
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::NonStrict,
                vec![true, true, true, true],
            ),
        ];

        for (case, (lower_strict, upper_strict, expected)) in cases.into_iter().enumerate() {
            let result = between(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected, "case {case}");
        }
    }

    /// Canonicalises `array` to a boolean array and collects its bits.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        array.to_bool().boolean_buffer().iter().collect()
    }
}