toasty 0.4.0

An async ORM for Rust supporting SQL and NoSQL databases
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
use indexmap::IndexSet;
use toasty_core::stmt::{self, visit_mut};

use crate::engine::{
    Engine, HirStatement, SelectItems, eval,
    exec::{MergeIndex, MergeQualification, NestedChild, NestedLevel},
    hir, mir,
    plan::HirPlanner,
};

/// Working state accumulated while planning a single nested merge.
///
/// The planner is consumed by its `plan_nested_merge` method, which moves
/// `inputs`, `hash_indexes`, and `sort_indexes` into the resulting
/// `mir::NestedMerge` node and hands `deps` to the MIR store alongside it.
#[derive(Debug)]
struct NestedMergePlanner<'a> {
    /// Engine handle; its schema is used when inferring select-item types.
    engine: &'a Engine,
    /// HIR statement graph being planned, indexed by `hir::StmtId`.
    hir: &'a HirStatement,
    /// MIR store that receives the planned `NestedMerge` node.
    mir: &'a mut mir::Store,
    /// MIR nodes whose output is consumed by the merge. The insertion index of
    /// a node in this set is what each planned level records as its `source`.
    inputs: IndexSet<mir::NodeId>,
    /// Statements that must execute before the merge but whose output is not needed
    deps: IndexSet<mir::NodeId>,
    /// Flat list of hash indexes to build, populated as HashLookup qualifications are planned.
    hash_indexes: Vec<MergeIndex>,
    /// Flat list of sorted indexes to build, populated as SortLookup qualifications are planned.
    sort_indexes: Vec<MergeIndex>,
    /// Statement stack, used to infer expression types
    stack: Vec<hir::StmtId>,
}

impl HirPlanner<'_> {
    /// Plans a `NestedMerge` node for a statement whose returning clause
    /// embeds the results of sub-statements.
    ///
    /// Arguments of the form `Arg::Sub { returning: true, .. }` stand for
    /// nested data that has to be stitched onto the parent rows. When at
    /// least one such argument is present, this method builds a merge plan
    /// that:
    ///
    /// 1. Collects every batch-loaded input the merge reads (parent and
    ///    child queries).
    /// 2. Builds a tree of levels that mirrors the nesting of the statements.
    /// 3. Records, for each level:
    ///    - where its rows come from (a reference to batch-loaded results),
    ///    - how child rows are qualified against a parent row,
    ///    - how the parent row plus its children are projected into the
    ///      final shape.
    ///
    /// At execution time the resulting `NestedMerge`:
    /// - loads all input data for every level before processing anything,
    /// - walks each root row and, for every nested child relationship,
    ///   filters the batch-loaded child rows and recursively merges the
    ///   matches with their own children,
    /// - collects the matches into a list, or a single value when `single`
    ///   is `true`,
    /// - projects the final row from the current row and its nested
    ///   children, and returns all merged rows.
    ///
    /// # Example
    ///
    /// For a query like:
    /// ```sql
    /// SELECT user.*, (SELECT * FROM todos WHERE user_id = user.id) as todos
    /// FROM users
    /// ```
    ///
    /// the plan is a two-level merge:
    /// - Root level: user rows from the batch load
    /// - Nested level: todo rows filtered by `user_id`, projected into a list
    ///
    /// # Returns
    ///
    /// `None` when the statement has no sub-statement argument with
    /// `returning: true`, or when the statement is an insert. Otherwise the
    /// id of the newly inserted `NestedMerge` MIR node.
    pub(super) fn plan_nested_merge(&mut self, stmt_id: hir::StmtId) -> Option<mir::NodeId> {
        let state = &self.hir[stmt_id];

        let has_returning_sub = state
            .args
            .iter()
            .any(|arg| matches!(arg, hir::Arg::Sub { returning: true, .. }));

        // Without a returning sub-statement there is nothing to merge, and
        // insert statements are not planned through this path. The statement
        // is only inspected once the first check has passed.
        if !has_returning_sub || state.stmt.as_ref().unwrap().is_insert() {
            return None;
        }

        let planner = NestedMergePlanner {
            engine: self.engine,
            hir: self.hir,
            mir: &mut self.mir,
            inputs: IndexSet::new(),
            deps: IndexSet::new(),
            hash_indexes: vec![],
            sort_indexes: vec![],
            stack: vec![],
        };

        Some(planner.plan_nested_merge(stmt_id))
    }
}

