mod association;
mod expr_or;
mod include;
mod insert;
mod lift_in_subquery;
mod lift_update_query;
mod paginate;
mod relation;
mod returning;
mod via_join;
#[cfg(test)]
mod tests;
use std::cell::Cell;
use hashbrown::HashSet;
use index_vec::IndexVec;
use toasty_core::{
Result, Schema,
driver::Capability,
schema::{
app::{self, ModelRoot},
db::ColumnId,
mapping,
},
stmt::{self, IntoExprTarget, VisitMut, visit_mut},
};
use crate::engine::{Engine, HirStatement, fold, hir, simplify::Simplify};
/// Builds a `Let` expression that binds `subquery` and matches on the bound
/// value (`arg(0)`).
///
/// A `Null` binding yields `Null`; `present` is handed to
/// `stmt::Expr::match_expr` as its final argument (presumably the fallback
/// branch used when the null arm does not match — confirm against
/// `match_expr`).
fn map_nullable_single(subquery: stmt::Expr, present: stmt::Expr) -> stmt::Expr {
    let null_arm = stmt::MatchArm {
        pattern: stmt::Value::Null,
        expr: stmt::Expr::Value(stmt::Value::Null),
    };
    let body = stmt::Expr::match_expr(stmt::Expr::arg(0), vec![null_arm], present);

    stmt::Expr::Let(stmt::ExprLet {
        bindings: vec![subquery],
        body: Box::new(body),
    })
}
impl Engine {
    /// Lowers `stmt` into a [`HirStatement`].
    ///
    /// A fresh [`LoweringState`] is created for the call, the statement is
    /// lowered through it, and the HIR accumulated in that state is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded during lowering. Any further recorded
    /// errors are dropped.
    pub(super) fn lower_stmt(&self, stmt: stmt::Statement) -> Result<HirStatement> {
        let schema = &self.schema;
        let mut state = LoweringState {
            hir: HirStatement::new(),
            scopes: IndexVec::new(),
            engine: self,
            relations: Vec::new(),
            errors: Vec::new(),
            dependencies: HashSet::new(),
        };

        let root_cx = stmt::ExprContext::new(schema);
        state.lower_stmt(root_cx, None, stmt);

        match state.errors.into_iter().next() {
            Some(first_error) => Err(first_error),
            None => Ok(state.hir),
        }
    }
}
impl LoweringState<'_> {
    /// Lowers one statement and stores the result in the HIR under a newly
    /// allocated statement id, which is returned.
    ///
    /// The statement first goes through the pre-lowering rewrites
    /// (`RewriteVia`, `LiftInSubquery`, `LiftUpdateQuery`, `Simplify`), then
    /// through the [`LowerStatement`] visitor inside a new scope, and finally
    /// through `Engine::simplify_stmt`.
    ///
    /// `row_index` is recorded on the scope pushed for this statement.
    fn lower_stmt(
        &mut self,
        expr_cx: stmt::ExprContext,
        row_index: Option<usize>,
        mut stmt: stmt::Statement,
    ) -> hir::StmtId {
        // Pre-lowering rewrites; the order is significant.
        association::RewriteVia::new(expr_cx).rewrite(&mut stmt);
        lift_in_subquery::LiftInSubquery::new(expr_cx).rewrite(&mut stmt);
        lift_update_query::LiftUpdateQuery::new().rewrite(&mut stmt);
        Simplify::with_context(expr_cx, self.engine.capability).visit_mut(&mut stmt);

        // The new statement starts with the dependencies currently in effect.
        let inherited = self.dependencies.clone();
        let stmt_id = self.hir.new_statement_info(inherited);
        let scope_id = self.scopes.push(Scope { stmt_id, row_index });

        let mut collect_dependencies = None;
        {
            let mut lower = LowerStatement {
                state: self,
                expr_cx,
                scope_id,
                cx: LoweringContext::Statement,
                collect_dependencies: &mut collect_dependencies,
            };
            lower.visit_stmt_mut(&mut stmt);
        }

        self.engine.simplify_stmt(&mut stmt);
        self.hir[stmt_id].stmt = Some(Box::new(stmt));
        self.scopes.pop();

        // Nothing at this level turns dependency collection on.
        debug_assert!(collect_dependencies.is_none());
        stmt_id
    }
}
/// Visitor that lowers a statement in place.
///
/// Child visitors built by the `scope_*` / `lower_*` helpers share `state`
/// and differ in `expr_cx`, `scope_id`, and `cx`.
struct LowerStatement<'a, 'b> {
/// Lowering state shared by every visitor working on the same top-level statement.
state: &'a mut LoweringState<'b>,
/// Expression context used to resolve references (models, fields, schema).
expr_cx: stmt::ExprContext<'a>,
/// Index into `state.scopes` of the scope this visitor lowers in.
scope_id: ScopeId,
/// Which part of a statement is currently being lowered.
cx: LoweringContext<'a>,
/// When `Some`, statement ids created through `new_dependency` are also
/// inserted into this set.
collect_dependencies: &'a mut Option<HashSet<hir::StmtId>>,
}
/// Mutable state accumulated while lowering one top-level statement.
#[derive(Debug)]
struct LoweringState<'a> {
/// Engine performing the lowering; gives access to the schema and driver capability.
engine: &'a Engine,
/// HIR being built; one entry per lowered (sub-)statement.
hir: HirStatement,
/// Stack of scopes for the statements currently being lowered.
scopes: IndexVec<ScopeId, Scope>,
/// NOTE(review): not touched by the code visible in this part of the file;
/// presumably used by the relation-planning submodules — confirm.
relations: Vec<app::FieldId>,
/// Statement ids that every newly created statement info starts out depending on.
dependencies: HashSet<hir::StmtId>,
/// Errors recorded during lowering; `Engine::lower_stmt` returns the first one.
errors: Vec<crate::Error>,
}
/// Identifies which part of a statement a [`LowerStatement`] is lowering.
#[derive(Debug, Clone, Copy)]
enum LoweringContext<'a> {
/// Lowering an insert: the target's columns, plus the row index while a
/// specific row is being processed (`None` otherwise).
Insert(&'a [ColumnId], Option<usize>),
/// Lowering a single insert row; field references at nesting 0 are read
/// from this row expression.
InsertRow(&'a stmt::Expr),
/// Lowering a returning clause, optionally for one specific row index.
Returning(Option<usize>),
/// Default context for the rest of a statement.
Statement,
}
/// One entry on the lowering scope stack.
#[derive(Debug)]
struct Scope {
/// HIR statement this scope belongs to.
stmt_id: hir::StmtId,
/// Row index the statement was created for, if any; read by `new_ref` to
/// fill in `batch_load_index`.
row_index: Option<usize>,
}
// Typed index into `LoweringState::scopes`, backed by a `u32`.
index_vec::define_index_type! {
struct ScopeId = u32;
}
/// Collection-mutating assignment operators handled by `lower_collection_op`.
///
/// Variants that carry an operand borrow it mutably so it can be lowered in
/// place and then taken.
enum CollectionOp<'a> {
/// Built from `stmt::Assignment::Append`.
Append(&'a mut stmt::Expr),
/// Built from `stmt::Assignment::Remove`.
Remove(&'a mut stmt::Expr),
/// Built from `stmt::Assignment::Pop`; carries no operand.
Pop,
/// Built from `stmt::Assignment::RemoveAt`.
RemoveAt(&'a mut stmt::Expr),
}
/// Arithmetic assignment operators handled by `lower_arithmetic_op`.
enum ArithmeticOp {
/// Built from `stmt::Assignment::Add`.
Add,
/// Built from `stmt::Assignment::Subtract`.
Subtract,
}
impl LowerStatement<'_, '_> {
fn lower_set_assignment(
&mut self,
out: &mut stmt::Assignments,
mapping: &mapping::Model,
projection: &stmt::Projection,
expr: &mut stmt::Expr,
) {
self.visit_expr_mut(expr);
let Some(field) = mapping.resolve_field_mapping(projection) else {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"invalid assignment projection: {projection:?}"
)));
return;
};
for (column, lowering_idx) in field.columns() {
let mut lowering_expr = mapping.model_to_table[lowering_idx].clone();
lowering_expr.substitute(AssignmentInput {
assignment_projection: projection.clone(),
value: expr,
});
self.visit_expr_mut(&mut lowering_expr);
out.set(column, lowering_expr);
}
}
fn lower_collection_op(
&mut self,
out: &mut stmt::Assignments,
mapping: &mapping::Model,
projection: &stmt::Projection,
op: CollectionOp,
) {
let Some(field) = mapping.resolve_field_mapping(projection) else {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"invalid assignment projection: {projection:?}"
)));
return;
};
let Some(prim) = field.as_primitive() else {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"collection operator on non-primitive field: {projection:?}"
)));
return;
};
let cap = self.capability();
let unsupported = match &op {
CollectionOp::Append(_) => None,
CollectionOp::Remove(_) if !cap.vec_remove => Some("stmt::remove"),
CollectionOp::Pop if !cap.vec_pop => Some("stmt::pop"),
CollectionOp::RemoveAt(_) if !cap.vec_remove_at => Some("stmt::remove_at"),
_ => None,
};
if let Some(op_name) = unsupported {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"{op_name} is not yet supported on this backend"
)));
return;
}
match op {
CollectionOp::Append(expr) => {
self.visit_expr_mut(expr);
out.append(prim.column, expr.take());
}
CollectionOp::Remove(expr) => {
self.visit_expr_mut(expr);
out.remove(prim.column, expr.take());
}
CollectionOp::Pop => {
out.pop(prim.column);
}
CollectionOp::RemoveAt(expr) => {
self.visit_expr_mut(expr);
out.remove_at(prim.column, expr.take());
}
}
}
fn lower_arithmetic_op(
&mut self,
out: &mut stmt::Assignments,
mapping: &mapping::Model,
projection: &stmt::Projection,
op: ArithmeticOp,
expr: &mut stmt::Expr,
) {
let Some(field) = mapping.resolve_field_mapping(projection) else {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"invalid assignment projection: {projection:?}"
)));
return;
};
let Some(prim) = field.as_primitive() else {
self.state
.errors
.push(crate::Error::invalid_statement(format!(
"arithmetic operator on non-primitive field: {projection:?}"
)));
return;
};
self.visit_expr_mut(expr);
let value = expr.take();
match op {
ArithmeticOp::Add => out.add(prim.column, value),
ArithmeticOp::Subtract => out.subtract(prim.column, value),
}
}
/// Dispatches an owned assignment on `projection` to the matching lowering routine.
///
/// A `Batch` is folded with `fold_batch` first and the folded assignment is
/// lowered recursively.
///
/// # Panics
///
/// Hits `todo!` when a batch does not fold into a single non-batch
/// assignment, and for `Insert` assignments.
fn lower_owned_assignment(
    &mut self,
    out: &mut stmt::Assignments,
    mapping: &mapping::Model,
    projection: &stmt::Projection,
    assignment: stmt::Assignment,
) {
    match assignment {
        // Plain value assignment.
        stmt::Assignment::Set(mut value) => {
            self.lower_set_assignment(out, mapping, projection, &mut value);
        }

        // Arithmetic operators.
        stmt::Assignment::Add(mut operand) => {
            self.lower_arithmetic_op(out, mapping, projection, ArithmeticOp::Add, &mut operand);
        }
        stmt::Assignment::Subtract(mut operand) => {
            self.lower_arithmetic_op(
                out,
                mapping,
                projection,
                ArithmeticOp::Subtract,
                &mut operand,
            );
        }

        // Collection operators.
        stmt::Assignment::Pop => {
            self.lower_collection_op(out, mapping, projection, CollectionOp::Pop);
        }
        stmt::Assignment::Append(mut operand) => {
            let op = CollectionOp::Append(&mut operand);
            self.lower_collection_op(out, mapping, projection, op);
        }
        stmt::Assignment::Remove(mut operand) => {
            let op = CollectionOp::Remove(&mut operand);
            self.lower_collection_op(out, mapping, projection, op);
        }
        stmt::Assignment::RemoveAt(mut operand) => {
            let op = CollectionOp::RemoveAt(&mut operand);
            self.lower_collection_op(out, mapping, projection, op);
        }

        // Batches must fold down to one non-batch assignment.
        stmt::Assignment::Batch(entries) => {
            let folded = fold_batch(entries);
            if let stmt::Assignment::Batch(_) = &folded {
                todo!("non-composable batch shape on {projection:?}: {folded:#?}")
            }
            self.lower_owned_assignment(out, mapping, projection, folded);
        }

        stmt::Assignment::Insert(_) => {
            todo!("Insert assignment is not produced for table lowering; got {assignment:#?}")
        }
    }
}
}
fn fold_batch(entries: Vec<stmt::Assignment>) -> stmt::Assignment {
let mut iter = entries.into_iter();
let mut acc = iter.next().expect("batch is non-empty");
for next in iter {
acc = compose_assignment(acc, next);
}
acc
}
fn compose_assignment(acc: stmt::Assignment, next: stmt::Assignment) -> stmt::Assignment {
use stmt::Assignment::*;
match (acc, next) {
(_, Set(v)) => Set(v),
(Set(v), Add(x)) => Set(stmt::Expr::add(v, x)),
(Set(v), Subtract(x)) => Set(stmt::Expr::sub(v, x)),
(Add(a), Add(x)) => Add(stmt::Expr::add(a, x)),
(Add(a), Subtract(x)) => Add(stmt::Expr::sub(a, x)),
(Subtract(a), Add(x)) => Subtract(stmt::Expr::sub(a, x)),
(Subtract(a), Subtract(x)) => Subtract(stmt::Expr::add(a, x)),
(Append(a), Append(b)) => match try_concat_list_literals(a, b) {
Ok(merged) => Append(merged),
Err((a, b)) => Batch(vec![Append(a), Append(b)]),
},
(Batch(mut tail), next) => {
tail.push(next);
Batch(tail)
}
(acc, next) => Batch(vec![acc, next]),
}
}
/// Concatenates two list literals into a single list expression.
///
/// Returns `Err((a, b))`, handing both inputs back unchanged, unless both
/// are list literals according to [`is_list_literal`].
fn try_concat_list_literals(
    a: stmt::Expr,
    b: stmt::Expr,
) -> Result<stmt::Expr, (stmt::Expr, stmt::Expr)> {
    let both_literals = is_list_literal(&a) && is_list_literal(&b);
    if !both_literals {
        return Err((a, b));
    }

    let merged: Vec<stmt::Expr> = [a, b]
        .into_iter()
        .flat_map(|list| take_list_items(list).expect("checked is_list_literal"))
        .collect();
    Ok(stmt::Expr::list_from_vec(merged))
}
/// Returns `true` when `e` is an `Expr::List` or an `Expr::Value` holding a
/// `Value::List`.
fn is_list_literal(e: &stmt::Expr) -> bool {
    match e {
        stmt::Expr::List(_) => true,
        stmt::Expr::Value(stmt::Value::List(_)) => true,
        _ => false,
    }
}
/// Extracts the items of a list literal as expressions.
///
/// Returns `None` when `e` is neither an `Expr::List` nor a list value.
fn take_list_items(e: stmt::Expr) -> Option<Vec<stmt::Expr>> {
    let items: Vec<stmt::Expr> = match e {
        stmt::Expr::List(list) => list.items,
        stmt::Expr::Value(stmt::Value::List(values)) => {
            values.into_iter().map(stmt::Expr::from).collect()
        }
        _ => return None,
    };
    Some(items)
}
impl LowerStatement<'_, '_> {
    /// Lowers `stmt` as a separate statement and records it as a dependency
    /// of the current statement.
    ///
    /// The row index of an `Insert` or `Returning` context is forwarded to
    /// the new statement's scope. When dependency collection is active, the
    /// new id is also added to the collected set.
    fn new_dependency(&mut self, stmt: impl Into<stmt::Statement>) -> hir::StmtId {
        let row_index = match self.cx {
            LoweringContext::Insert(_, row) | LoweringContext::Returning(row) => row,
            _ => None,
        };

        let stmt_id = self.state.lower_stmt(self.expr_cx, row_index, stmt.into());

        if let Some(collected) = &mut *self.collect_dependencies {
            collected.insert(stmt_id);
        }
        self.curr_stmt_info().deps.insert(stmt_id);
        stmt_id
    }

