use std::fmt::Display;
use std::fmt::Formatter;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::aggregate_fn::AggregateFnRef;
use crate::arrays::ConstantArray;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::scalar::Scalar;
use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
/// Options carried by a `vortex.stat` expression.
///
/// Wraps the aggregate function that identifies which statistic the expression refers to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StatOptions {
    /// Aggregate function selecting the statistic.
    aggregate_fn: AggregateFnRef,
}
impl StatOptions {
pub fn new(aggregate_fn: AggregateFnRef) -> Self {
Self { aggregate_fn }
}
pub fn aggregate_fn(&self) -> &AggregateFnRef {
&self.aggregate_fn
}
}
impl Display for StatOptions {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.aggregate_fn, f)
}
}
/// Scalar function registered under the id `vortex.stat`.
///
/// Its behaviour lives in the [`ScalarFnVTable`] implementation for this type, which is
/// configured through [`StatOptions`].
#[derive(Clone)]
pub struct StatFn;
/// [`ScalarFnVTable`] implementation for the single-input `vortex.stat` expression.
impl ScalarFnVTable for StatFn {
    type Options = StatOptions;

    /// Returns the identifier of this scalar function, `vortex.stat`.
    fn id(&self) -> ScalarFnId {
        ScalarFnId::new("vortex.stat")
    }

    /// The expression takes exactly one child regardless of its options.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(1)
    }

    /// Names the only child `input`.
    ///
    /// # Panics
    ///
    /// Panics when `child_idx` is anything other than `0`.
    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
        if child_idx != 0 {
            unreachable!("Invalid child index {} for Stat expression", child_idx);
        }
        ChildName::from("input")
    }

    /// Writes the expression as `stat(<child>, <aggregate fn>)`.
    fn fmt_sql(
        &self,
        options: &Self::Options,
        expr: &Expression,
        f: &mut Formatter<'_>,
    ) -> std::fmt::Result {
        f.write_str("stat(")?;
        expr.child(0).fmt_sql(f)?;
        let aggregate = options.aggregate_fn();
        write!(f, ", {})", aggregate)
    }

    /// Computes the result dtype from the dtype of the first argument.
    ///
    /// # Panics
    ///
    /// Panics if `arg_dtypes` is empty, because the first entry is indexed directly.
    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
        let aggregate = options.aggregate_fn();
        let input_dtype = &arg_dtypes[0];
        stat_dtype(aggregate, input_dtype)
    }

    /// Evaluates the expression for argument `0`, asking `stat_array` for an array with
    /// `args.row_count()` rows.
    ///
    /// # Errors
    ///
    /// Propagates failures from fetching the argument, resolving the result dtype, or
    /// building the output array.
    fn execute(
        &self,
        options: &Self::Options,
        args: &dyn ExecutionArgs,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        let aggregate = options.aggregate_fn();
        let array = args.get(0)?;
        let result_dtype = stat_dtype(aggregate, array.dtype())?;
        stat_array(&array, aggregate, result_dtype, args.row_count())
    }
}
fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
let Some(dtype) = aggregate_fn.return_dtype(input_dtype) else {
vortex_bail!(
"Aggregate function {} does not support input dtype {}",
aggregate_fn,
input_dtype
);
};
Ok(dtype.as_nullable())
}
/// Builds a constant array of length `len` holding the statistic of `array` selected by
/// `aggregate_fn`.
///
/// When `aggregate_fn` maps to a [`Stat`], the value is read from the array's statistics
/// set. When it does not, a trace event is emitted and no value is used. In both cases the
/// optional value is turned into a scalar of `dtype`; per the trace message, a missing value
/// is expected to resolve to null.
///
/// # Errors
///
/// Fails if `Scalar::try_new` rejects the combination of `dtype` and the looked-up value.
fn stat_array(
    array: &ArrayRef,
    aggregate_fn: &AggregateFnRef,
    dtype: DType,
    len: usize,
) -> VortexResult<ArrayRef> {
    let stat_value = match Stat::from_aggregate_fn(aggregate_fn) {
        Some(stat) => array
            .statistics()
            .with_typed_stats_set(|stats| stats.get(stat))
            // Unwrap the stored statistic to its scalar, then to the scalar's value.
            .and_then(|stored| Scalar::into_value(stored.into_inner())),
        None => {
            tracing::trace!(
                "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
                aggregate_fn
            );
            None
        }
    };

    let constant = Scalar::try_new(dtype, stat_value)?;
    Ok(ConstantArray::new(constant, len).into_array())
}