use std::any::Any;
use std::sync::LazyLock;
use arcref::ArcRef;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use crate::ArrayRef;
use crate::DynArray;
use crate::IntoArray as _;
use crate::arrays::ConstantVTable;
use crate::arrays::NullVTable;
use crate::compute::ComputeFn;
use crate::compute::ComputeFnVTable;
use crate::compute::InvocationArgs;
use crate::compute::Kernel;
use crate::compute::Options;
use crate::compute::Output;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::expr::stats::StatsProviderExt;
use crate::scalar::Scalar;
use crate::vtable::VTable;
/// Process-wide dispatcher for the `is_constant` compute function, built on first use.
///
/// Every [`IsConstantKernelRef`] collected by `inventory` is attached to it as a kernel.
static IS_CONSTANT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let dispatcher = ComputeFn::new("is_constant".into(), ArcRef::new_ref(&IsConstant));
    inventory::iter::<IsConstantKernelRef>
        .into_iter()
        .for_each(|registered| {
            dispatcher.register_kernel(registered.0.clone());
        });
    dispatcher
});
/// Forces initialization of [`IS_CONSTANT_FN`] and returns how many kernels are registered
/// with it.
pub(crate) fn warm_up_vtable() -> usize {
    let registered = IS_CONSTANT_FN.kernels();
    registered.len()
}
/// Checks whether all elements of `array` are equal, using [`IsConstantOpts::default`].
///
/// The default cost is `Cost::Canonicalize`, so the array may be canonicalized when no
/// cheaper strategy or kernel produces an answer. Empty arrays are reported as not constant.
///
/// # Errors
///
/// Propagates any error raised while dispatching or evaluating the compute function.
pub fn is_constant(array: &ArrayRef) -> VortexResult<Option<bool>> {
    is_constant_opts(array, &IsConstantOpts::default())
}
/// Checks whether all elements of `array` are equal, limited by the cost in `options`.
///
/// Returns `Ok(Some(true))` or `Ok(Some(false))` for a definite answer. It returns `Ok(None)`
/// when the answer could not be determined within the allowed cost.
///
/// # Errors
///
/// Propagates any error raised while dispatching, running a kernel, or unwrapping the
/// scalar result.
pub fn is_constant_opts(array: &ArrayRef, options: &IsConstantOpts) -> VortexResult<Option<bool>> {
    let args = InvocationArgs {
        inputs: &[array.into()],
        options,
    };
    let output = IS_CONSTANT_FN.invoke(&args)?;
    let scalar = output.unwrap_scalar()?;
    // The result is a nullable boolean scalar; null means "unknown".
    Ok(scalar.as_bool().value())
}
/// Vtable backing the `is_constant` compute function; [`is_constant_opts`] is the public
/// entry point.
struct IsConstant;

impl ComputeFnVTable for IsConstant {
    /// Answers `is_constant` for a single array input.
    ///
    /// If an exact `Stat::IsConstant` statistic is already cached, it is returned unchanged.
    /// Otherwise the answer is computed by `is_constant_impl`, and any definite answer is
    /// written back to the array's statistics. The output is a nullable boolean scalar, where
    /// null means "unknown".
    ///
    /// # Panics
    ///
    /// Panics if `options.cost` is `Cost::Canonicalize` and no definite answer was produced.
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let IsConstantArgs { array, options } = IsConstantArgs::try_from(args)?;
        let array = array.to_array();

        // Fast path: an exact answer is already cached on the array.
        if let Some(Precision::Exact(cached)) =
            array.statistics().get_as::<bool>(Stat::IsConstant)
        {
            let cached_scalar: Scalar = Some(cached).into();
            return Ok(cached_scalar.into());
        }

        let answer = is_constant_impl(&array, options, kernels)?;
        match answer {
            Some(known) => {
                // Remember the definite answer so later calls take the fast path.
                array
                    .statistics()
                    .set(Stat::IsConstant, Precision::Exact(known.into()));
            }
            None => assert!(
                options.cost != Cost::Canonicalize,
                "is constant in array {array} canonicalize returned None"
            ),
        }

