//! Interprocedural callee-before-caller pass dispatch via #74 level_wave (#74 self-consumer).
//!
//! Closes the recursion thesis for #74 — `level_wave_program` ships to
//! user dialects (whole-schema migrations, BFS layering, breadth-first
//! graph rewrites) AND drives vyre's interprocedural pass dispatch
//! when callees must finish before callers start.
//!
//! # The self-use
//!
//! Every interprocedural pass that walks a call graph has the same
//! shape: visit each function in callee-before-caller order, run a
//! per-function body, barrier between depth waves. Without
//! `level_wave_program`, each backend hand-codes that loop on the
//! host (one dispatch per depth, host-side termination check). With
//! it, the entire BFS becomes one Program and one dispatch.
//!
//! # Algorithm
//!
//! ```text
//! 1. Caller computes per-function depth in the call graph (leaves at
//!    depth 0, increasing toward main).
//! 2. Caller hands `step_body` (the per-function rewrite/analysis body)
//!    plus the depth array to `build_callee_before_caller_program`.
//! 3. Returned Program runs the body for every function at depth `d`,
//!    barriers, then advances to depth `d+1`, all in one dispatch.
//! ```
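//!
//! As a host-side illustration of steps 1 and 3, the sketch below
//! computes depths and replays the wave order on the CPU. It is not
//! part of vyre's API: `call_graph_depths` and `visit_in_waves` are
//! hypothetical names, the adjacency-list input is an assumed
//! representation, and the call graph is assumed to be acyclic
//! (recursive cycles would need to be collapsed first).
//!
//! ```rust
//! /// Hypothetical helper: depth of every function in an acyclic call
//! /// graph. `callees[f]` lists the functions that `f` calls. Leaves
//! /// get depth 0; a caller sits one level above its deepest callee.
//! fn call_graph_depths(callees: &[Vec<usize>]) -> Vec<u32> {
//!     // Memoised depth-first walk. Assumes no cycles; a cyclic graph
//!     // would recurse without terminating.
//!     fn visit(f: usize, callees: &[Vec<usize>], memo: &mut [Option<u32>]) -> u32 {
//!         if let Some(d) = memo[f] {
//!             return d;
//!         }
//!         let d = callees[f]
//!             .iter()
//!             .map(|&c| visit(c, callees, memo) + 1)
//!             .max()
//!             .unwrap_or(0);
//!         memo[f] = Some(d);
//!         d
//!     }
//!
//!     let mut memo: Vec<Option<u32>> = vec![None; callees.len()];
//!     (0..callees.len())
//!         .map(|f| visit(f, callees, &mut memo))
//!         .collect()
//! }
//!
//! /// Hypothetical CPU reference for the wave order: run `step` on
//! /// every function at depth `d` before any function at depth `d + 1`.
//! fn visit_in_waves(depths: &[u32], mut step: impl FnMut(usize)) {
//!     // Number of waves is `max(depth) + 1`, or 0 for an empty graph.
//!     let waves = depths.iter().max().map_or(0, |&m| m + 1);
//!     for d in 0..waves {
//!         for (f, &fd) in depths.iter().enumerate() {
//!             if fd == d {
//!                 step(f);
//!             }
//!         }
//!     }
//! }
//! ```
//!
//! For a graph where function 0 calls 1 and 2, and both 1 and 2 call
//! the leaf 3, the depths come out as `[2, 1, 1, 0]`, the visit order is
//! 3, 1, 2, 0, and the wave count is 3.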
//!
//! P-DRIVER-10: every interprocedural callee-before-caller pass should
//! consume this rather than hand-rolling a host depth loop.
use ;
use level_wave_program;
/// Build a Program that visits every function in callee-before-caller
/// order using GPU-side level-wave dispatch.
///
/// `step_body`: per-function body. It may read or write any
/// caller-declared buffer, using `Expr::InvocationId { axis: 0 }` as the
/// index of the function being visited.
///
/// `depth_buf`: name of the buffer containing per-function depth in the
/// call graph (leaves at 0).
///
/// `max_depth`: number of waves to run, i.e. `max(depth) + 1`, not the
/// largest depth value itself.
///
/// `function_count`: total functions in the dispatch grid.