impl NestedMergePlanner<'_> {
    /// Consumes the planner, plans the level tree rooted at `root`, and
    /// inserts the resulting `mir::NestedMerge` node into the MIR store.
    ///
    /// The root statement sits on the statement stack while its level (depth
    /// 0) is planned. The accumulated inputs and index lists are moved into
    /// the node, and the collected `deps` are passed to `insert_with_deps`.
    fn plan_nested_merge(mut self, root: hir::StmtId) -> mir::NodeId {
        self.stack.push(root);
        let root_level = self.plan_nested_level(root, 0);
        self.stack.pop();

        let merge = mir::NestedMerge {
            inputs: self.inputs,
            root: root_level,
            hash_indexes: self.hash_indexes,
            sort_indexes: self.sort_indexes,
        };

        self.mir.insert_with_deps(merge, self.deps)
    }

    /// Plans a child level of the merge together with the rule used to match
    /// its rows against a parent row.
    ///
    /// For a query child, the `WHERE` clause is rewritten into a filter over
    /// the row stack. If that filter is a pure equality conjunction between
    /// child and ancestor fields, an index lookup is registered: a hash index
    /// when the query is `single`, a sorted index otherwise. Any other filter
    /// falls back to a linear scan that evaluates the filter per row.
    ///
    /// For an insert child, every row qualifies (`MergeQualification::All`).
    fn plan_nested_child(&mut self, stmt_id: hir::StmtId, depth: usize) -> NestedChild {
        // The child stays on the statement stack for the duration of planning
        // so argument types can be inferred for it and all of its ancestors.
        self.stack.push(stmt_id);

        let level = self.plan_nested_level(stmt_id, depth);
        let state = &self.hir[stmt_id];
        let select_items = state.load_data_select_items.get().unwrap();

        let child = match state.stmt.as_deref().unwrap() {
            stmt::Statement::Query(query) => {
                let single = query.single;
                let filter = self.build_filter_for_nested_child(stmt_id, select_items, depth);
                let arg_tys = self.build_filter_arg_tys();

                let qualification = if let Some((child_projections, lookup_key)) =
                    try_eq_lookup(&filter, &arg_tys, depth)
                {
                    let merge_index = MergeIndex {
                        source: level.source,
                        child_projections,
                    };

                    if single {
                        // has_one / belongs_to: unique key → HashIndex (O(1) lookup).
                        let index = self.hash_indexes.len();
                        self.hash_indexes.push(merge_index);
                        MergeQualification::HashLookup { index, lookup_key }
                    } else {
                        // has_many: duplicate keys → SortedIndex (O(log M + k) lookup).
                        let index = self.sort_indexes.len();
                        self.sort_indexes.push(merge_index);
                        MergeQualification::SortLookup { index, lookup_key }
                    }
                } else {
                    // The filter is not a pure equality conjunction, so no
                    // index can drive the lookup; evaluate it against every
                    // row instead. `try_eq_lookup` documents how an index
                    // with a residual post-filter could improve on this.
                    MergeQualification::Scan(eval::Func::from_stmt(filter, arg_tys))
                };

                NestedChild {
                    level,
                    qualification,
                    single,
                }
            }
            stmt::Statement::Insert(insert) => {
                let single = insert.source.single;

                NestedChild {
                    level,
                    qualification: MergeQualification::All,
                    single,
                }
            }
            stmt => todo!("stmt={stmt:#?}"),
        };

        self.stack.pop();

        child
    }

    /// Plans one level of the merge tree for `stmt_id`.
    ///
    /// The level's data source is registered in `inputs` and referred to by
    /// its index in that set. When the returning clause is an expression, the
    /// source is the statement's load-data node and the expression is
    /// rewritten into a projection, planning nested children along the way.
    /// For any other returning clause, the statement's output node is the
    /// source and the row is passed through unchanged.
    fn plan_nested_level(&mut self, stmt_id: hir::StmtId, depth: usize) -> NestedLevel {
        let state = &self.hir[stmt_id];
        let returning = state.stmt.as_deref().unwrap().returning_unwrap();

        let mut nested = Vec::new();

        let (source, projection) = match returning {
            stmt::Returning::Expr(expr) => {
                // Register the input before descending so this level's
                // source is allocated ahead of its children's.
                let load_node = state.load_data_statement.get().unwrap();
                let (source, _) = self.inputs.insert_full(load_node);

                let projection =
                    self.build_projection_from_expr(stmt_id, expr, depth, &mut nested);

                (source, projection)
            }
            _ => {
                let node_id = state.output.get().unwrap();
                let (source, _) = self.inputs.insert_full(node_id);

                // Flatten list (bit of a hack): the projection is typed on
                // the element when the node's type is a list.
                let row_ty = match self.mir[node_id].ty().clone() {
                    stmt::Type::List(item_ty) => *item_ty,
                    other => other,
                };

                let projection = eval::Func::from_stmt(stmt::Expr::arg(0), vec![row_ty]);

                (source, projection)
            }
        };

        NestedLevel {
            source,
            projection,
            nested,
        }
    }

    /// Returns the row type of every statement currently on the stack, in
    /// stack order (outermost ancestor first, current statement last).
    fn build_filter_arg_tys(&self) -> Vec<stmt::Type> {
        let mut arg_tys = Vec::with_capacity(self.stack.len());

        for &stmt_id in &self.stack {
            arg_tys.push(self.build_exec_statement_ty_for(stmt_id));
        }

        arg_tys
    }

    /// Builds the argument types for a level's projection function.
    ///
    /// Argument 0 is the row type of the statement on top of the stack. Each
    /// nested child then contributes one argument: its projection's return
    /// type when the child is `single`, or a list of that type otherwise.
    fn build_projection_arg_tys(&self, nested_children: &[NestedChild]) -> Vec<stmt::Type> {
        let current = *self.stack.last().unwrap();
        let mut arg_tys = vec![self.build_exec_statement_ty_for(current)];

        arg_tys.extend(nested_children.iter().map(|child| {
            let ret = child.level.projection.ret.clone();

            if child.single {
                ret
            } else {
                stmt::Type::list(ret)
            }
        }));

        arg_tys
    }

    /// Infers the record type of the rows loaded for `stmt_id`: one field per
    /// load-data select item, typed in the context of that statement.
    fn build_exec_statement_ty_for(&self, stmt_id: hir::StmtId) -> stmt::Type {
        let state = &self.hir[stmt_id];
        let target = state.stmt.as_deref().unwrap();
        let cx = stmt::ExprContext::new_with_target(&*self.engine.schema, target);

        let fields: Vec<_> = state
            .load_data_select_items
            .get()
            .unwrap()
            .into_iter()
            .map(|item| item.infer_ty(&cx))
            .collect();

        stmt::Type::Record(fields)
    }

    /// Turns a returning expression into the projection function for a level.
    ///
    /// Statement-level `Arg` nodes and `Reference` nodes are replaced with
    /// references to nested merge arguments:
    ///
    /// - A sub-statement whose returning clause is a constant value is
    ///   inlined as that value, and its load-data node is recorded as a
    ///   dependency of the merge.
    /// - Any other sub-statement is planned as a nested child and replaced
    ///   with the argument slot that child occupies (slot 0 is the current
    ///   row, so the first child is slot 1).
    /// - A column reference becomes a projection of the matching field of
    ///   the current row (argument 0).
    ///
    /// `walk_expr_scoped_mut` tracks scope depth through Let/Map scopes, so
    /// only args with `nesting == scope_depth` are treated as statement-level
    /// and rewritten; bindings introduced by inner Let/Map scopes are left
    /// alone.
    fn build_projection_from_expr(
        &mut self,
        stmt_id: hir::StmtId,
        expr: &stmt::Expr,
        depth: usize,
        nested: &mut Vec<NestedChild>,
    ) -> eval::Func {
        // Take a copy of the shared HIR reference so the closure can read
        // HIR data while it also borrows `self` mutably.
        let hir = self.hir;
        let select_items = hir[stmt_id].load_data_select_items.get().unwrap();
        let mut rewritten = expr.clone();

        visit_mut::walk_expr_scoped_mut(&mut rewritten, 0, |node, scope_depth| match node {
            stmt::Expr::Arg(arg) if arg.nesting == scope_depth => {
                let position = arg.position;

                match &hir[stmt_id].args[position] {
                    hir::Arg::Sub {
                        stmt_id: child_id, ..
                    } => {
                        let child_id = *child_id;
                        let child_state = &hir[child_id];
                        let child_stmt = child_state.stmt.as_deref().unwrap();

                        match child_stmt.returning_unwrap() {
                            stmt::Returning::Value(returning_expr) if returning_expr.is_const() => {
                                // A single-row child must not return a list.
                                let single = match child_stmt {
                                    stmt::Statement::Query(query) => query.single,
                                    stmt::Statement::Insert(insert) => insert.source.single,
                                    _ => false,
                                };

                                if single {
                                    let stmt::Expr::Value(v) = returning_expr else {
                                        todo!()
                                    };
                                    assert!(!v.is_list());
                                }

                                // The child's output is not read by the merge,
                                // but it still has to run first.
                                let load_node = child_state.load_data_statement.get().unwrap();
                                self.deps.insert(load_node);

                                *node = returning_expr.clone();
                            }
                            _ => {
                                let child = self.plan_nested_child(child_id, depth + 1);
                                nested.push(child);

                                // Slot 0 is the current row; children start at 1.
                                *node = stmt::Expr::arg(nested.len());
                            }
                        }
                    }
                    hir::Arg::Ref { .. } => todo!(),
                }

                // The node was replaced wholesale; do not descend into it.
                false
            }
            stmt::Expr::Reference(reference) => {
                let expr_column = reference.as_expr_column_unwrap();
                debug_assert_eq!(0, expr_column.nesting);

                let field = select_items.get_index_of_expr_reference(*expr_column);
                *node = stmt::Expr::arg_project(0, [field]);

                false
            }
            _ => true,
        });

        let arg_tys = self.build_projection_arg_tys(nested);
        eval::Func::from_stmt(rewritten, arg_tys)
    }

    /// Rewrites the `WHERE` clause of a nested child query into a filter
    /// evaluated against the row stack.
    ///
    /// The row stack holds one row per statement on the path from the root
    /// to the child, with the child's own row at position `depth`:
    ///
    /// - Each `Arg` must resolve to an `Arg::Ref` into an ancestor statement;
    ///   it becomes a projection of the referenced field from the ancestor
    ///   row at position `depth - nesting`.
    /// - Each `Reference` becomes a projection of the matching field from
    ///   the child row at position `depth`, using `selection` to locate it.
    ///
    /// The statement at `stmt_id` must be a query with a select body.
    fn build_filter_for_nested_child(
        &self,
        stmt_id: hir::StmtId,
        selection: &SelectItems,
        depth: usize,
    ) -> stmt::Expr {
        let state = &self.hir[stmt_id];
        let stmt::Statement::Query(query) = state.stmt.as_deref().unwrap() else {
            unreachable!()
        };

        // For now the entire where clause is re-run as the qualification;
        // that can be narrowed down later.
        let mut filter = query.body.as_select_unwrap().filter.clone();

        visit_mut::for_each_expr_mut(&mut filter, |node| match node {
            stmt::Expr::Arg(arg) => {
                let hir::Arg::Ref {
                    nesting,
                    stmt_id: target_id,
                    target_expr_ref,
                    ..
                } = &state.args[arg.position]
                else {
                    todo!()
                };

                debug_assert!(*nesting > 0);

                // Roundabout lookup: resolve the referenced field through the
                // target statement's select items. More direct tracking of
                // this information would be preferable.
                let field = self.hir[target_id]
                    .load_data_select_items
                    .get()
                    .unwrap()
                    .get_index_of_expr_reference(*target_expr_ref);

                *node = stmt::Expr::arg_project(depth - *nesting, [field]);
            }
            stmt::Expr::Reference(reference) => {
                let field = selection.get_index_of_expr_reference(*reference);
                *node = stmt::Expr::arg_project(depth, [field]);
            }
            _ => {}
        });

        filter.into_expr()
    }
}

