Skip to main content

vyre_runtime/megakernel/
builder.rs

//! IR program builders — construct the megakernel `Program` from vyre IR.
//!
//! Two flavours:
//! - **Interpreted** (`build_program_sharded`) — If-tree opcode dispatch.
//! - **JIT** (`build_program_jit`) — payload processor fused directly.
6
7use std::sync::Arc;
8
9use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};
10
11use super::handlers::{claimed_slot_bindings, claimed_slot_body, load_miss_body, OpcodeHandler};
12use super::io::{
13    io_word, IO_DESTINATION_CAPABILITY_TABLE, IO_QUEUE_DMA_TAG, IO_SLOT_COUNT, IO_SLOT_WORDS,
14    IO_SOURCE_CAPABILITY_TABLE,
15};
16use super::ir_util::atomic_load_relaxed;
17use super::protocol::*;
18use super::workspace_adapter::MegakernelWorkspaceAdapter;
19mod cache;
20mod jit;
21mod priority;
22pub use jit::{build_program_jit, build_program_jit_slots, persistent_body_jit};
23pub use priority::{
24    build_program_priority, build_program_priority_slots, persistent_body_priority,
25    persistent_body_priority_slots,
26};
27
28/// Build the default megakernel IR (256 lanes × 1 workgroup, no custom opcodes).
29#[must_use]
30pub fn build_program() -> Program {
31    build_program_sharded(256, &[])
32}
33
34/// Build the megakernel IR with a custom workgroup size and optional
35/// custom opcodes.
36///
37/// Buffers are declared with concrete `with_count(...)` sizes so the
38/// backend readback layer allocates the right static staging size  -  a
39/// `count=0` default reads back 4 bytes regardless of how much the
40/// kernel wrote.
41#[must_use]
42pub fn build_program_sharded(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
43    build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
44}
45
46/// Build the megakernel IR for an explicit number of ring slots.
47///
48/// This is the production sharded ABI: `slot_count` sizes the ring buffer,
49/// while `workgroup_size_x` controls lanes per workgroup. Dispatch must launch
50/// `slot_count / workgroup_size_x` workgroups so every slot has an owning lane.
51#[must_use]
52pub fn build_program_sharded_slots(
53    workgroup_size_x: u32,
54    slot_count: u32,
55    opcodes: &[OpcodeHandler],
56) -> Program {
57    build_program_sharded_slots_with_io(workgroup_size_x, slot_count, opcodes, false)
58}
59
60/// Build the sharded megakernel IR as a shared immutable template.
61///
62/// Empty opcode sets use the thread-local template cache directly, allowing
63/// compile paths to avoid cloning the cached Program before wrapping it in
64/// `Arc` again.
65#[must_use]
66pub fn build_program_sharded_slots_shared(
67    workgroup_size_x: u32,
68    slot_count: u32,
69    opcodes: &[OpcodeHandler],
70) -> Arc<Program> {
71    if opcodes.is_empty() {
72        return cache::cached_empty_sharded_program_shared(workgroup_size_x, slot_count, false);
73    }
74    Arc::new(build_program_sharded_slots(
75        workgroup_size_x,
76        slot_count,
77        opcodes,
78    ))
79}
80
81/// Build the sharded megakernel IR with a consumer-owned resident workspace.
82#[must_use]
83pub fn build_program_sharded_with_workspace_adapter(
84    workgroup_size_x: u32,
85    slot_count: u32,
86    opcodes: &[OpcodeHandler],
87    adapter: &impl MegakernelWorkspaceAdapter,
88) -> Program {
89    wrap_persistent_megakernel_program_with_buffers(
90        default_buffers_with_workspace_adapter(slot_count, adapter),
91        workgroup_size_x,
92        persistent_body_with_workspace_adapter(workgroup_size_x, opcodes, adapter),
93    )
94}
95
96/// Build a finite one-pass sharded megakernel IR for host-submitted batches.
97///
98/// Unlike [`build_program_sharded_slots`], this program does not wrap the body
99/// in `Node::forever`; each lane attempts to drain its owning slot once and the
100/// dispatch returns. Use this for synchronous batch APIs that need a completion
101/// report from the same queue submission.
102#[must_use]
103pub fn build_program_sharded_once_slots(
104    workgroup_size_x: u32,
105    slot_count: u32,
106    opcodes: &[OpcodeHandler],
107) -> Program {
108    if opcodes.is_empty() {
109        return cache::cached_empty_sharded_once_program(workgroup_size_x, slot_count);
110    }
111    wrap_megakernel_program(
112        workgroup_size_x,
113        slot_count,
114        persistent_body_with_io(workgroup_size_x, opcodes, false),
115    )
116}
117
118/// Shared-Arc variant of [`build_program_sharded_once_slots`] for hot runtime
119/// dispatchers that must not clone the megakernel template every launch.
120#[must_use]
121pub fn build_program_sharded_once_slots_shared(
122    workgroup_size_x: u32,
123    slot_count: u32,
124    opcodes: &[OpcodeHandler],
125) -> Arc<Program> {
126    if opcodes.is_empty() {
127        return cache::cached_empty_sharded_once_program_shared(workgroup_size_x, slot_count);
128    }
129    Arc::new(build_program_sharded_once_slots(
130        workgroup_size_x,
131        slot_count,
132        opcodes,
133    ))
134}
135
136/// Build a finite one-pass megakernel that reports completion through the
137/// control buffer only.
138///
139/// Ring, debug, and IO buffers remain read-write device buffers, but their
140/// host readback ranges are empty. This is the hot dispatcher path: completion
141/// is already accumulated into control, so reading back the full ring/debug/IO
142/// surfaces is redundant launch latency.
143#[must_use]
144pub fn build_program_sharded_once_slots_control_report_shared(
145    workgroup_size_x: u32,
146    slot_count: u32,
147    opcodes: &[OpcodeHandler],
148) -> Arc<Program> {
149    if opcodes.is_empty() {
150        return cache::cached_empty_sharded_once_control_report_program_shared(
151            workgroup_size_x,
152            slot_count,
153        );
154    }
155    let mut buffers = default_buffers(slot_count);
156    for buffer in buffers.iter_mut().skip(1) {
157        buffer.output_byte_range = Some(0..0);
158    }
159    Arc::new(optimize_megakernel_program(Program::wrapped(
160        buffers,
161        [workgroup_size_x, 1, 1],
162        persistent_body_with_io(workgroup_size_x, opcodes, false),
163    )))
164}
165
166/// Build the megakernel IR without the IO polling sidecar.
167///
168/// This is the dispatch path for host-provided [`super::MegakernelWorkItem`]
169/// queues. It keeps the executable kernel free of `AsyncLoad` nodes until the
170/// runtime scheduler owns a concrete async-lowering pass.
171#[must_use]
172pub fn build_program_sharded_no_io(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
173    build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
174}
175
176/// Build the megakernel IR with the experimental IO polling sidecar.
177///
178/// The returned Program contains `AsyncLoad` nodes and must be lowered through
179/// a runtime scheduler pass before reaching a concrete backend lowering path.
180#[must_use]
181pub fn build_program_sharded_with_io_polling(
182    workgroup_size_x: u32,
183    opcodes: &[OpcodeHandler],
184) -> Program {
185    build_program_sharded_slots_with_io(workgroup_size_x, workgroup_size_x.max(1), opcodes, true)
186}
187
188/// Build the megakernel IR with a self-loading load-miss handler.
189///
190/// The persistent loop is extended with an [`opcode::LOAD_MISS`] handler.
191/// When the GPU sees this opcode it scans the IO queue for an empty slot,
192/// writes a DMA-read request, and polls until the host/runtime marks it
193/// complete. The `arg0` field of the slot is the consumer's opaque
194/// resource identifier; vyre does not interpret it.
195#[must_use]
196pub fn build_program_with_self_loading_miss_handler(
197    workgroup_size_x: u32,
198    slot_count: u32,
199    opcodes: &[OpcodeHandler],
200) -> Program {
201    match try_build_program_with_self_loading_miss_handler(workgroup_size_x, slot_count, opcodes) {
202        Ok(program) => program,
203        Err(error) => panic!("{error}"),
204    }
205}
206
207/// Fallible variant of [`build_program_with_self_loading_miss_handler`].
208pub fn try_build_program_with_self_loading_miss_handler(
209    workgroup_size_x: u32,
210    slot_count: u32,
211    opcodes: &[OpcodeHandler],
212) -> Result<Program, String> {
213    let mut extended = Vec::new();
214    let extended_len = opcodes.len().checked_add(1).ok_or_else(|| {
215        "megakernel self-loading opcode extension count overflowed usize. Fix: split opcode handler sets before building the megakernel."
216            .to_string()
217    })?;
218    vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut extended, extended_len).map_err(|error| {
219        format!(
220            "megakernel self-loading opcode extension allocation failed: {error}. Fix: split opcode handler sets before building the megakernel."
221        )
222    })?;
223    extended.extend_from_slice(opcodes);
224    extended.push(OpcodeHandler {
225        opcode: super::protocol::opcode::LOAD_MISS,
226        body: load_miss_body(),
227    });
228    Ok(wrap_persistent_megakernel_program(
229        workgroup_size_x,
230        slot_count,
231        persistent_body_with_io(workgroup_size_x, &extended, false),
232    ))
233}
234
235fn build_program_sharded_slots_with_io(
236    workgroup_size_x: u32,
237    slot_count: u32,
238    opcodes: &[OpcodeHandler],
239    include_io_polling: bool,
240) -> Program {
241    if opcodes.is_empty() {
242        return cache::cached_empty_sharded_program(
243            workgroup_size_x,
244            slot_count,
245            include_io_polling,
246        );
247    }
248    wrap_persistent_megakernel_program(
249        workgroup_size_x,
250        slot_count,
251        persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling),
252    )
253}
254
255fn wrap_persistent_megakernel_program(
256    workgroup_size_x: u32,
257    slot_count: u32,
258    body: Vec<Node>,
259) -> Program {
260    wrap_megakernel_program(workgroup_size_x, slot_count, vec![Node::forever(body)])
261}
262
263fn wrap_persistent_megakernel_program_with_buffers(
264    buffers: Vec<BufferDecl>,
265    workgroup_size_x: u32,
266    body: Vec<Node>,
267) -> Program {
268    optimize_megakernel_program(Program::wrapped(
269        buffers,
270        [workgroup_size_x, 1, 1],
271        vec![Node::forever(body)],
272    ))
273}
274
275fn wrap_megakernel_program(workgroup_size_x: u32, slot_count: u32, body: Vec<Node>) -> Program {
276    optimize_megakernel_program(Program::wrapped(
277        default_buffers(slot_count),
278        [workgroup_size_x, 1, 1],
279        body,
280    ))
281}
282
283fn optimize_megakernel_program(program: Program) -> Program {
284    let (program, _) = super::planner::try_elide_value_flow_barriers(program).unwrap_or_else(
285        |error| {
286            panic!(
287                "megakernel program barrier optimization failed: {error}. Fix: reduce fused program size before builder optimization."
288            )
289        },
290    );
291    vyre_foundation::optimizer::pre_lowering::optimize(program)
292}
293
294/// Reserve sizes for the megakernel's four host-visible buffers. All
295/// four go through the static-readback path so every buffer needs
296/// a concrete `count` (u32 elements). The numbers mirror the wire
297/// layout in `protocol.rs`:
298///
299/// - **control**: 128 u32 words covers SHUTDOWN, DONE_COUNT, EPOCH,
300///   METRICS_BASE..METRICS_BASE+METRICS_SLOTS, OBSERVABLE_BASE, and
301///   the 32-entry tenant-mask table.
302/// - **ring_buffer**: `slot_count` slots × `SLOT_WORDS`.
303///   `slot_count` must match host-published ring bytes and dispatch geometry.
304/// - **debug_log**: cursor word + `debug::RECORD_CAPACITY` × 4-word records.
305/// - **io_queue**: 64 slots × 8 words (source, destination,
306///   offset_low, offset_high, size, status, tag, pad).
307fn default_buffers(slot_count: u32) -> Vec<BufferDecl> {
308    let ring_slots = slot_count.max(1);
309    let control = BufferDecl::read_write("control", 0, DataType::U32).with_count(CONTROL_MIN_WORDS);
310    let ring_buffer = BufferDecl::read_write("ring_buffer", 1, DataType::U32).with_count(
311        ring_slots.checked_mul(SLOT_WORDS).unwrap_or_else(|| {
312            panic!(
313                "megakernel ring buffer word count overflowed u32. Fix: reduce slot_count or SLOT_WORDS before building default megakernel buffers."
314            )
315        }),
316    );
317    let debug_log =
318        BufferDecl::read_write("debug_log", 2, DataType::U32).with_count(debug::BUFFER_WORDS);
319    let io_queue = BufferDecl::read_write("io_queue", 3, DataType::U32).with_count(64 * 8);
320    vec![control, ring_buffer, debug_log, io_queue]
321}
322
323fn default_buffers_with_workspace_adapter(
324    slot_count: u32,
325    adapter: &impl MegakernelWorkspaceAdapter,
326) -> Vec<BufferDecl> {
327    let mut buffers = default_buffers(slot_count);
328    buffers.push(adapter.buffer_decl());
329    buffers
330}
331
332/// The body that runs once per iteration per lane. Exposed for tests
333/// and downstream crates that splice additional opcodes.
334#[must_use]
335pub fn persistent_body(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Vec<Node> {
336    persistent_body_with_io(workgroup_size_x, opcodes, false)
337}
338
339/// Fallible persistent body builder with explicit staging-allocation reporting.
340pub fn try_persistent_body(
341    workgroup_size_x: u32,
342    opcodes: &[OpcodeHandler],
343) -> Result<Vec<Node>, String> {
344    try_persistent_body_with_io(workgroup_size_x, opcodes, false)
345}
346
347fn persistent_body_with_io(
348    workgroup_size_x: u32,
349    opcodes: &[OpcodeHandler],
350    include_io_polling: bool,
351) -> Vec<Node> {
352    match try_persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling) {
353        Ok(body) => body,
354        Err(error) => panic!("{error}"),
355    }
356}
357
358fn try_persistent_body_with_io(
359    workgroup_size_x: u32,
360    opcodes: &[OpcodeHandler],
361    include_io_polling: bool,
362) -> Result<Vec<Node>, String> {
363    let mut body = persistent_lane_prologue(workgroup_size_x);
364    let additional_nodes = if include_io_polling { 3 } else { 2 };
365    let body_capacity = body.len().checked_add(additional_nodes).ok_or_else(|| {
366        "megakernel persistent body node reservation overflowed usize. Fix: reduce fused IO/body staging before building the megakernel."
367            .to_string()
368    })?;
369    vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity).map_err(|error| {
370        format!(
371            "megakernel persistent body node reservation failed: {error}. Fix: reduce fused IO/body staging before building the megakernel."
372        )
373    })?;
374    body.push(direct_slot_base_binding());
375    body.push(Node::Block(execute_slot_body(opcodes)));
376    if include_io_polling {
377        body.push(Node::Block(process_io_requests()));
378    }
379    Ok(body)
380}
381
382fn persistent_lane_prologue(workgroup_size_x: u32) -> Vec<Node> {
383    vec![
384        Node::let_bind(
385            "shutdown_flag",
386            atomic_load_relaxed("control", Expr::u32(control::SHUTDOWN)),
387        ),
388        Node::if_then(
389            Expr::ne(Expr::var("shutdown_flag"), Expr::u32(0)),
390            vec![Node::Return],
391        ),
392        Node::let_bind("lane_id", lane_id_expr(workgroup_size_x)),
393    ]
394}
395
396fn direct_slot_base_binding() -> Node {
397    Node::let_bind(
398        "slot_base",
399        Expr::mul(Expr::var("lane_id"), Expr::u32(SLOT_WORDS)),
400    )
401}
402
403fn slot_tenant_id_load() -> Expr {
404    Expr::load(
405        "ring_buffer",
406        Expr::add(Expr::var("slot_base"), Expr::u32(TENANT_WORD)),
407    )
408}
409
410fn tenant_authorized_body(tenant_id: Expr, authorized_body: Vec<Node>) -> Vec<Node> {
411    vec![
412        Node::let_bind("tenant_id", tenant_id),
413        Node::let_bind(
414            "tenant_base",
415            atomic_load_relaxed("control", Expr::u32(control::TENANT_BASE)),
416        ),
417        Node::let_bind(
418            "tenant_mask",
419            atomic_load_relaxed(
420                "control",
421                Expr::add(Expr::var("tenant_base"), Expr::var("tenant_id")),
422            ),
423        ),
424        Node::if_then(
425            Expr::ne(Expr::var("tenant_mask"), Expr::u32(0)),
426            authorized_body,
427        ),
428    ]
429}
430
431fn lane_id_expr(workgroup_size_x: u32) -> Expr {
432    Expr::add(
433        Expr::mul(Expr::workgroup_x(), Expr::u32(workgroup_size_x)),
434        Expr::local_x(),
435    )
436}
437
438fn persistent_body_with_workspace_adapter(
439    workgroup_size_x: u32,
440    opcodes: &[OpcodeHandler],
441    adapter: &impl MegakernelWorkspaceAdapter,
442) -> Vec<Node> {
443    let mut body = adapter.bootstrap_nodes();
444    body.extend(adapter.guard_nodes());
445    body.extend(adapter.dispatch_nodes());
446    body.extend(persistent_body_with_io(workgroup_size_x, opcodes, false));
447    body
448}
449
450
451fn process_io_requests() -> Vec<Node> {
452    let nodes = vec![Node::loop_for(
453        "io_idx",
454        Expr::u32(0),
455        Expr::u32(IO_SLOT_COUNT),
456        vec![
457            Node::let_bind(
458                "io_base",
459                Expr::mul(Expr::var("io_idx"), Expr::u32(IO_SLOT_WORDS)),
460            ),
461            Node::let_bind(
462                "io_status_idx",
463                Expr::add(Expr::var("io_base"), Expr::u32(io_word::STATUS)),
464            ),
465            // CAS PUBLISHED -> CLAIMED
466            Node::let_bind(
467                "prev_io_status",
468                Expr::atomic_compare_exchange(
469                    "io_queue",
470                    Expr::var("io_status_idx"),
471                    Expr::u32(slot::PUBLISHED),
472                    Expr::u32(slot::CLAIMED),
473                ),
474            ),
475            Node::if_then(
476                Expr::eq(Expr::var("prev_io_status"), Expr::u32(slot::PUBLISHED)),
477                vec![
478                    Node::let_bind(
479                        "io_src_handle",
480                        Expr::load(
481                            "io_queue",
482                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::SRC_HANDLE)),
483                        ),
484                    ),
485                    Node::let_bind(
486                        "io_dst_handle",
487                        Expr::load(
488                            "io_queue",
489                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::DST_HANDLE)),
490                        ),
491                    ),
492                    Node::AsyncLoad {
493                        source: IO_SOURCE_CAPABILITY_TABLE.into(),
494                        destination: IO_DESTINATION_CAPABILITY_TABLE.into(),
495                        offset: Box::new(Expr::load(
496                            "io_queue",
497                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::OFFSET_LO)),
498                        )),
499                        size: Box::new(Expr::load(
500                            "io_queue",
501                            Expr::add(Expr::var("io_base"), Expr::u32(io_word::BYTE_COUNT)),
502                        )),
503                        tag: IO_QUEUE_DMA_TAG.into(),
504                    },
505                    // Mark as DONE
506                    Node::store(
507                        "io_queue",
508                        Expr::var("io_status_idx"),
509                        Expr::u32(slot::DONE),
510                    ),
511                ],
512            ),
513        ],
514    )];
515
516    nodes
517}
518
519fn execute_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
520    vec![
521        Node::let_bind(
522            "status_index",
523            Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
524        ),
525        Node::let_bind(
526            "observed_status",
527            atomic_load_relaxed("ring_buffer", Expr::var("status_index")),
528        ),
529        Node::if_then(
530            Expr::eq(Expr::var("observed_status"), Expr::u32(slot::PUBLISHED)),
531            tenant_authorized_claim_body(slot_tenant_id_load(), claimed_slot_body(opcodes)),
532        ),
533    ]
534}
535
536fn tenant_authorized_claim_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
537    tenant_authorized_body(
538        tenant_id,
539        vec![
540            // CAS PUBLISHED -> CLAIMED after authorization. This keeps
541            // disabled tenants visible to the host instead of converting
542            // their slots into stuck CLAIMED work.
543            Node::let_bind(
544                "prev_status",
545                Expr::atomic_compare_exchange(
546                    "ring_buffer",
547                    Expr::var("status_index"),
548                    Expr::u32(slot::PUBLISHED),
549                    Expr::u32(slot::CLAIMED),
550                ),
551            ),
552            Node::if_then(
553                Expr::eq(Expr::var("prev_status"), Expr::u32(slot::PUBLISHED)),
554                claimed_body,
555            ),
556        ],
557    )
558}
559
560fn execute_already_claimed_slot_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
561    let mut body = vec![Node::let_bind(
562        "status_index",
563        Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
564    )];
565    body.extend(tenant_authorized_body(tenant_id, claimed_body));
566    body
567}
568
569#[cfg(test)]
570mod tests;
571