use std::sync::LazyLock;
use arcref::ArcRef;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use crate::ArrayRef;
use crate::DynArray;
use crate::IntoArray as _;
use crate::compute::ComputeFn;
use crate::compute::ComputeFnVTable;
use crate::compute::InvocationArgs;
use crate::compute::Kernel;
use crate::compute::Output;
use crate::compute::UnaryArgs;
use crate::dtype::DType;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProviderExt;
use crate::scalar::Scalar;
use crate::scalar::ScalarValue;
use crate::vtable::VTable;
/// Lazily constructed compute function named `"nan_count"`, backed by the [`NaNCount`] vtable.
///
/// On first access every [`NaNCountKernelRef`] collected through `inventory` is registered
/// on the function before it is handed out.
static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let function = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
    inventory::iter::<NaNCountKernelRef>
        .into_iter()
        .for_each(|entry| {
            function.register_kernel(entry.0.clone());
        });
    function
});
pub(crate) fn warm_up_vtable() -> usize {
NAN_COUNT_FN.kernels().len()
}
/// Computes the NaN count of `array` by invoking the `nan_count` compute function.
///
/// # Errors
///
/// Propagates any error raised while invoking the compute function or while unwrapping
/// its output as a scalar.
///
/// # Panics
///
/// Goes through `vortex_expect` (presumably panicking) with the message
/// "NaN count should not return null" if the resulting scalar does not yield a `usize`.
pub fn nan_count(array: &ArrayRef) -> VortexResult<usize> {
    let args = InvocationArgs {
        inputs: &[array.into()],
        options: &(),
    };
    let scalar = NAN_COUNT_FN.invoke(&args)?.unwrap_scalar()?;
    let count = scalar
        .as_primitive()
        .as_::<usize>()
        .vortex_expect("NaN count should not return null");
    Ok(count)
}
/// Vtable implementing the `nan_count` compute function.
struct NaNCount;

impl ComputeFnVTable for NaNCount {
    /// Computes the NaN count of the single input array.
    ///
    /// The count is obtained from [`nan_count_impl`], recorded on the array's statistics as
    /// an exact [`Stat::NaNCount`], and returned as a `u64` scalar.
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let UnaryArgs { array: input, .. } = UnaryArgs::<()>::try_from(args)?;
        let input = input.to_array();

        // Widen once; the same value feeds both the cached statistic and the result scalar.
        let count = nan_count_impl(&input, kernels)? as u64;

        input
            .statistics()
            .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));

        Ok(Scalar::from(count).into())
    }

    /// Returns the dtype that [`Stat::NaNCount`] reports for the input array's dtype.
    ///
    /// # Errors
    ///
    /// Fails when the statistic has no dtype for the input array's dtype.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        let UnaryArgs { array: input, .. } = UnaryArgs::<()>::try_from(args)?;
        match Stat::NaNCount.dtype(input.dtype()) {
            Some(dtype) => Ok(dtype),
            None => Err(vortex_err!(
                "Cannot compute NaN count for dtype {}",
                input.dtype()
            )),
        }
    }

    /// The result is always reported with length 1.
    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
        Ok(1)
    }

    /// This function is not elementwise.
    fn is_elementwise(&self) -> bool {
        false
    }
}
/// Implemented by [`VTable`] types that supply their own NaN count for their array type.
///
/// Implementations are exposed as a [`Kernel`] through [`NaNCountKernelAdapter`].
pub trait NaNCountKernel: VTable {
/// Returns the NaN count for `array`.
fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
}
/// Wrapper around a type-erased [`Kernel`] handle.
///
/// Values of this type are gathered through `inventory` and registered on the `nan_count`
/// compute function when that function is first initialized.
pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
// Declares `NaNCountKernelRef` as a type that can be collected through `inventory`.
inventory::collect!(NaNCountKernelRef);
/// Adapter that wraps a [`VTable`] value so it can be used as a [`Kernel`] for NaN counting.
#[derive(Debug)]
pub struct NaNCountKernelAdapter<V: VTable>(pub V);

impl<V> NaNCountKernelAdapter<V>
where
    V: VTable + NaNCountKernel,
{
    /// Wraps a `'static` reference to this adapter in a [`NaNCountKernelRef`].
    pub const fn lift(&'static self) -> NaNCountKernelRef {
        NaNCountKernelRef(ArcRef::new_ref(self))
    }
}
impl<V> Kernel for NaNCountKernelAdapter<V>
where
    V: VTable + NaNCountKernel,
{
    /// Runs the wrapped [`NaNCountKernel`] when the input array can be viewed as `V`'s array.
    ///
    /// Returns `Ok(None)` when `as_opt::<V>()` yields nothing for the input; otherwise
    /// returns the kernel's count as a `u64` scalar.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
        let UnaryArgs { array: input, .. } = UnaryArgs::<()>::try_from(args)?;
        match input.as_opt::<V>() {
            None => Ok(None),
            Some(typed) => {
                let count = V::nan_count(&self.0, typed)? as u64;
                Ok(Some(Scalar::from(count).into()))
            }
        }
    }
}
/// Resolves the NaN count of `array`, trying progressively more expensive strategies.
///
/// In order:
/// 1. an empty array, or one whose `valid_count()` is zero, yields `0`;
/// 2. an exact [`Stat::NaNCount`] already present in the array's statistics is returned;
/// 3. the first kernel in `kernels` that produces an output decides the result;
/// 4. a non-canonical array is canonicalized and passed back through [`nan_count`].
///
/// # Errors
///
/// Fails if a kernel's output yields no `usize`, if any step above returns an error, or if
/// the array is already canonical and no kernel produced an output.
fn nan_count_impl(array: &ArrayRef, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
    if array.is_empty() {
        return Ok(0);
    }
    if array.valid_count()? == 0 {
        return Ok(0);
    }

    // Reuse a previously computed value, but only when it is exact.
    let cached = array
        .statistics()
        .get_as::<usize>(Stat::NaNCount)
        .and_then(Precision::as_exact);
    if let Some(count) = cached {
        return Ok(count);
    }

    let args = InvocationArgs {
        inputs: &[array.into()],
        options: &(),
    };
    for kernel in kernels {
        // A kernel returning `None` declined this input; move on to the next one.
        let Some(output) = kernel.invoke(&args)? else {
            continue;
        };
        let scalar = output.unwrap_scalar()?;
        let count = scalar.as_primitive().as_::<usize>();
        return count.ok_or_else(|| vortex_err!("NaN count should not return null"));
    }

    if array.is_canonical() {
        vortex_bail!(
            "No NaN count kernel found for array type: {}",
            array.dtype()
        );
    }

    // No kernel handled this array: canonicalize and go through the public entry point.
    let canonical = array.to_canonical()?.into_array();
    nan_count(&canonical)
}