/// Attempts to derive index lookup keys from a rewritten child filter.
///
/// The filter is accepted when it has one of these shapes:
/// - Single equality: `arg_project(depth, [cf]) == arg_project(pos < depth, [pf])`
/// - Composite AND:   `eq1 AND eq2 AND ...` where every `eqi` has the shape above
///
/// On success the result is `(child_projections, lookup_key)`:
/// - `child_projections[i]` projects key field `i` out of the child record
/// - `lookup_key` is an `eval::Func` evaluated against the ancestor `RowStack`
///   that yields the lookup key (a scalar for a single field, a
///   `Value::Record` for a composite key)
///
/// # Limitations
///
/// Only filters that are *entirely* an equality conjunction are accepted.
/// Anything else (e.g. `a = b AND c > d`, or an `OR`) yields `None`, and the
/// caller falls back to a full `Scan` even when part of the filter could
/// drive an index lookup with the remainder applied as a post-filter.
///
/// A more complete approach — similar to `IndexMatch` in the index planner —
/// would pick out whichever equality terms can key an index, build the lookup
/// from those, and re-evaluate the full original predicate against each
/// candidate row the index returns. That would turn O(N×M) into O(log M + k)
/// even for compound filters. For now this stays conservative: an index is
/// used only when the whole filter matches.
fn try_eq_lookup(
    expr: &stmt::Expr,
    arg_tys: &[stmt::Type],
    depth: usize,
) -> Option<(Vec<stmt::Projection>, eval::Func)> {
    // Normalize to a list of conjuncts: an AND contributes its operands and
    // any other expression is treated as a single term.
    let conjuncts: Vec<&stmt::Expr> = match expr {
        stmt::Expr::And(and_expr) => and_expr.operands.iter().collect(),
        single => vec![single],
    };

    let mut child_projections = Vec::with_capacity(conjuncts.len());
    let mut key_exprs = Vec::with_capacity(conjuncts.len());

    for term in conjuncts {
        // Every conjunct must be an equality between a child field and an
        // ancestor field; a single mismatch disqualifies the whole filter.
        let op = match term {
            stmt::Expr::BinaryOp(op) if op.op == stmt::BinaryOp::Eq => op,
            _ => return None,
        };

        let (child_proj, parent_expr) = extract_child_parent_eq(&*op.lhs, &*op.rhs, depth)?;
        child_projections.push(child_proj);
        key_exprs.push(parent_expr);
    }

    // A single field keeps its scalar expression; multiple fields are wrapped
    // in a record (which evaluates to `Value::Record`). An empty conjunction
    // has nothing to key on.
    let lookup_key_expr = match key_exprs.len() {
        0 => return None,
        1 => key_exprs.remove(0),
        _ => stmt::Expr::record_from_vec(key_exprs),
    };

    // The key is evaluated against the ancestor rows only, so the child row
    // type at position `depth` is excluded from the argument types.
    let ancestor_tys = arg_tys[..depth].to_vec();
    let lookup_key = eval::Func::from_stmt(lookup_key_expr, ancestor_tys);

    Some((child_projections, lookup_key))
}