        let answer_scalar: Scalar = answer.into();
        Ok(answer_scalar.into())
    }

    fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
        // Nullable so that an undetermined answer can be represented as null.
        Ok(DType::Bool(Nullability::Nullable))
    }

    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
        // One scalar summarising the whole input array.
        Ok(1)
    }

    fn is_elementwise(&self) -> bool {
        false
    }
}
/// Decides whether every element of `array` is equal. It tries progressively more expensive
/// strategies, in this order:
///
/// 1. Length: empty arrays are reported as not constant, and single-element arrays as constant.
/// 2. `Constant` and `Null` encodings are constant by construction.
/// 3. Validity: an all-null array is constant, and a mix of nulls and non-nulls is not.
/// 4. Statistics: exact, equal `Min`/`Max` with no NaNs (or a dtype without a NaN count).
/// 5. The first kernel in `kernels` that accepts the array; its answer is returned as-is.
/// 6. With `Cost::Canonicalize`, a non-canonical array is canonicalized and re-checked.
///
/// Returns `Ok(None)` when none of these steps produces an answer.
fn is_constant_impl(
    array: &ArrayRef,
    options: &IsConstantOpts,
    kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Option<bool>> {
    let len = array.len();
    if len <= 1 {
        // Length 0 => Some(false), length 1 => Some(true).
        return Ok(Some(len == 1));
    }

    if array.is::<ConstantVTable>() || array.is::<NullVTable>() {
        return Ok(Some(true));
    }

    if array.all_invalid()? {
        return Ok(Some(true));
    }
    // At least one value is valid here, so any null at all makes the array non-constant.
    let all_valid = array.all_valid()?;
    if !all_valid {
        return Ok(Some(false));
    }

    // Shortcut via statistics, when both Min and Max are already known.
    let min = array.statistics().get(Stat::Min);
    let max = array.statistics().get(Stat::Max);
    let single_exact_value = match (min, max) {
        (Some(min), Some(max)) => min.is_exact() && min == max,
        _ => false,
    };
    if single_exact_value {
        // Equal Min/Max alone does not rule out NaNs (e.g. [0.0, 0.0, NaN]), so NaNs must be
        // excluded separately for dtypes that carry a NaN count.
        let nan_free = Stat::NaNCount.dtype(array.dtype()).is_none()
            || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64));
        if nan_free {
            return Ok(Some(true));
        }
    }

    assert!(
        all_valid,
        "All values must be valid as an invariant of the VTable."
    );

    let kernel_args = InvocationArgs {
        inputs: &[array.into()],
        options,
    };
    for candidate in kernels {
        if let Some(output) = candidate.invoke(&kernel_args)? {
            return Ok(output.unwrap_scalar()?.as_bool().value());
        }
    }

    tracing::debug!(
        "No is_constant implementation found for {}",
        array.encoding_id()
    );

    // Only pay for canonicalization when the caller allowed it and it would change the array.
    if options.cost != Cost::Canonicalize || array.is_canonical() {
        return Ok(None);
    }
    let canonical = array.to_canonical()?.into_array();
    is_constant_opts(&canonical, options)
}
/// Registration handle for an `is_constant` kernel.
///
/// Handles are collected via `inventory` and attached to [`IS_CONSTANT_FN`] when it is
/// first initialized.
pub struct IsConstantKernelRef(ArcRef<dyn Kernel>);

inventory::collect!(IsConstantKernelRef);

/// Encoding-specific implementation of `is_constant`.
///
/// Return `Ok(Some(_))` for a definite answer, or `Ok(None)` when it cannot be decided under
/// `opts`. Once a kernel accepts an array, its answer (including `None`) is returned directly;
/// no other kernel is tried and the canonicalize fallback is skipped. With `Cost::Canonicalize`,
/// a `None` answer therefore trips the assertion in the `is_constant` vtable.
pub trait IsConstantKernel: VTable {
    fn is_constant(
        &self,
        array: &Self::Array,
        opts: &IsConstantOpts,
    ) -> VortexResult<Option<bool>>;
}

