use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::sync::Arc;
use arcref::ArcRef;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
use crate::expr::stats::Stat;
use crate::scalar_fn::ScalarFn;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnRef;
/// Table of behaviours that defines a scalar function.
///
/// Implementors must provide [`id`](Self::id), [`arity`](Self::arity),
/// [`child_name`](Self::child_name), [`fmt_sql`](Self::fmt_sql),
/// [`return_dtype`](Self::return_dtype) and [`execute`](Self::execute). Every other hook
/// ships with a default body, described on the method itself.
pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
    /// Per-instance options passed back into every hook of this vtable.
    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;

    /// Returns the identifier of this scalar function.
    fn id(&self) -> ScalarFnId;

    /// Encodes `options` as bytes.
    ///
    /// The default implementation returns `Ok(None)`.
    /// NOTE(review): presumably `None` signals "not serializable" — confirm with callers.
    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
        let _ = options;
        Ok(None)
    }

    /// Rebuilds the options from previously serialized bytes.
    ///
    /// # Errors
    ///
    /// The default implementation always fails, reporting that the expression identified
    /// by [`id`](Self::id) is not deserializable.
    fn deserialize(
        &self,
        _metadata: &[u8],
        _session: &VortexSession,
    ) -> VortexResult<Self::Options> {
        vortex_bail!("Expression {} is not deserializable", self.id());
    }

    /// Returns how many children the function accepts for the given `options`.
    fn arity(&self, options: &Self::Options) -> Arity;

    /// Returns the name of the child at position `child_idx`.
    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;

    /// Writes a textual form of `expr` into `f`.
    ///
    /// NOTE(review): going by the method name the output is SQL-like — confirm.
    fn fmt_sql(
        &self,
        options: &Self::Options,
        expr: &Expression,
        f: &mut Formatter<'_>,
    ) -> fmt::Result;

    /// Maps the argument dtypes to the dtypes the function wants to receive.
    ///
    /// The default implementation returns a copy of `args` unchanged.
    fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
        let _ = options;
        Ok(args.to_vec())
    }

    /// Computes the result dtype from the dtypes of the arguments.
    fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;

    /// Evaluates the function over `args` and returns the resulting array.
    fn execute(
        &self,
        options: &Self::Options,
        args: &dyn ExecutionArgs,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef>;

    /// Optionally rewrites `node` into a replacement node built through `ctx`.
    ///
    /// The default implementation returns `Ok(None)`.
    /// NOTE(review): presumably `None` means "no reduction applies" — confirm.
    fn reduce(
        &self,
        options: &Self::Options,
        node: &dyn ReduceNode,
        ctx: &dyn ReduceCtx,
    ) -> VortexResult<Option<ReduceNodeRef>> {
        let _ = (options, node, ctx);
        Ok(None)
    }

    /// Optionally rewrites `expr` into a replacement expression; `ctx` can resolve the
    /// dtype of an expression.
    ///
    /// The default implementation returns `Ok(None)`.
    fn simplify(
        &self,
        options: &Self::Options,
        expr: &Expression,
        ctx: &dyn SimplifyCtx,
    ) -> VortexResult<Option<Expression>> {
        let _ = (options, expr, ctx);
        Ok(None)
    }

    /// Variant of [`simplify`](Self::simplify) that receives no [`SimplifyCtx`].
    ///
    /// The default implementation returns `Ok(None)`.
    fn simplify_untyped(
        &self,
        options: &Self::Options,
        expr: &Expression,
    ) -> VortexResult<Option<Expression>> {
        let _ = (options, expr);
        Ok(None)
    }

    /// Optionally derives an expression over statistics from `expr` using `catalog`.
    ///
    /// The default implementation returns `None`.
    /// NOTE(review): presumably the derived expression proves `expr` false when it
    /// evaluates to true (pruning) — confirm against the caller.
    fn stat_falsification(
        &self,
        options: &Self::Options,
        expr: &Expression,
        catalog: &dyn StatsCatalog,
    ) -> Option<Expression> {
        let _ = (options, expr, catalog);
        None
    }

    /// Optionally derives an expression for the statistic `stat` of `expr` using `catalog`.
    ///
    /// The default implementation returns `None`.
    fn stat_expression(
        &self,
        options: &Self::Options,
        expr: &Expression,
        stat: Stat,
        catalog: &dyn StatsCatalog,
    ) -> Option<Expression> {
        let _ = (options, expr, stat, catalog);
        None
    }

    /// Optionally derives an expression describing the validity of `expression`.
    ///
    /// The default implementation returns `Ok(None)`.
    fn validity(
        &self,
        options: &Self::Options,
        expression: &Expression,
    ) -> VortexResult<Option<Expression>> {
        let _ = (options, expression);
        Ok(None)
    }

    /// Reports whether the function is null-sensitive.
    ///
    /// The default implementation answers `true`.
    /// NOTE(review): the precise definition of "null-sensitive" is not visible in this
    /// file; `true` looks like the conservative answer — confirm.
    fn is_null_sensitive(&self, options: &Self::Options) -> bool {
        let _ = options;
        true
    }

    /// Reports whether the function is fallible.
    ///
    /// The default implementation answers `true`.
    fn is_fallible(&self, options: &Self::Options) -> bool {
        let _ = options;
        true
    }
}
/// Context handed to [`ScalarFnVTable::reduce`] for constructing new nodes.
pub trait ReduceCtx {
/// Creates a node from a scalar function and its child nodes.
///
/// NOTE(review): described from the signature only; implementations are not visible
/// in this file.
fn new_node(
&self,
scalar_fn: ScalarFnRef,
children: &[ReduceNodeRef],
) -> VortexResult<ReduceNodeRef>;
}
/// Reference-counted, type-erased handle to a [`ReduceNode`].
pub type ReduceNodeRef = Arc<dyn ReduceNode>;
/// A node as seen by [`ScalarFnVTable::reduce`].
pub trait ReduceNode {
/// Exposes the node as `&dyn Any` so callers can attempt a downcast.
fn as_any(&self) -> &dyn Any;
/// Returns the dtype of this node; the lookup may fail.
fn node_dtype(&self) -> VortexResult<DType>;
/// Returns the scalar function associated with this node, if there is one.
fn scalar_fn(&self) -> Option<&ScalarFnRef>;
/// Returns the child at position `idx`.
///
/// NOTE(review): the return type carries no error, so an out-of-range `idx`
/// presumably panics — confirm against the implementations.
fn child(&self, idx: usize) -> ReduceNodeRef;
/// Returns the number of children.
fn child_count(&self) -> usize;
}
/// The number of arguments a scalar function accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),
    /// At least `min` arguments and, when `max` is `Some`, at most `max` (inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `Exact(n)` as `n`, an unbounded range as `min+`, a bounded range as
    /// `min..max`, and a bounded range whose ends coincide as the single number.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{}", count),
            Arity::Variadic { min, max: None } => write!(f, "{}+", min),
            Arity::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{}", min),
            Arity::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{}..{}", min, max),
        }
    }
}

