use std::fmt;
use egg::{Id, Language};
use crate::array::*;
use crate::planner::{Expr, RecExpr};
use crate::types::{ConvertError, DataValue};
/// A cursor over a [`RecExpr`]: `expr` is the whole expression and `id`
/// selects the node that this evaluator currently evaluates.
pub struct Evaluator<'a> {
    /// The full expression; only ever borrowed immutably.
    expr: &'a RecExpr,
    /// Index of the current node within `expr`.
    id: Id,
}
impl fmt::Display for Evaluator<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let recexpr = self.node().build_recexpr(|id| self.expr[id].clone());
write!(f, "{recexpr}")
}
}
impl<'a> Evaluator<'a> {
    /// Creates an evaluator positioned at the root of `expr`, which is taken to
    /// be the last node of the `RecExpr`.
    ///
    /// # Panics
    ///
    /// Panics if `expr` contains no nodes.
    pub fn new(expr: &'a RecExpr) -> Self {
        let len = expr.as_ref().len();
        // Without this guard `len - 1` overflows (debug) or wraps to an
        // out-of-range id (release) and fails later with a confusing message.
        assert!(len > 0, "cannot create an evaluator for an empty expression");
        Self {
            expr,
            id: Id::from(len - 1),
        }
    }

    /// Returns the node this evaluator is positioned at.
    fn node(&self) -> &Expr {
        &self.expr[self.id]
    }

    /// Returns an evaluator over the same expression, positioned at `id`.
    fn next(&self, id: Id) -> Self {
        Self {
            expr: self.expr,
            id,
        }
    }

    /// Evaluates each element of the current list node against `chunk`,
    /// producing one output column per element.
    ///
    /// An empty list yields a column-less chunk that keeps `chunk`'s cardinality.
    ///
    /// # Errors
    ///
    /// Returns the first `ConvertError` raised while evaluating an element.
    pub fn eval_list(&self, chunk: &DataChunk) -> Result<DataChunk, ConvertError> {
        let list = self.node().as_list();
        if list.is_empty() {
            return Ok(DataChunk::no_column(chunk.cardinality()));
        }
        list.iter().map(|id| self.next(*id).eval(chunk)).collect()
    }

    /// Evaluates the current node against `chunk`, producing a single array.
    ///
    /// Aggregate nodes evaluate to their argument (the aggregation itself is
    /// performed by `eval_agg` / `agg_append`); `RowCount` has no argument and
    /// evaluates to an all-null placeholder column of `chunk.cardinality()` rows.
    ///
    /// # Errors
    ///
    /// Propagates `ConvertError` from casts and unary/binary operators.
    ///
    /// # Panics
    ///
    /// Panics if the node is neither handled explicitly nor a unary/binary op.
    pub fn eval(&self, chunk: &DataChunk) -> Result<ArrayImpl, ConvertError> {
        use Expr::*;
        match self.node() {
            ColumnIndex(idx) => Ok(chunk.array_at(idx.0 as _).clone()),
            Constant(v) => {
                // Broadcast the constant to one value per input row.
                let mut builder =
                    ArrayBuilderImpl::with_capacity(chunk.cardinality(), &v.data_type());
                builder.push_n(chunk.cardinality(), v);
                Ok(builder.finish())
            }
            Cast([ty, a]) => {
                let array = self.next(*a).eval(chunk)?;
                array.cast(self.next(*ty).node().as_type())
            }
            IsNull(a) => {
                // A row is null exactly when its validity bit is unset.
                let array = self.next(*a).eval(chunk)?;
                Ok(ArrayImpl::new_bool(
                    array.get_valid_bitmap().iter().map(|v| !v).collect(),
                ))
            }
            Asc(a) | Desc(a) | Nested(a) => self.next(*a).eval(chunk),
            RowCount => Ok(ArrayImpl::new_null(
                (0..chunk.cardinality()).map(|_| ()).collect(),
            )),
            Count(a) | Sum(a) | Min(a) | Max(a) | First(a) | Last(a) => self.next(*a).eval(chunk),
            e => {
                if let Some((op, a, b)) = e.binary_op() {
                    let left = self.next(a).eval(chunk)?;
                    let right = self.next(b).eval(chunk)?;
                    left.binary_op(&op, &right)
                } else if let Some((op, a)) = e.unary_op() {
                    let array = self.next(a).eval(chunk)?;
                    array.unary_op(&op)
                } else {
                    panic!("can not evaluate expression: {self}");
                }
            }
        }
    }

    /// Builds the initial state of every aggregate in the current list node.
    ///
    /// # Panics
    ///
    /// Panics if any list element is not an aggregation.
    pub fn init_agg_states<B: FromIterator<DataValue>>(&self) -> B {
        (self.node().as_list().iter())
            .map(|id| self.next(*id).init_agg_state())
            .collect()
    }

