use egg::{define_language, CostFunction, Id, Symbol};
use crate::binder_v2::copy::ExtSource;
use crate::binder_v2::{BoundDrop, CreateTable};
use crate::catalog::{ColumnRefId, TableRefId};
use crate::parser::{BinaryOperator, UnaryOperator};
use crate::types::{ColumnIndex, DataTypeKind, DataValue};
mod cost;
mod explain;
mod rules;
pub use explain::Explain;
pub use rules::{ColumnIndexResolver, ExprAnalysis, TypeError, TypeSchemaAnalysis};
// Shorthands that pin egg's generic types to this planner's language (`Expr`)
// and analysis (`ExprAnalysis`).

// E-graph over `Expr` nodes with `ExprAnalysis` attached.
type EGraph = egg::EGraph<Expr, ExprAnalysis>;
// A rewrite rule over `Expr` with `ExprAnalysis`.
type Rewrite = egg::Rewrite<Expr, ExprAnalysis>;
// A pattern over `Expr` nodes.
type Pattern = egg::Pattern<Expr>;
/// A flattened expression whose nodes are [`Expr`]s; the input and output type of [`optimize`].
pub type RecExpr = egg::RecExpr<Expr>;
// The node language of the planner, generated by egg's `define_language!`.
//
// Variants written as `"op" = Variant(..)` are associated with the operator string `op`
// and hold their children as `Id`s (`Id` = one child, `[Id; N]` = exactly N children,
// `Box<[Id]>` = a variable number of children, no payload = no children).
// Variants without a string literal wrap a data value instead of children.
//
// NOTE(review): the order of variants is kept exactly as-is; it may matter when an
// operator string or a data value could match more than one variant — confirm against
// egg's `define_language!` documentation before reordering.
define_language! {
pub enum Expr {
// Leaf nodes carrying a data payload.
Constant(DataValue), Type(DataTypeKind), Column(ColumnRefId), Table(TableRefId), ColumnIndex(ColumnIndex), ExtSource(ExtSource),
// A single-child wrapper and a variable-length list of children.
"`" = Nested(Id), "list" = List(Box<[Id]>),
// Two-child operators. `Add` through `Xor` are the ones mapped to a parser
// `BinaryOperator` by `Expr::binary_op`; `Like` is not.
"+" = Add([Id; 2]),
"-" = Sub([Id; 2]),
"*" = Mul([Id; 2]),
"/" = Div([Id; 2]),
"%" = Mod([Id; 2]),
"||" = StringConcat([Id; 2]),
">" = Gt([Id; 2]),
"<" = Lt([Id; 2]),
">=" = GtEq([Id; 2]),
"<=" = LtEq([Id; 2]),
"=" = Eq([Id; 2]),
"<>" = NotEq([Id; 2]),
"and" = And([Id; 2]),
"or" = Or([Id; 2]),
"xor" = Xor([Id; 2]),
"like" = Like([Id; 2]),
// Single-child operators. `Neg` shares the string "-" with `Sub`; the two differ in
// their number of children. `Neg` and `Not` are mapped by `Expr::unary_op`.
"-" = Neg(Id),
"not" = Not(Id),
"isnull" = IsNull(Id),
// NOTE(review): three children, presumably (condition, then, else) — confirm with the binder/rules.
"if" = If([Id; 3]),
// Function-like nodes; the names suggest SQL aggregates. `RowCount` takes no children.
"max" = Max(Id),
"min" = Min(Id),
"sum" = Sum(Id),
"avg" = Avg(Id),
"count" = Count(Id),
"rowcount" = RowCount,
"first" = First(Id),
"last" = Last(Id),
"exists" = Exists(Id),
"in" = In([Id; 2]),
"cast" = Cast([Id; 2]),
// Plan-level nodes, together with `asc`/`desc` (one child each) and the childless
// `inner` marker. NOTE(review): the meaning and order of each node's children is not
// visible in this file — see the rules/explain modules.
"scan" = Scan([Id; 2]), "values" = Values(Box<[Id]>), "proj" = Proj([Id; 2]), "filter" = Filter([Id; 2]), "order" = Order([Id; 2]), "asc" = Asc(Id), "desc" = Desc(Id), "limit" = Limit([Id; 3]), "topn" = TopN([Id; 4]), "join" = Join([Id; 4]), "hashjoin" = HashJoin([Id; 5]), "inner" = Inner,
// Further childless markers; by name these are the remaining join types.
"left_outer" = LeftOuter,
"right_outer" = RightOuter,
"full_outer" = FullOuter,
// `agg` takes three children; `CreateTable` and `Drop` wrap bound statements from the binder.
"agg" = Agg([Id; 3]), CreateTable(CreateTable),
Drop(BoundDrop),
// Statement-level nodes.
"insert" = Insert([Id; 3]), "delete" = Delete([Id; 2]), "copy_from" = CopyFrom([Id; 2]), "copy_to" = CopyTo([Id; 2]), "explain" = Explain(Id),
// NOTE(review): `prune` and `empty` look like helper nodes used internally by the
// rewrite rules — confirm in the rules module.
"prune" = Prune([Id; 2]), "empty" = Empty(Box<[Id]>),
// A bare symbol. Listed last; presumably the fallback when nothing above matches — TODO confirm.
Symbol(Symbol),
}
}
impl Expr {
    /// Returns the constant node holding boolean `true`.
    pub const fn true_() -> Self {
        Self::Constant(DataValue::Bool(true))
    }

