1use std::sync::Arc;
8
9use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};
10
11use super::handlers::{claimed_slot_bindings, claimed_slot_body, load_miss_body, OpcodeHandler};
12use super::io::{
13 io_word, IO_DESTINATION_CAPABILITY_TABLE, IO_QUEUE_DMA_TAG, IO_SLOT_COUNT, IO_SLOT_WORDS,
14 IO_SOURCE_CAPABILITY_TABLE,
15};
16use super::ir_util::atomic_load_relaxed;
17use super::protocol::*;
18use super::workspace_adapter::MegakernelWorkspaceAdapter;
19mod cache;
20mod jit;
21mod priority;
22pub use jit::{build_program_jit, build_program_jit_slots, persistent_body_jit};
23pub use priority::{
24 build_program_priority, build_program_priority_slots, persistent_body_priority,
25 persistent_body_priority_slots,
26};
27
28#[must_use]
30pub fn build_program() -> Program {
31 build_program_sharded(256, &[])
32}
33
34#[must_use]
42pub fn build_program_sharded(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
43 build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
44}
45
46#[must_use]
52pub fn build_program_sharded_slots(
53 workgroup_size_x: u32,
54 slot_count: u32,
55 opcodes: &[OpcodeHandler],
56) -> Program {
57 build_program_sharded_slots_with_io(workgroup_size_x, slot_count, opcodes, false)
58}
59
60#[must_use]
66pub fn build_program_sharded_slots_shared(
67 workgroup_size_x: u32,
68 slot_count: u32,
69 opcodes: &[OpcodeHandler],
70) -> Arc<Program> {
71 if opcodes.is_empty() {
72 return cache::cached_empty_sharded_program_shared(workgroup_size_x, slot_count, false);
73 }
74 Arc::new(build_program_sharded_slots(
75 workgroup_size_x,
76 slot_count,
77 opcodes,
78 ))
79}
80
81#[must_use]
83pub fn build_program_sharded_with_workspace_adapter(
84 workgroup_size_x: u32,
85 slot_count: u32,
86 opcodes: &[OpcodeHandler],
87 adapter: &impl MegakernelWorkspaceAdapter,
88) -> Program {
89 wrap_persistent_megakernel_program_with_buffers(
90 default_buffers_with_workspace_adapter(slot_count, adapter),
91 workgroup_size_x,
92 persistent_body_with_workspace_adapter(workgroup_size_x, opcodes, adapter),
93 )
94}
95
96#[must_use]
103pub fn build_program_sharded_once_slots(
104 workgroup_size_x: u32,
105 slot_count: u32,
106 opcodes: &[OpcodeHandler],
107) -> Program {
108 if opcodes.is_empty() {
109 return cache::cached_empty_sharded_once_program(workgroup_size_x, slot_count);
110 }
111 wrap_megakernel_program(
112 workgroup_size_x,
113 slot_count,
114 persistent_body_with_io(workgroup_size_x, opcodes, false),
115 )
116}
117
118#[must_use]
121pub fn build_program_sharded_once_slots_shared(
122 workgroup_size_x: u32,
123 slot_count: u32,
124 opcodes: &[OpcodeHandler],
125) -> Arc<Program> {
126 if opcodes.is_empty() {
127 return cache::cached_empty_sharded_once_program_shared(workgroup_size_x, slot_count);
128 }
129 Arc::new(build_program_sharded_once_slots(
130 workgroup_size_x,
131 slot_count,
132 opcodes,
133 ))
134}
135
136#[must_use]
144pub fn build_program_sharded_once_slots_control_report_shared(
145 workgroup_size_x: u32,
146 slot_count: u32,
147 opcodes: &[OpcodeHandler],
148) -> Arc<Program> {
149 if opcodes.is_empty() {
150 return cache::cached_empty_sharded_once_control_report_program_shared(
151 workgroup_size_x,
152 slot_count,
153 );
154 }
155 let mut buffers = default_buffers(slot_count);
156 for buffer in buffers.iter_mut().skip(1) {
157 buffer.output_byte_range = Some(0..0);
158 }
159 Arc::new(optimize_megakernel_program(Program::wrapped(
160 buffers,
161 [workgroup_size_x, 1, 1],
162 persistent_body_with_io(workgroup_size_x, opcodes, false),
163 )))
164}
165
166#[must_use]
172pub fn build_program_sharded_no_io(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Program {
173 build_program_sharded_slots(workgroup_size_x, workgroup_size_x.max(1), opcodes)
174}
175
176#[must_use]
181pub fn build_program_sharded_with_io_polling(
182 workgroup_size_x: u32,
183 opcodes: &[OpcodeHandler],
184) -> Program {
185 build_program_sharded_slots_with_io(workgroup_size_x, workgroup_size_x.max(1), opcodes, true)
186}
187
188#[must_use]
196pub fn build_program_with_self_loading_miss_handler(
197 workgroup_size_x: u32,
198 slot_count: u32,
199 opcodes: &[OpcodeHandler],
200) -> Program {
201 match try_build_program_with_self_loading_miss_handler(workgroup_size_x, slot_count, opcodes) {
202 Ok(program) => program,
203 Err(error) => panic!("{error}"),
204 }
205}
206
207pub fn try_build_program_with_self_loading_miss_handler(
209 workgroup_size_x: u32,
210 slot_count: u32,
211 opcodes: &[OpcodeHandler],
212) -> Result<Program, String> {
213 let mut extended = Vec::new();
214 let extended_len = opcodes.len().checked_add(1).ok_or_else(|| {
215 "megakernel self-loading opcode extension count overflowed usize. Fix: split opcode handler sets before building the megakernel."
216 .to_string()
217 })?;
218 vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut extended, extended_len).map_err(|error| {
219 format!(
220 "megakernel self-loading opcode extension allocation failed: {error}. Fix: split opcode handler sets before building the megakernel."
221 )
222 })?;
223 extended.extend_from_slice(opcodes);
224 extended.push(OpcodeHandler {
225 opcode: super::protocol::opcode::LOAD_MISS,
226 body: load_miss_body(),
227 });
228 Ok(wrap_persistent_megakernel_program(
229 workgroup_size_x,
230 slot_count,
231 persistent_body_with_io(workgroup_size_x, &extended, false),
232 ))
233}
234
235fn build_program_sharded_slots_with_io(
236 workgroup_size_x: u32,
237 slot_count: u32,
238 opcodes: &[OpcodeHandler],
239 include_io_polling: bool,
240) -> Program {
241 if opcodes.is_empty() {
242 return cache::cached_empty_sharded_program(
243 workgroup_size_x,
244 slot_count,
245 include_io_polling,
246 );
247 }
248 wrap_persistent_megakernel_program(
249 workgroup_size_x,
250 slot_count,
251 persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling),
252 )
253}
254
255fn wrap_persistent_megakernel_program(
256 workgroup_size_x: u32,
257 slot_count: u32,
258 body: Vec<Node>,
259) -> Program {
260 wrap_megakernel_program(workgroup_size_x, slot_count, vec![Node::forever(body)])
261}
262
263fn wrap_persistent_megakernel_program_with_buffers(
264 buffers: Vec<BufferDecl>,
265 workgroup_size_x: u32,
266 body: Vec<Node>,
267) -> Program {
268 optimize_megakernel_program(Program::wrapped(
269 buffers,
270 [workgroup_size_x, 1, 1],
271 vec![Node::forever(body)],
272 ))
273}
274
275fn wrap_megakernel_program(workgroup_size_x: u32, slot_count: u32, body: Vec<Node>) -> Program {
276 optimize_megakernel_program(Program::wrapped(
277 default_buffers(slot_count),
278 [workgroup_size_x, 1, 1],
279 body,
280 ))
281}
282
283fn optimize_megakernel_program(program: Program) -> Program {
284 let (program, _) = super::planner::try_elide_value_flow_barriers(program).unwrap_or_else(
285 |error| {
286 panic!(
287 "megakernel program barrier optimization failed: {error}. Fix: reduce fused program size before builder optimization."
288 )
289 },
290 );
291 vyre_foundation::optimizer::pre_lowering::optimize(program)
292}
293
294fn default_buffers(slot_count: u32) -> Vec<BufferDecl> {
308 let ring_slots = slot_count.max(1);
309 let control = BufferDecl::read_write("control", 0, DataType::U32).with_count(CONTROL_MIN_WORDS);
310 let ring_buffer = BufferDecl::read_write("ring_buffer", 1, DataType::U32).with_count(
311 ring_slots.checked_mul(SLOT_WORDS).unwrap_or_else(|| {
312 panic!(
313 "megakernel ring buffer word count overflowed u32. Fix: reduce slot_count or SLOT_WORDS before building default megakernel buffers."
314 )
315 }),
316 );
317 let debug_log =
318 BufferDecl::read_write("debug_log", 2, DataType::U32).with_count(debug::BUFFER_WORDS);
319 let io_queue = BufferDecl::read_write("io_queue", 3, DataType::U32).with_count(64 * 8);
320 vec![control, ring_buffer, debug_log, io_queue]
321}
322
323fn default_buffers_with_workspace_adapter(
324 slot_count: u32,
325 adapter: &impl MegakernelWorkspaceAdapter,
326) -> Vec<BufferDecl> {
327 let mut buffers = default_buffers(slot_count);
328 buffers.push(adapter.buffer_decl());
329 buffers
330}
331
332#[must_use]
335pub fn persistent_body(workgroup_size_x: u32, opcodes: &[OpcodeHandler]) -> Vec<Node> {
336 persistent_body_with_io(workgroup_size_x, opcodes, false)
337}
338
339pub fn try_persistent_body(
341 workgroup_size_x: u32,
342 opcodes: &[OpcodeHandler],
343) -> Result<Vec<Node>, String> {
344 try_persistent_body_with_io(workgroup_size_x, opcodes, false)
345}
346
347fn persistent_body_with_io(
348 workgroup_size_x: u32,
349 opcodes: &[OpcodeHandler],
350 include_io_polling: bool,
351) -> Vec<Node> {
352 match try_persistent_body_with_io(workgroup_size_x, opcodes, include_io_polling) {
353 Ok(body) => body,
354 Err(error) => panic!("{error}"),
355 }
356}
357
358fn try_persistent_body_with_io(
359 workgroup_size_x: u32,
360 opcodes: &[OpcodeHandler],
361 include_io_polling: bool,
362) -> Result<Vec<Node>, String> {
363 let mut body = persistent_lane_prologue(workgroup_size_x);
364 let additional_nodes = if include_io_polling { 3 } else { 2 };
365 let body_capacity = body.len().checked_add(additional_nodes).ok_or_else(|| {
366 "megakernel persistent body node reservation overflowed usize. Fix: reduce fused IO/body staging before building the megakernel."
367 .to_string()
368 })?;
369 vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut body, body_capacity).map_err(|error| {
370 format!(
371 "megakernel persistent body node reservation failed: {error}. Fix: reduce fused IO/body staging before building the megakernel."
372 )
373 })?;
374 body.push(direct_slot_base_binding());
375 body.push(Node::Block(execute_slot_body(opcodes)));
376 if include_io_polling {
377 body.push(Node::Block(process_io_requests()));
378 }
379 Ok(body)
380}
381
382fn persistent_lane_prologue(workgroup_size_x: u32) -> Vec<Node> {
383 vec![
384 Node::let_bind(
385 "shutdown_flag",
386 atomic_load_relaxed("control", Expr::u32(control::SHUTDOWN)),
387 ),
388 Node::if_then(
389 Expr::ne(Expr::var("shutdown_flag"), Expr::u32(0)),
390 vec![Node::Return],
391 ),
392 Node::let_bind("lane_id", lane_id_expr(workgroup_size_x)),
393 ]
394}
395
396fn direct_slot_base_binding() -> Node {
397 Node::let_bind(
398 "slot_base",
399 Expr::mul(Expr::var("lane_id"), Expr::u32(SLOT_WORDS)),
400 )
401}
402
403fn slot_tenant_id_load() -> Expr {
404 Expr::load(
405 "ring_buffer",
406 Expr::add(Expr::var("slot_base"), Expr::u32(TENANT_WORD)),
407 )
408}
409
410fn tenant_authorized_body(tenant_id: Expr, authorized_body: Vec<Node>) -> Vec<Node> {
411 vec![
412 Node::let_bind("tenant_id", tenant_id),
413 Node::let_bind(
414 "tenant_base",
415 atomic_load_relaxed("control", Expr::u32(control::TENANT_BASE)),
416 ),
417 Node::let_bind(
418 "tenant_mask",
419 atomic_load_relaxed(
420 "control",
421 Expr::add(Expr::var("tenant_base"), Expr::var("tenant_id")),
422 ),
423 ),
424 Node::if_then(
425 Expr::ne(Expr::var("tenant_mask"), Expr::u32(0)),
426 authorized_body,
427 ),
428 ]
429}
430
431fn lane_id_expr(workgroup_size_x: u32) -> Expr {
432 Expr::add(
433 Expr::mul(Expr::workgroup_x(), Expr::u32(workgroup_size_x)),
434 Expr::local_x(),
435 )
436}
437
438fn persistent_body_with_workspace_adapter(
439 workgroup_size_x: u32,
440 opcodes: &[OpcodeHandler],
441 adapter: &impl MegakernelWorkspaceAdapter,
442) -> Vec<Node> {
443 let mut body = adapter.bootstrap_nodes();
444 body.extend(adapter.guard_nodes());
445 body.extend(adapter.dispatch_nodes());
446 body.extend(persistent_body_with_io(workgroup_size_x, opcodes, false));
447 body
448}
449
450
451fn process_io_requests() -> Vec<Node> {
452 let nodes = vec![Node::loop_for(
453 "io_idx",
454 Expr::u32(0),
455 Expr::u32(IO_SLOT_COUNT),
456 vec![
457 Node::let_bind(
458 "io_base",
459 Expr::mul(Expr::var("io_idx"), Expr::u32(IO_SLOT_WORDS)),
460 ),
461 Node::let_bind(
462 "io_status_idx",
463 Expr::add(Expr::var("io_base"), Expr::u32(io_word::STATUS)),
464 ),
465 Node::let_bind(
467 "prev_io_status",
468 Expr::atomic_compare_exchange(
469 "io_queue",
470 Expr::var("io_status_idx"),
471 Expr::u32(slot::PUBLISHED),
472 Expr::u32(slot::CLAIMED),
473 ),
474 ),
475 Node::if_then(
476 Expr::eq(Expr::var("prev_io_status"), Expr::u32(slot::PUBLISHED)),
477 vec![
478 Node::let_bind(
479 "io_src_handle",
480 Expr::load(
481 "io_queue",
482 Expr::add(Expr::var("io_base"), Expr::u32(io_word::SRC_HANDLE)),
483 ),
484 ),
485 Node::let_bind(
486 "io_dst_handle",
487 Expr::load(
488 "io_queue",
489 Expr::add(Expr::var("io_base"), Expr::u32(io_word::DST_HANDLE)),
490 ),
491 ),
492 Node::AsyncLoad {
493 source: IO_SOURCE_CAPABILITY_TABLE.into(),
494 destination: IO_DESTINATION_CAPABILITY_TABLE.into(),
495 offset: Box::new(Expr::load(
496 "io_queue",
497 Expr::add(Expr::var("io_base"), Expr::u32(io_word::OFFSET_LO)),
498 )),
499 size: Box::new(Expr::load(
500 "io_queue",
501 Expr::add(Expr::var("io_base"), Expr::u32(io_word::BYTE_COUNT)),
502 )),
503 tag: IO_QUEUE_DMA_TAG.into(),
504 },
505 Node::store(
507 "io_queue",
508 Expr::var("io_status_idx"),
509 Expr::u32(slot::DONE),
510 ),
511 ],
512 ),
513 ],
514 )];
515
516 nodes
517}
518
519fn execute_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
520 vec![
521 Node::let_bind(
522 "status_index",
523 Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
524 ),
525 Node::let_bind(
526 "observed_status",
527 atomic_load_relaxed("ring_buffer", Expr::var("status_index")),
528 ),
529 Node::if_then(
530 Expr::eq(Expr::var("observed_status"), Expr::u32(slot::PUBLISHED)),
531 tenant_authorized_claim_body(slot_tenant_id_load(), claimed_slot_body(opcodes)),
532 ),
533 ]
534}
535
536fn tenant_authorized_claim_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
537 tenant_authorized_body(
538 tenant_id,
539 vec![
540 Node::let_bind(
544 "prev_status",
545 Expr::atomic_compare_exchange(
546 "ring_buffer",
547 Expr::var("status_index"),
548 Expr::u32(slot::PUBLISHED),
549 Expr::u32(slot::CLAIMED),
550 ),
551 ),
552 Node::if_then(
553 Expr::eq(Expr::var("prev_status"), Expr::u32(slot::PUBLISHED)),
554 claimed_body,
555 ),
556 ],
557 )
558}
559
560fn execute_already_claimed_slot_body(tenant_id: Expr, claimed_body: Vec<Node>) -> Vec<Node> {
561 let mut body = vec![Node::let_bind(
562 "status_index",
563 Expr::add(Expr::var("slot_base"), Expr::u32(STATUS_WORD)),
564 )];
565 body.extend(tenant_authorized_body(tenant_id, claimed_body));
566 body
567}
568
569#[cfg(test)]
570mod tests;
571