use std::mem;
use indexmap::{IndexMap, IndexSet};
use toasty_core::stmt::{self, Condition, visit_mut};
use toasty_core::driver::operation::QueryPkLimit;
use crate::{
Result,
engine::{
SelectItem, SelectItems, eval, exec,
hir::{self},
index::{self, IndexPlan},
mir,
plan::HirPlanner,
},
};
/// Mutable state accumulated while planning the data-loading phase of a
/// single statement.
#[derive(Debug)]
struct LoadData {
    // MIR nodes whose output feeds this statement as arguments.
    inputs: IndexSet<mir::NodeId>,
    // Columns/expressions the data-load statement must project.
    select_items: SelectItems,
    // Positions (into `inputs`) of args that require a batch-load rewrite.
    batch_load_args: IndexSet<usize>,
}
// Shorthand: a statement may or may not carry a returning clause.
type Returning = Option<stmt::Returning>;
/// A returning clause together with the MIR input nodes its rewritten
/// argument expressions reference.
#[derive(Debug)]
struct ReturningInfo {
    clause: Option<stmt::Returning>,
    inputs: IndexSet<mir::NodeId>,
}
/// Cursor-pagination parameters extracted from a query's limit/order-by.
struct PaginationInfo {
    page_size: i64,
    // Positions (into the select items) of the columns forming the cursor.
    cursor_column_indices: Vec<usize>,
}
/// Per-statement planner: drives lowering of one HIR statement into MIR
/// nodes, tracking intermediate data-load state.
struct PlanStatement<'a, 'b> {
    planner: &'a mut HirPlanner<'b>,
    stmt_id: hir::StmtId,
    stmt_info: &'b hir::StatementInfo,
    load_data: LoadData,
    // Dependencies whose outputs still need to be attached to a MIR node.
    remaining_deps: Vec<hir::StmtId>,
}
impl HirPlanner<'_> {
    /// Plans a single HIR statement, recursively planning *independent*
    /// dependencies first, then delegating to `PlanStatement::plan`.
    ///
    /// Idempotent: returns early if this statement's data-load node has
    /// already been planned.
    pub(super) fn plan_statement(&mut self, stmt_id: hir::StmtId) -> Result<()> {
        let stmt_info = &self.hir[stmt_id];
        // Already planned — nothing to do.
        if stmt_info.load_data_statement.get().is_some() {
            return Ok(());
        }
        // Independent dependencies are planned up front; dependent ones
        // are handled later via `plan_child_statements`.
        for &dep_stmt_id in &stmt_info.deps {
            if self.hir[dep_stmt_id].independent {
                self.plan_statement(dep_stmt_id)?;
            }
        }
        let stmt = stmt_info.stmt.as_deref().unwrap().clone();
        let mut planner = PlanStatement {
            planner: self,
            stmt_id,
            stmt_info,
            load_data: LoadData {
                inputs: IndexSet::new(),
                select_items: SelectItems::new(),
                batch_load_args: IndexSet::new(),
            },
            remaining_deps: stmt_info.deps.iter().cloned().collect(),
        };
        planner.plan(stmt)?;
        Ok(())
    }
}
impl<'a, 'b> PlanStatement<'a, 'b> {
/// Plans a statement end-to-end: data loading, back-reference
/// projections, child statements, and the final output node.
fn plan(&mut self, mut stmt: stmt::Statement) -> Result<()> {
    let mut returning = stmt.take_returning();
    // A VALUES-only query reads no table; its rows become the returning
    // value directly.
    if returning.is_none()
        && let stmt::Statement::Query(query) = &mut stmt
        && let stmt::ExprSet::Values(values) = &mut query.body
    {
        returning = Some(stmt::Returning::Value(if query.single {
            assert_eq!(1, values.rows.len());
            values.rows.drain(..).next().unwrap()
        } else {
            stmt::Expr::list(std::mem::take(&mut values.rows))
        }));
    }
    // Single-row semantics are applied at the output stage, not during
    // execution.
    match &mut stmt {
        stmt::Statement::Query(stmt) => stmt.single = false,
        stmt::Statement::Insert(stmt) => stmt.source.single = false,
        _ => {}
    }
    // Collect every column the returning clause and back-references need
    // the data-load statement to project.
    self.extract_columns_from_returning(&returning);
    self.extract_data_load_args(&mut stmt);
    self.collect_back_ref_columns();
    // Rewrite argument references so they resolve against planned inputs.
    if !self.load_data.batch_load_args.is_empty() {
        debug_assert!(stmt.is_query());
        self.rewrite_stmt_for_batch_load(&mut stmt);
    } else if let stmt::Statement::Insert(insert) = &mut stmt {
        self.rewrite_stmt_insert_arg_dependencies(insert);
    } else if let stmt::Statement::Update(update) = &mut stmt {
        self.rewrite_stmt_update_arg_dependencies(update);
    }
    let load_data_node_id = self.plan_data_loading(stmt, &mut returning)?;
    self.stmt_info
        .load_data_statement
        .set(Some(load_data_node_id));
    self.process_back_ref_projections(load_data_node_id);
    self.stmt_info
        .load_data_select_items
        .set(mem::take(&mut self.load_data.select_items))
        .unwrap();
    // Children must be planned before the returning clause is rewritten,
    // since it may reference their outputs.
    self.plan_child_statements()?;
    let returning_info = ReturningInfo {
        inputs: self.extract_inputs_from_returning(&mut returning, load_data_node_id),
        clause: returning,
    };
    let output_node_id = self.plan_output_node(load_data_node_id, returning_info);
    self.stmt_info.output.set(Some(output_node_id));
    Ok(())
}
/// Rewrites the returning clause so every argument/column reference
/// becomes an `arg`/`arg_project` over a MIR input node, returning the
/// set of input nodes referenced (in insertion order).
fn extract_inputs_from_returning(
    &mut self,
    returning: &mut Returning,
    load_data_node_id: mir::NodeId,
) -> IndexSet<mir::NodeId> {
    let mut inputs = IndexSet::new();
    // Projections (`Returning::Expr`) are evaluated per-row; values are
    // evaluated once over whole result lists, which changes arg nesting.
    let is_returning_projection = matches!(returning, Some(stmt::Returning::Expr(..)));
    debug_assert!(
        is_returning_projection || matches!(returning, None | Some(stmt::Returning::Value(..)))
    );
    visit_mut::for_each_expr_mut(returning, |expr| {
        match expr {
            stmt::Expr::Arg(expr_arg) => {
                match &self.stmt_info.args[expr_arg.position] {
                    hir::Arg::Ref {
                        stmt_id: target_id,
                        returning_input,
                        batch_load_index,
                        target_expr_ref,
                        ..
                    } => {
                        let target_stmt_info = &self.planner.hir[target_id];
                        let back_ref = &target_stmt_info.back_refs[&self.stmt_id];
                        // Column position within the back-ref projection.
                        let column = back_ref.exprs.get_index_of(target_expr_ref).unwrap();
                        // Register the back-ref node as an input once and
                        // cache its position in the arg's Cell.
                        if returning_input.get().is_none() {
                            let node_id = back_ref.node_id.get().unwrap();
                            let (index, _) = inputs.insert_full(node_id);
                            returning_input.set(Some(index));
                        }
                        let index = returning_input.get().unwrap();
                        let row = batch_load_index.get().unwrap();
                        // Insert projections wrap the arg one level deeper.
                        let nesting = if self.stmt().is_insert() && is_returning_projection {
                            1
                        } else {
                            0
                        };
                        *expr = stmt::Expr::project(
                            stmt::ExprArg {
                                position: index,
                                nesting,
                            },
                            [row, column],
                        );
                    }
                    hir::Arg::Sub {
                        stmt_id: target_id, ..
                    } => {
                        assert!(
                            !(self.stmt().is_insert() && is_returning_projection),
                            "TODO"
                        );
                        // Sub-statement output feeds in whole as an input.
                        let target_stmt_info = &self.planner.hir[target_id];
                        let target_node_id = target_stmt_info.output.get().expect("bug");
                        let (index, _) = inputs.insert_full(target_node_id);
                        *expr = stmt::Expr::arg(index);
                    }
                }
            }
            // Value-style returning: `col[row]` becomes a projection over
            // the data-load node's result list.
            stmt::Expr::Project(expr_project) if !is_returning_projection => {
                if let stmt::Expr::Reference(expr_reference) = &*expr_project.base {
                    let [row] = expr_project.projection.as_slice() else {
                        todo!("expr_projec{expr_project:#?}")
                    };
                    let column = self.load_data_expr_reference_position(expr_reference);
                    let (position, _) = inputs.insert_full(load_data_node_id);
                    *expr = stmt::Expr::arg_project(position, [*row, column]);
                }
            }
            // Projection-style returning: bare column refs index a row.
            stmt::Expr::Reference(expr_reference) if is_returning_projection => {
                let column = self.load_data_expr_reference_position(expr_reference);
                let (position, _) = inputs.insert_full(load_data_node_id);
                *expr = stmt::Expr::arg_project(position, [column]);
            }
            // `count(*)` reads the position reserved in the select items.
            stmt::Expr::Func(stmt::ExprFunc::Count(stmt::FuncCount { arg: None, .. }))
                if is_returning_projection =>
            {
                let index = self
                    .stmt_info
                    .load_data_select_items
                    .get()
                    .unwrap()
                    .get_index_of_count_star();
                let (position, _) = inputs.insert_full(load_data_node_id);
                *expr = stmt::Expr::arg_project(position, [index]);
            }
            _ => {}
        }
    });
    inputs
}
/// Resolves a column reference to its position within the planned
/// data-load select items. Panics if the reference was never collected.
fn load_data_expr_reference_position(&self, expr_reference: &stmt::ExprReference) -> usize {
    assert!(
        expr_reference.is_column(),
        "TODO: expr_reference = {expr_reference:#?}"
    );
    let select_items = self.stmt_info.load_data_select_items.get().unwrap();
    match select_items.try_get_index_of_expr_reference(*expr_reference) {
        Some(position) => position,
        None => panic!(
            "expr_reference={expr_reference:#?}; data_load.select_items={:#?}",
            self.load_data.select_items
        ),
    }
}
/// Registers every column reference (and `count(*)`) appearing in the
/// returning clause as a data-load select item.
fn extract_columns_from_returning(&mut self, returning: &Returning) {
    stmt::visit::for_each_expr(returning, |expr| {
        if let stmt::Expr::Reference(expr_reference) = expr {
            assert!(
                expr_reference.is_column(),
                "TODO: expr_reference = {expr_reference:#?}"
            );
            self.load_data.select_items.insert((*expr_reference).into());
        } else if let stmt::Expr::Func(stmt::ExprFunc::Count(stmt::FuncCount {
            arg: None, ..
        })) = expr
        {
            self.load_data.select_items.insert(SelectItem::CountStar);
        }
    })
}
/// Walks the statement's filter, insert rows, and update assignments,
/// recording every argument expression as a data-load input.
fn extract_data_load_args(&mut self, stmt: &mut stmt::Statement) {
    if let Some(filter) = stmt.filter() {
        stmt::visit::for_each_expr(filter, |expr| {
            self.extract_data_load_args_from_expr(expr, None);
        });
    }
    match stmt {
        stmt::Statement::Insert(insert) => {
            let stmt::ExprSet::Values(values) = &insert.source.body else {
                todo!()
            };
            // Insert args are tagged with their VALUES row index.
            for (row_index, row) in values.rows.iter().enumerate() {
                stmt::visit::for_each_expr(row, |expr| {
                    self.extract_data_load_args_from_expr(expr, Some(row_index));
                });
            }
        }
        stmt::Statement::Update(update) => {
            for (_, assignment) in update.assignments.iter() {
                stmt::visit::for_each_expr(assignment, |expr| {
                    self.extract_data_load_args_from_expr(expr, None);
                });
            }
        }
        _ => {}
    }
}
/// Records the MIR input node behind a single argument expression.
///
/// `insert_row` carries the VALUES row index when the expression came
/// from an insert row, `None` otherwise.
fn extract_data_load_args_from_expr(&mut self, expr: &stmt::Expr, insert_row: Option<usize>) {
    if let stmt::Expr::Arg(expr_arg) = expr {
        match &self.stmt_info.args[expr_arg.position] {
            hir::Arg::Sub {
                stmt_id: target_id,
                returning,
                input,
                batch_load_index,
                ..
            } => {
                debug_assert!(!returning, "the argument was found in a filter");
                // Sub-statement targets must already have been planned.
                let target = &self.planner.hir[target_id];
                let Some(node_id) = target.output.get() else {
                    panic!(
                        "bug: expected target statement to be planned; curr={:#?}; target={:#?}",
                        self.stmt_info, target
                    );
                };
                let (index, _) = self.load_data.inputs.insert_full(node_id);
                batch_load_index.set(insert_row);
                input.set(Some(index));
            }
            hir::Arg::Ref {
                stmt_id: target_id,
                data_load_input,
                batch_load_index,
                ..
            } => {
                // Only the first occurrence of a ref is recorded.
                if data_load_input.get().is_some() {
                    return;
                }
                let target_stmt_info = &self.planner.hir[target_id];
                let back_ref = &target_stmt_info.back_refs[&self.stmt_id];
                let node_id = back_ref.node_id.get().unwrap();
                let (index, _) = self.load_data.inputs.insert_full(node_id);
                data_load_input.set(Some(index));
                if target_stmt_info.stmt().is_query() {
                    // Query targets are batch-loaded; remember which input
                    // slot feeds the batch rewrite.
                    debug_assert!(insert_row.is_none());
                    let (batch_load_table_ref_index, _) =
                        self.load_data.batch_load_args.insert_full(index);
                    batch_load_index.set(Some(batch_load_table_ref_index));
                } else if let Some(row) = insert_row {
                    debug_assert!(target_stmt_info.stmt().is_insert());
                    // The first insert row referencing the target wins.
                    if batch_load_index.get().is_none() {
                        batch_load_index.set(Some(row));
                    }
                } else {
                    debug_assert!(
                        batch_load_index.get().is_some(),
                        "stmt={:#?}; target={:#?}; batch_load_index={:#?}",
                        self.stmt_info,
                        target_stmt_info,
                        batch_load_index.get()
                    );
                }
            }
        }
    }
}
/// Ensures every column a dependent statement reads via a back-reference
/// is part of this statement's data-load projection.
fn collect_back_ref_columns(&mut self) {
    let referenced = self
        .stmt_info
        .back_refs
        .values()
        .flat_map(|back_ref| back_ref.exprs.iter());
    for expr in referenced {
        self.load_data.select_items.insert((*expr).into());
    }
}
/// Dispatches the batch-load rewrite on driver capability: SQL drivers
/// get an EXISTS subquery, key-value drivers get an `any(map(..))` filter.
fn rewrite_stmt_for_batch_load(&mut self, stmt: &mut stmt::Statement) {
    if !self.planner.engine.capability().sql {
        self.rewrite_stmt_query_for_batch_load_nosql(stmt);
    } else {
        self.rewrite_stmt_query_for_batch_load_sql(stmt);
    }
}
/// Rewrites a batch-loaded query's filter into an EXISTS subquery over
/// the batched input argument (SQL drivers).
fn rewrite_stmt_query_for_batch_load_sql(&mut self, stmt: &mut stmt::Statement) {
    let mut filter = stmt
        .filter_mut()
        .map(|filter| filter.take())
        .unwrap_or_default();
    visit_mut::for_each_expr_mut(&mut filter, |expr| {
        match expr {
            // Existing column refs now live one subquery level deeper.
            stmt::Expr::Reference(stmt::ExprReference::Column(expr_column)) => {
                debug_assert_eq!(0, expr_column.nesting);
                expr_column.nesting += 1;
            }
            // Ref args become columns of the batched input "table".
            stmt::Expr::Arg(expr_arg) => {
                let hir::Arg::Ref {
                    stmt_id: target_id,
                    target_expr_ref,
                    batch_load_index: batch_load_table_ref_index,
                    ..
                } = &self.stmt_info.args[expr_arg.position]
                else {
                    todo!()
                };
                let back_ref = &self.planner.hir[target_id].back_refs[&self.stmt_id];
                let column = back_ref.exprs.get_index_of(target_expr_ref).unwrap();
                *expr = stmt::Expr::column(stmt::ExprColumn {
                    nesting: 0,
                    table: batch_load_table_ref_index.get().unwrap(),
                    column,
                });
            }
            _ => {}
        }
    });
    // Each batch-load arg becomes a table-ref over the corresponding input.
    let tables: Vec<stmt::TableRef> = self
        .load_data
        .batch_load_args
        .iter()
        .map(|position| stmt::TableRef::Arg(stmt::ExprArg::new(*position)))
        .collect();
    assert!(tables.len() <= 1, "TODO: handle more complicated cases");
    let sub_query = stmt::Select {
        returning: stmt::Returning::Expr(stmt::Expr::record([1])),
        source: stmt::Source::Table(stmt::SourceTable {
            tables,
            from: vec![stmt::TableWithJoins {
                relation: stmt::TableFactor::Table(stmt::SourceTableId(0)),
                joins: vec![],
            }],
        }),
        filter,
    };
    stmt.filter_mut_unwrap().set(stmt::Expr::exists(sub_query));
}
/// Rewrites a batch-loaded query's filter for key-value drivers: ref args
/// become per-row args and the whole filter is wrapped in
/// `any(map(input, filter))`.
fn rewrite_stmt_query_for_batch_load_nosql(&mut self, stmt: &mut stmt::Statement) {
    let mut filter = stmt.filter_expr_mut();
    visit_mut::for_each_expr_mut(&mut filter, |expr| match expr {
        stmt::Expr::Reference(stmt::ExprReference::Column(expr_column)) => {
            // Unlike the SQL path, column refs stay at nesting 0.
            debug_assert_eq!(0, expr_column.nesting);
        }
        stmt::Expr::Arg(expr_arg) => {
            let hir::Arg::Ref {
                stmt_id: target_id,
                target_expr_ref,
                ..
            } = &self.stmt_info.args[expr_arg.position]
            else {
                todo!()
            };
            // The arg position becomes the back-ref column index within
            // one batched input row.
            let back_ref = &self.planner.hir[target_id].back_refs[&self.stmt_id];
            let column = back_ref.exprs.get_index_of(target_expr_ref).unwrap();
            *expr = stmt::Expr::arg(column);
        }
        _ => {}
    });
    assert!(
        self.load_data.batch_load_args.len() == 1,
        "TODO: handle more complicated cases"
    );
    let input = self.load_data.batch_load_args[0];
    if let Some(filter) = filter {
        let expr = filter.take();
        *filter = stmt::Expr::any(stmt::Expr::map(stmt::Expr::arg(input), expr));
    }
}
/// Rewrites argument expressions in every VALUES row of an insert.
fn rewrite_stmt_insert_arg_dependencies(&mut self, stmt: &mut stmt::Insert) {
    match &mut stmt.source.body {
        stmt::ExprSet::Values(values) => {
            for row in values.rows.iter_mut() {
                self.rewrite_arg_dependencies(row);
            }
        }
        _ => todo!(),
    }
}
/// Rewrites argument expressions in every assignment of an update.
fn rewrite_stmt_update_arg_dependencies(&mut self, stmt: &mut stmt::Update) {
    for (_, assignment) in stmt.assignments.iter_mut() {
        match assignment {
            stmt::Assignment::Set(expr)
            | stmt::Assignment::Insert(expr)
            | stmt::Assignment::Remove(expr) => self.rewrite_arg_dependencies(expr),
            stmt::Assignment::Batch(_) => {
                todo!("batch assignments in arg dependency rewriting")
            }
        }
    }
}
/// Replaces argument expressions with `arg`/`arg_project` forms resolving
/// against the data-load inputs recorded by `extract_data_load_args`.
fn rewrite_arg_dependencies(&mut self, expr: &mut stmt::Expr) {
    visit_mut::for_each_expr_mut(expr, |expr| {
        if let stmt::Expr::Arg(expr_arg) = expr {
            match &self.stmt_info.args[expr_arg.position] {
                hir::Arg::Ref {
                    stmt_id: target_id,
                    target_expr_ref,
                    data_load_input,
                    batch_load_index,
                    ..
                } => {
                    debug_assert!(!self.load_data.inputs.is_empty(), "{:#?}", self.load_data);
                    // Project `[row, column]` out of the back-ref input.
                    let back_ref = &self.planner.hir[target_id].back_refs[&self.stmt_id];
                    let column = back_ref.exprs.get_index_of(target_expr_ref).unwrap();
                    *expr = stmt::Expr::arg_project(
                        data_load_input.get().unwrap(),
                        [batch_load_index.get().unwrap(), column],
                    );
                }
                hir::Arg::Sub { input, .. } => {
                    debug_assert!(
                        !self.load_data.inputs.is_empty(),
                        "{:#?} | is this needed?",
                        self.load_data
                    );
                    // Sub-statement output feeds in whole.
                    *expr = stmt::Expr::arg(input.get().unwrap());
                }
            }
        }
    });
}
/// Plans the data-loading node for the statement, dispatching between
/// const folding, the SQL path, and the key-value path.
fn plan_data_loading(
    &mut self,
    stmt: stmt::Statement,
    returning: &mut Returning,
) -> Result<mir::NodeId> {
    let sql_capable = self.planner.engine.capability().sql;
    // `count(*)` requires SQL support from the driver.
    if !sql_capable && self.load_data.select_items.contains(&SelectItem::CountStar) {
        return Err(toasty_core::Error::unsupported_feature(
            "count() queries are only supported with SQL drivers",
        ));
    }
    // Statements with plan-time-known results become const nodes.
    if let Some(node_id) = self.plan_const_or_empty_statement(&stmt, returning) {
        debug_assert!(
            stmt.is_query() || stmt.assignments().map(|a| a.is_empty()).unwrap_or(false),
            "planned a mutable statement as const; stmt={:#?}",
            stmt
        );
        return Ok(node_id);
    }
    // Inserts always go through the SQL-shaped path.
    if sql_capable || stmt.is_insert() {
        Ok(self.plan_data_loading_sql(stmt))
    } else {
        self.plan_data_loading_nosql(stmt)
    }
}
/// Attempts to plan the statement as a const node: either the statement
/// is fully constant, or it is a no-op mutation (empty assignments).
/// Returns `None` when real execution is required.
fn plan_const_or_empty_statement(
    &mut self,
    stmt: &stmt::Statement,
    returning: &mut Returning,
) -> Option<mir::NodeId> {
    if stmt.is_const() {
        let stmt::Value::List(rows) = stmt.eval_const().unwrap() else {
            todo!()
        };
        return Some(
            self.insert_const(
                rows,
                self.load_data
                    .select_items
                    .infer_record_list_ty(&self.planner.engine.expr_cx_for(stmt)),
            ),
        );
    }
    // An update with no assignments writes nothing; emit a placeholder
    // result (one empty record if something is returned, else nothing).
    if stmt.assignments().map(|a| a.is_empty()).unwrap_or(false) {
        if returning.is_some() {
            return Some(self.insert_const(
                vec![stmt::Value::empty_sparse_record()],
                stmt::Type::list(stmt::Type::empty_sparse_record()),
            ));
        } else {
            return Some(self.insert_const(
                Vec::<stmt::Value>::new(),
                stmt::Type::list(stmt::Type::empty_sparse_record()),
            ));
        }
    }
    None
}
/// Plans data loading for SQL-capable drivers (and for inserts, which
/// always use the SQL path).
///
/// Replaces the returning clause with the collected select items, infers
/// the result type, and emits either a plain `ExecStatement` node or —
/// for conditional updates — a CTE-based or read-modify-write plan
/// depending on driver capability. If an insert's returned columns are
/// all constant, the output is replaced by a precomputed const node that
/// depends on the exec node.
fn plan_data_loading_sql(&mut self, mut stmt: stmt::Statement) -> mir::NodeId {
    let const_returning = self.extract_insert_returning_as_const(&stmt);
    let pagination_info = self.plan_pagination_sql(&stmt);
    if !self.load_data.select_items.is_empty() {
        stmt.set_returning(
            stmt::Expr::record(
                self.load_data
                    .select_items
                    .iter()
                    .map(|item| item.to_expr()),
            )
            .into(),
        );
    }
    let input_args: Vec<_> = self
        .load_data
        .inputs
        .iter()
        .map(|input| self.planner.mir.ty(*input).clone())
        .collect();
    let ty = self.planner.engine.infer_ty(&stmt, &input_args[..]);
    let pagination_config = pagination_info.map(|info| self.build_extract_cursor(info, &ty));
    let node = if stmt.condition().is_some() {
        // Conditional updates need special lowering.
        if let stmt::Statement::Update(stmt) = stmt {
            assert!(stmt.returning.is_none(), "TODO: stmt={stmt:#?}");
            if self.planner.engine.capability().cte_with_update {
                mir::Operation::ExecStatement(Box::new(
                    self.plan_conditional_sql_query_as_cte(stmt, ty),
                ))
            } else {
                mir::Operation::ReadModifyWrite(Box::new(
                    self.plan_conditional_sql_query_as_rmw(stmt, ty),
                ))
            }
        } else {
            todo!("stmt={stmt:#?}");
        }
    } else {
        debug_assert!(
            stmt.returning()
                .and_then(|returning| returning.as_expr())
                .map(|expr| expr.is_record())
                .unwrap_or(true),
            "stmt={stmt:#?}"
        );
        mir::Operation::ExecStatement(Box::new(mir::ExecStatement {
            inputs: mem::take(&mut self.load_data.inputs),
            stmt,
            ty,
            conditional_update_with_no_returning: false,
            // `pagination_config` is not used after this point, so move it
            // instead of cloning (the `.clone()` here was redundant).
            pagination: pagination_config,
        }))
    };
    let mut exec_statement_node = self.insert_mir_with_deps(node);
    if let Some((const_value, const_ty)) = const_returning {
        // The const node depends on the exec node so the write still runs.
        exec_statement_node = self.planner.mir.insert_with_deps(
            mir::Const {
                value: const_value,
                ty: const_ty,
            },
            [exec_statement_node],
        );
    }
    exec_statement_node
}
/// For an insert whose selected columns are all constant in every VALUES
/// row, precomputes the returning value at plan time. Returns `None` if
/// the statement is not an insert, nothing is selected, or any selected
/// value is non-constant.
fn extract_insert_returning_as_const(
    &mut self,
    stmt: &stmt::Statement,
) -> Option<(stmt::Value, stmt::Type)> {
    let stmt::Statement::Insert(insert) = stmt else {
        return None;
    };
    if self.load_data.select_items.is_empty() {
        return None;
    }
    let target = insert.target.as_table_unwrap();
    let values = insert.source.body.as_values()?;
    // Map each select item to its position within the insert's column list.
    let mut indices = vec![];
    for select_item in &self.load_data.select_items {
        let expr_ref = select_item.as_expr_reference_unwrap();
        let expr_col = expr_ref.as_expr_column_unwrap();
        debug_assert!(expr_col.nesting == 0, "expr_column={expr_col:#?}");
        let Some(index) = target
            .columns
            .iter()
            .enumerate()
            .find(|(_, column_id)| column_id.index == expr_col.column)
            .map(|(index, _)| index)
        else {
            todo!("insert returning referencing parent statement");
        };
        indices.push(index);
    }
    // Evaluate each row; bail out (`?`) on any non-const field.
    let mut result = Vec::with_capacity(values.rows.len());
    for row in &values.rows {
        let mut fields = Vec::with_capacity(indices.len());
        for &index in &indices {
            let value = row.entry(index)?.eval_const().ok()?;
            fields.push(value);
        }
        result.push(stmt::Value::record_from_vec(fields));
    }
    let ty = self
        .load_data
        .select_items
        .infer_record_list_ty(&self.planner.engine.expr_cx_for(stmt));
    Some((stmt::Value::List(result), ty))
}
/// Extracts cursor-pagination info from a query, registering each
/// order-by column as a select item. Returns `None` when the statement is
/// not a cursor-paginated query with a literal page size and
/// reference-only ordering.
fn plan_pagination_sql(&mut self, stmt: &stmt::Statement) -> Option<PaginationInfo> {
    let query = match stmt {
        stmt::Statement::Query(query) => query,
        _ => return None,
    };
    let cursor = match query.limit.as_ref()? {
        stmt::Limit::Cursor(cursor) => cursor,
        _ => return None,
    };
    // Only literal page sizes are supported.
    let stmt::Expr::Value(stmt::Value::I64(page_size)) = &cursor.page_size else {
        return None;
    };
    let order_by = query.order_by.as_ref()?;
    let mut cursor_column_indices = Vec::with_capacity(order_by.exprs.len());
    for order_expr in &order_by.exprs {
        // Every ordering expression must be a plain reference.
        let expr_ref = order_expr.expr.as_expr_reference().copied()?;
        let (index, _) = self
            .load_data
            .select_items
            .insert_full(SelectItem::from(expr_ref));
        cursor_column_indices.push(index);
    }
    Some(PaginationInfo {
        page_size: *page_size,
        cursor_column_indices,
    })
}
/// Builds the pagination config: an eval function that extracts the
/// cursor columns from a single result row.
fn build_extract_cursor(
    &self,
    info: PaginationInfo,
    ty: &stmt::Type,
) -> exec::PaginationConfig {
    // The eval function receives one row, i.e. the list's item type.
    let row_ty = match ty {
        stmt::Type::List(item_ty) => (**item_ty).clone(),
        _ => stmt::Type::Unit,
    };
    let cursor_fields: Vec<_> = info
        .cursor_column_indices
        .into_iter()
        .map(|index| stmt::Expr::arg_project(0, [index]))
        .collect();
    let extract_cursor = stmt::Expr::record(cursor_fields);
    exec::PaginationConfig {
        page_size: info.page_size,
        extract_cursor: Some(eval::Func::from_stmt(extract_cursor, vec![row_ty])),
    }
}
/// Lowers a conditional update into a single SQL statement using CTEs
/// (for drivers that support updates inside CTEs).
///
/// Shape: CTE 0 counts matching rows and rows satisfying the condition;
/// CTE 1 performs the update, gated on those counts being equal; the
/// outer query left-joins the two so the caller can distinguish
/// "no match" / "condition failed" / "updated".
fn plan_conditional_sql_query_as_cte(
    &mut self,
    stmt: stmt::Update,
    ty: stmt::Type,
) -> mir::ExecStatement {
    let Some(condition) = stmt.condition.expr else {
        panic!("conditional update without condition");
    };
    let Some(filter) = stmt.filter.expr else {
        panic!("conditional update without filter");
    };
    let stmt::UpdateTarget::Table(target) = stmt.target.clone() else {
        panic!("conditional update without table");
    };
    let mut ctes = vec![];
    // CTE 0: `(count(*), count(*) filter (condition))` over the filter.
    ctes.push(stmt::Cte {
        query: stmt::Query::builder(target)
            .filter(filter.clone())
            .returning(vec![
                stmt::Expr::count_star(),
                stmt::FuncCount {
                    arg: None,
                    filter: Some(Box::new(condition)),
                }
                .into(),
            ])
            .build(),
    });
    // Number of user-requested returning columns (for the outer SELECT).
    let returning_len = match &stmt.returning {
        Some(stmt::Returning::Expr(expr)) => {
            let stmt::Expr::Record(expr_record) = expr else {
                panic!("returning must be a record");
            };
            expr_record.fields.len()
        }
        Some(_) => todo!(),
        None => 0,
    };
    // CTE 1: the update itself, additionally filtered on CTE 0 reporting
    // that every matched row satisfies the condition (col 0 == col 1).
    ctes.push(stmt::Cte {
        query: stmt::Query::new(stmt::Update {
            target: stmt.target,
            assignments: stmt.assignments,
            filter: stmt::Filter::new(stmt::Expr::and(
                filter,
                stmt::Expr::stmt(stmt::Select {
                    source: stmt::TableRef::Cte {
                        nesting: 2,
                        index: 0,
                    }
                    .into(),
                    filter: true.into(),
                    returning: stmt::Returning::Expr(stmt::Expr::record_from_vec(vec![
                        stmt::Expr::eq(
                            stmt::ExprColumn {
                                nesting: 0,
                                table: 0,
                                column: 0,
                            },
                            stmt::ExprColumn {
                                nesting: 0,
                                table: 0,
                                column: 1,
                            },
                        ),
                    ])),
                }),
            )),
            condition: Condition::default(),
            // The update must return *something* so the outer join can
            // detect whether it ran; a placeholder is used when the caller
            // requested no returning.
            returning: Some(
                stmt.returning
                    .unwrap_or_else(|| {
                        stmt::Returning::Expr(stmt::Expr::record_from_vec(vec![
                            stmt::Expr::from("hello"),
                        ]))
                    }),
            ),
        }),
    });
    // Outer SELECT: the two counts from CTE 0 plus the update's returning
    // columns from CTE 1.
    let mut columns = vec![
        stmt::Expr::column(stmt::ExprColumn {
            nesting: 0,
            table: 0,
            column: 0,
        }),
        stmt::Expr::column(stmt::ExprColumn {
            nesting: 0,
            table: 0,
            column: 1,
        }),
    ];
    for i in 0..returning_len {
        columns.push(stmt::Expr::column(stmt::ExprColumn {
            nesting: 0,
            table: 1,
            column: i,
        }));
    }
    let stmt = stmt::Query::builder(stmt::Select {
        source: stmt::Source::table_with_joins(
            vec![
                stmt::TableRef::Cte {
                    nesting: 0,
                    index: 0,
                },
                stmt::TableRef::Cte {
                    nesting: 0,
                    index: 1,
                },
            ],
            stmt::TableWithJoins {
                relation: stmt::TableFactor::Table(stmt::SourceTableId(0)),
                joins: vec![stmt::Join {
                    table: stmt::SourceTableId(1),
                    constraint: stmt::JoinOp::Left(stmt::Expr::from(true)),
                }],
            },
        ),
        filter: stmt::Filter::new(true),
        returning: stmt::Returning::Expr(stmt::Expr::record_from_vec(columns)),
    })
    .with(ctes)
    .build()
    .into();
    mir::ExecStatement {
        inputs: mem::take(&mut self.load_data.inputs),
        stmt,
        ty,
        conditional_update_with_no_returning: true,
        pagination: None,
    }
}
/// Lowers a conditional update into a read-modify-write pair for drivers
/// without CTE-update support: a counting read (optionally with
/// SELECT ... FOR UPDATE) followed by the unconditional write.
fn plan_conditional_sql_query_as_rmw(
    &mut self,
    stmt: stmt::Update,
    ty: stmt::Type,
) -> mir::ReadModifyWrite {
    assert!(stmt.returning.is_none(), "TODO: support returning");
    let Some(condition) = stmt.condition.expr else {
        panic!("conditional update without condition");
    };
    let Some(filter) = stmt.filter.expr else {
        panic!("conditional update without filter");
    };
    let stmt::UpdateTarget::Table(target) = stmt.target.clone() else {
        panic!("conditional update without table");
    };
    // Read: `(count(*), count(*) filter (condition))` over the filter,
    // locking the rows when the driver supports it.
    let read = stmt::Query::builder(target)
        .filter(filter.clone())
        .returning(vec![
            stmt::Expr::count_star(),
            stmt::FuncCount {
                arg: None,
                filter: Some(Box::new(condition)),
            }
            .into(),
        ])
        .locks(if self.planner.engine.capability().select_for_update {
            vec![stmt::Lock::Update]
        } else {
            vec![]
        })
        .build();
    // Write: the original update with the condition stripped — the
    // executor decides whether to run it based on the read's counts.
    let write = stmt::Update {
        target: stmt.target,
        assignments: stmt.assignments,
        filter: stmt::Filter::new(filter),
        condition: stmt::Condition::default(),
        returning: None,
    };
    mir::ReadModifyWrite {
        inputs: mem::take(&mut self.load_data.inputs),
        read,
        write: write.into(),
        ty,
    }
}
/// Plans data loading for key-value drivers: picks an index path, splits
/// off any in-memory post-filter, and dispatches to primary-key or
/// secondary-index execution.
fn plan_data_loading_nosql(&mut self, stmt: stmt::Statement) -> Result<mir::NodeId> {
    if stmt.is_insert() {
        debug_assert!(self.load_data.select_items.is_empty());
    }
    let mut index_plan = self.planner.engine.plan_index_path(&stmt)?;
    let post_filter = self.prepare_post_filter(&stmt, &mut index_plan);
    // Result type: the select-item record list when columns are loaded;
    // otherwise an empty-record list (query) or unit (mutation).
    let ty = match (self.load_data.select_items.is_empty(), stmt.is_query()) {
        (false, _) => self
            .load_data
            .select_items
            .infer_record_list_ty(&self.planner.engine.expr_cx_for(&stmt)),
        (true, true) => stmt::Type::list(stmt::Type::Record(vec![])),
        (true, false) => stmt::Type::Unit,
    };
    let node_id = if index_plan.index.primary_key {
        self.plan_primary_key_execution(stmt, &mut index_plan, &ty)
    } else {
        self.plan_secondary_index_execution(stmt, &mut index_plan, &ty)
    };
    Ok(self.apply_post_filter(node_id, post_filter, ty))
}
/// Plans execution against the primary key: a `GetByKey` when concrete
/// key values are known, otherwise a `QueryPk` scan (possibly followed by
/// a keyed mutation).
fn plan_primary_key_execution(
    &mut self,
    stmt: stmt::Statement,
    index_plan: &mut index::IndexPlan,
    ty: &stmt::Type,
) -> mir::NodeId {
    if let Some(mut key_expr) = index_plan.key_values.take() {
        // Keys are known (possibly as a function of inputs): evaluate them
        // and feed the result into the keyed operation.
        let (args, _input_nodes) = self.rewrite_expr_for_mir(&mut key_expr);
        let key_ty =
            stmt::Type::list(self.planner.engine.index_key_record_ty(index_plan.index));
        let keys = eval::Func::from_stmt_typed(key_expr, args, key_ty);
        let get_by_key_input = self.build_get_by_key_input(keys, self.index_key_ty(index_plan));
        self.build_key_operation(&stmt, index_plan, get_by_key_input, ty)
    } else {
        // No concrete keys — scan the primary key.
        let input = if self.load_data.inputs.is_empty() {
            None
        } else if self.load_data.inputs.len() == 1 {
            Some(self.load_data.inputs[0])
        } else {
            todo!()
        };
        if stmt.is_query() {
            let limit = extract_query_pk_limit(&stmt);
            let order = extract_query_pk_order(&stmt);
            self.insert_mir_with_deps(mir::QueryPk {
                input,
                table: index_plan.table_id(),
                index: None,
                columns: self.load_data.select_items.extract_expr_references(),
                pk_filter: index_plan.index_filter.take(),
                row_filter: index_plan.result_filter.take(),
                ty: ty.clone(),
                limit,
                order,
            })
        } else {
            // Mutation without known keys: scan for the key columns first,
            // then apply the keyed mutation to the scan result.
            let index_key_ty = self.index_key_ty(index_plan);
            let mut columns = self.load_data.select_items.extract_expr_references();
            assert!(columns.is_empty());
            for index_col in &index_plan.index.columns {
                columns.insert(stmt::ExprReference::Column(stmt::ExprColumn {
                    nesting: 0,
                    table: 0,
                    column: index_col.column.index,
                }));
            }
            let query_pk_node = self.insert_mir_with_deps(mir::QueryPk {
                input,
                table: index_plan.table_id(),
                index: None,
                columns,
                pk_filter: index_plan.index_filter.take(),
                row_filter: index_plan.result_filter.take(),
                ty: index_key_ty,
                limit: None,
                order: None,
            });
            self.build_key_operation(&stmt, index_plan, query_pk_node, ty)
        }
    }
}
/// Plans execution via a secondary index: a direct index scan for
/// non-unique-index queries, otherwise a `FindPkByIndex` lookup feeding a
/// keyed operation.
fn plan_secondary_index_execution(
    &mut self,
    stmt: stmt::Statement,
    index_plan: &mut index::IndexPlan,
    ty: &stmt::Type,
) -> mir::NodeId {
    let inputs = mem::take(&mut self.load_data.inputs);
    assert!(index_plan.post_filter.is_none(), "TODO");
    assert!(inputs.len() <= 1, "TODO: inputs={:#?}", inputs);
    if stmt.is_query() && !index_plan.index.unique {
        // Non-unique index queries scan the index directly.
        let input = if inputs.is_empty() {
            None
        } else {
            Some(inputs[0])
        };
        let limit = extract_query_pk_limit(&stmt);
        let order = extract_query_pk_order(&stmt);
        return self.insert_mir_with_deps(mir::QueryPk {
            input,
            table: index_plan.index.on,
            index: Some(index_plan.index.id),
            columns: self.load_data.select_items.extract_expr_references(),
            pk_filter: index_plan.index_filter.take(),
            row_filter: index_plan.result_filter.take(),
            ty: ty.clone(),
            limit,
            order,
        });
    }
    // Resolve primary keys through the index, then run the keyed op.
    let index_key_ty = self.index_key_ty(index_plan);
    let get_by_key_input = self.insert_mir_with_deps(mir::FindPkByIndex {
        inputs,
        table: index_plan.index.on,
        index: index_plan.index.id,
        filter: index_plan.index_filter.take(),
        ty: index_key_ty,
    });
    self.build_key_operation(&stmt, index_plan, get_by_key_input, ty)
}
/// Builds the in-memory post-filter for a key-value execution: merges the
/// index plan's post filter with any result filter that cannot run in the
/// driver, and rewrites column refs into select-item projections.
fn prepare_post_filter(
    &mut self,
    stmt: &stmt::Statement,
    index_plan: &mut index::IndexPlan,
) -> Option<stmt::Expr> {
    let mut post_filter = index_plan.post_filter.clone();
    // When keys are resolved (or a secondary index is used), the result
    // filter must be applied in memory after loading.
    if stmt.is_query()
        && (index_plan.has_pk_keys || !index_plan.index.primary_key)
        && let Some(result_filter) = index_plan.result_filter.take()
    {
        post_filter = Some(match post_filter {
            Some(post_filter) => stmt::Expr::and(result_filter, post_filter),
            None => result_filter,
        });
    }
    // Post-filters only make sense for queries.
    debug_assert!(
        post_filter.is_none() || stmt.is_query(),
        "stmt={:#?}; post_filter={post_filter:#?}",
        stmt
    );
    if let Some(post_filter) = &mut post_filter {
        // Column refs become projections over the loaded row; the column
        // is added to the select items if not already present.
        visit_mut::for_each_expr_mut(post_filter, |expr| match expr {
            stmt::Expr::Reference(expr_reference) => {
                let (index, _) = self
                    .load_data
                    .select_items
                    .insert_full((*expr_reference).into());
                *expr = stmt::Expr::arg_project(0, [index]);
            }
            stmt::Expr::Arg(_) => todo!("expr={expr:#?}"),
            _ => {}
        });
    }
    post_filter
}
/// Wraps `node_id` in a `Filter` node when an in-memory post-filter is
/// needed; otherwise returns the node unchanged.
fn apply_post_filter(
    &mut self,
    node_id: mir::NodeId,
    post_filter: Option<stmt::Expr>,
    ty: stmt::Type,
) -> mir::NodeId {
    let Some(post_filter) = post_filter else {
        return node_id;
    };
    // The filter function evaluates one row at a time.
    let item_ty = ty.as_list_unwrap().clone();
    self.planner.mir.insert(mir::Filter {
        input: node_id,
        filter: eval::Func::from_stmt(post_filter, vec![item_ty]),
        ty,
    })
}
/// Materializes the key list that feeds a keyed operation: a const node
/// for plan-time-known keys, the input node itself for the identity
/// function, or a `Project` over the input otherwise.
fn build_get_by_key_input(
    &mut self,
    keys: eval::Func,
    index_key_ty: stmt::Type,
) -> mir::NodeId {
    if keys.is_const() {
        // Keys fully known at plan time.
        return self.insert_const(keys.eval_const(), index_key_ty);
    }
    if keys.is_identity() {
        // The key list is exactly the single input node's output.
        debug_assert_eq!(1, self.load_data.inputs.len(), "TODO");
        return self.load_data.inputs[0];
    }
    // Project the input through the key function.
    let ty = stmt::Type::list(keys.ret.clone());
    self.planner.mir.insert(mir::Project {
        input: self.load_data.inputs[0],
        projection: keys,
        ty,
    })
}
/// Emits the keyed MIR operation (`GetByKey` / `DeleteByKey` /
/// `UpdateByKey`) consuming the resolved key list in `get_by_key_input`.
fn build_key_operation(
    &mut self,
    stmt: &stmt::Statement,
    index_plan: &mut index::IndexPlan,
    get_by_key_input: mir::NodeId,
    ty: &stmt::Type,
) -> mir::NodeId {
    match stmt {
        stmt::Statement::Query(_) => {
            debug_assert!(ty.is_list(), "ty={ty:#?}");
            self.insert_mir_with_deps(mir::GetByKey {
                input: get_by_key_input,
                table: index_plan.table_id(),
                columns: self.load_data.select_items.extract_expr_references(),
                ty: ty.clone(),
            })
        }
        stmt::Statement::Delete(delete_stmt) => self.insert_mir_with_deps(mir::DeleteByKey {
            input: get_by_key_input,
            table: index_plan.table_id(),
            filter: index_plan.result_filter.take(),
            condition: delete_stmt.condition.expr.clone(),
            ty: stmt::Type::Unit,
        }),
        stmt::Statement::Update(update_stmt) => {
            // Updates may need a guard node checking the pre-filter first.
            let guarded_input = self.apply_guard(get_by_key_input, index_plan);
            self.insert_mir_with_deps(mir::UpdateByKey {
                input: guarded_input,
                table: index_plan.table_id(),
                assignments: update_stmt.assignments.clone(),
                filter: index_plan.result_filter.take(),
                condition: update_stmt.condition.expr.clone(),
                ty: ty.clone(),
            })
        }
        _ => todo!("stmt={stmt:#?}"),
    }
}
/// Wraps `input` in a `Guard` node evaluating the index plan's pre-filter
/// when one exists; otherwise passes the input through.
fn apply_guard(
    &mut self,
    input: mir::NodeId,
    index_plan: &mut index::IndexPlan,
) -> mir::NodeId {
    match index_plan.pre_filter.take() {
        None => input,
        Some(mut pre_filter_expr) => {
            let (args, guard_inputs) = self.rewrite_expr_for_mir(&mut pre_filter_expr);
            let guard = eval::Func::from_stmt(pre_filter_expr, args);
            let ty = self.planner.mir[input].ty().clone();
            self.planner.mir.insert(mir::Guard {
                input,
                guard_inputs,
                guard,
                ty,
            })
        }
    }
}
/// Renumbers HIR argument positions in `expr` to dense eval-function
/// positions, deduplicating repeated args. Returns the argument types and
/// the input nodes they refer to (in renumbered order).
fn rewrite_expr_for_mir(
    &self,
    expr: &mut stmt::Expr,
) -> (Vec<stmt::Type>, IndexSet<mir::NodeId>) {
    // hir arg position -> (type, input node), keyed in first-seen order.
    let mut arg_map: IndexMap<usize, (stmt::Type, mir::NodeId)> = IndexMap::new();
    visit_mut::for_each_expr_mut(expr, |expr| {
        if let stmt::Expr::Arg(expr_arg) = expr {
            let hir_pos = expr_arg.position;
            let new_pos = match arg_map.get_index_of(&hir_pos) {
                Some(idx) => idx,
                None => {
                    // First occurrence: resolve the backing input node.
                    let input_idx = match &self.stmt_info.args[hir_pos] {
                        hir::Arg::Sub { input, .. } => input.get().unwrap(),
                        _ => todo!("rewrite_expr_for_mir with non-Sub arg"),
                    };
                    let node_id = self.load_data.inputs[input_idx];
                    let ty = self.planner.mir[node_id].ty().clone();
                    let (idx, _) = arg_map.insert_full(hir_pos, (ty, node_id));
                    idx
                }
            };
            expr_arg.position = new_pos;
        }
    });
    let mut types = Vec::with_capacity(arg_map.len());
    let mut nodes = IndexSet::with_capacity(arg_map.len());
    for (_, (ty, node_id)) in arg_map {
        types.push(ty);
        nodes.insert(node_id);
    }
    (types, nodes)
}
/// For each back-reference, inserts a `Project` node extracting the
/// referenced columns from the data-load output and records the node id
/// so dependent statements can use it as an input.
fn process_back_ref_projections(&mut self, exec_stmt_node_id: mir::NodeId) {
    for back_ref in self.stmt_info.back_refs.values() {
        let projection = stmt::Expr::record(back_ref.exprs.iter().map(|expr_reference| {
            let index = self
                .load_data
                .select_items
                .get_index_of_expr_reference(*expr_reference);
            stmt::Expr::arg_project(0, [index])
        }));
        // The projection runs per row; unwrap a list type if present.
        let arg_ty = match self.planner.mir[exec_stmt_node_id].ty() {
            stmt::Type::List(ty) => (**ty).clone(),
            ty => ty.clone(),
        };
        let projection = eval::Func::from_stmt(projection, vec![arg_ty]);
        let ty = stmt::Type::list(projection.ret.clone());
        let project_node_id = self.planner.mir.insert(mir::Project {
            input: exec_stmt_node_id,
            projection,
            ty,
        });
        back_ref.node_id.set(Some(project_node_id));
    }
}
/// Plans dependent (non-independent) dependency statements and all
/// sub-statement arguments, now that this statement's data-load node
/// exists for them to reference.
fn plan_child_statements(&mut self) -> Result<()> {
    for &dep_stmt_id in self.stmt_info.deps.iter() {
        if self.planner.hir[dep_stmt_id].independent {
            continue;
        }
        self.planner.plan_statement(dep_stmt_id)?;
    }
    for arg in self.stmt_info.args.iter() {
        if let hir::Arg::Sub { stmt_id, .. } = arg {
            self.planner.plan_statement(*stmt_id)?;
        }
    }
    Ok(())
}
/// Plans the node producing this statement's observable output from the
/// data-load node and the rewritten returning clause.
fn plan_output_node(
    &mut self,
    data_load_node_id: mir::NodeId,
    returning: ReturningInfo,
) -> mir::NodeId {
    // Nested merges (relation loading) take over output planning entirely.
    if let Some(node_id) = self.planner.plan_nested_merge(self.stmt_id) {
        return node_id;
    }
    let returning_arg_tys = returning
        .inputs
        .iter()
        .map(|input| self.planner.mir[input].ty().clone())
        .collect();
    if let Some(clause) = returning.clause {
        match clause {
            stmt::Returning::Value(expr) => {
                if let Ok(value) = expr.eval_const() {
                    // Constant value — emit a const node that still
                    // depends on the data load so side effects run.
                    let ty = value.infer_ty();
                    self.planner
                        .mir
                        .insert_with_deps(mir::Const { value, ty }, [data_load_node_id])
                } else {
                    let eval = eval::Func::from_stmt(expr, returning_arg_tys);
                    let node_id = self.insert_mir_with_deps(mir::Eval {
                        inputs: returning.inputs,
                        eval,
                        metadata: None,
                    });
                    // Mutations must still execute even if the returning
                    // value never reads their output.
                    if !self.stmt().is_query() {
                        self.planner.mir[node_id].deps.insert(data_load_node_id);
                    }
                    node_id
                }
            }
            stmt::Returning::Expr(projection) => {
                if let Some(position) = returning.inputs.get_index_of(&data_load_node_id) {
                    // The projection reads loaded rows: map it over the
                    // data-load input at `position`.
                    self.insert_mir_with_deps(mir::Eval {
                        inputs: returning.inputs,
                        eval: eval::Func::from_stmt(
                            stmt::Expr::map(stmt::Expr::arg(position), projection),
                            returning_arg_tys,
                        ),
                        metadata: Some(position),
                    })
                } else {
                    // Projection reads no inputs — a plain per-row project.
                    let projection = eval::Func::from_stmt(projection, vec![]);
                    let ty = stmt::Type::list(projection.ret.clone());
                    let node = mir::Project {
                        input: data_load_node_id,
                        projection,
                        ty,
                    };
                    self.insert_mir_with_deps(node)
                }
            }
            returning => panic!("unexpected `stmt::Returning` kind; returning={returning:#?}"),
        }
    } else {
        // No returning clause: the data-load node *is* the output.
        self.apply_dependencies_to_node(data_load_node_id);
        data_load_node_id
    }
}
/// Inserts a `Const` node, asserting (in debug builds) that the type is a
/// list and that the value actually inhabits it.
#[track_caller]
fn insert_const(&mut self, value: impl Into<stmt::Value>, ty: stmt::Type) -> mir::NodeId {
    let value: stmt::Value = value.into();
    debug_assert!(
        ty.is_list(),
        "const types must be of type `stmt::Type::List`"
    );
    debug_assert!(
        value.is_a(&ty),
        "const type mismatch; expected={ty:#?}; actual={value:#?}",
    );
    let node = mir::Const { value, ty };
    self.planner.mir.insert(node)
}
/// Inserts a MIR node and immediately attaches any dependency edges whose
/// statements have already been planned.
fn insert_mir_with_deps(&mut self, node: impl Into<mir::Node>) -> mir::NodeId {
    let inserted = self.planner.mir.insert(node);
    self.apply_dependencies_to_node(inserted);
    inserted
}
/// Moves every already-planned remaining dependency onto `node_id` as a
/// dependency edge; unplanned ones stay queued for a later node.
fn apply_dependencies_to_node(&mut self, node_id: mir::NodeId) {
    let node = &mut self.planner.mir[node_id];
    self.remaining_deps.retain(|stmt_id| {
        match self.planner.hir[stmt_id].output.get() {
            Some(dep_id) => {
                node.deps.insert(dep_id);
                false
            }
            None => true,
        }
    });
}
/// Type of a key lookup's input/output: a list of index-key records.
fn index_key_ty(&self, index_plan: &IndexPlan) -> stmt::Type {
    let record_ty = self.planner.engine.index_key_record_ty(index_plan.index);
    stmt::Type::list(record_ty)
}
/// The HIR statement being planned (always present during planning).
fn stmt(&self) -> &stmt::Statement {
    let stored = self.stmt_info.stmt.as_deref();
    stored.unwrap()
}
}
/// Translates a query's limit clause into the driver-level
/// `QueryPkLimit`. Returns `None` when the statement is not a query or
/// has no limit.
fn extract_query_pk_limit(stmt: &stmt::Statement) -> Option<QueryPkLimit> {
    let query = stmt.as_query()?;
    let limit = query.limit.as_ref()?;
    Some(match limit {
        stmt::Limit::Cursor(cursor) => QueryPkLimit::Cursor {
            page_size: as_i64_literal(&cursor.page_size),
            // Only literal cursor values survive; expressions are dropped.
            after: match cursor.after.as_ref() {
                Some(stmt::Expr::Value(value)) => Some(value.clone()),
                _ => None,
            },
        },
        stmt::Limit::Offset(lo) => QueryPkLimit::Offset {
            limit: as_i64_literal(&lo.limit),
            offset: lo.offset.as_ref().map(as_i64_literal),
        },
    })
}
/// Extracts an `i64` literal from an expression, panicking otherwise
/// (limits/offsets must be literals by this point in planning).
fn as_i64_literal(expr: &stmt::Expr) -> i64 {
    if let stmt::Expr::Value(stmt::Value::I64(n)) = expr {
        *n
    } else {
        panic!("limit/offset must be an i64 literal; got {expr:#?}")
    }
}
/// Direction of the first order-by expression of a query; anything other
/// than an explicit DESC is treated as ascending. `None` when the
/// statement is not a query or has no ordering.
fn extract_query_pk_order(stmt: &stmt::Statement) -> Option<stmt::Direction> {
    let order_by = stmt.as_query()?.order_by.as_ref()?;
    let first = order_by.exprs.first()?;
    Some(if matches!(first.order, Some(stmt::Direction::Desc)) {
        stmt::Direction::Desc
    } else {
        stmt::Direction::Asc
    })
}