    /// Returns the constant node holding SQL `NULL`.
    pub const fn null() -> Self {
        Self::Constant(DataValue::Null)
    }

    /// Returns the constant node holding the 32-bit integer `0`.
    pub const fn zero() -> Self {
        Self::Constant(DataValue::Int32(0))
    }

    /// Returns a clone of the value held by a [`Expr::Constant`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `Constant`.
    pub fn as_const(&self) -> DataValue {
        let Self::Constant(v) = self else {
            panic!("not a constant: {self}")
        };
        v.clone()
    }

    /// Returns the children of a [`Expr::List`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `List`.
    pub fn as_list(&self) -> &[Id] {
        let Self::List(l) = self else {
            panic!("not a list: {self}")
        };
        l
    }

    /// Returns the column reference held by a [`Expr::Column`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `Column`.
    pub fn as_column(&self) -> ColumnRefId {
        let Self::Column(c) = self else {
            // Fixed typo in the message (was "columnn").
            panic!("not a column: {self}")
        };
        *c
    }

    /// Returns the table reference held by a [`Expr::Table`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `Table`.
    pub fn as_table(&self) -> TableRefId {
        let Self::Table(t) = self else {
            panic!("not a table: {self}")
        };
        *t
    }

    /// Returns a reference to the type held by a [`Expr::Type`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `Type`.
    pub fn as_type(&self) -> &DataTypeKind {
        let Self::Type(t) = self else {
            panic!("not a type: {self}")
        };
        t
    }

    /// Returns a clone of the external source held by a [`Expr::ExtSource`] node.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not an `ExtSource`.
    pub fn as_ext_source(&self) -> ExtSource {
        let Self::ExtSource(v) = self else {
            panic!("not an external source: {self}")
        };
        v.clone()
    }

    /// Maps a binary-operator node to its parser-level [`BinaryOperator`] and its two
    /// operand ids, in the order they are stored in the node.
    ///
    /// Returns `None` for every other node, including two-child nodes that have no
    /// mapping here (for example `Like`, `In` and `Cast`).
    pub const fn binary_op(&self) -> Option<(BinaryOperator, Id, Id)> {
        use BinaryOperator as Op;
        // The `&Variant(..)` patterns copy the `Id`s out by value; clippy would prefer
        // dereferencing the scrutinee instead, so the lint is silenced for this match.
        #[allow(clippy::match_ref_pats)]
        Some(match self {
            &Self::Add([a, b]) => (Op::Plus, a, b),
            &Self::Sub([a, b]) => (Op::Minus, a, b),
            &Self::Mul([a, b]) => (Op::Multiply, a, b),
            &Self::Div([a, b]) => (Op::Divide, a, b),
            &Self::Mod([a, b]) => (Op::Modulo, a, b),
            &Self::StringConcat([a, b]) => (Op::StringConcat, a, b),
            &Self::Gt([a, b]) => (Op::Gt, a, b),
            &Self::Lt([a, b]) => (Op::Lt, a, b),
            &Self::GtEq([a, b]) => (Op::GtEq, a, b),
            &Self::LtEq([a, b]) => (Op::LtEq, a, b),
            &Self::Eq([a, b]) => (Op::Eq, a, b),
            &Self::NotEq([a, b]) => (Op::NotEq, a, b),
            &Self::And([a, b]) => (Op::And, a, b),
            &Self::Or([a, b]) => (Op::Or, a, b),
            &Self::Xor([a, b]) => (Op::Xor, a, b),
            _ => return None,
        })
    }

