use toasty_core::{
schema::app::{BelongsTo, FieldId, FieldTy, ModelId},
stmt::{self, Expr, ExprContext, IntoExprTarget, ResolvedRef, Visit, VisitMut},
};
/// Statement rewriter that lifts relation-level predicates into predicates
/// over the underlying key fields.
///
/// Two expression shapes are rewritten while walking a statement:
///
/// * `relation IN (subquery)` expressions, handled by `lift_in_subquery`;
/// * binary comparisons where one operand is a projection through a relation
///   field, handled by `try_lift_relation_path_comparison`.
pub(super) struct LiftInSubquery<'a> {
    /// Context used to resolve field references. It is narrowed to each
    /// statement's source / target as the visitor descends into it.
    cx: ExprContext<'a>,
}

impl<'a> LiftInSubquery<'a> {
    /// Builds a rewriter rooted at the given expression context.
    pub(super) fn new(cx: ExprContext<'a>) -> Self {
        LiftInSubquery { cx }
    }

    /// Rewrites `stmt` in place.
    pub(super) fn rewrite(&mut self, stmt: &mut stmt::Statement) {
        self.visit_mut(stmt);
    }

    /// Returns a child rewriter whose context is narrowed to `target`.
    ///
    /// The child borrows this rewriter's context, so it lives at most as long
    /// as `self`.
    fn scope<'scope>(&'scope self, target: impl IntoExprTarget<'scope>) -> LiftInSubquery<'scope> {
        let cx = self.cx.scope(target);
        LiftInSubquery { cx }
    }
}
impl VisitMut for LiftInSubquery<'_> {
    /// Pre-order rewrite: the node itself is lifted first (if it matches one
    /// of the supported shapes), then its children are visited. If the node was
    /// replaced, it is the replacement's children that get visited.
    fn visit_expr_mut(&mut self, expr: &mut stmt::Expr) {
        let lifted = match expr {
            stmt::Expr::InSubquery(in_subquery) => {
                lift_in_subquery(&self.cx, &in_subquery.expr, &in_subquery.query)
            }
            stmt::Expr::BinaryOp(binary_op) => try_lift_relation_path_comparison(
                &self.cx,
                binary_op.op,
                &binary_op.lhs,
                &binary_op.rhs,
            )
            .or_else(|| {
                // The relation path may sit on the right-hand side; swap the
                // operands and commute the operator so it reads `path op value`.
                try_lift_relation_path_comparison(
                    &self.cx,
                    binary_op.op.commute(),
                    &binary_op.rhs,
                    &binary_op.lhs,
                )
            }),
            _ => None,
        };

        if let Some(replacement) = lifted {
            *expr = replacement;
        }

        stmt::visit_mut::visit_expr_mut(self, expr);
    }

    /// `DELETE`: the source is visited in the current scope; the filter and
    /// `RETURNING` clause are visited with references resolved against it.
    fn visit_stmt_delete_mut(&mut self, stmt: &mut stmt::Delete) {
        self.visit_source_mut(&mut stmt.from);

        let mut scoped = self.scope(&stmt.from);
        scoped.visit_filter_mut(&mut stmt.filter);
        if let Some(returning) = stmt.returning.as_mut() {
            scoped.visit_returning_mut(returning);
        }
    }

    /// `INSERT`: the target is visited in the current scope; the source query
    /// and `RETURNING` clause are visited scoped to the target.
    fn visit_stmt_insert_mut(&mut self, stmt: &mut stmt::Insert) {
        self.visit_insert_target_mut(&mut stmt.target);

        let mut scoped = self.scope(&stmt.target);
        scoped.visit_stmt_query_mut(&mut stmt.source);
        if let Some(returning) = stmt.returning.as_mut() {
            scoped.visit_returning_mut(returning);
        }
    }

    /// `SELECT`: the source is visited in the current scope; the filter and
    /// returning expressions are visited scoped to that source.
    fn visit_stmt_select_mut(&mut self, stmt: &mut stmt::Select) {
        self.visit_source_mut(&mut stmt.source);

        let mut scoped = self.scope(&stmt.source);
        scoped.visit_filter_mut(&mut stmt.filter);
        scoped.visit_returning_mut(&mut stmt.returning);
    }

    /// `UPDATE`: the target is visited in the current scope. Assignments,
    /// filter, optional condition and optional `RETURNING` are visited scoped
    /// to the target.
    fn visit_stmt_update_mut(&mut self, stmt: &mut stmt::Update) {
        self.visit_update_target_mut(&mut stmt.target);

        let mut scoped = self.scope(&stmt.target);
        scoped.visit_assignments_mut(&mut stmt.assignments);
        scoped.visit_filter_mut(&mut stmt.filter);
        if let Some(condition) = stmt.condition.expr.as_mut() {
            scoped.visit_expr_mut(condition);
        }
        if let Some(returning) = stmt.returning.as_mut() {
            scoped.visit_returning_mut(returning);
        }
    }
}
/// Visitor state used to rewrite the filter of a
/// `belongs_to IN (SELECT … FROM target WHERE …)` subquery onto the owning
/// model's foreign-key fields.
struct LiftBelongsTo<'a> {
    /// Context scoped to the subquery's source; resolves the filter's field
    /// references.
    cx: ExprContext<'a>,

    /// The relation being lifted. Its foreign key pairs each field on the
    /// target model with a source field on the owning model.
    belongs_to: &'a BelongsTo,

    /// One flag per foreign-key field, set once a constraint on that field's
    /// target has been lifted.
    fk_field_matches: Vec<bool>,

    /// Set when some part of the filter cannot be rewritten; the caller then
    /// keeps the subquery form instead.
    fail: bool,

    /// Comparisons that have been rewritten onto foreign-key source fields.
    operands: Vec<stmt::Expr>,
}
/// Attempts to rewrite `expr IN (query)` when `expr` is a reference to a
/// relation field.
///
/// Dispatches on the kind of relation:
/// * `BelongsTo` → `lift_belongs_to_in_subquery`;
/// * `HasOne` / `HasMany` → `lift_has_n_in_subquery`, using the paired
///   `BelongsTo` on the target model.
///
/// Returns `None` when `expr` is not a field reference or the field is not a
/// relation, leaving the original expression untouched.
///
/// # Panics
///
/// Panics (`todo!`) when `expr` is a projection; that shape is not supported
/// yet.
pub(super) fn lift_in_subquery(
    cx: &ExprContext,
    expr: &stmt::Expr,
    query: &stmt::Query,
) -> Option<stmt::Expr> {
    let expr_reference = match expr {
        stmt::Expr::Project(_) => todo!(),
        stmt::Expr::Reference(field_ref @ stmt::ExprReference::Field { .. }) => field_ref,
        _ => return None,
    };

    let field = cx.resolve_expr_reference(expr_reference).as_field_unwrap();

    let (target, pair) = match &field.ty {
        FieldTy::BelongsTo(belongs_to) => {
            return lift_belongs_to_in_subquery(cx, belongs_to, query);
        }
        FieldTy::HasOne(has_one) => (has_one.target, has_one.pair(&cx.schema().app)),
        FieldTy::HasMany(has_many) => (has_many.target, has_many.pair(&cx.schema().app)),
        _ => return None,
    };

    lift_has_n_in_subquery(target, pair, query)
}
/// Rewrites a comparison through a relation path into a relation
/// `IN (subquery)` and then lifts that via `lift_in_subquery`.
///
/// Given `relation.a.b… op other_side`, this builds
/// `relation IN (SELECT … FROM Target WHERE a.b… op other_side)` and hands it
/// to `lift_in_subquery`.
///
/// Returns `None` when `project_side` is not a projection whose base is a
/// reference to a `BelongsTo` / `HasOne` / `HasMany` field, when the
/// projection is empty, or when the resulting `IN` cannot be lifted.
///
/// NOTE(review): `other_side` is cloned verbatim into the subquery's scope.
/// If it contains field references, they would be resolved against the target
/// model rather than the outer one — confirm callers only pass
/// scope-independent values here.
pub(super) fn try_lift_relation_path_comparison(
    cx: &ExprContext,
    op: stmt::BinaryOp,
    project_side: &stmt::Expr,
    other_side: &stmt::Expr,
) -> Option<stmt::Expr> {
    let (base, projection) = match project_side {
        Expr::Project(project) => (&*project.base, &project.projection),
        _ => return None,
    };

    let relation_ref = match base {
        Expr::Reference(reference) => reference,
        _ => return None,
    };

    let field = match cx.resolve_expr_reference(relation_ref) {
        ResolvedRef::Field(field) => field,
        _ => return None,
    };

    let target = match &field.ty {
        FieldTy::BelongsTo(rel) => rel.target,
        FieldTy::HasOne(rel) => rel.target,
        FieldTy::HasMany(rel) => rel.target,
        _ => return None,
    };

    // The first projection step names a field on the target model; any
    // remaining steps are re-applied on top of it inside the subquery.
    let (&first, rest) = projection.as_slice().split_first()?;
    let head = Expr::ref_self_field(FieldId {
        model: target,
        index: first,
    });
    let lhs = match rest {
        [] => head,
        _ => Expr::project(head, stmt::Projection::from(rest)),
    };

    let filter = Expr::binary_op(lhs, op, other_side.clone());
    let subquery = stmt::Query::new_select(stmt::Source::from(target), filter);

    lift_in_subquery(cx, base, &subquery)
}
/// Lifts `belongs_to IN (SELECT … FROM target WHERE filter)` onto the owning
/// model's foreign-key fields.
///
/// If every conjunct of `filter` compares a foreign-key *target* field against
/// a value, each comparison is rewritten onto the matching foreign-key *source*
/// field. The result is a single comparison, or an `AND` of several, with no
/// subquery.
///
/// Otherwise the subquery is kept, with its returning clause replaced by the
/// key's target field, producing `fk_source IN (SELECT fk_target …)`. This
/// fallback is also used when nothing was lifted: an empty `AND` would accept
/// every outer row, which is not equivalent to the subquery when the key may
/// be NULL.
///
/// Returns `None` when the subquery does not select from the relation's target
/// model.
///
/// # Panics
///
/// Panics via `as_select_unwrap` / `model_id_unwrap` if the query body is not a
/// model `SELECT`. Also panics if the foreign key is composite, which is not
/// supported yet.
fn lift_belongs_to_in_subquery(
    cx: &ExprContext,
    belongs_to: &BelongsTo,
    query: &stmt::Query,
) -> Option<stmt::Expr> {
    let select = query.body.as_select_unwrap();

    if belongs_to.target != select.source.model_id_unwrap() {
        return None;
    }

    assert_eq!(
        belongs_to.foreign_key.fields.len(),
        1,
        "TODO: composite keys"
    );

    let mut lift = LiftBelongsTo {
        cx: cx.scope(&select.source),
        belongs_to,
        fk_field_matches: vec![false; belongs_to.foreign_key.fields.len()],
        operands: vec![],
        fail: false,
    };

    lift.visit_filter(&select.filter);

    if lift.fail || lift.operands.is_empty() {
        // Keep the subquery, but have it return the referenced key so it can
        // be compared against the local foreign-key column.
        let [fk_field] = &belongs_to.foreign_key.fields[..] else {
            todo!("composite keys")
        };

        let mut subquery = query.clone();
        subquery.body.as_select_mut_unwrap().returning =
            stmt::Returning::Project(stmt::Expr::ref_self_field(fk_field.target));

        return Some(stmt::Expr::in_subquery(
            stmt::Expr::ref_self_field(fk_field.source),
            subquery,
        ));
    }

    let mut operands = lift.operands;
    let lifted = match operands.len() {
        1 => operands.pop().expect("operand list has exactly one element"),
        _ => stmt::ExprAnd { operands }.into(),
    };

    Some(lifted)
}
/// Lifts `has_one/has_many IN (SELECT … FROM target WHERE …)` into
/// `key IN (SELECT fk FROM target WHERE …)`.
///
/// `pair` is the `BelongsTo` on the target model that pairs with the has-one /
/// has-many relation. Its foreign-key `target` is the field on the owning
/// model that is compared. Its `source` is the child's column, which the
/// subquery is made to return.
///
/// Returns `None` when the subquery does not select from `target`.
///
/// # Panics
///
/// Panics if the query body is not a model `SELECT`, or (`todo!`) if the
/// foreign key is composite.
fn lift_has_n_in_subquery(
    target: ModelId,
    pair: &BelongsTo,
    query: &stmt::Query,
) -> Option<stmt::Expr> {
    let source_model = query.body.as_select_unwrap().source.model_id_unwrap();
    if target != source_model {
        return None;
    }

    let [fk_field] = &pair.foreign_key.fields[..] else {
        todo!("composite keys")
    };

    let mut subquery = query.clone();
    let stmt::ExprSet::Select(select) = &mut subquery.body else {
        todo!()
    };
    select.returning = stmt::Returning::Project(stmt::Expr::ref_self_field(fk_field.source));

    let in_subquery = stmt::ExprInSubquery {
        expr: Box::new(stmt::Expr::ref_self_field(fk_field.target)),
        query: Box::new(subquery),
    };
    Some(in_subquery.into())
}
impl Visit for LiftBelongsTo<'_> {
fn visit_expr_binary_op(&mut self, i: &stmt::ExprBinaryOp) {
match (&*i.lhs, &*i.rhs) {
(stmt::Expr::Reference(expr_reference), other)
| (other, stmt::Expr::Reference(expr_reference)) => {
assert!(i.op.is_eq() || i.op.is_ne());
if i.op.is_eq() || i.op.is_ne() {
let field = self
.cx
.resolve_expr_reference(expr_reference)
.as_field_unwrap();
self.lift_fk_constraint(field.id, i.op, other);
} else {
self.fail = true;
}
}
_ => {
self.fail = true;
}
}
}
}
impl LiftBelongsTo<'_> {
    /// Rewrites `field op expr` (with `field` on the target model) as
    /// `fk_source op expr` on the owning model.
    ///
    /// The lift is marked as failed instead when either:
    /// * `field` is not the target of any foreign-key field;
    /// * the same key field has already been constrained once.
    ///
    /// The caller then falls back to the subquery form, which is always
    /// equivalent.
    fn lift_fk_constraint(&mut self, field: FieldId, op: stmt::BinaryOp, expr: &stmt::Expr) {
        let fk_fields = &self.belongs_to.foreign_key.fields;

        let Some(index) = fk_fields.iter().position(|fk_field| fk_field.target == field) else {
            // The constraint is on a non-key field of the target model.
            self.fail = true;
            return;
        };

        if self.fk_field_matches[index] {
            // A second constraint on the same key field. Bail out rather than
            // panic so the subquery form is used.
            self.fail = true;
            return;
        }

        let source = fk_fields[index].source;
        self.operands.push(stmt::Expr::binary_op(
            stmt::Expr::ref_self_field(source),
            op,
            expr.clone(),
        ));
        self.fk_field_matches[index] = true;
    }
}