use std::fmt;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::ops::Deref;
use std::sync::Arc;
use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use crate::dtype::DType;
use crate::expr::StatsCatalog;
use crate::expr::display::DisplayTreeExpr;
use crate::expr::stats::Stat;
use crate::scalar_fn::ScalarFnRef;
use crate::scalar_fn::fns::root::Root;
/// A node in an expression tree: a scalar function applied to a list of child expressions.
///
/// The child list is held behind an `Arc`, so cloning an `Expression` shares the existing
/// children rather than deep-copying them.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Expression {
    /// The function evaluated at this node.
    scalar_fn: ScalarFnRef,
    /// The argument expressions passed to `scalar_fn`; their count is checked against the
    /// function's arity at construction time.
    children: Arc<Vec<Expression>>,
}
/// Lets callers invoke [`ScalarFnRef`] methods (e.g. `signature()`) directly on an
/// [`Expression`], forwarding to the node's scalar function.
impl Deref for Expression {
    type Target = ScalarFnRef;

    /// Borrows the scalar function evaluated at this node.
    fn deref(&self) -> &ScalarFnRef {
        &self.scalar_fn
    }
}
impl Expression {
pub fn try_new(
scalar_fn: ScalarFnRef,
children: impl IntoIterator<Item = Expression>,
) -> VortexResult<Self> {
let children = Vec::from_iter(children);
vortex_ensure!(
scalar_fn.signature().arity().matches(children.len()),
"Expression arity mismatch: expected {} children but got {}",
scalar_fn.signature().arity(),
children.len()
);
Ok(Self {
scalar_fn,
children: children.into(),
})
}
pub fn scalar_fn(&self) -> &ScalarFnRef {
&self.scalar_fn
}
pub fn children(&self) -> &Arc<Vec<Expression>> {
&self.children
}
pub fn child(&self, n: usize) -> &Expression {
&self.children[n]
}
pub fn with_children(
mut self,
children: impl IntoIterator<Item = Expression>,
) -> VortexResult<Self> {
let children = Vec::from_iter(children);
vortex_ensure!(
self.signature().arity().matches(children.len()),
"Expression arity mismatch: expected {} children but got {}",
self.signature().arity(),
children.len()
);
self.children = Arc::new(children);
Ok(self)
}
pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
if self.is::<Root>() {
return Ok(scope.clone());
}
let dtypes: Vec<_> = self
.children
.iter()
.map(|c| c.return_dtype(scope))
.try_collect()?;
self.scalar_fn.return_dtype(&dtypes)
}
pub fn validity(&self) -> VortexResult<Expression> {
self.scalar_fn.validity(self)
}
pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
self.scalar_fn().stat_falsification(self, catalog)
}
pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option<Expression> {
self.scalar_fn().stat_expression(self, stat, catalog)
}
pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
self.stat_expression(Stat::Min, catalog)
}
pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option<Expression> {
self.stat_expression(Stat::Max, catalog)
}
pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.scalar_fn().fmt_sql(self, f)
}
pub fn display_tree(&self) -> impl Display {
DisplayTreeExpr(self)
}
}
impl Display for Expression {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.fmt_sql(f)
}
}
impl Drop for Expression {
fn drop(&mut self) {
if let Some(children) = Arc::get_mut(&mut self.children) {
let mut children_to_drop = std::mem::take(children);
while let Some(mut child) = children_to_drop.pop() {
if let Some(expr_children) = Arc::get_mut(&mut child.children) {
children_to_drop.append(expr_children);
}
}
}
}
}