use indexmap::IndexSet;
use toasty_core::stmt::{self, visit_mut};
use crate::engine::{
Engine, HirStatement, SelectItems, eval,
exec::{MergeIndex, MergeQualification, NestedChild, NestedLevel},
hir, mir,
plan::HirPlanner,
};
/// Accumulates the pieces of a single `mir::NestedMerge` node while walking a
/// tree of HIR statements whose results embed the results of sub-statements.
///
/// Constructed by `HirPlanner::plan_nested_merge` and consumed by
/// `NestedMergePlanner::plan_nested_merge`, which inserts the finished node.
#[derive(Debug)]
struct NestedMergePlanner<'a> {
    /// Engine handle; used here to reach the schema for type inference.
    engine: &'a Engine,
    /// HIR being planned; statement state is looked up by `hir::StmtId`.
    hir: &'a HirStatement,
    /// MIR store that the finished `NestedMerge` node is inserted into.
    mir: &'a mut mir::Store,
    /// Distinct MIR nodes read by the merge. A level's `source` is the
    /// position of its node in this set (see `IndexSet::insert_full`).
    inputs: IndexSet<mir::NodeId>,
    /// Extra MIR nodes passed to `insert_with_deps` rather than recorded in
    /// `inputs` (load-data statements of sub-statements with constant returning).
    deps: IndexSet<mir::NodeId>,
    /// Index definitions referenced by `MergeQualification::HashLookup { index, .. }`.
    hash_indexes: Vec<MergeIndex>,
    /// Index definitions referenced by `MergeQualification::SortLookup { index, .. }`.
    sort_indexes: Vec<MergeIndex>,
    /// Statements currently being planned, root first. Entry `i` supplies the
    /// type of filter argument `arg(i)` (see `build_filter_arg_tys`).
    stack: Vec<hir::StmtId>,
}
impl HirPlanner<'_> {
    /// Plans a `mir::NestedMerge` node for `stmt_id` if one is required.
    ///
    /// A nested merge is needed only when at least one argument of the
    /// statement is an `hir::Arg::Sub` with `returning: true`. Insert
    /// statements never get one here. In either of those cases `None` is
    /// returned; otherwise the id of the newly inserted MIR node is returned.
    pub(super) fn plan_nested_merge(&mut self, stmt_id: hir::StmtId) -> Option<mir::NodeId> {
        let state = &self.hir[stmt_id];

        let has_returning_sub = state
            .args
            .iter()
            .any(|arg| matches!(arg, hir::Arg::Sub { returning: true, .. }));

        // `||` short-circuits, so the statement is only unwrapped once we know
        // a returning sub-statement exists (same order as separate guards).
        if !has_returning_sub || state.stmt.as_ref().unwrap().is_insert() {
            return None;
        }

        let planner = NestedMergePlanner {
            engine: self.engine,
            hir: self.hir,
            mir: &mut self.mir,
            inputs: IndexSet::new(),
            deps: IndexSet::new(),
            hash_indexes: Vec::new(),
            sort_indexes: Vec::new(),
            stack: Vec::new(),
        };

        Some(planner.plan_nested_merge(stmt_id))
    }
}
impl NestedMergePlanner<'_> {
    /// Plans the whole nested-merge tree rooted at `root`, inserts the
    /// resulting `mir::NestedMerge` node into the MIR store and returns its id.
    ///
    /// Consumes the planner: the accumulated `inputs`, `hash_indexes`,
    /// `sort_indexes` and `deps` are moved into the inserted node.
    fn plan_nested_merge(mut self, root: hir::StmtId) -> mir::NodeId {
        // The root is planned at depth 0 with itself as the only stack entry.
        self.stack.push(root);
        let root = self.plan_nested_level(root, 0);
        self.stack.pop();
        self.mir.insert_with_deps(
            mir::NestedMerge {
                inputs: self.inputs,
                root,
                hash_indexes: self.hash_indexes,
                sort_indexes: self.sort_indexes,
            },
            self.deps,
        )
    }

    /// Plans the nested child statement `stmt_id` located at `depth` and
    /// decides how its rows are qualified against the enclosing rows.
    ///
    /// - Query: the select filter is rewritten into positional arguments. If
    ///   `try_eq_lookup` recognizes it as child/ancestor equalities, a
    ///   `MergeIndex` is registered: a hash index when `query.single`,
    ///   otherwise a sort index. Any other filter becomes a `Scan` over the
    ///   full filter function.
    /// - Insert: qualification is `All`.
    /// - Any other statement kind hits `todo!`.
    fn plan_nested_child(&mut self, stmt_id: hir::StmtId, depth: usize) -> NestedChild {
        // Keep `stmt_id` on the stack while planning, so that its own level
        // and its filter argument types both see it as the innermost entry.
        self.stack.push(stmt_id);
        let level = self.plan_nested_level(stmt_id, depth);
        let stmt_state = &self.hir[stmt_id];
        let selection = stmt_state.load_data_select_items.get().unwrap();
        let ret = match stmt_state.stmt.as_deref().unwrap() {
            stmt::Statement::Query(query) => {
                let filter_expr = self.build_filter_for_nested_child(stmt_id, selection, depth);
                let filter_arg_tys = self.build_filter_arg_tys();
                let qualification = match try_eq_lookup(&filter_expr, &filter_arg_tys, depth) {
                    // Single-row child: register a hash index over this level's source,
                    // keyed by the child-side projections of the equalities.
                    Some((child_projections, lookup_key)) if query.single => {
                        let index = self.hash_indexes.len();
                        self.hash_indexes.push(MergeIndex {
                            source: level.source,
                            child_projections,
                        });
                        MergeQualification::HashLookup { index, lookup_key }
                    }
                    // Multi-row child: same key extraction, but registered as a sort index.
                    Some((child_projections, lookup_key)) => {
                        let index = self.sort_indexes.len();
                        self.sort_indexes.push(MergeIndex {
                            source: level.source,
                            child_projections,
                        });
                        MergeQualification::SortLookup { index, lookup_key }
                    }
                    // Not a pure equality join: fall back to the full filter function.
                    None => {
                        MergeQualification::Scan(eval::Func::from_stmt(filter_expr, filter_arg_tys))
                    }
                };
                NestedChild {
                    level,
                    qualification,
                    single: query.single,
                }
            }
            stmt::Statement::Insert(insert) => NestedChild {
                level,
                qualification: MergeQualification::All,
                single: insert.source.single,
            },
            stmt => todo!("stmt={stmt:#?}"),
        };
        self.stack.pop();
        ret
    }

    /// Plans one level of the merge tree for `stmt_id` at `depth`.
    ///
    /// Registers the level's data source in `inputs` and records its position
    /// as `source`. For `Returning::Expr` the source is the statement's
    /// load-data node, and the projection is built from the expression; doing
    /// so may recursively plan nested children into `nested`. For any other
    /// returning, the source is the statement's output node and the projection
    /// is the identity `arg(0)`, typed with the element type when the output
    /// type is a list.
    fn plan_nested_level(&mut self, stmt_id: hir::StmtId, depth: usize) -> NestedLevel {
        let stmt_state = &self.hir[stmt_id];
        let stmt = stmt_state.stmt.as_deref().unwrap();
        let returning = stmt.returning_unwrap();
        let source;
        let mut nested = vec![];
        let projection = match returning {
            stmt::Returning::Expr(expr) => {
                let (s, _) = self
                    .inputs
                    .insert_full(stmt_state.load_data_statement.get().unwrap());
                source = s;
                self.build_projection_from_expr(stmt_id, expr, depth, &mut nested)
            }
            _ => {
                let node_id = stmt_state.output.get().unwrap();
                let (s, _) = self.inputs.insert_full(node_id);
                source = s;
                // The projection is applied per element, so unwrap a list type.
                let ty = match self.mir[node_id].ty().clone() {
                    stmt::Type::List(ty) => *ty,
                    ty => ty,
                };
                eval::Func::from_stmt(stmt::Expr::arg(0), vec![ty])
            }
        };
        NestedLevel {
            source,
            projection,
            nested,
        }
    }

    /// Argument types for a nested child's filter: one record type per
    /// statement on the stack, root first, so `arg(i)` is `stack[i]`'s row.
    fn build_filter_arg_tys(&self) -> Vec<stmt::Type> {
        self.stack
            .iter()
            .map(|stmt_id| self.build_exec_statement_ty_for(*stmt_id))
            .collect()
    }

    /// Argument types for a level projection. `arg(0)` is the row of the
    /// innermost stack statement. It is followed by one argument per nested
    /// child: the child's projection type, wrapped in a list when the child is
    /// not `single`.
    fn build_projection_arg_tys(&self, nested_children: &[NestedChild]) -> Vec<stmt::Type> {
        let curr = self.stack.last().unwrap();
        let mut projection_arg_tys = vec![self.build_exec_statement_ty_for(*curr)];
        for nested in nested_children {
            projection_arg_tys.push(if nested.single {
                nested.level.projection.ret.clone()
            } else {
                stmt::Type::list(nested.level.projection.ret.clone())
            });
        }
        projection_arg_tys
    }

    /// Record type of the rows loaded for `stmt_id`: one field per entry of
    /// its `load_data_select_items`, inferred against the schema.
    fn build_exec_statement_ty_for(&self, stmt_id: hir::StmtId) -> stmt::Type {
        let stmt_state = &self.hir[stmt_id];
        let stmt = stmt_state.stmt.as_deref().unwrap();
        let cx = stmt::ExprContext::new_with_target(&*self.engine.schema, stmt);
        let mut fields = vec![];
        for select_item in stmt_state.load_data_select_items.get().unwrap() {
            fields.push(select_item.infer_ty(&cx));
        }
        stmt::Type::Record(fields)
    }

    /// Rewrites the returning expression `expr` of `stmt_id` into a projection
    /// function over positional arguments.
    ///
    /// - An argument at the current scope that refers to a sub-statement:
    ///   - if the sub-statement's returning is a constant value, it is inlined
    ///     and the sub-statement's load-data node is added to `deps`;
    ///   - otherwise the sub-statement is planned as a nested child (at
    ///     `depth + 1`), pushed onto `nested`, and replaced by `arg(k)`, where
    ///     `k` is its 1-based position in `nested`.
    /// - A column reference becomes `arg(0)` projected to the column's index
    ///   in this statement's select items.
    fn build_projection_from_expr(
        &mut self,
        stmt_id: hir::StmtId,
        expr: &stmt::Expr,
        depth: usize,
        nested: &mut Vec<NestedChild>,
    ) -> eval::Func {
        // Copy the `&HirStatement` out of `self` so the closure can read HIR
        // state while also holding `self` mutably.
        let hir = self.hir;
        let selection = hir[stmt_id].load_data_select_items.get().unwrap();
        let mut projection = expr.clone();
        // NOTE(review): the closure returns `false` after rewriting a node and
        // `true` otherwise; presumably `false` stops descent into the replaced
        // node — confirm against `walk_expr_scoped_mut`.
        visit_mut::walk_expr_scoped_mut(&mut projection, 0, |expr, scope_depth| match expr {
            stmt::Expr::Arg(expr_arg) if expr_arg.nesting == scope_depth => {
                let position = expr_arg.position;
                let stmt_state = &hir[stmt_id];
                match &stmt_state.args[position] {
                    hir::Arg::Sub {
                        stmt_id: child_stmt_id,
                        ..
                    } => {
                        let child_stmt_id = *child_stmt_id;
                        let child_stmt_state = &hir[child_stmt_id];
                        let child_stmt = child_stmt_state.stmt.as_deref().unwrap();
                        let child_returning = child_stmt.returning_unwrap();
                        match child_returning {
                            stmt::Returning::Value(returning_expr) if returning_expr.is_const() => {
                                // A single-row child must not produce a list constant.
                                match child_stmt {
                                    stmt::Statement::Query(query) => {
                                        if query.single {
                                            let stmt::Expr::Value(v) = returning_expr else {
                                                todo!()
                                            };
                                            assert!(!v.is_list());
                                        }
                                    }
                                    stmt::Statement::Insert(insert) => {
                                        if insert.source.single {
                                            let stmt::Expr::Value(v) = returning_expr else {
                                                todo!()
                                            };
                                            assert!(!v.is_list());
                                        }
                                    }
                                    _ => {}
                                }
                                self.deps
                                    .insert(child_stmt_state.load_data_statement.get().unwrap());
                                *expr = returning_expr.clone();
                            }
                            _ => {
                                let nested_child = self.plan_nested_child(child_stmt_id, depth + 1);
                                nested.push(nested_child);
                                // `arg(0)` is this level's row, so children start at 1.
                                *expr = stmt::Expr::arg(nested.len());
                            }
                        }
                    }
                    hir::Arg::Ref { .. } => todo!(),
                }
                false
            }
            stmt::Expr::Reference(expr_reference) => {
                let expr_column = expr_reference.as_expr_column_unwrap();
                debug_assert_eq!(0, expr_column.nesting);
                let index = selection.get_index_of_expr_reference(*expr_column);
                *expr = stmt::Expr::arg_project(0, [index]);
                false
            }
            _ => true,
        });
        let projection_arg_tys = self.build_projection_arg_tys(nested);
        eval::Func::from_stmt(projection, projection_arg_tys)
    }

    /// Clones the select filter of the nested query `stmt_id` and rewrites it
    /// over stack-positional arguments. The query's own columns become
    /// `arg(depth)`, projected via `selection`. Each `hir::Arg::Ref` to an
    /// ancestor `nesting` levels up becomes `arg(depth - nesting)`, projected
    /// to the column's index in the target's select items.
    ///
    /// Panics (`unreachable!`) if `stmt_id` is not a query, and hits `todo!`
    /// for arguments that are not `hir::Arg::Ref`.
    fn build_filter_for_nested_child(
        &self,
        stmt_id: hir::StmtId,
        selection: &SelectItems,
        depth: usize,
    ) -> stmt::Expr {
        let stmt_state = &self.hir[stmt_id];
        let stmt::Statement::Query(query) = stmt_state.stmt.as_deref().unwrap() else {
            unreachable!()
        };
        let select = query.body.as_select_unwrap();
        let mut filter = select.filter.clone();
        visit_mut::for_each_expr_mut(&mut filter, |expr| match expr {
            stmt::Expr::Arg(expr_arg) => {
                let hir::Arg::Ref {
                    nesting,
                    stmt_id: target_id,
                    target_expr_ref,
                    ..
                } = &stmt_state.args[expr_arg.position]
                else {
                    todo!()
                };
                debug_assert!(*nesting > 0);
                let target_stmt = &self.hir[target_id];
                let target_exec_statement_index = target_stmt
                    .load_data_select_items
                    .get()
                    .unwrap()
                    .get_index_of_expr_reference(*target_expr_ref);
                // NOTE(review): underflows (panics in debug) if `nesting > depth`.
                *expr = stmt::Expr::arg_project(depth - *nesting, [target_exec_statement_index]);
            }
            stmt::Expr::Reference(expr_reference) => {
                let index = selection.get_index_of_expr_reference(*expr_reference);
                *expr = stmt::Expr::arg_project(depth, [index]);
            }
            _ => {}
        });
        filter.into_expr()
    }
}
/// Tries to express a nested child's filter as an equality join so that it
/// can be answered with an index lookup instead of a scan.
///
/// `expr` must be a single `==` or an `AND` made up only of `==` terms. In
/// each term, one side must be a projection of `arg(depth)` (the child row)
/// and the other a projection of an ancestor `arg(p)` with `p < depth`.
///
/// On success, returns the child-side projections (in term order) together
/// with a function that builds the lookup key from the ancestor arguments
/// `arg_tys[..depth]`. The key is the bare parent expression for a single
/// term and a record of them otherwise. Returns `None` if any term does not
/// fit, or if there are no terms at all.
fn try_eq_lookup(
    expr: &stmt::Expr,
    arg_tys: &[stmt::Type],
    depth: usize,
) -> Option<(Vec<stmt::Projection>, eval::Func)> {
    let eq_terms: Vec<(&stmt::Expr, &stmt::Expr)> = match expr {
        // Every conjunct must itself be an equality; the first one that is
        // not short-circuits the whole collection to `None`.
        stmt::Expr::And(and_expr) => and_expr
            .operands
            .iter()
            .map(|operand| match operand {
                stmt::Expr::BinaryOp(op) if op.op == stmt::BinaryOp::Eq => {
                    Some((&*op.lhs, &*op.rhs))
                }
                _ => None,
            })
            .collect::<Option<Vec<_>>>()?,
        stmt::Expr::BinaryOp(op) if op.op == stmt::BinaryOp::Eq => {
            vec![(&*op.lhs, &*op.rhs)]
        }
        _ => return None,
    };

    let (child_projections, mut lookup_key_exprs): (Vec<stmt::Projection>, Vec<stmt::Expr>) =
        eq_terms
            .into_iter()
            .map(|(lhs, rhs)| extract_child_parent_eq(lhs, rhs, depth))
            .collect::<Option<Vec<_>>>()?
            .into_iter()
            .unzip();

    if child_projections.is_empty() {
        return None;
    }

    let lookup_key_expr = if lookup_key_exprs.len() == 1 {
        lookup_key_exprs
            .pop()
            .expect("exactly one lookup key expression")
    } else {
        stmt::Expr::record_from_vec(lookup_key_exprs)
    };

    // The key is evaluated against the ancestors only, never the child row.
    let ancestor_arg_tys = arg_tys[..depth].to_vec();
    Some((
        child_projections,
        eval::Func::from_stmt(lookup_key_expr, ancestor_arg_tys),
    ))
}
/// Splits one `lhs == rhs` term into its child-side projection and its
/// parent-side expression.
///
/// Both sides must be simple argument projections. Exactly one of them must
/// project `arg(depth)` (the child row), and the other must project an
/// ancestor `arg(p)` with `p < depth`. Either orientation is accepted.
/// Returns `None` otherwise.
fn extract_child_parent_eq(
    lhs: &stmt::Expr,
    rhs: &stmt::Expr,
    depth: usize,
) -> Option<(stmt::Projection, stmt::Expr)> {
    let (lhs_position, lhs_projection) = as_simple_arg_project(lhs)?;
    let (rhs_position, rhs_projection) = as_simple_arg_project(rhs)?;

    if lhs_position == depth && rhs_position < depth {
        // child == parent
        Some((lhs_projection.clone(), rhs.clone()))
    } else if rhs_position == depth && lhs_position < depth {
        // parent == child
        Some((rhs_projection.clone(), lhs.clone()))
    } else {
        None
    }
}
/// Matches an expression of the form `arg(position).projection`, where the
/// argument is at nesting 0, and returns `(position, &projection)`. Any
/// other shape yields `None`.
fn as_simple_arg_project(expr: &stmt::Expr) -> Option<(usize, &stmt::Projection)> {
    let stmt::Expr::Project(project) = expr else {
        return None;
    };
    match project.base.as_ref() {
        stmt::Expr::Arg(arg) if arg.nesting == 0 => Some((arg.position, &project.projection)),
        _ => None,
    }
}