mod insert;
mod paginate;
mod relation;
mod returning;
use std::{cell::Cell, collections::HashSet};
use index_vec::IndexVec;
use toasty_core::{
Result, Schema,
driver::Capability,
schema::{
app::{self, FieldTy, ModelRoot},
db::ColumnId,
mapping,
},
stmt::{self, IntoExprTarget, VisitMut, visit_mut},
};
use crate::engine::{Engine, HirStatement, hir, simplify::Simplify};
impl Engine {
    /// Lowers an application-level statement into a [`HirStatement`].
    ///
    /// Errors that lowering records (instead of panicking on) are accumulated
    /// on the lowering state; only the first recorded error is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error pushed onto the lowering state's error list,
    /// e.g. an assignment projection that does not resolve to a mapped field.
    pub(super) fn lower_stmt(&self, stmt: stmt::Statement) -> Result<HirStatement> {
        let mut state = LoweringState {
            engine: self,
            hir: HirStatement::new(),
            scopes: IndexVec::new(),
            relations: Vec::new(),
            dependencies: HashSet::new(),
            errors: Vec::new(),
        };

        // The root statement is not lowered on behalf of any insert row.
        state.lower_stmt(stmt::ExprContext::new(&self.schema), None, stmt);

        match state.errors.into_iter().next() {
            Some(first_error) => Err(first_error),
            None => Ok(state.hir),
        }
    }
}
impl LoweringState<'_> {
fn lower_stmt(
&mut self,
expr_cx: stmt::ExprContext,
row_index: Option<usize>,
mut stmt: stmt::Statement,
) -> hir::StmtId {
Simplify::with_context(expr_cx).visit_mut(&mut stmt);
let stmt_id = self.hir.new_statement_info(self.dependencies.clone());
let scope_id = self.scopes.push(Scope { stmt_id, row_index });
let mut collect_dependencies = None;
LowerStatement {
state: self,
expr_cx,
scope_id,
cx: LoweringContext::Statement,
collect_dependencies: &mut collect_dependencies,
}
.visit_stmt_mut(&mut stmt);
self.engine.simplify_stmt(&mut stmt);
let stmt_info = &mut self.hir[stmt_id];
stmt_info.stmt = Some(Box::new(stmt));
self.scopes.pop();
debug_assert!(collect_dependencies.is_none());
stmt_id
}
}
/// Visitor that lowers one statement (or a part of one) in place.
///
/// Instances are cheap to re-create: helpers such as `scope_expr`,
/// `lower_insert` and `lower_returning` build new visitors that share the same
/// `state` but carry a different expression context or lowering context.
struct LowerStatement<'a, 'b> {
    /// Shared lowering state (HIR being built, scope stack, errors).
    state: &'a mut LoweringState<'b>,
    /// Expression context used to resolve field references and models.
    expr_cx: stmt::ExprContext<'a>,
    /// Scope (and therefore HIR statement) this visitor lowers into.
    scope_id: ScopeId,
    /// Which part of the statement is currently being lowered.
    cx: LoweringContext<'a>,
    /// When `Some`, ids of statements created via `new_dependency` are also
    /// recorded here (see `collect_dependencies`).
    collect_dependencies: &'a mut Option<HashSet<hir::StmtId>>,
}
/// Mutable state shared by every `LowerStatement` during a single call to
/// `Engine::lower_stmt`.
#[derive(Debug)]
struct LoweringState<'a> {
    /// Engine whose schema and capabilities drive lowering.
    engine: &'a Engine,
    /// HIR being built; every lowered statement gets an entry here.
    hir: HirStatement,
    /// Stack of scopes; the last entry is the statement currently being
    /// lowered. Column references with `nesting > 0` resolve against it.
    scopes: IndexVec<ScopeId, Scope>,
    // NOTE(review): only initialized in this file; presumably used by the
    // `relation` submodule — confirm.
    relations: Vec<app::FieldId>,
    /// Dependencies inherited by every statement created while this set is
    /// active (see `LowerStatement::with_dependencies`).
    dependencies: HashSet<hir::StmtId>,
    /// Errors recorded during lowering; `Engine::lower_stmt` returns the first.
    errors: Vec<crate::Error>,
}
/// Identifies which part of a statement is being lowered. This changes how
/// field references and nested statements are handled.
#[derive(Debug, Clone, Copy)]
enum LoweringContext<'a> {
    /// Lowering an `INSERT`. Carries the target table's columns and, while a
    /// specific row is being processed, that row's index.
    Insert(&'a [ColumnId], Option<usize>),
    /// Lowering the model-to-table record for one inserted row. Field
    /// references with `nesting == 0` resolve to entries of this row.
    InsertRow(&'a stmt::Expr),
    /// Lowering a `RETURNING` clause, optionally for a specific row index.
    Returning(Option<usize>),
    /// Lowering any other part of a statement (filters, sources, ...).
    Statement,
}
/// One entry of the lowering scope stack.
#[derive(Debug)]
struct Scope {
    /// HIR statement lowered in this scope.
    stmt_id: hir::StmtId,
    /// Insert/returning row index the scope was opened for, if any. It is
    /// copied into the `batch_load_index` of argument references created in
    /// this scope.
    row_index: Option<usize>,
}
// Typed index into `LoweringState::scopes`, stored as a `u32`.
index_vec::define_index_type! {
    struct ScopeId = u32;
}
impl LowerStatement<'_, '_> {
    /// Lowers `stmt` as a separate HIR statement that the current statement
    /// depends on, and returns its id.
    ///
    /// The new statement inherits the row index of the current insert or
    /// returning context. Its id is added to the current statement's `deps`.
    /// It is also added to the set being gathered when a
    /// `collect_dependencies` call is in progress.
    fn new_dependency(&mut self, stmt: impl Into<stmt::Statement>) -> hir::StmtId {
        let row_index = match self.cx {
            LoweringContext::Insert(_, row) | LoweringContext::Returning(row) => row,
            LoweringContext::InsertRow(_) | LoweringContext::Statement => None,
        };

        let id = self.state.lower_stmt(self.expr_cx, row_index, stmt.into());

        if let Some(collected) = &mut *self.collect_dependencies {
            collected.insert(id);
        }
        self.track_dependency(id);
        id
    }

    /// Runs `f` and returns the ids of statements created through
    /// `new_dependency` while it ran.
    ///
    /// The previous collection state is restored afterwards.
    fn collect_dependencies(
        &mut self,
        f: impl FnOnce(&mut LowerStatement<'_, '_>),
    ) -> HashSet<hir::StmtId> {
        let outer = std::mem::replace(self.collect_dependencies, Some(HashSet::new()));
        f(self);
        let collected = std::mem::replace(self.collect_dependencies, outer);
        collected.unwrap()
    }

    /// Records `dependency` as a dependency of the current statement.
    fn track_dependency(&mut self, dependency: hir::StmtId) {
        let info = self.curr_stmt_info();
        info.deps.insert(dependency);
    }

    /// Runs `f` with `dependencies` added to the inherited dependency set, so
    /// that statements created inside `f` start out depending on them.
    ///
    /// The previous inherited set is restored when `f` returns.
    fn with_dependencies(
        &mut self,
        mut dependencies: HashSet<hir::StmtId>,
        f: impl FnOnce(&mut LowerStatement<'_, '_>),
    ) {
        dependencies.extend(self.state.dependencies.iter().copied());
        let saved = std::mem::replace(&mut self.state.dependencies, dependencies);
        f(self);
        self.state.dependencies = saved;
    }
}
impl visit_mut::VisitMut for LowerStatement<'_, '_> {
    /// Lowers an `ORDER BY` expression.
    ///
    /// After the expression itself is lowered, it is passed through
    /// `lower_expr_binary_op` as both operands of an equality. This applies
    /// the same cast rewriting used for comparisons. Any replacement that call
    /// returns is discarded; only the rewritten left-hand side is kept.
    fn visit_order_by_expr_mut(&mut self, node: &mut stmt::OrderByExpr) {
        self.visit_expr_mut(&mut node.expr);
        let mut lhs = node.expr.clone();
        let mut rhs = node.expr.take();
        self.lower_expr_binary_op(stmt::BinaryOp::Eq, &mut lhs, &mut rhs);
        node.expr = lhs;
    }
    /// Rewrites model-field assignments into table-column assignments.
    ///
    /// Each assigned value is lowered once. It is then substituted into the
    /// model-to-table expression of every column backing the assigned field.
    /// Only `SET` assignments are supported (others hit `todo!`). A projection
    /// that does not resolve to a mapped field records an error and is skipped.
    fn visit_assignments_mut(&mut self, i: &mut stmt::Assignments) {
        let mut assignments = stmt::Assignments::default();
        let mapping = self.mapping_unwrap();
        for (projection, assignment) in &*i {
            let stmt::Assignment::Set(expr) = assignment else {
                todo!("only SET supported; got {assignment:#?}");
            };
            // Lowered once, reused for every column backing the field.
            let mut lowered_expr = expr.clone();
            self.visit_expr_mut(&mut lowered_expr);
            let Some(field) = mapping.resolve_field_mapping(projection) else {
                self.state
                    .errors
                    .push(crate::Error::invalid_statement(format!(
                        "invalid assignment projection: {:?}",
                        projection
                    )));
                continue;
            };
            for (column, lowering_idx) in field.columns() {
                // Start from the column's model-to-table expression and
                // substitute references to the assigned field with the value.
                let mut lowering_expr = mapping.model_to_table[lowering_idx].clone();
                let input = AssignmentInput {
                    assignment_projection: projection.clone(),
                    value: &lowered_expr,
                };
                lowering_expr.substitute(input);
                self.visit_expr_mut(&mut lowering_expr);
                assignments.set(column, lowering_expr);
            }
        }
        *i = assignments;
    }
    /// Set-operation expressions are not supported by lowering yet.
    fn visit_expr_set_op_mut(&mut self, i: &mut stmt::ExprSetOp) {
        todo!("stmt={i:#?}");
    }
    /// Lowers a single expression in place.
    ///
    /// The variants matched explicitly below are rewritten into table-level
    /// form. All others use the default traversal.
    fn visit_expr_mut(&mut self, expr: &mut stmt::Expr) {
        match expr {
            stmt::Expr::BinaryOp(e) => {
                // Lower the operands first, then let the whole comparison be
                // replaced when needed (e.g. `x = NULL` -> `x IS NULL`).
                self.visit_expr_binary_op_mut(e);
                if let Some(lowered) = self.lower_expr_binary_op(e.op, &mut e.lhs, &mut e.rhs) {
                    *expr = lowered;
                }
            }
            stmt::Expr::InList(e) => {
                self.visit_expr_in_list_mut(e);
                if let Some(lowered) = self.lower_expr_in_list(&mut e.expr, &mut e.list) {
                    *expr = lowered;
                }
            }
            stmt::Expr::InSubquery(e) => {
                if self.capability().sql {
                    // SQL drivers keep the subquery inline in the statement.
                    self.visit_expr_in_subquery_mut(e);
                    let maybe_res = self.lower_expr_binary_op(
                        stmt::BinaryOp::Eq,
                        &mut e.expr,
                        e.query.returning_mut_unwrap().as_expr_mut_unwrap(),
                    );
                    assert!(maybe_res.is_none(), "TODO");
                    // Wrap a non-record returning expression in a
                    // single-element record.
                    let returning = e.query.returning_mut_unwrap().as_expr_mut_unwrap();
                    if !returning.is_record() {
                        *returning = stmt::Expr::record([returning.take()]);
                    }
                } else {
                    // Without SQL, the subquery becomes its own HIR statement
                    // and `expr IN (subquery)` becomes `expr IN <arg>`.
                    self.visit_expr_mut(&mut e.expr);
                    let source_id = self.scope_stmt_id();
                    let target_id = self.scope_statement(|child| {
                        child.visit_stmt_query_mut(&mut e.query);
                    });
                    let target_stmt_info = &self.state.hir[target_id];
                    debug_assert!(target_stmt_info.args.is_empty(), "TODO");
                    debug_assert!(target_stmt_info.back_refs.is_empty(), "TODO");
                    self.track_dependency(target_id);
                    let maybe_res = self.lower_expr_binary_op(
                        stmt::BinaryOp::Eq,
                        &mut e.expr,
                        e.query.returning_mut_unwrap().as_expr_mut_unwrap(),
                    );
                    assert!(maybe_res.is_none(), "TODO");
                    let stmt::Expr::InSubquery(e) = expr.take() else {
                        panic!()
                    };
                    let arg =
                        self.new_sub_statement(source_id, target_id, Box::new((*e.query).into()));
                    *expr = stmt::ExprInList {
                        expr: e.expr,
                        list: Box::new(arg),
                    }
                    .into();
                }
            }
            stmt::Expr::IsVariant(e) => {
                // Lower to a discriminant comparison. For enums with data
                // variants the discriminant is element 0 of the lowered value;
                // otherwise the lowered value is compared directly.
                let enum_model = self
                    .schema()
                    .app
                    .model(e.variant.model)
                    .as_embedded_enum_unwrap();
                let has_data = enum_model.has_data_variants();
                let disc_value = enum_model.variants[e.variant.index].discriminant.clone();
                self.visit_expr_mut(&mut e.expr);
                let lowered_expr = e.expr.take();
                if has_data {
                    *expr = stmt::Expr::eq(
                        stmt::Expr::project(lowered_expr, [0usize]),
                        stmt::Expr::Value(disc_value),
                    );
                } else {
                    *expr = stmt::Expr::eq(lowered_expr, stmt::Expr::Value(disc_value));
                }
            }
            stmt::Expr::Reference(expr_reference) => {
                match expr_reference {
                    stmt::ExprReference::Field { nesting, index } => {
                        // Replace the model field with its table-level
                        // expression, then lower that expression too.
                        *expr = self.lower_expr_field(*nesting, *index);
                        self.visit_expr_mut(expr);
                    }
                    stmt::ExprReference::Model { .. } => todo!(),
                    stmt::ExprReference::Column(expr_column) => {
                        // A column of an enclosing statement is passed in as
                        // an argument. Every scope between here and the target
                        // is marked as not independent.
                        if expr_column.nesting > 0 {
                            let source_id = self.scope_stmt_id();
                            let target_id = self.resolve_stmt_id(expr_column.nesting);
                            debug_assert_eq!(self.state.scopes.len(), self.scope_id + 1);
                            for scope in self.state.scopes.iter().rev() {
                                if scope.stmt_id == target_id {
                                    break;
                                }
                                self.state.hir[scope.stmt_id].independent = false;
                            }
                            let position = self.new_ref(source_id, target_id, *expr_reference);
                            *expr = stmt::Expr::arg(position);
                        }
                    }
                }
            }
            stmt::Expr::Stmt(_) => {
                // A nested statement is lowered as its own HIR statement and
                // replaced by an argument referencing it. If it does not
                // reference outer scopes, it also becomes a dependency.
                let stmt::Expr::Stmt(mut expr_stmt) = expr.take() else {
                    panic!()
                };
                debug_assert!(
                    self.cx.is_returning() || matches!(self.cx, LoweringContext::Statement),
                    "cx={:#?}",
                    self.cx,
                );
                let source_id = self.scope_stmt_id();
                let target_id = self.scope_statement(|child| {
                    visit_mut::visit_expr_stmt_mut(child, &mut expr_stmt);
                });
                self.state.engine.simplify_stmt(&mut *expr_stmt.stmt);
                *expr = self.new_sub_statement(source_id, target_id, expr_stmt.stmt);
                if self.state.hir[target_id].independent {
                    self.curr_stmt_info().deps.insert(target_id);
                }
            }
            stmt::Expr::Exists(_) if !self.capability().sql => {
                // Without SQL, the EXISTS subquery runs as a separate
                // statement. The expression becomes EXISTS over a single-row
                // VALUES holding the sub-statement's argument.
                let stmt::Expr::Exists(mut expr_exists) = expr.take() else {
                    panic!()
                };
                let source_id = self.scope_stmt_id();
                let target_id = self.scope_statement(|child| {
                    child.visit_stmt_query_mut(&mut expr_exists.subquery);
                });
                let mut stmt = stmt::Statement::Query(*expr_exists.subquery);
                self.state.engine.simplify_stmt(&mut stmt);
                let arg = self.new_sub_statement(source_id, target_id, Box::new(stmt));
                if self.state.hir[target_id].independent {
                    self.curr_stmt_info().deps.insert(target_id);
                }
                let mut subquery = stmt::Query::values(arg);
                subquery.single = true;
                *expr = stmt::Expr::Exists(stmt::ExprExists {
                    subquery: Box::new(subquery),
                });
            }
            _ => {
                stmt::visit_mut::visit_expr_mut(self, expr);
            }
        }
    }
    /// Replaces a model insert target with its table and mapped columns.
    ///
    /// Scope targets are not handled here (`todo!`); `visit_stmt_insert_mut`
    /// runs `apply_insert_scope` before this is reached.
    fn visit_insert_target_mut(&mut self, i: &mut stmt::InsertTarget) {
        match i {
            stmt::InsertTarget::Scope(_) => todo!("stmt={i:#?}"),
            stmt::InsertTarget::Model(model_id) => {
                let mapping = self.schema().mapping_for(model_id);
                *i = stmt::InsertTable {
                    table: mapping.table,
                    columns: mapping.columns.clone(),
                }
                .into();
            }
            _ => todo!(),
        }
    }
    /// Replaces a model update target with the model's table.
    fn visit_update_target_mut(&mut self, i: &mut stmt::UpdateTarget) {
        match i {
            stmt::UpdateTarget::Query(_) => todo!("update_target={i:#?}"),
            stmt::UpdateTarget::Model(model_id) => {
                let table_id = self.schema().table_id_for(model_id);
                *i = stmt::UpdateTarget::table(table_id);
            }
            stmt::UpdateTarget::Table(_) => {}
        }
    }
    /// Uses the default traversal. Replacing a nested statement with an
    /// argument happens in `visit_expr_mut`.
    fn visit_expr_stmt_mut(&mut self, i: &mut stmt::ExprStmt) {
        stmt::visit_mut::visit_expr_stmt_mut(self, i);
    }
    /// Lowers a `RETURNING` clause.
    ///
    /// `Returning::Model` is expanded into the model's lowered returning
    /// record, with an include subquery added for each include path. In an
    /// insert context, a list of returned values is lowered item by item,
    /// each with its row index. Anything else is lowered in a `Returning`
    /// context.
    fn visit_returning_mut(&mut self, i: &mut stmt::Returning) {
        if let stmt::Returning::Model { include } = i {
            let include = std::mem::take(include);
            let mut returning = self.mapping_unwrap().table_to_model.lower_returning_model();
            for path in &include {
                self.build_include_subquery(&mut returning, path);
            }
            *i = stmt::Returning::Expr(returning);
        }
        if matches!(&self.cx, LoweringContext::Insert(..))
            && let stmt::Returning::Value(stmt::Expr::List(list)) = i
        {
            for (index, item) in list.items.iter_mut().enumerate() {
                self.lower_returning_for_row(index).visit_expr_mut(item);
            }
            return;
        }
        stmt::visit_mut::visit_returning_mut(&mut self.lower_returning(), i);
    }
    /// Lowers a `DELETE`.
    ///
    /// Relations are planned, then the filter and returning clause are
    /// lowered in the source's scope. The source is replaced with its table
    /// last.
    fn visit_stmt_delete_mut(&mut self, stmt: &mut stmt::Delete) {
        let mut lower = self.scope_expr(&stmt.from);
        lower.plan_stmt_delete_relations(stmt);
        lower.visit_filter_mut(&mut stmt.filter);
        if let Some(returning) = &mut stmt.returning {
            lower.visit_returning_mut(returning);
        }
        lower.apply_lowering_filter_constraint(&mut stmt.filter);
        self.visit_source_mut(&mut stmt.from);
    }
    /// Lowers an `INSERT`.
    ///
    /// The insert scope is applied first. The returning clause and source
    /// rows are then lowered in an `Insert` context, and the returning clause
    /// is constantized against the source. The target is converted to a
    /// table last.
    // NOTE(review): `visit_returning_mut` runs on the returning clause both
    // before `preprocess_insert_values` and again after the source is
    // lowered. Confirm the second pass is meant to be a no-op on
    // already-lowered expressions.
    fn visit_stmt_insert_mut(&mut self, stmt: &mut stmt::Insert) {
        self.apply_insert_scope(&mut stmt.target, &mut stmt.source);
        let mut lower = self.lower_insert(&stmt.target);
        if let Some(returning) = &mut stmt.returning {
            lower.visit_returning_mut(returning);
        }
        lower.preprocess_insert_values(&mut stmt.source, &mut stmt.returning);
        lower.visit_stmt_query_mut(&mut stmt.source);
        if let Some(returning) = &mut stmt.returning {
            lower.visit_returning_mut(returning);
            lower.constantize_insert_returning(returning, &stmt.source);
            // A single-row insert must not return a list.
            if stmt.source.single
                && let stmt::Returning::Value(expr) = &returning
            {
                debug_assert!(!expr.is_list());
            }
        }
        self.visit_insert_target_mut(&mut stmt.target);
    }
    /// Lowers a query.
    ///
    /// `WITH`, `ORDER BY` and `LIMIT` are lowered in the body's scope, then
    /// the body itself. `rewrite_offset_after_as_filter` runs last.
    fn visit_stmt_query_mut(&mut self, stmt: &mut stmt::Query) {
        let mut lower = self.scope_expr(&stmt.body);
        if let Some(with) = &mut stmt.with {
            lower.visit_with_mut(with);
        }
        if let Some(order_by) = &mut stmt.order_by {
            lower.visit_order_by_mut(order_by);
        }
        if let Some(limit) = &mut stmt.limit {
            lower.visit_limit_mut(limit);
        }
        self.visit_expr_set_mut(&mut stmt.body);
        self.rewrite_offset_after_as_filter(stmt);
    }
    /// Lowers a `SELECT`. The filter and returning clause are lowered in the
    /// source's scope, and the source is replaced with its table last.
    fn visit_stmt_select_mut(&mut self, stmt: &mut stmt::Select) {
        let mut lower = self.scope_expr(&stmt.source);
        lower.visit_filter_mut(&mut stmt.filter);
        lower.visit_returning_mut(&mut stmt.returning);
        lower.apply_lowering_filter_constraint(&mut stmt.filter);
        self.visit_source_mut(&mut stmt.source);
    }
    /// Lowers an `UPDATE`.
    ///
    /// A "changed" returning clause is first replaced with a sparse record of
    /// the assigned fields (see `build_update_returning`). Relations are then
    /// planned, and assignments, filter, condition and returning are lowered.
    /// The target is converted to a table last.
    fn visit_stmt_update_mut(&mut self, stmt: &mut stmt::Update) {
        let mut lower = self.scope_expr(&stmt.target);
        let mut returning_changed = false;
        if let Some(returning) = &mut stmt.returning
            && returning.is_changed()
        {
            returning_changed = true;
            if let Some(model) = lower.model() {
                let mapping = lower.mapping_unwrap();
                // Union of the field masks of every assigned projection.
                let mut changed_bits = stmt::PathFieldSet::new();
                for projection in stmt.assignments.keys() {
                    if let Some(mf) = mapping.resolve_field_mapping(projection) {
                        changed_bits |= mf.field_mask();
                    }
                }
                *returning = stmt::Returning::Expr(build_update_returning(
                    model.id,
                    None,
                    &mapping.fields,
                    &changed_bits,
                ));
            }
        }
        lower.plan_stmt_update_relations(
            &mut stmt.assignments,
            &stmt.filter,
            &mut stmt.returning,
            returning_changed,
        );
        lower.visit_assignments_mut(&mut stmt.assignments);
        lower.visit_filter_mut(&mut stmt.filter);
        if let Some(expr) = &mut stmt.condition.expr {
            lower.visit_expr_mut(expr);
        }
        if let Some(returning) = &mut stmt.returning {
            lower.visit_returning_mut(returning);
            lower.constantize_update_returning(returning, &stmt.assignments);
        }
        self.visit_update_target_mut(&mut stmt.target);
    }
    /// Replaces a model source with the model's table. `via` sources are not
    /// supported yet (checked only in debug builds).
    fn visit_source_mut(&mut self, stmt: &mut stmt::Source) {
        if let stmt::Source::Model(source_model) = stmt {
            debug_assert!(source_model.via.is_none(), "TODO");
            let table_id = self.schema().table_id_for(source_model.id);
            *stmt = stmt::Source::table(table_id);
        }
    }
    /// Lowers a `VALUES` list.
    ///
    /// In an insert context with a model target, each row is replaced by the
    /// model's model-to-table record, lowered against that row. Otherwise the
    /// default traversal is used.
    fn visit_values_mut(&mut self, stmt: &mut stmt::Values) {
        if self.cx.is_insert()
            && let Some(mapping) = self.mapping()
        {
            for row in &mut stmt.rows {
                let mut lowered = mapping.model_to_table.clone();
                self.lower_insert_row(row)
                    .visit_expr_record_mut(&mut lowered);
                *row = lowered.into();
            }
            return;
        }
        visit_mut::visit_values_mut(self, stmt);
    }
}
impl<'a, 'b> LowerStatement<'a, 'b> {
    /// Lowers a binary operation whose operands have already been lowered.
    ///
    /// Returns `Some(replacement)` when the whole operation must be replaced:
    /// comparing against a `NULL` value becomes `IS NULL` (`Eq`) or
    /// `IS NOT NULL` (`Ne`).
    ///
    /// When either operand is a cast, both operands are rewritten in place by
    /// `cast_expr` to the driver's native type for the cast target, and
    /// `None` is returned. In every other case `None` is returned and the
    /// operands are left untouched.
    fn lower_expr_binary_op(
        &mut self,
        op: stmt::BinaryOp,
        lhs: &mut stmt::Expr,
        rhs: &mut stmt::Expr,
    ) -> Option<stmt::Expr> {
        match (&mut *lhs, &mut *rhs) {
            (stmt::Expr::Value(value), other) | (other, stmt::Expr::Value(value))
                if value.is_null() =>
            {
                let other = other.take();
                Some(match op {
                    stmt::BinaryOp::Eq => stmt::Expr::is_null(other),
                    stmt::BinaryOp::Ne => stmt::Expr::is_not_null(other),
                    _ => todo!(),
                })
            }
            (stmt::Expr::Cast(expr_cast), _) | (_, stmt::Expr::Cast(expr_cast)) => {
                let target_ty = self.capability().native_type_for(&expr_cast.ty);
                self.cast_expr(lhs, &target_ty);
                self.cast_expr(rhs, &target_ty);
                None
            }
            _ => None,
        }
    }

    /// Normalizes an `IN` list whose operands have already been lowered.
    ///
    /// A cast on the left-hand side is pushed into the list:
    /// - literal items are converted to the driver's native type;
    /// - an argument list is wrapped in a map that casts each element.
    ///
    /// Record and column operands are only checked (via `assert!`) to have
    /// the expected shape. Always returns `None`; unsupported combinations
    /// hit `todo!`/`panic!`.
    fn lower_expr_in_list(
        &mut self,
        expr: &mut stmt::Expr,
        list: &mut stmt::Expr,
    ) -> Option<stmt::Expr> {
        match (&mut *expr, list) {
            (expr, stmt::Expr::Map(expr_map)) => {
                // Mapped argument list: reconcile the operand with the map's
                // per-element expression.
                assert!(expr_map.base.is_arg(), "TODO");
                let maybe_res =
                    self.lower_expr_binary_op(stmt::BinaryOp::Eq, expr, &mut expr_map.map);
                assert!(maybe_res.is_none(), "TODO");
                None
            }
            (stmt::Expr::Cast(expr_cast), list) => {
                let target_ty = self.capability().native_type_for(&expr_cast.ty);
                self.cast_expr(expr, &target_ty);
                match list {
                    stmt::Expr::List(expr_list) => {
                        for item in &mut expr_list.items {
                            self.cast_expr(item, &target_ty);
                        }
                    }
                    stmt::Expr::Value(stmt::Value::List(items)) => {
                        for item in items {
                            *item = target_ty.cast(item.take()).expect("failed to cast value");
                        }
                    }
                    stmt::Expr::Arg(_) => {
                        // The list is only known at execution time, so each
                        // element is cast through a map over the argument.
                        let arg = list.take();
                        let cast = stmt::Expr::cast(stmt::Expr::arg(0), target_ty);
                        *list = stmt::Expr::map(arg, cast);
                    }
                    _ => todo!("expr={expr:#?}; list={list:#?}"),
                }
                None
            }
            (stmt::Expr::Record(lhs), stmt::Expr::List(list)) => {
                for lhs in lhs {
                    assert!(lhs.is_column());
                }
                for item in &list.items {
                    assert!(item.is_value());
                }
                None
            }
            (stmt::Expr::Record(lhs), stmt::Expr::Value(stmt::Value::List(_))) => {
                for lhs in lhs {
                    assert!(lhs.is_column());
                }
                None
            }
            (stmt::Expr::Reference(expr_reference), list) => {
                assert!(expr_reference.is_column());
                match list {
                    stmt::Expr::Value(stmt::Value::List(_)) => {}
                    stmt::Expr::List(list) => {
                        for item in &list.items {
                            assert!(item.is_value());
                        }
                    }
                    _ => panic!("invalid; should have been caught earlier"),
                }
                None
            }
            (expr, list) => todo!("expr={expr:#?}; list={list:#?}"),
        }
    }

    /// Hook for constraining a lowered filter; currently a no-op.
    fn apply_lowering_filter_constraint(&self, _filter: &mut stmt::Filter) {}

    /// Returns the table-level expression for model field `index` located
    /// `nesting` scopes above the current one.
    ///
    /// In statement and returning contexts the field is lowered through the
    /// model's `table_to_model` mapping. While lowering an insert row, fields
    /// of the row itself (`nesting == 0`) resolve to the row's entry at
    /// `index`. Other contexts are unsupported (`todo!`).
    fn lower_expr_field(&self, nesting: usize, index: usize) -> stmt::Expr {
        match self.cx {
            LoweringContext::Statement | LoweringContext::Returning(_) => {
                let mapping = self.mapping_at_unwrap(nesting);
                mapping.table_to_model.lower_expr_reference(nesting, index)
            }
            LoweringContext::InsertRow(row) => {
                if nesting > 0 {
                    let mapping = self.mapping_at_unwrap(nesting);
                    mapping.table_to_model.lower_expr_reference(nesting, index)
                } else {
                    row.entry(index).unwrap().to_expr()
                }
            }
            _ => todo!("cx={:#?}", self.cx),
        }
    }

    /// Inserts a subquery that loads the relation named by the include
    /// `path` into `returning`, at the entry of the path's first field.
    ///
    /// The first step of `path` must be a `HasMany`, `BelongsTo` or `HasOne`
    /// field. The remaining steps are attached as a nested include on the
    /// generated subquery. `BelongsTo` and `HasOne` subqueries are marked
    /// `single`.
    ///
    /// # Panics
    ///
    /// Panics if `path` is empty. Other field types hit `todo!`.
    fn build_include_subquery(&mut self, returning: &mut stmt::Expr, path: &stmt::Path) {
        let projection = &path.projection[..];
        let (field_index, rest) = match projection {
            [] => panic!("Empty include path"),
            [first, rest @ ..] => (first, rest),
        };
        let field = &self.model_unwrap().fields[*field_index];
        let (mut stmt, target_model_id) = match &field.ty {
            FieldTy::HasMany(rel) => (
                stmt::Query::new_select(
                    rel.target,
                    stmt::Expr::eq(
                        stmt::Expr::ref_parent_model(),
                        stmt::Expr::ref_self_field(rel.pair),
                    ),
                ),
                rel.target,
            ),
            FieldTy::BelongsTo(rel) => {
                // Match the parent's foreign-key field(s) against the target
                // model's referenced field(s).
                let source_fk;
                let target_pk;
                if let [fk_field] = &rel.foreign_key.fields[..] {
                    source_fk = stmt::Expr::ref_parent_field(fk_field.source);
                    target_pk = stmt::Expr::ref_self_field(fk_field.target);
                } else {
                    // Composite key: compare records field by field. Each
                    // target entry must reference the target model's own
                    // field, as in the single-column case above. Referencing
                    // the parent's source field again would compare the
                    // foreign key with itself.
                    let mut source_fk_fields = vec![];
                    let mut target_pk_fields = vec![];
                    for fk_field in &rel.foreign_key.fields {
                        source_fk_fields.push(stmt::Expr::ref_parent_field(fk_field.source));
                        target_pk_fields.push(stmt::Expr::ref_self_field(fk_field.target));
                    }
                    source_fk = stmt::Expr::record_from_vec(source_fk_fields);
                    target_pk = stmt::Expr::record_from_vec(target_pk_fields);
                }
                let mut query =
                    stmt::Query::new_select(rel.target, stmt::Expr::eq(source_fk, target_pk));
                query.single = true;
                (query, rel.target)
            }
            FieldTy::HasOne(rel) => {
                let mut query = stmt::Query::new_select(
                    rel.target,
                    stmt::Expr::eq(
                        stmt::Expr::ref_parent_model(),
                        stmt::Expr::ref_self_field(rel.pair),
                    ),
                );
                query.single = true;
                (query, rel.target)
            }
            _ => todo!(),
        };
        if !rest.is_empty() {
            let remaining_path = stmt::Path {
                root: stmt::PathRoot::Model(target_model_id),
                projection: stmt::Projection::from(rest),
            };
            stmt.include(remaining_path);
        }
        Simplify::with_context(self.expr_cx).visit_stmt_query_mut(&mut stmt);
        let mut sub_expr = stmt::Expr::stmt(stmt);
        // Nullable single-valued relations: bind the subquery result and map
        // a `NULL` result to `0`, passing any other value through.
        // NOTE(review): the meaning of the `0` sentinel is defined outside
        // this file — confirm against the decoder.
        if field.nullable() && !field.ty.is_has_many() {
            sub_expr = stmt::Expr::Let(stmt::ExprLet {
                bindings: vec![sub_expr],
                body: Box::new(stmt::Expr::match_expr(
                    stmt::Expr::arg(0),
                    vec![stmt::MatchArm {
                        pattern: stmt::Value::Null,
                        expr: stmt::Expr::from(0i64),
                    }],
                    stmt::Expr::arg(0),
                )),
            });
        }
        returning.entry_mut(*field_index).insert(sub_expr);
    }

    /// Registers a reference from statement `source_id` to a column of the
    /// enclosing statement `target_id`. Returns the argument position that
    /// replaces the reference in the source statement.
    ///
    /// The reference is rewritten to `nesting == 0` (relative to the target)
    /// and recorded in the target's `back_refs`. An existing `Arg::Ref` on the
    /// source is reused only when it refers to the same column of the same
    /// target statement.
    fn new_ref(
        &mut self,
        source_id: hir::StmtId,
        target_id: hir::StmtId,
        mut expr_reference: stmt::ExprReference,
    ) -> usize {
        let stmt::ExprReference::Column(expr_column) = &mut expr_reference else {
            todo!()
        };
        let nesting = expr_column.nesting;
        debug_assert!(nesting != 0, "expr_reference={expr_reference:#?}");
        expr_column.nesting = 0;
        let target = &mut self.state.hir[target_id];
        target
            .back_refs
            .entry(source_id)
            .or_default()
            .exprs
            .insert_full(expr_reference);
        let source = &mut self.state.hir[source_id];
        for (i, arg) in source.args.iter().enumerate() {
            let hir::Arg::Ref {
                target_expr_ref,
                stmt_id,
                ..
            } = arg
            else {
                continue;
            };
            // After normalization to `nesting == 0`, equal references to
            // different enclosing statements look identical, so the bound
            // statement must match too.
            if *stmt_id == target_id && *target_expr_ref == expr_reference {
                return i;
            }
        }
        let arg = source.args.len();
        source.args.push(hir::Arg::Ref {
            target_expr_ref: expr_reference,
            stmt_id: target_id,
            nesting,
            data_load_input: Cell::new(None),
            returning_input: Cell::new(None),
            batch_load_index: if let Some(row_index) = self.state.scopes[self.scope_id].row_index {
                debug_assert_eq!(1, nesting, "TODO");
                Cell::new(Some(row_index))
            } else {
                Cell::new(None)
            },
        });
        arg
    }

    /// Creates a HIR statement entry. Its initial dependencies are the
    /// inherited set plus the current statement's `deps`.
    fn new_statement_info(&mut self) -> hir::StmtId {
        let mut deps = self.state.dependencies.clone();
        deps.extend(&self.curr_stmt_info().deps);
        self.state.hir.new_statement_info(deps)
    }

    /// Stores `stmt` as the body of `target_id`. Returns an argument
    /// expression that refers to it from `source_id`.
    fn new_sub_statement(
        &mut self,
        source_id: hir::StmtId,
        target_id: hir::StmtId,
        stmt: Box<stmt::Statement>,
    ) -> stmt::Expr {
        self.state.hir[target_id].stmt = Some(stmt);
        self.new_dependency_arg(source_id, target_id)
    }

    /// Appends an `Arg::Sub` for `target_id` to `source_id`'s arguments and
    /// returns the matching `arg` expression. `returning` records whether the
    /// argument was created while lowering a returning clause.
    fn new_dependency_arg(&mut self, source_id: hir::StmtId, target_id: hir::StmtId) -> stmt::Expr {
        let source = &mut self.state.hir[source_id];
        let arg = source.args.len();
        source.args.push(hir::Arg::Sub {
            stmt_id: target_id,
            returning: self.cx.is_returning(),
            input: Cell::new(None),
            batch_load_index: Cell::new(None),
        });
        stmt::Expr::arg(arg)
    }

    /// The engine's schema.
    fn schema(&self) -> &'b Schema {
        &self.state.engine.schema
    }

    /// The driver capabilities of the engine.
    fn capability(&self) -> &Capability {
        self.state.engine.capability()
    }

    /// Looks up an app-level field by id.
    fn field(&self, id: impl Into<app::FieldId>) -> &'b app::Field {
        self.schema().app.field(id.into())
    }

    /// The model targeted by the current expression context, if any.
    fn model(&self) -> Option<&'a ModelRoot> {
        self.expr_cx.target().as_model()
    }

    /// The model targeted by the current expression context.
    ///
    /// # Panics
    ///
    /// Panics if the current target is not a model.
    #[track_caller]
    fn model_unwrap(&self) -> &'a ModelRoot {
        self.expr_cx.target().as_model_unwrap()
    }

    /// The model-to-table mapping of the current target model, if any.
    fn mapping(&self) -> Option<&'b mapping::Model> {
        self.model()
            .map(|model| self.state.engine.schema.mapping_for(model))
    }

    /// The model-to-table mapping of the current target model.
    ///
    /// # Panics
    ///
    /// Panics if the current target is not a model.
    #[track_caller]
    fn mapping_unwrap(&self) -> &'b mapping::Model {
        self.state.engine.schema.mapping_for(self.model_unwrap())
    }

    /// The mapping of the model targeted `nesting` contexts up.
    ///
    /// # Panics
    ///
    /// Panics if that target is not a model.
    #[track_caller]
    fn mapping_at_unwrap(&self, nesting: usize) -> &'b mapping::Model {
        let model = self.expr_cx.target_at(nesting).as_model_unwrap();
        self.state.engine.schema.mapping_for(model)
    }

    /// HIR entry of the statement owning the current scope.
    fn curr_stmt_info(&mut self) -> &mut hir::StatementInfo {
        let stmt_id = self.scope_stmt_id();
        &mut self.state.hir[stmt_id]
    }

    /// Id of the statement owning the current scope.
    fn scope_stmt_id(&self) -> hir::StmtId {
        self.state.scopes[self.scope_id].stmt_id
    }

    /// Id of the statement `nesting` scopes above the current one.
    fn resolve_stmt_id(&self, nesting: usize) -> hir::StmtId {
        debug_assert!(
            self.scope_id >= nesting,
            "invalid nesting; nesting={nesting:#?}; scopes={:#?}",
            self.state.scopes
        );
        self.state.scopes[self.scope_id - nesting].stmt_id
    }

    /// Creates a new HIR statement and pushes a scope for it. The scope
    /// inherits the current insert/returning row index. `f` runs with a
    /// visitor in that scope, and the new statement's id is returned.
    ///
    /// Any `collect_dependencies` call in progress does not see statements
    /// created inside the child scope.
    fn scope_statement(&mut self, f: impl FnOnce(&mut LowerStatement<'_, '_>)) -> hir::StmtId {
        let stmt_id = self.new_statement_info();
        let row_index = match &self.cx {
            LoweringContext::Insert(_, row_index) => *row_index,
            LoweringContext::Returning(row_index) => *row_index,
            _ => None,
        };
        let scope_id = self.state.scopes.push(Scope { stmt_id, row_index });
        let mut dependencies = None;
        let mut lower = LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx,
            scope_id,
            cx: LoweringContext::Statement,
            collect_dependencies: &mut dependencies,
        };
        f(&mut lower);
        debug_assert!(dependencies.is_none());
        self.state.scopes.pop();
        stmt_id
    }

    /// Visitor for the same scope whose expression context is narrowed to
    /// `target`.
    fn scope_expr<'child>(
        &'child mut self,
        target: impl IntoExprTarget<'child>,
    ) -> LowerStatement<'child, 'b> {
        LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx.scope(target),
            scope_id: self.scope_id,
            cx: self.cx,
            collect_dependencies: self.collect_dependencies,
        }
    }

    /// Visitor in an `Insert` context for `target`, carrying the target's
    /// column list.
    ///
    /// # Panics
    ///
    /// Panics if `target` is still an unlowered `InsertTarget::Scope`.
    fn lower_insert<'child>(
        &'child mut self,
        target: &'child stmt::InsertTarget,
    ) -> LowerStatement<'child, 'b> {
        let columns = match target {
            stmt::InsertTarget::Scope(_) => {
                panic!("InsertTarget::Scope should already have been lowered by this point")
            }
            stmt::InsertTarget::Model(model_id) => &self.schema().mapping_for(model_id).columns,
            stmt::InsertTarget::Table(insert_table) => &insert_table.columns,
        };
        LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx.scope(target),
            scope_id: self.scope_id,
            cx: LoweringContext::Insert(columns, None),
            collect_dependencies: self.collect_dependencies,
        }
    }

    /// Runs `f` with the `Insert` context's row index set to `row`, then
    /// clears it again. Non-insert contexts hit `todo!`.
    fn lower_insert_with_row(&mut self, row: usize, f: impl FnOnce(&mut Self)) {
        let LoweringContext::Insert(_, maybe_row) = &mut self.cx else {
            todo!()
        };
        debug_assert!(maybe_row.is_none());
        *maybe_row = Some(row);
        f(self);
        let LoweringContext::Insert(_, maybe_row) = &mut self.cx else {
            todo!()
        };
        debug_assert_eq!(Some(row), *maybe_row);
        *maybe_row = None;
    }

    /// Visitor that lowers the model-to-table record of one inserted `row`.
    fn lower_insert_row<'child>(
        &'child mut self,
        row: &'child stmt::Expr,
    ) -> LowerStatement<'child, 'b> {
        LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx,
            scope_id: self.scope_id,
            cx: LoweringContext::InsertRow(row),
            collect_dependencies: self.collect_dependencies,
        }
    }

    /// Visitor for a `RETURNING` clause that is not tied to a specific row.
    fn lower_returning(&mut self) -> LowerStatement<'_, 'b> {
        LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx,
            scope_id: self.scope_id,
            cx: LoweringContext::Returning(None),
            collect_dependencies: self.collect_dependencies,
        }
    }

    /// Visitor for the `RETURNING` entry of row `row_index`.
    fn lower_returning_for_row<'child>(
        &'child mut self,
        row_index: usize,
    ) -> LowerStatement<'child, 'b> {
        LowerStatement {
            state: self.state,
            expr_cx: self.expr_cx,
            scope_id: self.scope_id,
            cx: LoweringContext::Returning(Some(row_index)),
            collect_dependencies: self.collect_dependencies,
        }
    }

    /// Rewrites `expr` to type `target_ty`:
    /// - an existing cast is removed (its inner expression is kept);
    /// - a literal value is converted eagerly;
    /// - an argument is wrapped in a cast.
    ///
    /// # Panics
    ///
    /// Panics on list target types, on a value that fails to convert, and on
    /// any other expression kind.
    fn cast_expr(&mut self, expr: &mut stmt::Expr, target_ty: &stmt::Type) {
        assert!(!target_ty.is_list(), "TODO");
        match expr {
            stmt::Expr::Cast(expr_cast) => {
                *expr = expr_cast.expr.take();
            }
            stmt::Expr::Value(value) => {
                let casted = target_ty.cast(value.take()).expect("failed to cast value");
                *value = casted;
            }
            stmt::Expr::Project(_) => {
                todo!()
            }
            stmt::Expr::Arg(_) => {
                let base = expr.take();
                *expr = stmt::Expr::cast(base, target_ty.clone());
            }
            _ => todo!("cast_expr: cannot cast {expr:#?} to {target_ty:?}"),
        }
    }
}
impl LoweringContext<'_> {
fn is_insert(&self) -> bool {
matches!(self, LoweringContext::Insert { .. })
}
fn is_returning(&self) -> bool {
matches!(self, LoweringContext::Returning(_))
}
}
/// `stmt::Input` used when lowering an assignment. It resolves references to
/// the assigned model field to the (already lowered) assigned value.
struct AssignmentInput<'a> {
    /// Field path being assigned; its first step is a top-level field index.
    assignment_projection: stmt::Projection,
    /// Lowered value assigned to `assignment_projection`.
    value: &'a stmt::Expr,
}
impl stmt::Input for AssignmentInput<'_> {
    /// Resolves a `nesting == 0` reference to the assigned field.
    ///
    /// Returns `None`, leaving the reference untouched, for references to
    /// other fields or to enclosing scopes. If the referenced projection
    /// equals the rest of the assignment path, the whole assigned value is
    /// returned. Otherwise `expr_projection` is looked up inside the value.
    ///
    /// NOTE(review): that fallback looks up `expr_projection` from the root of
    /// `value` without stripping the assignment path's remaining steps.
    /// Confirm this is correct for assignments to nested sub-fields.
    ///
    /// # Panics
    ///
    /// Panics if `assignment_projection` is empty.
    fn resolve_ref(
        &mut self,
        expr_reference: &stmt::ExprReference,
        expr_projection: &stmt::Projection,
    ) -> Option<stmt::Expr> {
        let stmt::ExprReference::Field {
            nesting: 0,
            index: field,
        } = expr_reference
        else {
            return None;
        };

        let path = self.assignment_projection.as_slice();
        let (head, tail) = (path[0], &path[1..]);
        if head != *field {
            return None;
        }

        if expr_projection.as_slice() != tail {
            return self.value.entry(expr_projection).map(|entry| entry.to_expr());
        }
        Some(self.value.clone())
    }
}
/// Builds the `RETURNING` expression for an update whose returning clause
/// asks for the changed fields.
///
/// The result is a record cast to `Type::SparseRecord`. It has one entry per
/// field in `mapping_fields` whose mask intersects `changed_bits`, and the
/// sparse-record field set lists which field indices are present.
/// - Relation fields contribute `NULL`.
/// - A fully changed field is a reference to the field (projected by the
///   mapping's sub-projection, if any).
/// - A partially changed field is expanded recursively through its
///   embedded-struct mapping. Sub-fields keep referencing the top-level field
///   (`root_field_id`).
///
/// # Panics
///
/// Panics (via `unwrap`) if a partially changed field is not an embedded
/// struct mapping.
fn build_update_returning(
    model_id: app::ModelId,
    root_field_id: Option<app::FieldId>,
    mapping_fields: &[mapping::Field],
    changed_bits: &stmt::PathFieldSet,
) -> stmt::Expr {
    let mut present = stmt::PathFieldSet::new();
    let mut values = vec![];

    for (index, field_mapping) in mapping_fields.iter().enumerate() {
        let changed = changed_bits.clone() & field_mapping.field_mask();
        if changed.is_empty() {
            continue;
        }
        present.insert(index);

        if field_mapping.is_relation() {
            values.push(stmt::Expr::null());
            continue;
        }

        // Nested fields keep referencing the top-level field they live in.
        let field_id = root_field_id.unwrap_or(app::FieldId {
            model: model_id,
            index,
        });

        let value = if changed == field_mapping.field_mask() {
            let field_ref = stmt::Expr::ref_self_field(field_id);
            let projection = field_mapping.sub_projection();
            if projection.is_identity() {
                field_ref
            } else {
                stmt::Expr::project(field_ref, projection.clone())
            }
        } else {
            let embedded = field_mapping.as_struct().unwrap();
            build_update_returning(model_id, Some(field_id), &embedded.fields, &changed)
        };
        values.push(value);
    }

    stmt::Expr::cast(
        stmt::ExprRecord::from_vec(values),
        stmt::Type::SparseRecord(present),
    )
}