    /// Initial state for the aggregate at the current node: counters start at
    /// `Int32(0)`, every other aggregate starts at `Null` ("no value seen yet").
    fn init_agg_state(&self) -> DataValue {
        use Expr::*;
        match self.node() {
            RowCount | Count(_) => DataValue::Int32(0),
            Sum(_) | Min(_) | Max(_) | First(_) | Last(_) => DataValue::Null,
            t => panic!("not aggregation: {t}"),
        }
    }

    /// Folds the whole of `chunk` into `states`, one state per aggregate in the
    /// current list node (pairs beyond the shorter of the two are ignored).
    ///
    /// # Errors
    ///
    /// Propagates any `ConvertError` raised while evaluating an aggregate's
    /// argument; states already updated before the error keep their new values.
    pub fn eval_agg_list(
        &self,
        states: &mut [DataValue],
        chunk: &DataChunk,
    ) -> Result<(), ConvertError> {
        let list = self.node().as_list();
        for (state, id) in states.iter_mut().zip(list) {
            // Clone rather than move: if `eval_agg` fails, the state it was
            // about to update must be left untouched.
            *state = self.next(*id).eval_agg(state.clone(), chunk)?;
        }
        Ok(())
    }

    /// Folds one row of already-evaluated argument `values` into `states`, one
    /// value per aggregate in the current list node.
    pub fn agg_list_append(
        &self,
        states: &mut [DataValue],
        values: impl Iterator<Item = DataValue>,
    ) {
        let list = self.node().as_list();
        for ((state, id), value) in states.iter_mut().zip(list).zip(values) {
            // This runs once per input row, so move the state out instead of
            // cloning it; `agg_append` is infallible, so nothing needs the old value.
            let old = std::mem::replace(state, DataValue::Null);
            *state = self.next(*id).agg_append(old, value);
        }
    }

    /// Folds the whole of `chunk` into `state` for the aggregate at the current
    /// node and returns the updated state.
    ///
    /// # Errors
    ///
    /// Propagates any `ConvertError` raised while evaluating the argument.
    ///
    /// # Panics
    ///
    /// Panics if the current node is not an aggregation.
    fn eval_agg(&self, state: DataValue, chunk: &DataChunk) -> Result<DataValue, ConvertError> {
        use Expr::*;
        match self.node() {
            RowCount => Ok(state.add(DataValue::Int32(chunk.cardinality() as _))),
            Count(a) => Ok(state.add(DataValue::Int32(self.next(*a).eval(chunk)?.count() as _))),
            Sum(a) => Ok(state.add(self.next(*a).eval(chunk)?.sum())),
            Min(a) => Ok(state.min(self.next(*a).eval(chunk)?.min_())),
            Max(a) => Ok(state.max(self.next(*a).eval(chunk)?.max_())),
            // Keep the existing state unless it is still null.
            First(a) => Ok(state.or(self.next(*a).eval(chunk)?.first())),
            // Prefer this chunk's last value, falling back to the state when it is null.
            Last(a) => Ok(self.next(*a).eval(chunk)?.last().or(state)),
            t => panic!("not aggregation: {t}"),
        }
    }

    /// Folds a single row's argument `value` into `state` for the aggregate at
    /// the current node and returns the updated state.
    ///
    /// # Panics
    ///
    /// Panics if the current node is not an aggregation.
    fn agg_append(&self, state: DataValue, value: DataValue) -> DataValue {
        use Expr::*;
        match self.node() {
            RowCount => state.add(DataValue::Int32(1)),
            // Only non-null values are counted.
            Count(_) => state.add(DataValue::Int32(!value.is_null() as _)),
            Sum(_) => state.add(value),
            Min(_) => state.min(value),
            Max(_) => state.max(value),
            First(_) => state.or(value),
            // NOTE(review): unlike `eval_agg`'s `Last` (which keeps `state` when
            // the chunk's last value is null), a null `value` overwrites the
            // state here — confirm which null semantics is intended.
            Last(_) => value,
            t => panic!("not aggregation: {t}"),
        }
    }

    /// Returns, for each element of the current list node, whether it sorts
    /// descending (`true` for `Desc`, `false` for `Asc`).
    ///
    /// # Panics
    ///
    /// Panics if any element is neither `Asc` nor `Desc`.
    pub fn orders(&self) -> Vec<bool> {
        (self.node().as_list().iter())
            .map(|id| match self.next(*id).node() {
                Expr::Asc(_) => false,
                Expr::Desc(_) => true,
                _ => panic!("not order"),
            })
            .collect()
    }
}

/// Null-aware combinators used by the aggregate update paths of `Evaluator`.
///
/// These used to be declared inside `Evaluator::eval_agg`'s body even though
/// `Evaluator::agg_append` relies on them too; they live at module level so
/// their definition is discoverable.
impl DataValue {
    /// Adds `other` to `self`, treating a null accumulator as "no value yet"
    /// (so the result is `other`). Otherwise defers to `DataValue`'s `+`.
    fn add(self, other: Self) -> Self {
        if self.is_null() {
            other
        } else {
            self + other
        }
    }

    /// Returns `self` unless it is null, in which case returns `other`.
    fn or(self, other: Self) -> Self {
        if self.is_null() {
            other
        } else {
            self
        }
    }
}