/// Splits an equality `lhs == rhs` into its child side (the row at `depth`)
/// and its parent side (a row at some position below `depth`).
///
/// Both operands have to be simple `arg_project(pos, projection)`
/// expressions, in either order. Returns the child's projection together with
/// a clone of the parent-side expression, or `None` when the operands do not
/// fit that pattern.
fn extract_child_parent_eq(
    lhs: &stmt::Expr,
    rhs: &stmt::Expr,
    depth: usize,
) -> Option<(stmt::Projection, stmt::Expr)> {
    let (l_pos, l_proj) = as_simple_arg_project(lhs)?;
    let (r_pos, r_proj) = as_simple_arg_project(rhs)?;

    if l_pos == depth && r_pos < depth {
        // Child on the left, ancestor on the right.
        Some((l_proj.clone(), rhs.clone()))
    } else if r_pos == depth && l_pos < depth {
        // Child on the right, ancestor on the left.
        Some((r_proj.clone(), lhs.clone()))
    } else {
        None
    }
}

/// Recognizes `Project(Arg { position, nesting: 0 }, projection)` and yields
/// `(position, &projection)`. Every other expression shape yields `None`.
fn as_simple_arg_project(expr: &stmt::Expr) -> Option<(usize, &stmt::Projection)> {
    let stmt::Expr::Project(proj) = expr else {
        return None;
    };

    match proj.base.as_ref() {
        stmt::Expr::Arg(arg) if arg.nesting == 0 => Some((arg.position, &proj.projection)),
        _ => None,
    }
}