use std::fmt::Formatter;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::arrays::ScalarFn;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::expr::Expression;
use vortex_array::scalar_fn::Arity;
use vortex_array::scalar_fn::ChildName;
use vortex_array::scalar_fn::EmptyOptions;
use vortex_array::scalar_fn::ExecutionArgs;
use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
/// Placeholder scalar function rendered as `row_count()`.
///
/// It takes no child expressions and its result dtype is always a non-nullable `u64`. It cannot
/// be executed directly. Executing it returns an error, so any `RowCount` node must be replaced
/// with a concrete array (for example with `substitute_row_count`) before evaluation.
#[derive(Clone)]
pub struct RowCount;
impl ScalarFnVTable for RowCount {
    type Options = EmptyOptions;

    /// Registry identifier for this scalar function.
    fn id(&self) -> ScalarFnId {
        let fn_name = "vortex.row_count";
        ScalarFnId::from(fn_name)
    }

    /// `row_count()` is nullary: it never has child expressions.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(0)
    }

    /// Never reached in practice. With an arity of zero there is no child index to name.
    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
        unreachable!("RowCount has arity 0")
    }

    /// Renders the expression as the literal text `row_count()`.
    fn fmt_sql(
        &self,
        _options: &Self::Options,
        _expr: &Expression,
        f: &mut Formatter<'_>,
    ) -> std::fmt::Result {
        f.write_str("row_count()")
    }

    /// Always a non-nullable `u64`. The argument dtypes are ignored.
    fn return_dtype(&self, _options: &Self::Options, _args: &[DType]) -> VortexResult<DType> {
        let count_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        Ok(count_dtype)
    }

    /// Always returns an error. This node is a placeholder that must be replaced with a
    /// concrete array before evaluation.
    fn execute(
        &self,
        _options: &Self::Options,
        _args: &dyn ExecutionArgs,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        vortex_bail!("RowCount must be substituted before evaluation")
    }

    /// Declared `false` for this placeholder.
    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
        false
    }

    /// Declared `false` for this placeholder.
    fn is_fallible(&self, _options: &Self::Options) -> bool {
        false
    }
}
/// Reports whether `array` contains a `RowCount` node.
///
/// Returns `true` when `array` is itself a `RowCount` scalar-function array, or when any child
/// of a `ScalarFn` array contains one (checked recursively). Arrays that are not `ScalarFn`
/// arrays are not descended into, so a `RowCount` nested below one of them is not detected.
pub fn contains_row_count(array: &ArrayRef) -> bool {
    array.is::<ExactScalarFn<RowCount>>()
        || array
            .as_opt::<ScalarFn>()
            .is_some_and(|view| view.iter_children().any(contains_row_count))
}
/// Returns `array` with every reachable `RowCount` node replaced by `replacement`.
///
/// If `array` is itself a `RowCount` array, a clone of `replacement` is returned after checking
/// that it has the same length and dtype. Otherwise only `ScalarFn` arrays are descended into:
/// each of their child slots is taken out, substituted recursively, and put back. Any other
/// array is returned unchanged, without visiting its children.
///
/// # Errors
///
/// Fails if a `RowCount` node is reached and `replacement` differs from it in length or dtype,
/// or if taking a child slot out or putting it back fails.
pub fn substitute_row_count(array: ArrayRef, replacement: &ArrayRef) -> VortexResult<ArrayRef> {
    if array.is::<ExactScalarFn<RowCount>>() {
        vortex_ensure!(
            replacement.len() == array.len(),
            "RowCount replacement length {} does not match scope length {}",
            replacement.len(),
            array.len(),
        );
        vortex_ensure!(
            replacement.dtype() == array.dtype(),
            "RowCount replacement dtype {} does not match scope dtype {}",
            replacement.dtype(),
            array.dtype(),
        );
        return Ok(replacement.clone());
    }
    // Like `contains_row_count`, only `ScalarFn` arrays are traversed.
    if !array.is::<ScalarFn>() {
        return Ok(array);
    }
    let nchildren = array.nchildren();
    let mut array = array;
    for slot_idx in 0..nchildren {
        // SAFETY: `slot_idx < nchildren` by the loop bound.
        // NOTE(review): confirm this is the full contract of `take_slot_unchecked`.
        let (taken, child) = unsafe { array.take_slot_unchecked(slot_idx)? };
        let new_child = substitute_row_count(child, replacement)?;
        // SAFETY: `new_child` is one of three things: the taken child unchanged, a `ScalarFn`
        // child with its own slots substituted, or a `replacement` checked above to match the
        // replaced node's length and dtype.
        // NOTE(review): presumably `put_slot_unchecked` requires exactly this compatibility
        // with the slot taken out; confirm against its documentation.
        array = unsafe { taken.put_slot_unchecked(slot_idx, new_child)? };
    }
    Ok(array)
}
#[cfg(test)]
mod tests {
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::scalar_fn::EmptyOptions;
    use crate::scalar_fn::internal::row_count::RowCount;
    use crate::scalar_fn::vtable::ScalarFnVTableExt;

    /// `row_count()` reports a non-nullable `u64` even when the scope dtype is something else
    /// (here a nullable `i32`).
    #[test]
    fn row_count_helper_dtype() {
        let scope_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let expected = DType::Primitive(PType::U64, Nullability::NonNullable);

        let actual = RowCount
            .new_expr(EmptyOptions, [])
            .return_dtype(&scope_dtype)
            .unwrap();

        assert_eq!(actual, expected);
    }
}