    /// Runs `f` with dependency collection switched on and returns the ids
    /// collected during the call. The previous collection state is restored
    /// afterwards.
    fn collect_dependencies(
        &mut self,
        f: impl FnOnce(&mut LowerStatement<'_, '_>),
    ) -> HashSet<hir::StmtId> {
        let previous = std::mem::replace(&mut *self.collect_dependencies, Some(HashSet::new()));
        f(self);
        let collected = std::mem::replace(self.collect_dependencies, previous);
        collected.unwrap()
    }

    /// Records `dependency` as a dependency of the current statement.
    fn track_dependency(&mut self, dependency: hir::StmtId) {
        let current = self.curr_stmt_info();
        current.deps.insert(dependency);
    }

    /// Runs `f` with `dependencies` added to the state's current dependency
    /// set, then restores the previous set.
    fn with_dependencies(
        &mut self,
        mut dependencies: HashSet<hir::StmtId>,
        f: impl FnOnce(&mut LowerStatement<'_, '_>),
    ) {
        dependencies.extend(self.state.dependencies.iter().copied());
        // After the swap, `dependencies` holds the previous set.
        std::mem::swap(&mut self.state.dependencies, &mut dependencies);
        f(self);
        self.state.dependencies = dependencies;
    }
}
impl visit_mut::VisitMut for LowerStatement<'_, '_> {
/// Lowers an `ORDER BY` expression.
///
/// After the expression is lowered it is run through `lower_expr_binary_op`
/// as both operands of an `Eq` comparison. The returned replacement is
/// discarded; only the in-place changes made to the left operand are kept.
fn visit_order_by_expr_mut(&mut self, node: &mut stmt::OrderByExpr) {
    self.visit_expr_mut(&mut node.expr);

    let mut kept = node.expr.clone();
    let mut twin = node.expr.take();
    let _ = self.lower_expr_binary_op(stmt::BinaryOp::Eq, &mut kept, &mut twin);

    node.expr = kept;
}
/// Replaces the assignments in `i` with their lowered form.
///
/// Each `(projection, assignment)` pair is lowered into a fresh
/// `Assignments`, which then replaces `i`.
fn visit_assignments_mut(&mut self, i: &mut stmt::Assignments) {
    let mut lowered = stmt::Assignments::default();
    let mapping = self.mapping_unwrap();

    for (projection, assignment) in std::mem::take(i) {
        self.lower_owned_assignment(&mut lowered, mapping, &projection, assignment);
    }

    *i = lowered;
}
/// Set operations are not lowered yet; reaching this visitor panics via `todo!`.
fn visit_expr_set_op_mut(&mut self, i: &mut stmt::ExprSetOp) {
todo!("stmt={i:#?}");
}
/// Visits a binary operation.
///
/// For `Eq` / `Ne`, both operands go through `rewrite_eq_operand` before the
/// default traversal runs.
fn visit_expr_binary_op_mut(&mut self, i: &mut stmt::ExprBinaryOp) {
    let is_equality = i.op.is_eq() || i.op.is_ne();
    if is_equality {
        self.rewrite_eq_operand(&mut i.lhs);
        self.rewrite_eq_operand(&mut i.rhs);
    }

    stmt::visit_mut::visit_expr_binary_op_mut(self, i);
}
/// Visits an `IN (list)` expression: a model-reference operand is rewritten
/// by `rewrite_in_list_model_operand` first, then the default traversal runs.
fn visit_expr_in_list_mut(&mut self, i: &mut stmt::ExprInList) {
self.rewrite_in_list_model_operand(i);
stmt::visit_mut::visit_expr_in_list_mut(self, i);
}
/// Lowers a single expression in place.
///
/// Each arm handles one expression kind that needs special lowering; every
/// other kind falls through to the default traversal.
fn visit_expr_mut(&mut self, expr: &mut stmt::Expr) {
match expr {
// Binary op: visit the operands, then let `lower_expr_binary_op` replace
// the whole expression if it produces a rewritten form.
stmt::Expr::BinaryOp(e) => {
self.visit_expr_binary_op_mut(e);
if let Some(lowered) = self.lower_expr_binary_op(e.op, &mut e.lhs, &mut e.rhs) {
*expr = lowered;
}
}
// An `Or` that `is_variant_tautology_or` recognizes is replaced by `true`.
stmt::Expr::Or(e) if expr_or::is_variant_tautology_or(self.expr_cx.schema(), e) => {
*expr = true.into();
}
stmt::Expr::InList(e) => {
self.visit_expr_in_list_mut(e);
if let Some(lowered) = self.lower_expr_in_list(&mut e.expr, &mut e.list) {
*expr = lowered;
}
// If this is still an `InList`, the backend supports the rewrite, and
// `in_list_is_value_list` accepts it, turn it into `expr = ANY(list)`.
if let stmt::Expr::InList(e) = expr
&& self.supports_any_rewrite()
&& in_list_is_value_list(e)
{
let stmt::Expr::InList(e) = expr.take() else {
unreachable!()
};
*expr = stmt::Expr::any_op(*e.expr, stmt::BinaryOp::Eq, *e.list);
}
}
// `NOT (x IN list)`: lower the inner expression first; if it became
// `x = ANY(list)`, rewrite the whole expression to `x != ALL(list)`.
stmt::Expr::Not(e) if matches!(*e.expr, stmt::Expr::InList(_)) => {
self.visit_expr_not_mut(e);
if self.supports_any_rewrite()
&& let stmt::Expr::AnyOp(any) = e.expr.as_mut()
&& any.op == stmt::BinaryOp::Eq
{
let stmt::Expr::Not(not) = expr.take() else {
unreachable!()
};
let stmt::Expr::AnyOp(any) = *not.expr else {
unreachable!()
};
*expr = stmt::Expr::all_op(*any.lhs, stmt::BinaryOp::Ne, *any.rhs);
}
}
stmt::Expr::InSubquery(e) => {
if self.capability().sql {
// SQL backends keep the subquery inline: lower it in place and
// make sure its returning projection is a record.
self.visit_expr_in_subquery_mut(e);
self.lower_in_subquery_operands(
&mut e.expr,
e.query.returning_mut_unwrap().as_project_mut_unwrap(),
);
let returning = e.query.returning_mut_unwrap().as_project_mut_unwrap();
if !returning.is_record() {
*returning = stmt::Expr::record([returning.take()]);
}
} else {
// Non-SQL backends: lower the subquery as its own HIR statement
// and replace the expression with `expr IN <arg>`.
self.visit_expr_mut(&mut e.expr);
let source_id = self.scope_stmt_id();
let target_id = self.scope_statement(|child| {
child.visit_stmt_query_mut(&mut e.query);
});
let target_stmt_info = &self.state.hir[target_id];
debug_assert!(
target_stmt_info
.args
.iter()
.all(|arg| matches!(arg, hir::Arg::Sub { .. })),
"TODO: sub-statement references parent scope"
);
debug_assert!(target_stmt_info.back_refs.is_empty(), "TODO");
self.track_dependency(target_id);
self.lower_in_subquery_operands(
&mut e.expr,
e.query.returning_mut_unwrap().as_project_mut_unwrap(),
);
let stmt::Expr::InSubquery(e) = expr.take() else {
panic!()
};
let arg =
self.new_sub_statement(source_id, target_id, Box::new((*e.query).into()));
*expr = stmt::ExprInList {
expr: e.expr,
list: Box::new(arg),
}
.into();
}
}
// `IsVariant` becomes an equality test against the variant's discriminant.
stmt::Expr::IsVariant(e) => {
let enum_model = self
.schema()
.app
.model(e.variant.model)
.as_embedded_enum_unwrap();
let has_data = enum_model.has_data_variants();
let disc_value = enum_model.variants[e.variant.index].discriminant.clone();
self.visit_expr_mut(&mut e.expr);
let lowered_expr = e.expr.take();
if has_data {
// With data variants, the comparison is made against the projection
// of field 0 of the lowered expression.
*expr = stmt::Expr::eq(
stmt::Expr::project(lowered_expr, [0usize]),
stmt::Expr::Value(disc_value),
);
} else {
*expr = stmt::Expr::eq(lowered_expr, stmt::Expr::Value(disc_value));
}
}
stmt::Expr::Reference(expr_reference) => {
match expr_reference {
// A relation field of the current model inside a returning clause
// is replaced by a relation subquery.
stmt::ExprReference::Field { nesting: 0, index }
if matches!(self.cx, LoweringContext::Returning(_))
&& self.model_unwrap().fields[*index].ty.is_relation() =>
{
*expr = self.build_relation_subquery(*index, &[]);
}
// Any other field reference is replaced by its lowered expression,
// which is then lowered again.
stmt::ExprReference::Field { nesting, index } => {
*expr = self.lower_expr_field(*nesting, *index);
self.visit_expr_mut(expr);
}
stmt::ExprReference::Model { .. } => todo!(),
stmt::ExprReference::Column(expr_column) => {
// Only column references into an outer statement need work; they
// become an argument of the current statement.
if expr_column.nesting > 0 {
let source_id = self.scope_stmt_id();
let target_id = self.resolve_stmt_id(expr_column.nesting);
debug_assert_eq!(self.state.scopes.len(), self.scope_id + 1);
// Walk the scope stack from the innermost scope outward, marking
// every statement before the target as not independent.
for scope in self.state.scopes.iter().rev() {
if scope.stmt_id == target_id {
break;
}
self.state.hir[scope.stmt_id].independent = false;
}
let position = self.new_ref(source_id, target_id, *expr_reference);
*expr = stmt::Expr::arg(position);
}
}
}
}
// Embedded statement: lower it as a sub-statement of the current one.
stmt::Expr::Stmt(_) => {
let stmt::Expr::Stmt(mut expr_stmt) = expr.take() else {
panic!()
};
debug_assert!(
self.cx.is_returning() || matches!(self.cx, LoweringContext::Statement),
"cx={:#?}",
self.cx,
);
let source_id = self.scope_stmt_id();
let target_id = self.scope_statement(|child| {
visit_mut::visit_expr_stmt_mut(child, &mut expr_stmt);
});
fold::fold_stmt(&mut *expr_stmt.stmt);
*expr = self.new_sub_statement(source_id, target_id, expr_stmt.stmt);
// Only statements still marked independent are recorded as dependencies.
if self.state.hir[target_id].independent {
self.curr_stmt_info().deps.insert(target_id);
}
}
// `EXISTS` on non-SQL backends: lower the subquery as a sub-statement and
// rebuild the `Exists` over a single-row `values(arg)` query.
stmt::Expr::Exists(_) if !self.capability().sql => {
let stmt::Expr::Exists(mut expr_exists) = expr.take() else {
panic!()
};
let source_id = self.scope_stmt_id();
let target_id = self.scope_statement(|child| {
child.visit_stmt_query_mut(&mut expr_exists.subquery);
});
let mut stmt = stmt::Statement::Query(*expr_exists.subquery);
fold::fold_stmt(&mut stmt);
let arg = self.new_sub_statement(source_id, target_id, Box::new(stmt));
if self.state.hir[target_id].independent {
self.curr_stmt_info().deps.insert(target_id);
}
let mut subquery = stmt::Query::values(arg);
subquery.single = true;
*expr = stmt::Expr::Exists(stmt::ExprExists {
subquery: Box::new(subquery),
});
}
// Everything else: default recursive traversal.
_ => {
stmt::visit_mut::visit_expr_mut(self, expr);
}
}
}
/// Rewrites a model insert target into the table (and columns) the model
/// maps to.
///
/// # Panics
///
/// Hits `todo!` for `InsertTarget::Scope` and for any target other than
/// `InsertTarget::Model`.
fn visit_insert_target_mut(&mut self, i: &mut stmt::InsertTarget) {
    match i {
        stmt::InsertTarget::Model(model_id) => {
            let mapping = self.schema().mapping_for(model_id);
            let table_target = stmt::InsertTable {
                table: mapping.table,
                columns: mapping.columns.clone(),
            };
            *i = table_target.into();
        }
        stmt::InsertTarget::Scope(_) => todo!("stmt={i:#?}"),
        _ => todo!(),
    }
}
/// Rewrites a model update target into the table the model maps to.
///
/// Table targets are left untouched.
///
/// # Panics
///
/// Hits `todo!` for `UpdateTarget::Query`.
fn visit_update_target_mut(&mut self, i: &mut stmt::UpdateTarget) {
    match i {
        stmt::UpdateTarget::Table(_) => {}
        stmt::UpdateTarget::Model(model_id) => {
            let table_id = self.schema().table_id_for(model_id);
            *i = stmt::UpdateTarget::table(table_id);
        }
        stmt::UpdateTarget::Query(_) => todo!("update_target={i:#?}"),
    }
}
/// Visits an embedded statement expression using the default traversal.
fn visit_expr_stmt_mut(&mut self, i: &mut stmt::ExprStmt) {
stmt::visit_mut::visit_expr_stmt_mut(self, i);
}
/// Lowers a returning clause.
fn visit_returning_mut(&mut self, i: &mut stmt::Returning) {
// `Returning::Model` is replaced by a projection built from the mapping's
// default returning expression, adjusted for this context and the
// requested include paths.
if let stmt::Returning::Model { include } = i {
let mut returning = self.mapping_unwrap().default_returning.clone();
let mut include_paths = std::mem::take(include);
let is_insert = self.cx.is_insert();
self.prepare_model_returning_for_context(&mut returning, &mut include_paths, is_insert);
self.process_top_level_includes(&mut returning, &include_paths, is_insert);
*i = stmt::Returning::Project(returning);
}
// In an insert context, a list-valued `Returning::Expr` is lowered item by
// item, each item with its own row index.
if matches!(&self.cx, LoweringContext::Insert(..))
&& let stmt::Returning::Expr(stmt::Expr::List(list)) = i
{
for (index, item) in list.items.iter_mut().enumerate() {
self.lower_returning_for_row(index).visit_expr_mut(item);
}
return;
}
// Otherwise run the default traversal with a `Returning(None)` visitor.
stmt::visit_mut::visit_returning_mut(&mut self.lower_returning(), i);
}
/// Lowers a `DELETE` statement.
///
/// Relation planning, the filter, the condition, and the returning clause
/// are lowered with a visitor scoped to the statement's `from` source. The
/// source itself is lowered last, after everything that resolves against it.
fn visit_stmt_delete_mut(&mut self, stmt: &mut stmt::Delete) {
let mut lower = self.scope_expr(&stmt.from);
lower.plan_stmt_delete_relations(stmt);
lower.visit_filter_mut(&mut stmt.filter);
if let Some(expr) = &mut stmt.condition.expr {
lower.visit_expr_mut(expr);
}
if let Some(returning) = &mut stmt.returning {
lower.visit_returning_mut(returning);
}
lower.apply_lowering_filter_constraint(&mut stmt.filter);
self.visit_source_mut(&mut stmt.from);
}
/// Lowers an `INSERT` statement.
///
/// The insert scope is applied first; the returning clause and the source
/// are then lowered with an insert-context visitor, and the target is
/// lowered last.
fn visit_stmt_insert_mut(&mut self, stmt: &mut stmt::Insert) {
self.apply_insert_scope(&mut stmt.target, &mut stmt.source);
let mut lower = self.lower_insert(&stmt.target);
// NOTE(review): the returning clause is visited here and again after the
// source has been lowered; presumably intentional because
// `preprocess_insert_values` may touch it — confirm.
if let Some(returning) = &mut stmt.returning {
lower.visit_returning_mut(returning);
}
lower.preprocess_insert_values(&mut stmt.source, &mut stmt.returning);
lower.visit_stmt_query_mut(&mut stmt.source);
if let Some(returning) = &mut stmt.returning {
lower.visit_returning_mut(returning);
lower.constantize_insert_returning(returning, &stmt.source);
// A single-row insert must not end up with a list-valued returning expression.
if stmt.source.single
&& let stmt::Returning::Expr(expr) = &returning
{
debug_assert!(!expr.is_list());
}
}
self.visit_insert_target_mut(&mut stmt.target);
}
/// Lowers a query.
///
/// `WITH`, `ORDER BY`, and `LIMIT` are lowered with a visitor scoped to the
/// query body. The body is then lowered, and finally
/// `rewrite_offset_after_as_filter` is applied to the whole query.
fn visit_stmt_query_mut(&mut self, stmt: &mut stmt::Query) {
let mut lower = self.scope_expr(&stmt.body);
if let Some(with) = &mut stmt.with {
lower.visit_with_mut(with);
}
if let Some(order_by) = &mut stmt.order_by {
lower.visit_order_by_mut(order_by);
}
if let Some(limit) = &mut stmt.limit {
lower.visit_limit_mut(limit);
}
self.visit_expr_set_mut(&mut stmt.body);
self.rewrite_offset_after_as_filter(stmt);
}
/// Lowers a `SELECT`.
///
/// The filter and the returning clause are lowered with a visitor scoped to
/// the select's source; the source itself is lowered last.
fn visit_stmt_select_mut(&mut self, stmt: &mut stmt::Select) {
let mut lower = self.scope_expr(&stmt.source);
lower.visit_filter_mut(&mut stmt.filter);
lower.visit_returning_mut(&mut stmt.returning);
lower.apply_lowering_filter_constraint(&mut stmt.filter);
self.visit_source_mut(&mut stmt.source);
}
/// Lowers an `UPDATE` statement.
///
/// Everything except the target is lowered with a visitor scoped to the
/// update target; the target itself is lowered last.
fn visit_stmt_update_mut(&mut self, stmt: &mut stmt::Update) {
let mut lower = self.scope_expr(&stmt.target);
let mut returning_changed = false;
// A "changed" returning clause is replaced by a projection built from the
// field masks of the assigned projections. The replacement is only made
// when the target resolves to a model.
if let Some(returning) = &mut stmt.returning
&& returning.is_changed()
{
returning_changed = true;
if let Some(model) = lower.model() {
let mapping = lower.mapping_unwrap();
let mut changed_bits = stmt::PathFieldSet::new();
for projection in stmt.assignments.keys() {
// Projections that do not resolve to a field mapping are skipped here.
if let Some(mf) = mapping.resolve_field_mapping(projection) {
changed_bits |= mf.field_mask();
}
}
*returning = stmt::Returning::Project(build_update_returning(
model.id,
None,
&mapping.fields,
&changed_bits,
));
}
}
lower.plan_stmt_update_relations(
&mut stmt.assignments,
&stmt.filter,
&mut stmt.returning,
returning_changed,
);
lower.visit_assignments_mut(&mut stmt.assignments);
lower.visit_filter_mut(&mut stmt.filter);
if let Some(expr) = &mut stmt.condition.expr {
lower.visit_expr_mut(expr);
}
if let Some(returning) = &mut stmt.returning {
lower.visit_returning_mut(returning);
lower.constantize_update_returning(returning, &stmt.assignments);
}
self.visit_update_target_mut(&mut stmt.target);
}
/// Rewrites a model source into the table the model maps to.
///
/// Any other kind of source is left untouched. Model sources that still
/// carry a `via` are not supported yet (debug assertion).
fn visit_source_mut(&mut self, stmt: &mut stmt::Source) {
    let stmt::Source::Model(source_model) = stmt else {
        return;
    };

    debug_assert!(source_model.via.is_none(), "TODO");
    let table_id = self.schema().table_id_for(source_model.id);
    *stmt = stmt::Source::table(table_id);
}
fn visit_values_mut(&mut self, stmt: &mut stmt::Values) {
if self.cx.is_insert()
&& let Some(mapping) = self.mapping()
{
for row in &mut stmt.rows {
let mut lowered = mapping.model_to_table.clone();
self.lower_insert_row(row)
.visit_expr_record_mut(&mut lowered);
*row = lowered.into();
}
return;
}
visit_mut::visit_values_mut(self, stmt);
}
}
impl<'a, 'b> LowerStatement<'a, 'b> {
/// Rewrites an equality operand that refers to a model or a `BelongsTo`
/// field into references to the underlying key fields.
///
/// - A model reference becomes `key_field_refs` over the model's primary-key
///   fields.
/// - A `BelongsTo` field reference becomes `key_field_refs` over the source
///   fields of the relation's foreign key.
/// - Primitive and embedded field references, and every non-reference
///   operand, are left unchanged.
///
/// # Panics
///
/// Hits `todo!` for `Has` and `Via` field references.
fn rewrite_eq_operand(&self, operand: &mut stmt::Expr) {
    let stmt::Expr::Reference(expr_reference) = &*operand else {
        return;
    };

    let replacement = match expr_reference {
        stmt::ExprReference::Model { nesting } => {
            let model = self
                .expr_cx
                .resolve_expr_reference(expr_reference)
                .as_model_unwrap();
            Some(key_field_refs(
                *nesting,
                model.primary_key.fields.iter().copied(),
            ))
        }
        stmt::ExprReference::Field { nesting, .. } => {
            let field = self
                .expr_cx
                .resolve_expr_reference(expr_reference)
                .as_field_unwrap();
            match &field.ty {
                app::FieldTy::BelongsTo(rel) => Some(key_field_refs(
                    *nesting,
                    rel.foreign_key.fields.iter().map(|fk| fk.source),
                )),
                app::FieldTy::Has(_) | app::FieldTy::Via(_) => todo!(),
                app::FieldTy::Primitive(_) | app::FieldTy::Embedded(_) => None,
            }
        }
        _ => None,
    };

    if let Some(replacement) = replacement {
        *operand = replacement;
    }
}
/// Rewrites `model IN (list)` into `pk_field IN (list)`.
///
/// Does nothing unless the left-hand side is a model reference.
///
/// # Panics
///
/// - `todo!` when the model's primary key is not exactly one field, when a
///   list item is not a value, or when the list is neither an `Expr::List`
///   nor a list value.
/// - Assertion failure when a list value does not match the primary-key
///   field's primitive type.
fn rewrite_in_list_model_operand(&self, expr: &mut stmt::ExprInList) {
// The inner block ends the borrow of `expr.expr` before `expr` is mutated below.
let (nesting, pk_field_id) = {
let stmt::Expr::Reference(expr_ref @ stmt::ExprReference::Model { nesting }) =
&*expr.expr
else {
return;
};
let nesting = *nesting;
let model = self
.expr_cx
.resolve_expr_reference(expr_ref)
.as_model_unwrap();
// Only single-field primary keys are handled.
let [pk_field_id] = &model.primary_key.fields[..] else {
todo!()
};
(nesting, *pk_field_id)
};
let pk = self.expr_cx.schema().app.field(pk_field_id);
// Check that every list item is a value of the primary-key type.
match &mut *expr.list {
stmt::Expr::List(expr_list) => {
for item in &mut expr_list.items {
match item {
stmt::Expr::Value(value) => {
assert!(value.is_a(&pk.ty.as_primitive_unwrap().ty));
}
_ => todo!("{item:#?}"),
}
}
}
stmt::Expr::Value(stmt::Value::List(values)) => {
for value in values {
assert!(value.is_a(&pk.ty.as_primitive_unwrap().ty));
}
}
_ => todo!("expr={expr:#?}"),
}
// Swap the model reference for a reference to the primary-key field.
*expr.expr = stmt::Expr::ref_field(nesting, pk.id());
}
/// Lowers the operands of a binary operation.
///
/// Returns `Some(replacement)` when the whole operation should be replaced,
/// or `None` when the operands were left as they are or only adjusted in
/// place (the cast case).
///
/// # Panics
///
/// Hits `todo!` when one operand is a null value and `op` is neither `Eq`
/// nor `Ne`.
fn lower_expr_binary_op(
&mut self,
op: stmt::BinaryOp,
lhs: &mut stmt::Expr,
rhs: &mut stmt::Expr,
) -> Option<stmt::Expr> {
match (&mut *lhs, &mut *rhs) {
// Comparison with a null value becomes `IS NULL` / `IS NOT NULL` on the
// other operand.
(stmt::Expr::Value(value), other) | (other, stmt::Expr::Value(value))
if value.is_null() =>
{
let other = other.take();
Some(match op {
stmt::BinaryOp::Eq => stmt::Expr::is_null(other),
stmt::BinaryOp::Ne => stmt::Expr::is_not_null(other),
_ => todo!(),
})
}
// Record (in)equality between equal-length records is expanded field by field.
(stmt::Expr::Record(lhs_rec), stmt::Expr::Record(rhs_rec))
if (op.is_eq() || op.is_ne()) && lhs_rec.len() == rhs_rec.len() =>
{
Some(self.combine_record_op(
op,
std::mem::take(&mut lhs_rec.fields),
std::mem::take(&mut rhs_rec.fields),
))
}
// Same expansion when one side is a record expression and the other a
// record value of the same length.
(stmt::Expr::Record(rec), stmt::Expr::Value(stmt::Value::Record(val_rec)))
| (stmt::Expr::Value(stmt::Value::Record(val_rec)), stmt::Expr::Record(rec))
if (op.is_eq() || op.is_ne()) && rec.len() == val_rec.len() =>
{
let val_exprs = std::mem::take(&mut val_rec.fields)
.into_iter()
.map(stmt::Expr::Value)
.collect();
Some(self.combine_record_op(op, std::mem::take(&mut rec.fields), val_exprs))
}
// If either side is a cast, both sides are passed through `cast_expr`
// with the backend's native type for the cast's target type.
(stmt::Expr::Cast(expr_cast), _) | (_, stmt::Expr::Cast(expr_cast)) => {
let target_ty = self.capability().native_type_for(&expr_cast.ty);
self.cast_expr(lhs, &target_ty);
self.cast_expr(rhs, &target_ty);
None
}
_ => None,
}
}
/// Expands a record comparison into per-field comparisons.
///
/// Each field pair is lowered with `lower_expr_binary_op`; pairs it does not
/// replace become a plain binary operation with `op`. The comparisons are
/// joined with `AND` when `op` is an equality and with `OR` otherwise.
fn combine_record_op(
    &mut self,
    op: stmt::BinaryOp,
    lhs_fields: Vec<stmt::Expr>,
    rhs_fields: Vec<stmt::Expr>,
) -> stmt::Expr {
    let mut comparisons = Vec::with_capacity(lhs_fields.len());
    for (mut left, mut right) in lhs_fields.into_iter().zip(rhs_fields) {
        let comparison = match self.lower_expr_binary_op(op, &mut left, &mut right) {
            Some(lowered) => lowered,
            None => stmt::Expr::binary_op(left, op, right),
        };
        comparisons.push(comparison);
    }

    if op.is_eq() {
        stmt::Expr::and_from_vec(comparisons)
    } else {
        stmt::Expr::or_from_vec(comparisons)
    }
}
/// Lowers the operands of an `IN (subquery)` as `Eq` comparisons.
///
/// Equal-length records are lowered field by field; anything else is lowered
/// as a single pair.
///
/// # Panics
///
/// Asserts that `lower_expr_binary_op` never asks for the operation to be
/// replaced; that case is not supported yet.
fn lower_in_subquery_operands(&mut self, lhs: &mut stmt::Expr, rhs: &mut stmt::Expr) {
if let (stmt::Expr::Record(lhs_rec), stmt::Expr::Record(rhs_rec)) = (&mut *lhs, &mut *rhs)
&& lhs_rec.len() == rhs_rec.len()
{
for (l, r) in lhs_rec.fields.iter_mut().zip(rhs_rec.fields.iter_mut()) {
let maybe_res = self.lower_expr_binary_op(stmt::BinaryOp::Eq, l, r);
assert!(maybe_res.is_none(), "TODO");
}
} else {
let maybe_res = self.lower_expr_binary_op(stmt::BinaryOp::Eq, lhs, rhs);
assert!(maybe_res.is_none(), "TODO");
}
}
/// Lowers the operands of an `IN (list)` expression.
///
/// Every arm visible here returns `None`, so the caller keeps the (possibly
/// adjusted) `InList` expression.
///
/// # Panics
///
/// Several arms assert the operand shapes they expect, and shapes that are
/// not handled hit `todo!` / `panic!`.
fn lower_expr_in_list(
&mut self,
expr: &mut stmt::Expr,
list: &mut stmt::Expr,
) -> Option<stmt::Expr> {
match (&mut *expr, list) {
// List produced by a `Map`: lower `expr` against the map body as an `Eq`
// comparison.
(expr, stmt::Expr::Map(expr_map)) => {
assert!(expr_map.base.is_arg(), "TODO");
let maybe_res =
self.lower_expr_binary_op(stmt::BinaryOp::Eq, expr, &mut expr_map.map);
assert!(maybe_res.is_none(), "TODO");
None
}
// Cast on the left: apply the backend's native type for the cast target
// to the left operand and to the list.
(stmt::Expr::Cast(expr_cast), list) => {
let target_ty = self.capability().native_type_for(&expr_cast.ty);
self.cast_expr(expr, &target_ty);
match list {
stmt::Expr::List(expr_list) => {
for item in &mut expr_list.items {
self.cast_expr(item, &target_ty);
}
}
stmt::Expr::Value(stmt::Value::List(items)) => {
for item in items {
*item = target_ty.cast(item.take()).expect("failed to cast value");
}
}
// An argument list is wrapped in a `Map` whose body is `cast(arg(0), target_ty)`.
stmt::Expr::Arg(_) => {
let arg = list.take();
let cast = stmt::Expr::cast(stmt::Expr::arg(0), target_ty);
*list = stmt::Expr::map(arg, cast);
}
_ => todo!("expr={expr:#?}; list={list:#?}"),
}
None
}
// Record on the left, expression list on the right: only shape checks.
(stmt::Expr::Record(lhs), stmt::Expr::List(list)) => {
for lhs in lhs {
assert!(lhs.is_column());
}
for item in &mut list.items {
assert!(item.is_value());
}
None
}
// Record on the left, value list on the right: only shape checks.
(stmt::Expr::Record(lhs), stmt::Expr::Value(stmt::Value::List(_))) => {
for lhs in lhs {
assert!(lhs.is_column());
}
None
}
// Reference on the left: it must be a column, and the list must consist of values.
(stmt::Expr::Reference(expr_reference), list) => {
assert!(expr_reference.is_column());
match list {
stmt::Expr::Value(stmt::Value::List(_)) => {}
stmt::Expr::List(list) => {
for item in &list.items {
assert!(item.is_value());
}
}
_ => panic!("invalid; should have been caught earlier"),
}
None
}
(expr, list) => todo!("expr={expr:#?}; list={list:#?}"),
}
}
/// Hook called after a statement's filter has been lowered; currently a no-op.
fn apply_lowering_filter_constraint(&self, _filter: &mut stmt::Filter) {}
/// Produces the lowered expression for the field `index` at `nesting`.
///
/// While lowering an insert row, a field of the current statement
/// (`nesting == 0`) is read from the row itself. In `Statement`,
/// `Returning`, and outer-scope `InsertRow` lookups, the expression comes
/// from the `table_to_model` mapping of the model at `nesting`.
///
/// # Panics
///
/// Hits `todo!` for any other lowering context. Panics if the row has no
/// entry at `index`.
fn lower_expr_field(&self, nesting: usize, index: usize) -> stmt::Expr {
    match self.cx {
        LoweringContext::InsertRow(row) if nesting == 0 => {
            row.entry(index).unwrap().to_expr()
        }
        LoweringContext::Statement
        | LoweringContext::Returning(_)
        | LoweringContext::InsertRow(_) => {
            let mapping = self.mapping_at_unwrap(nesting);
            mapping.table_to_model.lower_expr_reference(nesting, index)
        }
        _ => todo!("cx={:#?}", self.cx),
    }
}
/// Registers a reference from statement `source_id` to a column of the outer
/// statement `target_id` and returns the position of the matching argument
/// in the source statement's `args`.
///
/// The reference is stored with its nesting reset to 0. It is recorded in the
/// target's `back_refs` for this source on every call. An existing
/// `Arg::Ref` in the source with the same reference is reused; otherwise a
/// new one is appended.
///
/// # Panics
///
/// Hits `todo!` when `expr_reference` is not a column reference.
fn new_ref(
&mut self,
source_id: hir::StmtId,
target_id: hir::StmtId,
mut expr_reference: stmt::ExprReference,
) -> usize {
let stmt::ExprReference::Column(expr_column) = &mut expr_reference else {
todo!()
};
// Keep the original nesting for the argument; the stored reference is
// normalized to nesting 0.
let nesting = expr_column.nesting;
debug_assert!(nesting != 0, "expr_reference={expr_reference:#?}");
expr_column.nesting = 0;
let target = &mut self.state.hir[target_id];
target
.back_refs
.entry(source_id)
.or_default()
.exprs
.insert_full(expr_reference);
let source = &mut self.state.hir[source_id];
// Reuse an existing argument that refers to the same expression.
for (i, arg) in source.args.iter().enumerate() {
let hir::Arg::Ref {
target_expr_ref, ..
} = arg
else {
continue;
};
if *target_expr_ref == expr_reference {
return i;
}
}
let arg = source.args.len();
source.args.push(hir::Arg::Ref {
target_expr_ref: expr_reference,
stmt_id: target_id,
nesting,
data_load_input: Cell::new(None),
returning_input: Cell::new(None),
// The current scope's row index, if any, becomes the batch-load index.
batch_load_index: if let Some(row_index) = self.state.scopes[self.scope_id].row_index {
debug_assert_eq!(1, nesting, "TODO");
Cell::new(Some(row_index))
} else {
Cell::new(None)
},
});
arg
}
/// Allocates a new HIR statement info whose dependencies are the state's
/// current dependency set plus the dependencies of the current statement.
fn new_statement_info(&mut self) -> hir::StmtId {
    let mut deps = self.state.dependencies.clone();
    let current_deps = &self.curr_stmt_info().deps;
    deps.extend(current_deps);

    self.state.hir.new_statement_info(deps)
}
/// Stores `stmt` as the body of statement `target_id` and returns an
/// argument expression through which `source_id` refers to it.
fn new_sub_statement(
    &mut self,
    source_id: hir::StmtId,
    target_id: hir::StmtId,
    stmt: Box<stmt::Statement>,
) -> stmt::Expr {
    let target = &mut self.state.hir[target_id];
    target.stmt = Some(stmt);

    self.new_dependency_arg(source_id, target_id)
}
/// Appends an `Arg::Sub` pointing at `target_id` to the arguments of
/// `source_id` and returns an `arg` expression for its position.
///
/// The argument's `returning` flag is taken from the current lowering context.
fn new_dependency_arg(&mut self, source_id: hir::StmtId, target_id: hir::StmtId) -> stmt::Expr {
    let returning = self.cx.is_returning();

    let args = &mut self.state.hir[source_id].args;
    let position = args.len();
    args.push(hir::Arg::Sub {
        stmt_id: target_id,
        returning,
        input: Cell::new(None),
        batch_load_index: Cell::new(None),
    });

    stmt::Expr::arg(position)
}
/// Lowers `stmt` as a sub-statement of the current statement and returns the
/// argument expression that refers to it.
///
/// The statement goes through the same rewrite / simplify / lower / simplify
/// sequence as a top-level statement, except that `LiftUpdateQuery` is not
/// applied here.
fn lower_sub_stmt(&mut self, stmt: stmt::Statement) -> stmt::Expr {
let source_id = self.scope_stmt_id();
let mut stmt = Box::new(stmt);
let target_id = self.scope_statement(|child| {
association::RewriteVia::new(child.expr_cx).rewrite(&mut stmt);
lift_in_subquery::LiftInSubquery::new(child.expr_cx).rewrite(&mut stmt);
Simplify::with_context(child.expr_cx, child.state.engine.capability)
.visit_mut(&mut *stmt);
child.visit_stmt_mut(&mut stmt);
child.state.engine.simplify_stmt(&mut *stmt);
});
// Temporarily switch to a returning context so the new argument is created
// with `returning` set, then restore the caller's context.
let saved_cx = std::mem::replace(&mut self.cx, LoweringContext::Returning(None));
let arg = self.new_sub_statement(source_id, target_id, stmt);
self.cx = saved_cx;
// Only statements still marked independent are recorded as dependencies.
if self.state.hir[target_id].independent {
self.curr_stmt_info().deps.insert(target_id);
}
arg
}
/// Returns the engine's schema.
fn schema(&self) -> &'b Schema {
&self.state.engine.schema
}
/// Returns the capability reported by the engine.
fn capability(&self) -> &Capability {
self.state.engine.capability()
}
/// Returns `true` when the backend reports both `bind_list_param` and
/// `predicate_match_any`, the two capabilities the `IN` → `ANY` rewrite is
/// gated on.
fn supports_any_rewrite(&self) -> bool {
    let capability = self.capability();
    let can_bind_list = capability.bind_list_param;
    can_bind_list && capability.predicate_match_any
}
/// Looks up an application-level field by id in the schema.
fn field(&self, id: impl Into<app::FieldId>) -> &'b app::Field {
self.schema().app.field(id.into())
}
/// Returns the model targeted by the current expression context, if the
/// target is a model.
fn model(&self) -> Option<&'a ModelRoot> {
self.expr_cx.target().as_model()
}
/// Returns the model targeted by the current expression context, via
/// `as_model_unwrap` (presumably panics when the target is not a model —
/// confirm against `as_model_unwrap`).
#[track_caller]
fn model_unwrap(&self) -> &'a ModelRoot {
self.expr_cx.target().as_model_unwrap()
}
/// Returns the schema mapping for the current target model, or `None` when
/// the current target is not a model.
fn mapping(&self) -> Option<&'b mapping::Model> {
    let model = self.model()?;
    Some(self.state.engine.schema.mapping_for(model))
}
/// Returns the schema mapping for the current target model, resolved via
/// `model_unwrap`.
#[track_caller]
fn mapping_unwrap(&self) -> &'b mapping::Model {
self.state.engine.schema.mapping_for(self.model_unwrap())
}
/// Returns the schema mapping for the model targeted `nesting` levels up in
/// the expression context, resolved via `as_model_unwrap`.
#[track_caller]
fn mapping_at_unwrap(&self, nesting: usize) -> &'b mapping::Model {
let model = self.expr_cx.target_at(nesting).as_model_unwrap();
self.state.engine.schema.mapping_for(model)
}
/// Returns the HIR statement info of the statement owning the current scope.
fn curr_stmt_info(&mut self) -> &mut hir::StatementInfo {
    let current = self.state.scopes[self.scope_id].stmt_id;
    &mut self.state.hir[current]
}
/// Returns the id of the statement owning the current scope.
fn scope_stmt_id(&self) -> hir::StmtId {
self.state.scopes[self.scope_id].stmt_id
}
/// Returns the id of the statement `nesting` scopes above the current one
/// (`0` is the current scope).
///
/// Debug builds assert that `nesting` does not reach past the outermost scope.
fn resolve_stmt_id(&self, nesting: usize) -> hir::StmtId {
    debug_assert!(
        self.scope_id >= nesting,
        "invalid nesting; nesting={nesting:#?}; scopes={:#?}",
        self.state.scopes
    );

    let target_scope = self.scope_id - nesting;
    self.state.scopes[target_scope].stmt_id
}
/// Allocates a new HIR statement, pushes a scope for it, and runs `f` with a
/// child visitor lowering in that scope. Returns the new statement's id.
///
/// The child starts in `LoweringContext::Statement` with dependency
/// collection off. The row index of an `Insert` or `Returning` context is
/// carried over to the new scope.
fn scope_statement(&mut self, f: impl FnOnce(&mut LowerStatement<'_, '_>)) -> hir::StmtId {
    let stmt_id = self.new_statement_info();

    let row_index = match self.cx {
        LoweringContext::Insert(_, row) | LoweringContext::Returning(row) => row,
        _ => None,
    };
    let scope_id = self.state.scopes.push(Scope { stmt_id, row_index });

    let mut child_dependencies = None;
    let mut child = LowerStatement {
        state: self.state,
        expr_cx: self.expr_cx,
        scope_id,
        cx: LoweringContext::Statement,
        collect_dependencies: &mut child_dependencies,
    };
    f(&mut child);

    // The child must leave dependency collection off.
    debug_assert!(child_dependencies.is_none());
    self.state.scopes.pop();
    stmt_id
}
/// Creates a child lowerer whose expression context is scoped to `target`.
///
/// The scope id, lowering context, shared state and dependency slot are
/// carried over unchanged from `self`.
fn scope_expr<'child>(
    &'child mut self,
    target: impl IntoExprTarget<'child>,
) -> LowerStatement<'child, 'b> {
    let expr_cx = self.expr_cx.scope(target);
    LowerStatement {
        expr_cx,
        cx: self.cx,
        scope_id: self.scope_id,
        state: self.state,
        collect_dependencies: self.collect_dependencies,
    }
}
/// Creates a child lowerer for an insert into `target`.
///
/// The child's expression context is scoped to `target` and its lowering
/// context is `LoweringContext::Insert` holding the target's columns with
/// no row selected. For a model target the columns come from the model's
/// schema mapping; for a table target they come from the insert target
/// itself.
///
/// # Panics
///
/// Panics when `target` is still `InsertTarget::Scope`.
fn lower_insert<'child>(
    &'child mut self,
    target: &'child stmt::InsertTarget,
) -> LowerStatement<'child, 'b> {
    let columns = match target {
        stmt::InsertTarget::Model(model_id) => {
            let model_mapping = self.schema().mapping_for(model_id);
            &model_mapping.columns
        }
        stmt::InsertTarget::Table(table) => &table.columns,
        stmt::InsertTarget::Scope(_) => {
            panic!("InsertTarget::Scope should already have been lowered by this point")
        }
    };

    let expr_cx = self.expr_cx.scope(target);
    LowerStatement {
        expr_cx,
        cx: LoweringContext::Insert(columns, None),
        scope_id: self.scope_id,
        state: self.state,
        collect_dependencies: self.collect_dependencies,
    }
}
/// Runs `f` with the insert context's row index temporarily set to `row`.
///
/// The row slot of `LoweringContext::Insert` is set to `Some(row)` before
/// `f` is called and reset to `None` afterwards.
///
/// # Panics
///
/// Panics if the current lowering context is not
/// `LoweringContext::Insert`, either on entry or after `f` returns. In
/// debug builds it also asserts that no row was already selected on entry
/// and that `f` left the selected row untouched.
fn lower_insert_with_row(&mut self, row: usize, f: impl FnOnce(&mut Self)) {
    let LoweringContext::Insert(_, maybe_row) = &mut self.cx else {
        // This is a broken caller invariant, not a missing feature, so say
        // so explicitly instead of reporting "not yet implemented".
        panic!("lower_insert_with_row called outside of an insert lowering context");
    };
    debug_assert!(
        maybe_row.is_none(),
        "lower_insert_with_row does not support nesting; a row is already selected"
    );
    *maybe_row = Some(row);

    f(self);

    // `f` has mutable access to `self`, so re-check the context before
    // clearing the row slot.
    let LoweringContext::Insert(_, maybe_row) = &mut self.cx else {
        panic!("lowering context stopped being an insert context while lowering row {row}");
    };
    debug_assert_eq!(
        Some(row),
        *maybe_row,
        "the selected insert row changed while it was being lowered"
    );
    *maybe_row = None;
}
/// Creates a child lowerer whose lowering context is
/// `LoweringContext::InsertRow(row)`.
///
/// The expression context, scope id, shared state and dependency slot are
/// carried over unchanged from `self`.
fn lower_insert_row<'child>(
    &'child mut self,
    row: &'child stmt::Expr,
) -> LowerStatement<'child, 'b> {
    let cx = LoweringContext::InsertRow(row);
    LowerStatement {
        cx,
        expr_cx: self.expr_cx,
        scope_id: self.scope_id,
        state: self.state,
        collect_dependencies: self.collect_dependencies,
    }
}
/// Creates a child lowerer in `LoweringContext::Returning` with no row
/// index selected.
///
/// The expression context, scope id, shared state and dependency slot are
/// carried over unchanged from `self`.
fn lower_returning(&mut self) -> LowerStatement<'_, 'b> {
    let cx = LoweringContext::Returning(None);
    LowerStatement {
        cx,
        expr_cx: self.expr_cx,
        scope_id: self.scope_id,
        state: self.state,
        collect_dependencies: self.collect_dependencies,
    }
}
/// Creates a child lowerer in `LoweringContext::Returning` bound to the
/// row at `row_index`.
///
/// The expression context, scope id, shared state and dependency slot are
/// carried over unchanged from `self`.
fn lower_returning_for_row<'child>(
    &'child mut self,
    row_index: usize,
) -> LowerStatement<'child, 'b> {
    let cx = LoweringContext::Returning(Some(row_index));
    LowerStatement {
        cx,
        expr_cx: self.expr_cx,
        scope_id: self.scope_id,
        state: self.state,
        collect_dependencies: self.collect_dependencies,
    }
}
/// Rewrites `expr` in place so that it is cast to `target_ty`.
///
/// * An existing cast expression is replaced by its inner expression.
/// * A constant value is cast immediately through `target_ty.cast`.
/// * An argument expression is wrapped in a new cast to `target_ty`.
///
/// # Panics
///
/// Panics when `target_ty` is a list type, when a constant value cannot be
/// cast, and for projection or any other expression kind, none of which
/// are handled yet.
fn cast_expr(&mut self, expr: &mut stmt::Expr, target_ty: &stmt::Type) {
    assert!(!target_ty.is_list(), "TODO");

    match expr {
        stmt::Expr::Value(value) => {
            *value = target_ty.cast(value.take()).expect("failed to cast value");
        }
        stmt::Expr::Arg(_) => {
            let argument = expr.take();
            *expr = stmt::Expr::cast(argument, target_ty.clone());
        }
        stmt::Expr::Cast(existing) => {
            // Drop the existing cast and keep only what it wrapped.
            let inner = existing.expr.take();
            *expr = inner;
        }
        stmt::Expr::Project(_) => {
            todo!()
        }
        _ => todo!("cast_expr: cannot cast {expr:#?} to {target_ty:?}"),
    }
}
}
impl LoweringContext<'_> {
fn is_insert(&self) -> bool {
matches!(self, LoweringContext::Insert { .. })
}
fn is_returning(&self) -> bool {
matches!(self, LoweringContext::Returning(_))
}
}
/// Builds an expression referencing the given key fields at `nesting`.
///
/// A single field yields a bare field reference. Any other number of
/// fields (including zero) yields a record of field references in
/// iteration order.
pub(super) fn key_field_refs(
    nesting: usize,
    mut fields: impl ExactSizeIterator<Item = app::FieldId>,
) -> stmt::Expr {
    match fields.len() {
        1 => {
            let only = fields.next().unwrap();
            stmt::Expr::ref_field(nesting, only)
        }
        _ => stmt::Expr::record(fields.map(|field_id| stmt::Expr::ref_field(nesting, field_id))),
    }
}
/// Substitution input that resolves references to an assigned location
/// with the expression assigned to it (see its `stmt::Input` impl).
struct AssignmentInput<'a> {
    /// Projection identifying the assigned location. `resolve_ref` compares
    /// its first step with the referenced field index and its remaining
    /// steps with the reference's own projection.
    assignment_projection: stmt::Projection,
    /// Expression handed back (cloned, or projected into) for references
    /// that match the assignment.
    value: &'a stmt::Expr,
}
impl stmt::Input for AssignmentInput<'_> {
    /// Resolves a field reference against the tracked assignment.
    ///
    /// Only references to a field at nesting `0` whose index equals the
    /// first step of the assignment projection are considered; anything
    /// else returns `None`. When the reference's projection equals the
    /// remaining assignment steps, a clone of the assigned value is
    /// returned. Otherwise the result is whatever `entry` finds when the
    /// reference's projection is applied to the assigned value.
    ///
    /// # Panics
    ///
    /// Panics if the assignment projection has no steps.
    fn resolve_ref(
        &mut self,
        expr_reference: &stmt::ExprReference,
        expr_projection: &stmt::Projection,
    ) -> Option<stmt::Expr> {
        let field_index = match expr_reference {
            stmt::ExprReference::Field { nesting: 0, index } => index,
            _ => return None,
        };

        let steps = self.assignment_projection.as_slice();
        if *field_index != steps[0] {
            return None;
        }

        let tail = &steps[1..];
        if expr_projection.as_slice() == tail {
            return Some(self.value.clone());
        }

        // NOTE(review): this fallback applies the full reference projection
        // to the value without accounting for `tail`; confirm that is
        // intended when the assignment projection has more than one step.
        self.value
            .entry(expr_projection)
            .map(|entry| entry.to_expr())
    }
}
/// Builds the returning expression for an update, covering only the
/// mapping fields touched by `changed_bits`.
///
/// Each mapping field whose mask intersects `changed_bits` contributes one
/// entry, and its index is recorded in the resulting sparse field set:
///
/// * relation fields contribute a null expression;
/// * fields whose whole mask is covered contribute a reference to the root
///   field, projected through the field's sub-projection unless that
///   sub-projection is the identity;
/// * partially covered fields are treated as struct mappings and handled
///   by recursing into their sub-fields with the intersecting bits.
///
/// `root_field_id` is `None` at the top level, where the root field is
/// derived from `model_id` and the field's index. Recursive calls pass the
/// root field down so nested entries keep referencing it.
///
/// The collected entries are wrapped in a record cast to
/// `Type::SparseRecord` over the recorded indices.
///
/// # Panics
///
/// Panics if a partially covered, non-relation field is not a struct
/// mapping.
fn build_update_returning(
    model_id: app::ModelId,
    root_field_id: Option<app::FieldId>,
    mapping_fields: &[mapping::Field],
    changed_bits: &stmt::PathFieldSet,
) -> stmt::Expr {
    let mut returned = vec![];
    let mut present = stmt::PathFieldSet::new();

    for (index, field) in mapping_fields.iter().enumerate() {
        let changed = changed_bits.clone() & field.field_mask();
        if changed.is_empty() {
            continue;
        }
        present.insert(index);

        if field.is_relation() {
            returned.push(stmt::Expr::null());
            continue;
        }

        let root = root_field_id.unwrap_or(app::FieldId {
            model: model_id,
            index,
        });

        let entry = if changed == field.field_mask() {
            let whole = stmt::Expr::ref_self_field(root);
            if field.sub_projection().is_identity() {
                whole
            } else {
                stmt::Expr::project(whole, field.sub_projection().clone())
            }
        } else {
            let embedded = field.as_struct().unwrap();
            build_update_returning(model_id, Some(root), &embedded.fields, &changed)
        };
        returned.push(entry);
    }

    stmt::Expr::cast(
        stmt::ExprRecord::from_vec(returned),
        stmt::Type::SparseRecord(present),
    )
}
/// Reports whether an `IN` list expression tests a non-record expression
/// against a list made up solely of scalar constant values.
///
/// Returns `false` when the tested expression is a record, when the list
/// is neither a constant list value nor a list expression, or when any
/// item is a record value, a list value, or (for list expressions) not a
/// constant value at all. An empty list counts as a value list.
fn in_list_is_value_list(e: &stmt::ExprInList) -> bool {
    /// A value is scalar when it is neither a record nor a list.
    fn is_scalar(value: &stmt::Value) -> bool {
        !matches!(value, stmt::Value::Record(_) | stmt::Value::List(_))
    }

    if matches!(&*e.expr, stmt::Expr::Record(_)) {
        return false;
    }

    match &*e.list {
        stmt::Expr::List(list) => list.items.iter().all(|item| match item {
            stmt::Expr::Value(value) => is_scalar(value),
            _ => false,
        }),
        stmt::Expr::Value(stmt::Value::List(values)) => values.iter().all(is_scalar),
        _ => false,
    }
}