    /// Maps a unary-operator node (`Neg` or `Not`) to its parser-level
    /// [`UnaryOperator`] and its operand id.
    ///
    /// Returns `None` for every other node.
    pub const fn unary_op(&self) -> Option<(UnaryOperator, Id)> {
        use UnaryOperator as Op;
        // Same reasoning as in `binary_op`: the `&Variant(..)` patterns copy the `Id` out.
        #[allow(clippy::match_ref_pats)]
        Some(match self {
            &Self::Neg(a) => (Op::Minus, a),
            &Self::Not(a) => (Op::Not, a),
            _ => return None,
        })
    }
}
/// Extension helpers for e-classes whose nodes are [`Expr`]s.
trait ExprExt {
    /// Returns the children of a `List` node contained in this e-class.
    ///
    /// # Panics
    ///
    /// Panics if the e-class contains no `List` node.
    fn as_list(&self) -> &[Id];
}

impl<D> ExprExt for egg::EClass<Expr, D> {
    fn as_list(&self) -> &[Id] {
        // Take the first `List` node found among the e-class's nodes.
        let first_list = self.iter().find_map(|node| {
            if let Expr::List(children) = node {
                Some(children)
            } else {
                None
            }
        });
        first_list.expect("not list")
    }
}
/// Optimizes `expr` by running three stages of rewrite rules and returns the
/// expression extracted after the last stage.
///
/// 1. `STAGE1_RULES` run once; the result is extracted with the `cost::NoPrune` cost function.
/// 2. `STAGE2_RULES` run for at most three rounds, each limited to 6 iterations and
///    extracted with `cost::CostFn`. The rounds stop as soon as a round's cost is not
///    strictly lower than the best cost seen so far; the expression extracted in that
///    final round is still the one carried forward.
/// 3. `STAGE3_RULES` run once; the result is extracted with `cost::CostFn`.
pub fn optimize(expr: &RecExpr) -> RecExpr {
    // Stage 1.
    let stage1 = egg::Runner::default()
        .with_expr(expr)
        .run(&*rules::STAGE1_RULES);
    let (_, mut plan) =
        egg::Extractor::new(&stage1.egraph, cost::NoPrune).find_best(stage1.roots[0]);

    // Stage 2: repeat while the extracted cost keeps improving.
    let mut best_cost = f32::MAX;
    for _round in 0..3 {
        let stage2 = egg::Runner::default()
            .with_expr(&plan)
            .with_iter_limit(6)
            .run(&*rules::STAGE2_RULES);
        let graph = &stage2.egraph;
        let (cost, candidate) =
            egg::Extractor::new(graph, cost::CostFn { egraph: graph }).find_best(stage2.roots[0]);
        // The candidate replaces the current plan before the cost is compared.
        plan = candidate;
        if cost >= best_cost {
            break;
        }
        best_cost = cost;
    }

    // Stage 3.
    let stage3 = egg::Runner::default()
        .with_expr(&plan)
        .run(&*rules::STAGE3_RULES);
    let graph = &stage3.egraph;
    let (_, optimized) =
        egg::Extractor::new(graph, cost::CostFn { egraph: graph }).find_best(stage3.roots[0]);
    optimized
}
/// Computes the cost of every node in `expr` using `cost::CostFn`.
///
/// The returned vector has one entry per node, in the same order as the nodes of
/// `expr`. Each node's cost is computed from the already-stored entries of its children.
pub fn costs(expr: &RecExpr) -> Vec<f32> {
    let mut egraph = EGraph::default();
    egraph.add_expr(expr);
    let mut cost_fn = cost::CostFn { egraph: &egraph };

    let nodes = expr.as_ref();
    let mut per_node = vec![0.0; nodes.len()];
    for (idx, node) in nodes.iter().enumerate() {
        // Child costs are looked up by the child's position in the expression.
        let node_cost = cost_fn.cost(node, |child| per_node[usize::from(child)]);
        per_node[idx] = node_cost;
    }
    per_node
}