/// Adapts an [`IsConstantKernel`] implementation to the type-erased [`Kernel`] interface.
#[derive(Debug)]
pub struct IsConstantKernelAdapter<V: VTable>(pub V);

impl<V> IsConstantKernelAdapter<V>
where
    V: VTable + IsConstantKernel,
{
    /// Wraps a `'static` adapter into a handle suitable for `inventory` registration.
    pub const fn lift(&'static self) -> IsConstantKernelRef {
        IsConstantKernelRef(ArcRef::new_ref(self))
    }
}
impl<V: VTable + IsConstantKernel> Kernel for IsConstantKernelAdapter<V> {
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
let args = IsConstantArgs::try_from(args)?;
let Some(array) = args.array.as_opt::<V>() else {
return Ok(None);
};
let is_constant = V::is_constant(&self.0, array, args.options)?;
let scalar: Scalar = is_constant.into();
Ok(Some(scalar.into()))
}
}
/// Typed view of the generic [`InvocationArgs`] passed to `is_constant`.
struct IsConstantArgs<'a> {
    array: &'a dyn DynArray,
    options: &'a IsConstantOpts,
}

impl<'a> TryFrom<&InvocationArgs<'a>> for IsConstantArgs<'a> {
    type Error = VortexError;

    /// Expects exactly one array input and options of type [`IsConstantOpts`].
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
        let [input] = value.inputs else {
            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
        };

        let array = input
            .array()
            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;

        let options = value.options.as_any().downcast_ref::<IsConstantOpts>();
        let options =
            options.ok_or_else(|| vortex_err!("Expected options to be of type IsConstantOpts"))?;

        Ok(Self { array, options })
    }
}
/// How much work `is_constant` may spend to reach a definite answer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cost {
    /// Presumably restricted to essentially free checks; kernels can query it through
    /// `IsConstantOpts::is_negligible_cost`. NOTE(review): confirm intended semantics.
    Negligible,
    /// Not referenced in this module; presumably allows encoding-specific kernels of
    /// non-trivial cost. NOTE(review): confirm intended semantics.
    Specialized,
    /// May canonicalize the array when nothing cheaper answers. A definite answer is expected
    /// at this cost.
    Canonicalize,
}

/// Options controlling an `is_constant` invocation.
#[derive(Clone, Debug)]
pub struct IsConstantOpts {
    /// Upper bound on the work allowed to decide constancy.
    pub cost: Cost,
}

impl Default for IsConstantOpts {
    /// Defaults to `Cost::Canonicalize`, the most thorough setting.
    fn default() -> Self {
        IsConstantOpts {
            cost: Cost::Canonicalize,
        }
    }
}
impl Options for IsConstantOpts {
    /// Exposes the concrete type so `IsConstantArgs` can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl IsConstantOpts {
    /// Returns `true` when only `Cost::Negligible` work is permitted.
    pub fn is_negligible_cost(&self) -> bool {
        matches!(self.cost, Cost::Negligible)
    }
}
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray as _;
    use crate::arrays::PrimitiveArray;
    use crate::compute::is_constant;
    use crate::expr::stats::Stat;

    /// Eagerly computes `Min`/`Max` so `is_constant` can take its statistics shortcut.
    fn with_min_max(array: ArrayRef) -> ArrayRef {
        array
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max])
            .unwrap();
        array
    }

    /// Runs `is_constant`, treating an undetermined answer as `false`.
    fn check_constant(array: &ArrayRef) -> bool {
        is_constant(array).unwrap().unwrap_or_default()
    }

    #[test]
    fn is_constant_min_max_no_nan() {
        assert!(!check_constant(&with_min_max(buffer![0, 1].into_array())));
        assert!(check_constant(&with_min_max(buffer![0, 0].into_array())));

        // Nullable but fully valid, without precomputed Min/Max.
        let nullable_zeros = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert!(check_constant(&nullable_zeros));
    }

    #[test]
    fn is_constant_min_max_with_nan() {
        // Equal Min/Max must not be trusted while a NaN may be present.
        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        assert!(!check_constant(&with_min_max(with_nan)));

        let neg_infinities =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        assert!(check_constant(&with_min_max(neg_infinities)));
    }
}