1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
use toasty_core::{
schema::app,
stmt::{self, ExprContext, IntoExprTarget, VisitMut},
};
/// Pre-lowering pass that rewrites `Source::Model { via: Some(_) }` into
/// an explicit WHERE filter on the surrounding statement.
///
/// `via` associations are an app-level construct. They appear when a
/// query is built from a relation traversal, e.g.
/// `user.todos().delete(...)`. The lowering walk converts
/// `Source::Model` into `Source::Table` once the association has been
/// rewritten as a filter, so this pass runs before the lowering walk and
/// the walk only ever sees the rewritten form.
///
/// The pass is driven through its `VisitMut` implementation: Delete,
/// Insert, and Query nodes are rewritten first and then walked.
pub(super) struct RewriteVia<'a> {
/// Expression context used for schema lookups; re-scoped to the
/// statement source whenever an association is turned into a filter.
cx: ExprContext<'a>,
}
impl<'a> RewriteVia<'a> {
pub(super) fn new(cx: ExprContext<'a>) -> Self {
Self { cx }
}
/// Walk a statement and apply the via-association rewrite to every
/// Delete, Insert, and Query node it contains.
pub(super) fn rewrite(&mut self, stmt: &mut stmt::Statement) {
self.visit_mut(stmt);
}
/// Schema reachable through the expression context. The returned
/// reference carries the context lifetime `'a`, not the borrow of `self`.
fn schema(&self) -> &'a toasty_core::Schema {
    let schema = self.cx.schema();
    schema
}
/// Build a child rewriter whose expression context is scoped to `target`.
fn scope<'scope>(&'scope self, target: impl IntoExprTarget<'scope>) -> RewriteVia<'scope> {
    let cx = self.cx.scope(target);
    RewriteVia { cx }
}
pub(super) fn rewrite_via_for_delete(&mut self, stmt: &mut stmt::Delete) {
if let stmt::Source::Model(model) = &mut stmt.from
&& let Some(via) = model.via.take()
{
// Create a new scope to indicate we are operating in the
// context of stmt.from
let mut s = self.scope(&stmt.from);
let filter = s.rewrite_association_as_filter(via);
stmt.filter = stmt::Filter::and(stmt.filter.take(), filter);
}
}
/// Rewrite the via association on an `Insert` statement.
///
/// Only the `InsertTarget::Scope` variant is handled: its scope query is
/// rewritten in place by [`rewrite_via_for_query`](Self::rewrite_via_for_query).
/// Every other target is left untouched.
pub(super) fn rewrite_via_for_insert(&mut self, stmt: &mut stmt::Insert) {
    let stmt::InsertTarget::Scope(scope) = &mut stmt.target else {
        return;
    };
    self.rewrite_via_for_query(scope);
}
pub(super) fn rewrite_via_for_query(&mut self, stmt: &mut stmt::Query) {
if let stmt::ExprSet::Select(select) = &mut stmt.body
&& let stmt::Source::Model(model) = &mut select.source
&& let Some(via) = model.via.take()
{
// Create a new scope to indicate we are operating in the
// context of stmt.target
let mut s = self.scope(&select.source);
let filter = s.rewrite_association_as_filter(via);
select.filter = stmt::Filter::and(select.filter.take(), filter);
}
}
/// Turn `association` into a filter on the association's target.
///
/// The path is first unfolded into a single direct-relation step, with
/// any intermediate steps becoming nested `Source::Model { via }` sources.
/// Those nested sources are then rewritten, and finally the remaining
/// step is expressed as an `IN (subquery)` filter.
///
/// # Panics
///
/// Panics if the association path is empty, if the unfolded path does not
/// resolve to a field, or if the resolved field is neither a `BelongsTo`
/// nor a `Has` relation.
pub(super) fn rewrite_association_as_filter(
    &mut self,
    association: stmt::Association,
) -> stmt::Filter {
    assert!(
        !association.path.projection.is_empty(),
        "via path must have at least one step"
    );

    // Expand via relations and multi-step paths. What comes back has a
    // one-step path whose field is a direct (non-via) relation.
    let mut unfolded = self.unfold_path(association);

    // Call the overridden `visit_stmt_query_mut` directly so that a
    // `Source::Model { via: Some(_) }` introduced by unfolding is itself
    // rewritten before the outer filter is built. The free-function
    // walker would not apply the override to the source query itself.
    self.visit_stmt_query_mut(&mut unfolded.source);

    let field = match self.schema().app.resolve_field_path(&unfolded.path) {
        Some(field) => field,
        None => todo!(),
    };

    match &field.ty {
        app::FieldTy::BelongsTo(rel) => {
            self.rewrite_association_belongs_to_as_filter(rel, unfolded)
        }
        // Direct has-one / has-many: the target is filtered by its paired
        // field being IN the source query. Via relations never reach this
        // arm because unfolding has already expanded them.
        app::FieldTy::Has(has) => {
            let paired = stmt::Expr::ref_self_field(has.pair_id);
            stmt::Expr::in_subquery(paired, *unfolded.source).into()
        }
        _ => todo!("field={field:#?}"),
    }
}
/// Start path unfolding for `association`.
///
/// Reads the seed model id from the select source of the association's
/// source query, then hands the source, that id, and the path steps to
/// the recursive [`unfold_steps`](Self::unfold_steps) helper. The result
/// is an association whose path is a single step that does **not** name
/// a via relation.
fn unfold_path(&self, association: stmt::Association) -> stmt::Association {
    let seed_model_id = association
        .source
        .body
        .as_select_unwrap()
        .source
        .model_id_unwrap();

    self.unfold_steps(
        association.source,
        seed_model_id,
        association.path.projection.as_slice(),
    )
}
/// Walk `steps`, splicing each via relation's resolved path inline and
/// wrapping every intermediate step in a nested `Source::Model { via }`.
/// Returns the outer single-step association the caller filters against.
///
/// * `source` - query producing the records the first step starts from.
/// * `source_model_id` - model that `steps[0]` indexes into.
/// * `steps` - field indices still to traverse; must be non-empty.
///
/// Via splicing allocates a `Vec<usize>` per via segment so the recursion
/// can borrow it as a slice. Paths are short (typically 1-3 steps) and
/// vias are rare, so this is cheap in practice.
///
/// # Panics
///
/// Panics if `steps` is empty, if the first step is not a valid field
/// index on the source model, or if an intermediate (non-final) step names
/// a field that is not a relation.
fn unfold_steps(
    &self,
    source: Box<stmt::Query>,
    source_model_id: app::ModelId,
    steps: &[usize],
) -> stmt::Association {
    let [first, rest @ ..] = steps else {
        unreachable!("unfold_steps called with empty steps")
    };
    let field = &self
        .schema()
        .app
        .model(source_model_id)
        .as_root_unwrap()
        .fields[*first];

    // A via relation is replaced by its own resolved path, followed by the
    // remaining steps. Recursing on the spliced list re-examines the first
    // spliced step, so via-of-via is expanded as well.
    if let app::FieldTy::Via(via) = &field.ty {
        let via_steps = via.path.projection.as_slice();
        let mut spliced = Vec::with_capacity(via_steps.len() + rest.len());
        spliced.extend_from_slice(via_steps);
        spliced.extend_from_slice(rest);
        return self.unfold_steps(source, source_model_id, &spliced);
    }

    // From here on `field` is known not to be a via relation. Both the
    // base case and the nesting case need the same one-step association.
    let single_step = |source: Box<stmt::Query>| stmt::Association {
        source,
        path: stmt::Path::from_index(source_model_id, *first),
    };

    // Base case: a single direct step stays on the outer association.
    if rest.is_empty() {
        return single_step(source);
    }

    let next_model_id = match &field.ty {
        app::FieldTy::Has(rel) => rel.target,
        app::FieldTy::BelongsTo(rel) => rel.target,
        // Via fields always take the splice branch above. State that
        // invariant explicitly instead of quietly nesting an unexpanded via.
        app::FieldTy::Via(_) => unreachable!("via steps are spliced before this point"),
        other => todo!("non-relation field in via path: {other:#?}"),
    };

    // Wrap the current step in a new source query over the next model and
    // continue with the remaining steps from there. The always-true filter
    // means the nested query is restricted only by its via association.
    let new_source = Box::new(stmt::Query::new_select(
        stmt::Source::Model(stmt::SourceModel {
            id: next_model_id,
            via: Some(single_step(source)),
        }),
        stmt::Expr::Value(stmt::Value::Bool(true)),
    ));
    self.unfold_steps(new_source, next_model_id, rest)
}
/// Build the filter for a single-step `BelongsTo` association.
///
/// The foreign key lives on the source model and the target model carries
/// the referenced fields, so the filter has the shape
/// `<fk.target...> IN (SELECT <fk.source...> FROM <source>)`. Each side is
/// a single field reference for a single-column key and a record of
/// references for a composite key (lowered to a tuple-style IN by the SQL
/// serializer).
fn rewrite_association_belongs_to_as_filter(
    &mut self,
    rel: &app::BelongsTo,
    association: stmt::Association,
) -> stmt::Filter {
    let fk_fields = &rel.foreign_key.fields;

    // Left-hand side: the referenced fields on the target model.
    let target = super::key_field_refs(0, fk_fields.iter().map(|fk| fk.target));
    // Projection of the subquery: the foreign-key fields on the source.
    let returning = super::key_field_refs(0, fk_fields.iter().map(|fk| fk.source));

    // Make the source query return only the foreign-key columns.
    let mut subquery = *association.source;
    subquery.body.as_select_mut_unwrap().returning = stmt::Returning::Project(returning);

    stmt::Expr::in_subquery(target, subquery).into()
}
}
/// Visitor hooks for the pass. Each hook first rewrites the via
/// association on the node itself and then continues the default walk into
/// the node's children.
impl VisitMut for RewriteVia<'_> {
    fn visit_stmt_query_mut(&mut self, query: &mut stmt::Query) {
        Self::rewrite_via_for_query(self, query);
        stmt::visit_mut::visit_stmt_query_mut(self, query);
    }

    fn visit_stmt_insert_mut(&mut self, insert: &mut stmt::Insert) {
        Self::rewrite_via_for_insert(self, insert);
        stmt::visit_mut::visit_stmt_insert_mut(self, insert);
    }

    fn visit_stmt_delete_mut(&mut self, delete: &mut stmt::Delete) {
        Self::rewrite_via_for_delete(self, delete);
        stmt::visit_mut::visit_stmt_delete_mut(self, delete);
    }
}