impl Arity {
    /// Returns `true` when `arg_count` arguments satisfy this arity.
    ///
    /// Both bounds of a variadic arity are inclusive; a range whose `min` exceeds its
    /// `max` matches nothing.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => arg_count == expected,
            Arity::Variadic { min, max: None } => arg_count >= min,
            Arity::Variadic {
                min,
                max: Some(max),
            } => (min..=max).contains(&arg_count),
        }
    }
}
/// Context handed to [`ScalarFnVTable::simplify`].
pub trait SimplifyCtx {
/// Resolves the dtype that `expr` evaluates to; the lookup may fail.
fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
}
/// Arguments passed to [`ScalarFnVTable::execute`].
pub trait ExecutionArgs {
/// Returns the input array at position `index`, or an error if it cannot be provided.
fn get(&self, index: usize) -> VortexResult<ArrayRef>;
/// Returns the number of input arrays.
fn num_inputs(&self) -> usize;
/// Returns the row count associated with this execution.
fn row_count(&self) -> usize;
}
/// [`ExecutionArgs`] implementation backed by a `Vec` of input arrays.
pub struct VecExecutionArgs {
    /// Input arrays, addressed by position.
    inputs: Vec<ArrayRef>,
    /// Row count reported to the function being executed.
    row_count: usize,
}

impl VecExecutionArgs {
    /// Wraps `inputs` and `row_count` as given.
    ///
    /// No check is made that `row_count` agrees with the arrays in `inputs`.
    pub fn new(inputs: Vec<ArrayRef>, row_count: usize) -> Self {
        VecExecutionArgs { inputs, row_count }
    }
}
impl ExecutionArgs for VecExecutionArgs {
    /// Returns a clone of the input at `index`.
    ///
    /// # Errors
    ///
    /// Fails when `index` is not smaller than the number of stored inputs.
    fn get(&self, index: usize) -> VortexResult<ArrayRef> {
        match self.inputs.get(index) {
            Some(input) => Ok(input.clone()),
            None => Err(vortex_err!(
                "Input index {} out of bounds (num_inputs={})",
                index,
                self.inputs.len()
            )),
        }
    }

    /// Returns the number of stored inputs.
    fn num_inputs(&self) -> usize {
        self.inputs.len()
    }

    /// Returns the row count supplied at construction.
    fn row_count(&self) -> usize {
        self.row_count
    }
}
/// Field-less options type.
///
/// It derives or implements every trait that `ScalarFnVTable::Options` requires, so it can
/// serve as the options of a scalar function that needs no configuration.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes the empty string; there is nothing to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
/// Convenience constructors layered on top of [`ScalarFnVTable`].
pub trait ScalarFnVTableExt: ScalarFnVTable {
    /// Pairs a clone of this vtable with `options` and returns the erased [`ScalarFnRef`].
    fn bind(&self, options: Self::Options) -> ScalarFnRef {
        let vtable = self.clone();
        ScalarFn::new(vtable, options).erased()
    }

    /// Builds an expression that applies this function to `children`.
    ///
    /// Infallible counterpart of [`try_new_expr`](Self::try_new_expr): the result is
    /// unwrapped with `vortex_expect`.
    ///
    /// # Panics
    ///
    /// NOTE(review): presumably panics when `try_new_expr` returns an error, as
    /// `vortex_expect` suggests — confirm.
    fn new_expr(
        &self,
        options: Self::Options,
        children: impl IntoIterator<Item = Expression>,
    ) -> Expression {
        self.try_new_expr(options, children)
            .vortex_expect("Failed to create expression")
    }

    /// Builds an expression that applies this function to `children`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `Expression::try_new`.
    fn try_new_expr(
        &self,
        options: Self::Options,
        children: impl IntoIterator<Item = Expression>,
    ) -> VortexResult<Expression> {
        let scalar_fn = self.bind(options);
        Expression::try_new(scalar_fn, children)
    }
}
/// Blanket implementation: every [`ScalarFnVTable`] gets the [`ScalarFnVTableExt`] methods.
impl<V: ScalarFnVTable> ScalarFnVTableExt for V {}
/// Name of a child, as returned by [`ScalarFnVTable::child_name`]; stored as an `ArcRef<str>`.
pub type ChildName = ArcRef<str>;