// vyre_runtime/megakernel/handlers.rs

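//! Built-in opcode handler bodies: STORE_U32, ATOMIC_ADD, LOAD_U32,
//! COMPARE_SWAP, MEMCPY, BATCH_FENCE, PRINTF, SHUTDOWN, LOAD_MISS, plus the
//! PACKED_SLOT and claimed-slot dispatch builders.
//!
//! Each handler function returns a `Vec<Node>` that executes when its opcode
//! matches in the claimed-slot dispatch. The variables `opcode`, `arg0`,
//! `arg1`, `arg2` and `slot_base` are in scope.
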
7use vyre_foundation::ir::{Expr, Node};
8
9use super::ir_util::{atomic_load_relaxed, atomic_store_relaxed};
10use super::protocol::{control, debug, opcode, ARGS_PER_SLOT};
11
12/// Caller-supplied opcode extension wired into the megakernel at
13/// bootstrap. The `body` executes when `opcode` matches; within the
14/// body, variables `slot_base`, `opcode`, `arg0..arg2` and the
15/// buffers `control`/`ring_buffer`/`debug_log` are in scope.
16#[derive(Debug, Clone)]
17pub struct OpcodeHandler {
18    /// Discriminant matched against `ring_buffer[slot_base + OPCODE_WORD]`.
19    pub opcode: u32,
20    /// IR nodes executed when the match lands.
21    pub body: Vec<Node>,
22}
23
24/// Wrap a body in `if opcode == discriminant { body }`.
25pub(crate) fn opcode_if(op: u32, body: Vec<Node>) -> Node {
26    Node::if_then(Expr::eq(Expr::var("opcode"), Expr::u32(op)), body)
27}
28
29pub(crate) fn store_u32_body() -> Vec<Node> {
30    vec![atomic_store_relaxed(
31        "store_u32_prev",
32        "control",
33        Expr::var("arg1"),
34        Expr::var("arg0"),
35    )]
36}
37
38pub(crate) fn atomic_add_body() -> Vec<Node> {
39    vec![Node::let_bind(
40        "atomic_add_prev",
41        Expr::atomic_add("control", Expr::var("arg1"), Expr::var("arg0")),
42    )]
43}
44
45pub(crate) fn shutdown_body() -> Vec<Node> {
46    vec![Node::let_bind(
47        "shutdown_prev",
48        Expr::atomic_exchange("control", Expr::u32(control::SHUTDOWN), Expr::u32(1)),
49    )]
50}
51
52pub(crate) fn printf_body() -> Vec<Node> {
53    // Reserve 4 u32 words at debug_log[cursor..cursor+4] atomically,
54    // then write (fmt_id=arg0, arg1, arg2, slot_base) into them.
55    // atomic_add returns the pre-increment value  -  our reservation
56    // base.
57    vec![
58        Node::let_bind(
59            "printf_base",
60            Expr::add(
61                Expr::atomic_add(
62                    "debug_log",
63                    Expr::u32(debug::CURSOR_WORD),
64                    Expr::u32(debug::RECORD_WORDS),
65                ),
66                Expr::u32(debug::RECORDS_BASE),
67            ),
68        ),
69        Node::if_then(
70            Expr::le(
71                Expr::add(Expr::var("printf_base"), Expr::u32(debug::RECORD_WORDS)),
72                Expr::u32(debug::BUFFER_WORDS),
73            ),
74            vec![
75                atomic_store_relaxed(
76                    "printf_fmt_prev",
77                    "debug_log",
78                    Expr::var("printf_base"),
79                    Expr::var("arg0"),
80                ),
81                atomic_store_relaxed(
82                    "printf_arg1_prev",
83                    "debug_log",
84                    Expr::add(Expr::var("printf_base"), Expr::u32(1)),
85                    Expr::var("arg1"),
86                ),
87                atomic_store_relaxed(
88                    "printf_arg2_prev",
89                    "debug_log",
90                    Expr::add(Expr::var("printf_base"), Expr::u32(2)),
91                    Expr::var("arg2"),
92                ),
93                atomic_store_relaxed(
94                    "printf_slot_prev",
95                    "debug_log",
96                    Expr::add(Expr::var("printf_base"), Expr::u32(3)),
97                    Expr::var("slot_base"),
98                ),
99            ],
100        ),
101    ]
102}
103
// --- V6.4 new opcode bodies ---

106/// LOAD_U32: copy `control[arg0]` into `control[OBSERVABLE_BASE + arg1]`.
107pub(crate) fn load_u32_body() -> Vec<Node> {
108    vec![atomic_store_relaxed(
109        "load_u32_observable_prev",
110        "control",
111        Expr::add(Expr::u32(control::OBSERVABLE_BASE), Expr::var("arg1")),
112        atomic_load_relaxed("control", Expr::var("arg0")),
113    )]
114}
115
116/// COMPARE_SWAP: CAS on `control[arg0]`, expected=arg1, desired=arg2.
117/// Write the previous value (before CAS) to `control[OBSERVABLE_BASE + arg0]`
118/// so the host can detect success (prev == expected means swap happened).
119pub(crate) fn compare_swap_body() -> Vec<Node> {
120    vec![
121        Node::let_bind(
122            "cas_prev",
123            Expr::atomic_compare_exchange(
124                "control",
125                Expr::var("arg0"),
126                Expr::var("arg1"),
127                Expr::var("arg2"),
128            ),
129        ),
130        atomic_store_relaxed(
131            "cas_observable_prev",
132            "control",
133            Expr::add(Expr::u32(control::OBSERVABLE_BASE), Expr::var("arg0")),
134            Expr::var("cas_prev"),
135        ),
136    ]
137}
138
139/// MEMCPY: copy `control[arg0..arg0+arg2]` → `control[arg1..arg1+arg2]`.
140/// Sequential loop  -  fine for small copies within the control buffer.
141pub(crate) fn memcpy_body() -> Vec<Node> {
142    vec![Node::loop_for(
143        "copy_i",
144        Expr::u32(0),
145        Expr::var("arg2"),
146        vec![atomic_store_relaxed(
147            "memcpy_dst_prev",
148            "control",
149            Expr::add(Expr::var("arg1"), Expr::var("copy_i")),
150            atomic_load_relaxed("control", Expr::add(Expr::var("arg0"), Expr::var("copy_i"))),
151        )],
152    )]
153}
154
155/// BATCH_FENCE: atomically increment `control[EPOCH]` and write the
156/// user-tag (`arg1`) to `control[OBSERVABLE_BASE]`.
157pub(crate) fn batch_fence_body() -> Vec<Node> {
158    vec![
159        Node::let_bind(
160            "epoch_prev",
161            Expr::atomic_add("control", Expr::u32(control::EPOCH), Expr::u32(1)),
162        ),
163        atomic_store_relaxed(
164            "fence_observable_prev",
165            "control",
166            Expr::u32(control::OBSERVABLE_BASE),
167            Expr::var("arg1"),
168        ),
169    ]
170}
171
172/// LOAD_MISS: GPU-initiated DMA request to the IO queue.
173///
174/// Reads the consumer's `resource_id` from `arg0` and `prefetch_flag` from
175/// `arg1`, scans the IO queue for an empty slot, writes a READ request,
176/// and spins until the host/runtime marks it OK. vyre is opaque to the
177/// resource identifier  -  it's just a u32 the consumer uses to look up
178/// the source and destination of the read.
179pub(crate) fn load_miss_body() -> Vec<Node> {
180    let io_slot_count = super::io::IO_SLOT_COUNT;
181    let io_slot_words = super::io::IO_SLOT_WORDS;
182
183    vec![
184        Node::let_bind(
185            "resource_id",
186            Expr::load(
187                "ring_buffer",
188                Expr::add(
189                    Expr::var("slot_base"),
190                    Expr::u32(super::protocol::ARG0_WORD),
191                ),
192            ),
193        ),
194        Node::let_bind(
195            "prefetch_flag",
196            Expr::load(
197                "ring_buffer",
198                Expr::add(
199                    Expr::var("slot_base"),
200                    Expr::u32(super::protocol::ARG0_WORD + 1),
201                ),
202            ),
203        ),
204        // Scan for an empty IO slot.
205        Node::let_bind("found_io_slot", Expr::u32(io_slot_count)),
206        Node::loop_for(
207            "scan_i",
208            Expr::u32(0),
209            Expr::u32(io_slot_count),
210            vec![
211                Node::if_then(
212                    Expr::ne(Expr::var("found_io_slot"), Expr::u32(io_slot_count)),
213                    vec![], // already found, skip remaining scan iterations
214                ),
215                Node::if_then(
216                    Expr::eq(Expr::var("found_io_slot"), Expr::u32(io_slot_count)),
217                    vec![
218                        Node::let_bind(
219                            "scan_base",
220                            Expr::mul(Expr::var("scan_i"), Expr::u32(io_slot_words)),
221                        ),
222                        Node::let_bind(
223                            "scan_status",
224                            Expr::load(
225                                "io_queue",
226                                Expr::add(
227                                    Expr::var("scan_base"),
228                                    Expr::u32(super::io::io_word::STATUS),
229                                ),
230                            ),
231                        ),
232                        Node::if_then(
233                            Expr::eq(
234                                Expr::var("scan_status"),
235                                Expr::u32(super::protocol::slot::EMPTY),
236                            ),
237                            vec![Node::assign("found_io_slot", Expr::var("scan_i"))],
238                        ),
239                    ],
240                ),
241            ],
242        ),
243        // If a slot was found, write the DMA request and poll for completion.
244        Node::if_then(
245            Expr::ne(Expr::var("found_io_slot"), Expr::u32(io_slot_count)),
246            vec![
247                Node::let_bind(
248                    "io_base",
249                    Expr::mul(Expr::var("found_io_slot"), Expr::u32(io_slot_words)),
250                ),
251                Node::store(
252                    "io_queue",
253                    Expr::add(Expr::var("io_base"), Expr::u32(super::io::io_word::OP_TYPE)),
254                    Expr::u32(super::io::io_op::READ),
255                ),
256                Node::store(
257                    "io_queue",
258                    Expr::add(
259                        Expr::var("io_base"),
260                        Expr::u32(super::io::io_word::SRC_HANDLE),
261                    ),
262                    Expr::var("resource_id"),
263                ),
264                Node::store(
265                    "io_queue",
266                    Expr::add(
267                        Expr::var("io_base"),
268                        Expr::u32(super::io::io_word::DST_HANDLE),
269                    ),
270                    Expr::var("resource_id"),
271                ),
272                Node::store(
273                    "io_queue",
274                    Expr::add(
275                        Expr::var("io_base"),
276                        Expr::u32(super::io::io_word::OFFSET_LO),
277                    ),
278                    Expr::u32(0),
279                ),
280                Node::store(
281                    "io_queue",
282                    Expr::add(
283                        Expr::var("io_base"),
284                        Expr::u32(super::io::io_word::OFFSET_HI),
285                    ),
286                    Expr::u32(0),
287                ),
288                Node::store(
289                    "io_queue",
290                    Expr::add(
291                        Expr::var("io_base"),
292                        Expr::u32(super::io::io_word::BYTE_COUNT),
293                    ),
294                    Expr::u32(0),
295                ),
296                Node::store(
297                    "io_queue",
298                    Expr::add(Expr::var("io_base"), Expr::u32(super::io::io_word::TAG)),
299                    Expr::var("resource_id"),
300                ),
301                // Publish the request.
302                Node::store(
303                    "io_queue",
304                    Expr::add(Expr::var("io_base"), Expr::u32(super::io::io_word::STATUS)),
305                    Expr::u32(super::protocol::slot::PUBLISHED),
306                ),
307                // Poll until the host/runtime marks it OK.
308                Node::let_bind("poll_done", Expr::u32(0)),
309                Node::let_bind("poll_max_iters", Expr::u32(u32::MAX)),
310                Node::loop_for(
311                    "poll_i",
312                    Expr::u32(0),
313                    Expr::var("poll_max_iters"),
314                    vec![
315                        Node::if_then(
316                            Expr::eq(Expr::var("poll_done"), Expr::u32(1)),
317                            vec![], // skip once done
318                        ),
319                        Node::if_then(
320                            Expr::ne(Expr::var("poll_done"), Expr::u32(1)),
321                            vec![
322                                Node::let_bind(
323                                    "poll_status",
324                                    Expr::load(
325                                        "io_queue",
326                                        Expr::add(
327                                            Expr::var("io_base"),
328                                            Expr::u32(super::io::io_word::STATUS),
329                                        ),
330                                    ),
331                                ),
332                                Node::if_then(
333                                    Expr::eq(
334                                        Expr::var("poll_status"),
335                                        Expr::u32(super::io::io_status::OK),
336                                    ),
337                                    vec![
338                                        Node::store(
339                                            "io_queue",
340                                            Expr::add(
341                                                Expr::var("io_base"),
342                                                Expr::u32(super::io::io_word::STATUS),
343                                            ),
344                                            Expr::u32(super::protocol::slot::EMPTY),
345                                        ),
346                                        Node::assign("poll_done", Expr::u32(1)),
347                                    ],
348                                ),
349                            ],
350                        ),
351                    ],
352                ),
353            ],
354        ),
355    ]
356}
357
358fn packed_payload_byte(byte_offset: Expr) -> Expr {
359    let word_offset = Expr::div(byte_offset.clone(), Expr::u32(4));
360    let bit_shift = Expr::mul(Expr::rem(byte_offset, Expr::u32(4)), Expr::u32(8));
361    let word = Expr::load(
362        "ring_buffer",
363        Expr::add(
364            Expr::add(
365                Expr::var("slot_base"),
366                Expr::u32(super::protocol::ARG0_WORD),
367            ),
368            word_offset,
369        ),
370    );
371    Expr::bitand(Expr::shr(word, bit_shift), Expr::u32(0xFF))
372}
373
374fn dispatch_opcode_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
375    let mut nodes = vec![
376        Node::if_then(
377            Expr::lt(Expr::var("opcode"), Expr::u32(control::METRICS_SLOTS)),
378            vec![Node::let_bind(
379                "metric_prev",
380                Expr::atomic_add(
381                    "control",
382                    Expr::add(Expr::u32(control::METRICS_BASE), Expr::var("opcode")),
383                    Expr::u32(1),
384                ),
385            )],
386        ),
387        opcode_if(opcode::STORE_U32, store_u32_body()),
388        opcode_if(opcode::ATOMIC_ADD, atomic_add_body()),
389        opcode_if(opcode::LOAD_U32, load_u32_body()),
390        opcode_if(opcode::COMPARE_SWAP, compare_swap_body()),
391        opcode_if(opcode::MEMCPY, memcpy_body()),
392        opcode_if(opcode::BATCH_FENCE, batch_fence_body()),
393        opcode_if(opcode::PRINTF, printf_body()),
394        opcode_if(opcode::SHUTDOWN, shutdown_body()),
395    ];
396
397    for handler in opcodes {
398        nodes.push(opcode_if(handler.opcode, handler.body.clone()));
399    }
400
401    nodes
402}
403
404pub(crate) fn packed_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
405    vec![
406        Node::let_bind("packed_raw_opcode_count", packed_payload_byte(Expr::u32(0))),
407        Node::let_bind(
408            "packed_opcode_count",
409            Expr::select(
410                Expr::gt(
411                    Expr::var("packed_raw_opcode_count"),
412                    Expr::u32(ARGS_PER_SLOT / 3),
413                ),
414                Expr::u32(ARGS_PER_SLOT / 3),
415                Expr::var("packed_raw_opcode_count"),
416            ),
417        ),
418        Node::let_bind(
419            "packed_metadata_bytes",
420            Expr::add(
421                Expr::u32(2),
422                Expr::mul(Expr::var("packed_opcode_count"), Expr::u32(2)),
423            ),
424        ),
425        Node::let_bind(
426            "packed_metadata_words",
427            Expr::div(
428                Expr::add(Expr::var("packed_metadata_bytes"), Expr::u32(3)),
429                Expr::u32(4),
430            ),
431        ),
432        Node::loop_for(
433            "packed_inner_index",
434            Expr::u32(0),
435            Expr::var("packed_opcode_count"),
436            vec![Node::block(vec![
437                Node::let_bind(
438                    "packed_pair_byte",
439                    Expr::add(
440                        Expr::u32(2),
441                        Expr::mul(Expr::var("packed_inner_index"), Expr::u32(2)),
442                    ),
443                ),
444                Node::let_bind(
445                    "packed_opcode",
446                    packed_payload_byte(Expr::var("packed_pair_byte")),
447                ),
448                Node::let_bind(
449                    "packed_arg_offset",
450                    packed_payload_byte(Expr::add(Expr::var("packed_pair_byte"), Expr::u32(1))),
451                ),
452                Node::let_bind(
453                    "packed_arg_base",
454                    Expr::add(
455                        Expr::var("packed_metadata_words"),
456                        Expr::var("packed_arg_offset"),
457                    ),
458                ),
459                Node::assign("opcode", Expr::var("packed_opcode")),
460                Node::assign(
461                    "arg0",
462                    Expr::load(
463                        "ring_buffer",
464                        Expr::add(
465                            Expr::add(
466                                Expr::var("slot_base"),
467                                Expr::u32(super::protocol::ARG0_WORD),
468                            ),
469                            Expr::var("packed_arg_base"),
470                        ),
471                    ),
472                ),
473                Node::assign(
474                    "arg1",
475                    Expr::load(
476                        "ring_buffer",
477                        Expr::add(
478                            Expr::add(
479                                Expr::var("slot_base"),
480                                Expr::u32(super::protocol::ARG0_WORD),
481                            ),
482                            Expr::add(Expr::var("packed_arg_base"), Expr::u32(1)),
483                        ),
484                    ),
485                ),
486                Node::assign(
487                    "arg2",
488                    Expr::load(
489                        "ring_buffer",
490                        Expr::add(
491                            Expr::add(
492                                Expr::var("slot_base"),
493                                Expr::u32(super::protocol::ARG0_WORD),
494                            ),
495                            Expr::add(Expr::var("packed_arg_base"), Expr::u32(2)),
496                        ),
497                    ),
498                ),
499                Node::if_then(
500                    Expr::le(
501                        Expr::add(Expr::var("packed_arg_base"), Expr::u32(3)),
502                        Expr::u32(ARGS_PER_SLOT),
503                    ),
504                    vec![Node::block(dispatch_opcode_body(opcodes))],
505                ),
506            ])],
507        ),
508    ]
509}
510
511/// Build the claimed-slot dispatch body (opcode If-tree + custom handlers).
512pub(crate) fn claimed_slot_bindings() -> Vec<Node> {
513    vec![
514        Node::let_bind(
515            "opcode",
516            Expr::load(
517                "ring_buffer",
518                Expr::add(
519                    Expr::var("slot_base"),
520                    Expr::u32(super::protocol::OPCODE_WORD),
521                ),
522            ),
523        ),
524        Node::let_bind(
525            "arg0",
526            Expr::load(
527                "ring_buffer",
528                Expr::add(
529                    Expr::var("slot_base"),
530                    Expr::u32(super::protocol::ARG0_WORD),
531                ),
532            ),
533        ),
534        Node::let_bind(
535            "arg1",
536            Expr::load(
537                "ring_buffer",
538                Expr::add(
539                    Expr::var("slot_base"),
540                    Expr::u32(super::protocol::ARG0_WORD + 1),
541                ),
542            ),
543        ),
544        Node::let_bind(
545            "arg2",
546            Expr::load(
547                "ring_buffer",
548                Expr::add(
549                    Expr::var("slot_base"),
550                    Expr::u32(super::protocol::ARG0_WORD + 2),
551                ),
552            ),
553        ),
554    ]
555}
556
557/// Build the claimed-slot dispatch body (opcode If-tree + custom handlers).
558pub(crate) fn claimed_slot_body(opcodes: &[OpcodeHandler]) -> Vec<Node> {
559    let mut nodes = claimed_slot_bindings();
560    nodes.push(Node::block(dispatch_opcode_body(opcodes)));
561    nodes.push(opcode_if(opcode::PACKED_SLOT, packed_slot_body(opcodes)));
562
563    // Tally progress so the host can observe done_count.
564    nodes.push(Node::let_bind(
565        "done_prev",
566        Expr::atomic_add("control", Expr::u32(control::DONE_COUNT), Expr::u32(1)),
567    ));
568
569    // Mark slot DONE.
570    nodes.push(Node::store(
571        "ring_buffer",
572        Expr::var("status_index"),
573        Expr::u32(super::protocol::slot::DONE),
574    ));
575
576    nodes
577}
578
// Unit tests for these builders live in the separate `tests` submodule file.
#[cfg(test)]
mod tests;