Skip to main content

vyre_runtime/megakernel/
builder.rs

//! IR program builders — construct the megakernel `Program` from vyre IR.
//!
//! Two flavours:
//! - **Interpreted** (`build_program_sharded`) — if-tree opcode dispatch.
//! - **JIT** (`build_program_jit`) — payload processor fused directly.
6
7use std::sync::Arc;
8
9use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};
10
11use super::handlers::{claimed_slot_bindings, claimed_slot_body, load_miss_body, OpcodeHandler};
12use super::io::{
13    io_word, IO_DESTINATION_CAPABILITY_TABLE, IO_QUEUE_DMA_TAG, IO_SLOT_COUNT, IO_SLOT_WORDS,
14    IO_SOURCE_CAPABILITY_TABLE,
15};
16use super::ir_util::atomic_load_relaxed;
17use super::protocol::*;
18use super::workspace_adapter::MegakernelWorkspaceAdapter;
19mod cache;
20mod jit;
21mod priority;
22pub use jit::{build_program_jit, build_program_jit_slots, persistent_body_jit};
23pub use priority::{
24    build_program_priority, build_program_priority_slots, persistent_body_priority,
25    persistent_body_priority_slots,
26};
27
28/// Build the default megakernel IR (256 lanes × 1 workgroup, no custom opcodes).
29#[must_use]
30pub fn build_program() -> Program {
31    build_program_sharded(256, &[])
32}
33
34/// Build the megakernel IR with a custom workgroup size and optional
35/// custom opcodes.
36///
37/// Buffers are declared with concrete `with_count(...)` sizes so the
38/// backend readback layer allocates the right static staging size  -  a
39/// `count=0` default reads back 4 bytes regardless of how much the
40/// kernel wrote.
41#[must_use]
42pub fn build_program_sharded(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
43    build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
44}
45
46/// Build the megakernel IR for an explicit number of ring slots.
47///
48/// This is the production sharded ABI: `slot_count` sizes the ring buffer,
49/// while `workgroup_size_x` controls lanes per workgroup. Dispatch must launch
50/// `slot_count / workgroup_size_x` workgroups so every slot has an owning lane.
51#[must_use]
52pub fn build_program_sharded_slots(
53    workgroup_size_x: u32,
54    slot_count: u32,
55    opcodes: &[OpcodeHandler],
56) -> Program {
57    build_program_sharded_slots_with_io(workgroup_size_x, slot_count, opcodes, false)
58}
59
60/// Build the sharded megakernel IR as a shared immutable template.
61///
62/// Empty opcode sets use the thread-local template cache directly, allowing
63/// compile paths to avoid cloning the cached Program before wrapping it in
64/// `Arc` again.
65#[must_use]
66pub fn build_program_sharded_slots_shared(
67    workgroup_size_x: u32,
68    slot_count: u32,
69    opcodes: &[OpcodeHandler],
70) -> Arc<Program> {
71    if opcodes.is_empty() {
72        return cache::cached_empty_sharded_program_shared(workgroup_size_x, slot_count, false);
73    }
74    Arc::new(build_program_sharded_slots(
75        workgroup_size_x,
76        slot_count,
77        opcodes,
78    ))
79}
80
81/// Build the sharded megakernel IR with a consumer-owned resident workspace.
82#[must_use]
83pub fn build_program_sharded_with_workspace_adapter(
84    workgroup_size_x: u32,
85    slot_count: u32,
86    opcodes: &[OpcodeHandler],
87    adapter: &impl MegakernelWorkspaceAdapter,
88) -> Program {
89    wrap_persistent_megakernel_program_with_buffers(
90        default_buffers_with_workspace_adapter(slot_count, adapter),
91        workgroup_size_x,
92        persistent_body_with_workspace_adapter(workgroup_size_x, opcodes, adapter),
93    )
94}
95
96/// Build a finite one-pass sharded megakernel IR for host-submitted batches.
97///
98/// Unlike [`build_program_sharded_slots`], this program does not wrap the body
99/// in `Node::forever`; each lane attempts to drain its owning slot once and the
100/// dispatch returns. Use this for synchronous batch APIs that need a completion
101/// report from the same queue submission.
102#[must_use]
103pub fn build_program_sharded_once_slots(
104    workgroup_size_x: u32,
105    slot_count: u32,
106    opcodes: &[OpcodeHandler],
107) -> Program {
108    if opcodes.is_empty() {
109        return cache::cached_empty_sharded_once_program(workgroup_size_x, slot_count);
110    }
111    wrap_megakernel_program(
112        workgroup_size_x,
113        slot_count,
114        persistent_body_with_io(workgroup_size_x, opcodes, false),
115    )
116}
117
118/// Shared-Arc variant of [`build_program_sharded_once_slots`] for hot runtime
119/// dispatchers that must not clone the megakernel template every launch.
120#[must_use]
121pub fn build_program_sharded_once_slots_shared(
122    workgroup_size_x: u32,
123    slot_count: u32,
124    opcodes: &[OpcodeHandler],
125) -> Arc<Program> {
126    if opcodes.is_empty() {
127        return cache::cached_empty_sharded_once_program_shared(workgroup_size_x, slot_count);
128    }
129    Arc::new(build_program_sharded_once_slots(
130        workgroup_size_x,
131        slot_count,
132        opcodes,
133    ))
134}
135
136/// Build a finite one-pass megakernel that reports completion through the
137/// control buffer only.
138///
139/// Ring, debug, and IO buffers remain read-write device buffers, but their
140/// host readback ranges are empty. This is the hot dispatcher path: completion
141/// is already accumulated into control, so reading back the full ring/debug/IO
142/// surfaces is redundant launch latency.
143#[must_use]
144pub fn build_program_sharded_once_slots_control_report_shared(
145    workgroup_size_x: u32,
146    slot_count: u32,
147    opcodes: &[OpcodeHandler],
148) -> Arc<Program> {
149    if opcodes.is_empty() {
150        return cache::cached_empty_sharded_once_control_report_program_shared(
151            workgroup_size_x,
152            slot_count,
153        );
154    }
155    let mut buffers = default_buffers(slot_count);
156    for buffer in buffers.iter_mut().skip(1) {
157        buffer.output_byte_range = Some(0..0);
158    }
159    Arc::new(optimize_megakernel_program(Program::wrapped(
160        buffers,
161        [workgroup_size_x, 1, 1],
162        persistent_body_with_io(workgroup_size_x, opcodes, false),
163    )))
164}
165
166/// Build the megakernel IR without the IO polling sidecar.
167///
168/// This is the dispatch path for host-provided [`super::MegakernelWorkItem`]
169/// queues. It keeps the executable kernel free of `AsyncLoad` nodes until the
170/// runtime scheduler owns a concrete async-lowering pass.
171#[must_use]
172pub fn build_program_sharded_no_io(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
173    build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
174}
175
176/// Build the megakernel IR with the experimental IO polling sidecar.
177///
178/// The returned Program contains `AsyncLoad` nodes and must be lowered through
179/// a runtime scheduler pass before reaching a concrete backend lowering path.
180#[must_use]
181pub fn build_program_sharded_with_io_polling(
182    workgroup_size_x: u32,
183    opcodes: &[OpcodeHandler],
184) -> Program {
185    build_program_sharded_slots_with_io(workgroup_size_x, workgroup_size_x.max(1), opcodes, true)
186}
187
/// Build the megakernel IR with a self-loading load-miss handler.
///
/// The persistent loop is extended with an [`opcode::LOAD_MISS`] handler.
/// When the GPU sees this opcode it scans the IO queue for an empty slot,
/// writes a DMA-read request, and polls until the host/runtime marks it
/// complete. The slot's `arg0` field is the consumer's opaque resource
/// identifier; vyre does not interpret it.
///
/// # Panics
///
/// Panics with the builder's error message if
/// [`try_build_program_with_self_loading_miss_handler`] returns an error.
#[must_use]
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn build_program_with_self_loading_miss_handler(
    workgroup_size_x: u32,
    slot_count: u32,
    opcodes: &[OpcodeHandler],
) -> Program {
    try_build_program_with_self_loading_miss_handler(workgroup_size_x, slot_count, opcodes)
        .unwrap_or_else(|error| panic!("{error}"))
}
207
208/// Fallible variant of [`build_program_with_self_loading_miss_handler`].
209pub fn try_build_program_with_self_loading_miss_handler(
210    workgroup_size_x: u32,
211    slot_count: u32,
212    opcodes: &[OpcodeHandler],
213) -> Result<Program, String> {
214    let mut extended = Vec::new();
215    let extended_len = opcodes.len().checked_add(1).ok_or_else(|| {
216        "megakernel self-loading opcode extension count overflowed usize. Fix: split opcode handler sets before building the megakernel."
217            .to_string()
218    })?;
219    vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut extended, extended_len).map_err(|error| {
220        format!(
221            "megakernel self-loading opcode extension allocation failed: {error}. Fix: split opcode handler sets before building the megakernel."
222        )
223    })?;
224    extended.extend_from_slice(opcodes);
225    extended.push(OpcodeHandler {
226        opcode: super::protocol::opcode::LOAD_MISS,
227        body: load_miss_body(),
228    });
229    Ok(wrap_persistent_megakernel_program(
230        workgroup_size_x,
231        slot_count,
232        persistent_body_with_io(workgroup_size_x, &extended, false),
233    ))
234}
235
236fn build_program_sharded_slots_with_io(
237    workgroup_size_x: u32,
238    slot_count: u32,
239    opcodes: &[OpcodeHandler],
240    include_io_polling: bool,
241) -> Program {
242    if opcodes.is_empty() {
243        return cache::cached_empty_sharded_program(
244            workgroup_size_x,
245            slot_count,
246            include_io_polling,
247        );
248    }
249    wrap_persistent_megakernel_program(
250        workgroup_size_x,
251        slot_count,
252        persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling),
253    )
254}
255
256fn wrap_persistent_megakernel_program(
257    workgroup_size_x: u32,
258    slot_count: u32,
259    body: Vec<Node>,
260) -> Program {
261    wrap_megakernel_program(workgroup_size_x, slot_count, vec![Node::forever(body)])
262}
263
264fn wrap_persistent_megakernel_program_with_buffers(
265    buffers: Vec<BufferDecl>,
266    workgroup_size_x: u32,
267    body: Vec<Node>,
268) -> Program {
269    optimize_megakernel_program(Program::wrapped(
270        buffers,
271        [workgroup_size_x, 1, 1],
272        vec![Node::forever(body)],
273    ))
274}
275
276fn wrap_megakernel_program(workgroup_size_x: u32, slot_count: u32, body: Vec<Node>) -> Program {
277    optimize_megakernel_program(Program::wrapped(
278        default_buffers(slot_count),
279        [workgroup_size_x, 1, 1],
280        body,
281    ))
282}
283
284fn optimize_megakernel_program(program: Program) -> Program {
285    let fallback = program.clone();
286    let program = match super::planner::try_elide_value_flow_barriers(program) {
287        Ok((program, _)) => program,
288        Err(_) => fallback,
289    };
290    vyre_foundation::optimizer::pre_lowering::optimize(program)
291}
292
293/// Reserve sizes for the megakernel's four host-visible buffers. All
294/// four go through the static-readback path so every buffer needs
295/// a concrete `count` (u32 elements). The numbers mirror the wire
296/// layout in `protocol.rs`:
297///
298/// - **control**: 128 u32 words covers SHUTDOWN, DONE_COUNT, EPOCH,
299///   METRICS_BASE..METRICS_BASE+METRICS_SLOTS, OBSERVABLE_BASE, and
300///   the 32-entry tenant-mask table.
301/// - **ring_buffer**: `slot_count` slots × `SLOT_WORDS`.
302///   `slot_count` must match host-published ring bytes and dispatch geometry.
303/// - **debug_log**: cursor word + `debug::RECORD_CAPACITY` × 4-word records.
304/// - **io_queue**: 64 slots × 8 words (source, destination,
305///   offset_low, offset_high, size, status, tag, pad).
306fn default_buffers(slot_count: u32) -> Vec<BufferDecl> {
307    let ring_slots = slot_count.max(1);
308    let control = BufferDecl::read_write("control", 0, DataType::U32).with_count(CONTROL_MIN_WORDS);
309    let ring_buffer = BufferDecl::read_write("ring_buffer", 1, DataType::U32)
310        .with_count(ring_slots.saturating_mul(SLOT_WORDS));
311    let debug_log =
312        BufferDecl::read_write("debug_log", 2, DataType::U32).with_count(debug::BUFFER_WORDS);
313    let io_queue = BufferDecl::read_write("io_queue", 3, DataType::U32).with_count(64 * 8);
314    vec![control, ring_buffer, debug_log, io_queue]
315}
316
317fn default_buffers_with_workspace_adapter(
318    slot_count: u32,
319    adapter: &impl MegakernelWorkspaceAdapter,
320) -> Vec<BufferDecl> {
321    let mut buffers = default_buffers(slot_count);
322    buffers.push(adapter.buffer_decl());
323    buffers
324}
325
326/// The body that runs once per iteration per lane. Exposed for tests
327/// and downstream crates that splice additional opcodes.
328#[must_use]
329pub fn persistent_body(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Vec<Node> {
330    persistent_body_with_io(workgroup_size_x, opcodes, false)
331}
332
333/// Fallible persistent body builder with explicit staging-allocation reporting.
334pub fn try_persistent_body(
335    workgroup_size_x: u32,
336    opcodes: &[OpcodeHandler],
337) -> Result<Vec<Node>, String> {
338    try_persistent_body_with_io(workgroup_size_x, opcodes, false)
339}
340
341fn persistent_body_with_io(
342    workgroup_size_x: u32,
343    opcodes: &[OpcodeHandler],
344    include_io_polling: bool,
345) -> Vec<Node> {
346    let mut body = persistent_lane_prologue(workgroup_size_x);
347    let additional_nodes = if include_io_polling { 3 } else { 2 };
348    if let Some(body_capacity) = body.len().checked_add(additional_nodes) {
349        let _ = vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity);
350    }
351    body.push(direct_slot_base_binding());
352    body.push(Node::Block(execute_slot_body(opcodes)));
353    if include_io_polling {
354        body.push(Node::Block(process_io_requests()));
355    }
356    body
357}
358
359fn try_persistent_body_with_io(
360    workgroup_size_x: u32,
361    opcodes: &[OpcodeHandler],
362    include_io_polling: bool,
363) -> Result<Vec<Node>, String> {
364    let mut body = persistent_lane_prologue(workgroup_size_x);
365    let additional_nodes = if include_io_polling { 3 } else { 2 };
366    let body_capacity = body.len().checked_add(additional_nodes).ok_or_else(|| {
367        "megakernel persistent body node reservation overflowed usize. Fix: reduce fused IO/body staging before building the megakernel."
368            .to_string()
369    })?;
370    vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity).map_err(|error| {
371        format!(
372            "megakernel persistent body node reservation failed: {error}. Fix: reduce fused IO/body staging before building the megakernel."
373        )
374    })?;
375    body.push(direct_slot_base_binding());
376    body.push(Node::Block(execute_slot_body(opcodes)));
377    if include_io_polling {
378        body.push(Node::Block(process_io_requests()));
379    }
380    Ok(body)
381}
382
383fn persistent_lane_prologue(workgroup_size_x: u32) -> Vec<Node> {
384    vec![
385        Node::let_bind(
386            "shutdown_flag",
387            atomic_load_relaxed("control", Expr::u32(control::SHUTDOWN)),
388        ),
389        Node::if_then(
390            Expr::ne(Expr::var("shutdown_flag"), Expr::u32(0)),
391            vec![Node::Return],
392        ),
393        Node::let_bind("lane_id", lane_id_expr(workgroup_size_x)),
394    ]
395}
396
397fn direct_slot_base_binding() -> Node {
398    Node::let_bind(
399        "slot_base",
400        Expr::mul(Expr::var("lane_id"), Expr::u32(SLOT_WORDS)),
401    )
402}
403
404fn slot_tenant_id_load() -> Expr {
405    Expr::load(
406        "ring_buffer",
407        Expr::add(Expr::var("slot_base"), Expr::u32(TENANT_WORD)),
408    )
409}
410
411fn tenant_authorized_body(tenant_id: Expr, authorized_body: Vec<Node>) -> Vec<Node> {
412    vec![
413        Node::let_bind("tenant_id", tenant_id),
414        Node::let_bind(
415            "tenant_base",
416            atomic_load_relaxed("control", Expr::u32(control::TENANT_BASE)),
417        ),
418        Node::let_bind(
419            "tenant_mask",
420            atomic_load_relaxed(
421                "control",
422                Expr::add(Expr::var("tenant_base"), Expr::var("tenant_id")),
423            ),
424        ),
425        Node::if_then(
426            Expr::ne(Expr::var("tenant_mask"), Expr::u32(0)),
427            authorized_body,
428        ),
429    ]
430}
431
432fn lane_id_expr(workgroup_size_x: u32) -> Expr {
433    Expr::add(
434        Expr::mul(Expr::workgroup_x(), Expr::u32(workgroup_size_x)),
435        Expr::local_x(),
436    )
437}
438
439fn persistent_body_with_workspace_adapter(
440    workgroup_size_x: u32,
441    opcodes: &[OpcodeHandler],
442    adapter: &impl MegakernelWorkspaceAdapter,
443) -> Vec<Node> {
444    let mut body = adapter.bootstrap_nodes();
445    body.extend(adapter.guard_nodes());
446    body.extend(adapter.dispatch_nodes());
447    body.extend(persistent_body_with_io(workgroup_size_x, opcodes, false));
448    body
449}
450
451fn process_io_requests() -> Vec<Node> {
452    let nodes = vec![Node::loop_for(
453        "io_idx",
454        Expr::u32(0),
455        Expr::u32(IO_SLOT_COUNT),
456        vec![
457            Node::let_bind(
458                "io_base",
459                Expr::mul(Expr::var("io_idx"), Expr::u32(IO_SLOT_WORDS)),
460            ),
461            Node::let_bind(
462                "io_status_idx",
463                Expr::add(Expr::var("io_base"), Expr::u32(io_word::STATUS)),
464            ),
465            // CAS PUBLISHED -> CLAIMED
466            Node::let_bind(
467                "prev_io_status",
468                Expr::atomic_compare_exchange(
469                    "io_queue",
470                    Expr::var("io_status_idx"),
471                    Expr::u32(slot::PUBLISHED),
472                    Expr::u32(slot::CLAIMED),
473                ),
474            ),
475            Node::if_then(
476                Expr::eq(Expr::var("prev_io_status"), Expr::u32(slot::PUBLISHED)),
477                vec![
478                    Node::let_bind(
479                        "io_src_handle",
480                        Expr::load(
481                            "io_queue",
482                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::SRC_HANDLE)),
483                        ),
484                    ),
485                    Node::let_bind(
486                        "io_dst_handle",
487                        Expr::load(
488                            "io_queue",
489                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::DST_HANDLE)),
490                        ),
491                    ),
492                    Node::AsyncLoad {
493                        source: IO_SOURCE_CAPABILITY_TABLE.into(),
494                        destination: IO_DESTINATION_CAPABILITY_TABLE.into(),
495                        offset: Box::new(Expr::load(
496                            "io_queue",
497                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::OFFSET_LO)),
498                        )),
499                        size: Box::new(Expr::load(
500                            "io_queue",
501                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::BYTE_COUNT)),
502                        )),
503                        tag: IO_QUEUE_DMA_TAG.into(),
504                    },
505                    // Mark as DONE
506                    Node::store(
507                        "io_queue",
508                        Expr::var("io_status_idx"),
509                        Expr::u32(slot::DONE),
510                    ),
511                ],
512            ),
513        ],
514    )];
515
516    nodes
517}
518
519fn execute_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
520    vec![
521        Node::let_bind(
522            "status_index",
523            Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
524        ),
525        Node::let_bind(
526            "observed_status",
527            atomic_load_relaxed("ring_buffer", Expr::var("status_index")),
528        ),
529        Node::if_then(
530            Expr::eq(Expr::var("observed_status"), Expr::u32(slot::PUBLISHED)),
531            tenant_authorized_claim_body(slot_tenant_id_load(), claimed_slot_body(opcodes)),
532        ),
533    ]
534}
535
536fn tenant_authorized_claim_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
537    tenant_authorized_body(
538        tenant_id,
539        vec![
540            // CAS PUBLISHED -> CLAIMED after authorization. This keeps
541            // disabled tenants visible to the host instead of converting
542            // their slots into stuck CLAIMED work.
543            Node::let_bind(
544                "prev_status",
545                Expr::atomic_compare_exchange(
546                    "ring_buffer",
547                    Expr::var("status_index"),
548                    Expr::u32(slot::PUBLISHED),
549                    Expr::u32(slot::CLAIMED),
550                ),
551            ),
552            Node::if_then(
553                Expr::eq(Expr::var("prev_status"), Expr::u32(slot::PUBLISHED)),
554                claimed_body,
555            ),
556        ],
557    )
558}
559
560fn execute_already_claimed_slot_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
561    let mut body = vec![Node::let_bind(
562        "status_index",
563        Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
564    )];
565    body.extend(tenant_authorized_body(tenant_id, claimed_body));
566    body
567}
568
569#[cfg(test)]
570mod tests;