Skip to main content

vyre_runtime/megakernel/
execution.rs

1//! Compiled persistent-megakernel handle and dispatch path.
2
3mod persistent_handles;
4mod readback_dispatch;
5mod types;
6
7use super::builder::{build_program_jit_slots, build_program_sharded_slots_shared};
8use super::handlers::OpcodeHandler;
9use super::io;
10use super::planner::MegakernelLaunchGeometry;
11use super::protocol;
12use super::protocol_api::{validate_control_bytes, validate_debug_log_bytes};
13use super::recovery::{
14    backend_error_indicates_device_loss, recover_compiled_pipeline, MegakernelRecoveryDecision,
15    MegakernelRecoveryPolicy,
16};
17use super::staging_reserve::reserve_vec_capacity;
18use crate::PipelineError;
19use arc_swap::ArcSwap;
20use std::sync::Arc;
21use std::time::Instant;
22use vyre_driver::backend::{
23    CompiledPipeline, DispatchConfig, OutputBuffers, Resource, VyreBackend,
24};
25use vyre_foundation::ir::Program;
26
27pub use types::{
28    MegakernelBatchDispatchOutput, MegakernelDispatchOutput, MegakernelDispatchStats,
29    MegakernelResidentBatchScratch, MegakernelResidentHandles,
30};
31
/// Orchestrated persistent-megakernel handle.
///
/// Construct with [`Megakernel::bootstrap`] (default 256 lanes x 1
/// workgroup) or [`Megakernel::bootstrap_sharded`] for multi-tenant fan-in.
/// Feed bytecode with [`Megakernel::dispatch`].
pub struct Megakernel {
    /// Backend that compiled the pipeline; also used for grid-sync split
    /// dispatch and for recompilation after device loss.
    backend: Arc<dyn VyreBackend>,
    /// Currently active compiled pipeline. Replaced via `ArcSwap::store` when
    /// the pipeline is rebuilt after device loss.
    pipeline: ArcSwap<PipelineSlot>,
    /// Backend-reported pipeline id captured at bootstrap (not refreshed when
    /// the pipeline is rebuilt).
    pipeline_id: String,
    /// IR program the pipeline was compiled from; reused for recompilation
    /// and by the grid-sync split dispatch path.
    program: Arc<Program>,
    /// Whether `program` contains a grid sync. When the backend cannot do
    /// grid sync natively, dispatch falls back to the split path.
    has_grid_sync: bool,
    /// Pre-encoded empty IO queue (`io::IO_SLOT_COUNT` slots) supplied by the
    /// dispatch entry points that take no caller IO queue.
    empty_io_queue_bytes: Arc<[u8]>,
    /// Ring slot count; validated at bootstrap to be a non-zero multiple of
    /// `workgroup_size_x`.
    slot_count: u32,
    /// X workgroup size the program was built for.
    workgroup_size_x: u32,
    /// Policy deciding whether a failed dispatch is retried after recompiling.
    recovery_policy: MegakernelRecoveryPolicy,
}
48
/// Holder for the active compiled pipeline, kept behind `ArcSwap` in
/// [`Megakernel`] so it can be replaced after device-loss recompilation.
struct PipelineSlot {
    /// Backend-compiled pipeline used for dispatch.
    inner: Arc<dyn CompiledPipeline>,
}
52
53impl Megakernel {
54    /// Default bootstrap: 256 lanes x 1 workgroup, no custom opcodes.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`PipelineError::Backend`] if the backend rejects the program.
59    pub fn bootstrap(backend: Arc<dyn VyreBackend>) -> Result<Self, PipelineError> {
60        Self::bootstrap_sharded(backend, 256, 256, Vec::new())
61    }
62
63    /// Bootstrap with custom opcodes but default sharding.
64    ///
65    /// # Errors
66    ///
67    /// See [`Megakernel::bootstrap`].
68    pub fn bootstrap_with_opcodes(
69        backend: Arc<dyn VyreBackend>,
70        opcodes: Vec<OpcodeHandler>,
71    ) -> Result<Self, PipelineError> {
72        Self::bootstrap_sharded(backend, 256, 256, opcodes)
73    }
74
75    /// Compute worker groups for a megakernel slot geometry without compiling.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`PipelineError::QueueFull`] when the geometry cannot map slots
80    /// to whole workgroups.
81    pub fn worker_groups_for_geometry(
82        slot_count: u32,
83        workgroup_size_x: u32,
84    ) -> Result<u32, PipelineError> {
85        validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
86        Ok(slot_count / workgroup_size_x)
87    }
88
89    /// Full bootstrap with sharding and custom opcodes.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`PipelineError::QueueFull`] when geometry is invalid or
94    /// [`PipelineError::Backend`] from the underlying compile.
95    pub fn bootstrap_sharded(
96        backend: Arc<dyn VyreBackend>,
97        slot_count: u32,
98        workgroup_size_x: u32,
99        opcodes: Vec<OpcodeHandler>,
100    ) -> Result<Self, PipelineError> {
101        validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
102        let program = build_program_sharded_slots_shared(workgroup_size_x, slot_count, &opcodes);
103        Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, program)
104    }
105
106    /// JIT compiler bootstrap for fused payload processors.
107    ///
108    /// # Errors
109    ///
110    /// See [`Megakernel::bootstrap_sharded`].
111    pub fn bootstrap_jit(
112        backend: Arc<dyn VyreBackend>,
113        slot_count: u32,
114        workgroup_size_x: u32,
115        payload_processor: &[vyre_foundation::ir::Node],
116    ) -> Result<Self, PipelineError> {
117        validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
118        let program = build_program_jit_slots(workgroup_size_x, slot_count, payload_processor);
119        Self::compile_bootstrap(backend, slot_count, workgroup_size_x, program)
120    }
121
122    fn compile_bootstrap(
123        backend: Arc<dyn VyreBackend>,
124        slot_count: u32,
125        workgroup_size_x: u32,
126        program: Program,
127    ) -> Result<Self, PipelineError> {
128        Self::compile_bootstrap_shared(backend, slot_count, workgroup_size_x, Arc::new(program))
129    }
130
131    fn compile_bootstrap_shared(
132        backend: Arc<dyn VyreBackend>,
133        slot_count: u32,
134        workgroup_size_x: u32,
135        program: Arc<Program>,
136    ) -> Result<Self, PipelineError> {
137        validate_bootstrap_geometry(slot_count, workgroup_size_x)?;
138        let config = DispatchConfig::default();
139        let pipeline = vyre_driver::pipeline::compile_shared(
140            Arc::clone(&backend),
141            Arc::clone(&program),
142            &config,
143        )?;
144        let pipeline_id = pipeline.id().to_string();
145        let has_grid_sync = vyre_driver::grid_sync::contains_grid_sync(&program);
146        let empty_io_queue_bytes =
147            Arc::<[u8]>::from(io::try_encode_empty_io_queue(io::IO_SLOT_COUNT)?.into_boxed_slice());
148        Ok(Self {
149            backend,
150            pipeline: ArcSwap::from(Arc::new(PipelineSlot { inner: pipeline })),
151            pipeline_id,
152            program,
153            has_grid_sync,
154            empty_io_queue_bytes,
155            slot_count,
156            workgroup_size_x,
157            recovery_policy: MegakernelRecoveryPolicy::default(),
158        })
159    }
160
161    /// Dispatch a full storage buffer set with an empty IO queue.
162    ///
163    /// # Errors
164    ///
165    /// Returns [`PipelineError`] when protocol buffers are malformed, dispatch
166    /// fails, or device-loss recovery cannot rebuild the compiled pipeline.
167    pub fn dispatch(
168        &self,
169        control_bytes: Vec<u8>,
170        ring_bytes: Vec<u8>,
171        debug_log_bytes: Vec<u8>,
172    ) -> Result<Vec<Vec<u8>>, PipelineError> {
173        self.dispatch_borrowed(&control_bytes, &ring_bytes, &debug_log_bytes)
174    }
175
176    /// Dispatch a borrowed storage buffer set with an empty IO queue.
177    ///
178    /// # Errors
179    ///
180    /// Returns [`PipelineError`] when protocol buffers are malformed, dispatch
181    /// fails, or device-loss recovery cannot rebuild the compiled pipeline.
182    pub fn dispatch_borrowed(
183        &self,
184        control_bytes: &[u8],
185        ring_bytes: &[u8],
186        debug_log_bytes: &[u8],
187    ) -> Result<Vec<Vec<u8>>, PipelineError> {
188        Ok(self
189            .dispatch_borrowed_observed(control_bytes, ring_bytes, debug_log_bytes)?
190            .buffers)
191    }
192
193    /// Dispatch a full storage buffer set and return runtime instrumentation.
194    ///
195    /// # Errors
196    ///
197    /// See [`Megakernel::dispatch`].
198    pub fn dispatch_observed(
199        &self,
200        control_bytes: Vec<u8>,
201        ring_bytes: Vec<u8>,
202        debug_log_bytes: Vec<u8>,
203    ) -> Result<MegakernelDispatchOutput, PipelineError> {
204        self.dispatch_with_io_queue_borrowed_observed(
205            &control_bytes,
206            &ring_bytes,
207            &debug_log_bytes,
208            &self.empty_io_queue_bytes,
209        )
210    }
211
212    /// Dispatch borrowed buffers with an empty IO queue and return runtime
213    /// instrumentation.
214    ///
215    /// # Errors
216    ///
217    /// See [`Megakernel::dispatch_borrowed`].
218    pub fn dispatch_borrowed_observed(
219        &self,
220        control_bytes: &[u8],
221        ring_bytes: &[u8],
222        debug_log_bytes: &[u8],
223    ) -> Result<MegakernelDispatchOutput, PipelineError> {
224        self.dispatch_with_io_queue_borrowed_observed(
225            control_bytes,
226            ring_bytes,
227            debug_log_bytes,
228            &self.empty_io_queue_bytes,
229        )
230    }
231
232    /// Dispatch a full storage buffer set with a caller-supplied `io_queue`.
233    ///
234    /// # Errors
235    ///
236    /// Returns [`PipelineError`] when any protocol buffer is malformed, backend
237    /// dispatch fails, or device-loss recovery cannot rebuild the pipeline.
238    pub fn dispatch_with_io_queue(
239        &self,
240        control_bytes: Vec<u8>,
241        ring_bytes: Vec<u8>,
242        debug_log_bytes: Vec<u8>,
243        io_queue_bytes: Vec<u8>,
244    ) -> Result<Vec<Vec<u8>>, PipelineError> {
245        self.dispatch_with_io_queue_borrowed(
246            &control_bytes,
247            &ring_bytes,
248            &debug_log_bytes,
249            &io_queue_bytes,
250        )
251    }
252
253    /// Dispatch borrowed buffers with a caller-supplied `io_queue`.
254    ///
255    /// # Errors
256    ///
257    /// See [`Megakernel::dispatch_with_io_queue`].
258    pub fn dispatch_with_io_queue_borrowed(
259        &self,
260        control_bytes: &[u8],
261        ring_bytes: &[u8],
262        debug_log_bytes: &[u8],
263        io_queue_bytes: &[u8],
264    ) -> Result<Vec<Vec<u8>>, PipelineError> {
265        Ok(self
266            .dispatch_with_io_queue_borrowed_observed(
267                control_bytes,
268                ring_bytes,
269                debug_log_bytes,
270                io_queue_bytes,
271            )?
272            .buffers)
273    }
274
275    /// Dispatch with a caller-supplied `io_queue` and return instrumentation.
276    ///
277    /// # Errors
278    ///
279    /// See [`Megakernel::dispatch_with_io_queue`].
280    pub fn dispatch_with_io_queue_observed(
281        &self,
282        control_bytes: Vec<u8>,
283        ring_bytes: Vec<u8>,
284        debug_log_bytes: Vec<u8>,
285        io_queue_bytes: Vec<u8>,
286    ) -> Result<MegakernelDispatchOutput, PipelineError> {
287        self.dispatch_with_io_queue_borrowed_observed(
288            &control_bytes,
289            &ring_bytes,
290            &debug_log_bytes,
291            &io_queue_bytes,
292        )
293    }
294
295    /// Dispatch borrowed buffers with a caller-supplied `io_queue` and return
296    /// instrumentation.
297    ///
298    /// # Errors
299    ///
300    /// See [`Megakernel::dispatch_with_io_queue`].
301    pub fn dispatch_with_io_queue_borrowed_observed(
302        &self,
303        control_bytes: &[u8],
304        ring_bytes: &[u8],
305        debug_log_bytes: &[u8],
306        io_queue_bytes: &[u8],
307    ) -> Result<MegakernelDispatchOutput, PipelineError> {
308        let mut buffers = Vec::new();
309        reserve_output_shell(
310            &mut buffers,
311            MegakernelResidentHandles::ABI_RESOURCE_COUNT,
312            "borrowed megakernel output shell",
313        )?;
314        let stats = self.dispatch_with_io_queue_borrowed_into(
315            control_bytes,
316            ring_bytes,
317            debug_log_bytes,
318            io_queue_bytes,
319            &mut buffers,
320        )?;
321        Ok(MegakernelDispatchOutput { buffers, stats })
322    }
323
324    /// Dispatch borrowed buffers with a caller-supplied IO queue, writing
325    /// backend outputs into caller-owned storage.
326    ///
327    /// # Errors
328    ///
329    /// See [`Megakernel::dispatch_with_io_queue_borrowed`].
330    pub fn dispatch_with_io_queue_borrowed_into(
331        &self,
332        control_bytes: &[u8],
333        ring_bytes: &[u8],
334        debug_log_bytes: &[u8],
335        io_queue_bytes: &[u8],
336        outputs: &mut OutputBuffers,
337    ) -> Result<MegakernelDispatchStats, PipelineError> {
338        validate_control_bytes(control_bytes)?;
339        validate_debug_log_bytes(debug_log_bytes)?;
340        io::validate_io_queue_bytes(io_queue_bytes)?;
341        self.validate_ring_bytes(ring_bytes)?;
342
343        let input_bytes = total_len([control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes])?;
344        let inputs = [control_bytes, ring_bytes, debug_log_bytes, io_queue_bytes];
345        let config = self.launch_geometry().dispatch_config(None);
346        let started = Instant::now();
347        let mut recovered = false;
348        match self.dispatch_once_into(&inputs, &config, outputs) {
349            Ok(()) => {}
350            Err(error) if self.recovery_policy.allows_retry(&error) => {
351                self.recover_after_device_loss()?;
352                recovered = true;
353                self.dispatch_once_into(&inputs, &config, outputs)?
354            }
355            Err(error) => return Err(error.into()),
356        }
357        let latency_ns = nanos_u64(started.elapsed().as_nanos())?;
358        let output_bytes = output_bytes(outputs)?;
359        let readback_bytes = output_bytes;
360        let bytes_moved = checked_add_u64(input_bytes, readback_bytes, "megakernel bytes moved")?;
361        let device_allocation_bytes = checked_add_u64(
362            input_bytes,
363            output_bytes,
364            "megakernel host-visible device allocation bytes",
365        )?;
366        let output_buffers = count_u32(outputs.len(), "megakernel output buffer count")?;
367        let device_allocation_events =
368            checked_add_u32(4, output_buffers, "megakernel allocation event count")?;
369        Ok(MegakernelDispatchStats {
370            input_bytes,
371            output_bytes,
372            readback_bytes,
373            bytes_moved,
374            device_allocation_bytes,
375            device_allocation_events,
376            latency_ns,
377            output_buffers,
378            resident_resource_rows: 0,
379            resident_resource_handles: 0,
380            kernel_launches: if recovered { 2 } else { 1 },
381            sync_points: 1,
382            recovered_after_device_loss: recovered,
383        })
384    }
385
386    /// Rebuild the compiled pipeline after device-loss symptoms.
387    ///
388    /// This does not mask the failure: if recompilation fails, the structured
389    /// backend error is returned with the original remediation text intact.
390    ///
391    /// # Errors
392    ///
393    /// Returns [`PipelineError::Backend`] when the backend cannot recompile.
394    pub fn recover_after_device_loss(&self) -> Result<MegakernelRecoveryDecision, PipelineError> {
395        let config = self.launch_geometry().dispatch_config(None);
396        let rebuilt = recover_compiled_pipeline(&self.backend, Arc::clone(&self.program), &config)?;
397        self.pipeline
398            .store(Arc::new(PipelineSlot { inner: rebuilt }));
399        Ok(MegakernelRecoveryDecision::RecompiledPipeline)
400    }
401
402    /// Pipeline id from the backend.
403    #[must_use]
404    pub fn pipeline_id(&self) -> &str {
405        &self.pipeline_id
406    }
407
408    /// Slot count this kernel was sharded for.
409    #[must_use]
410    pub const fn slot_count(&self) -> u32 {
411        self.slot_count
412    }
413
414    /// Workgroup size this kernel was compiled for.
415    #[must_use]
416    pub const fn workgroup_size_x(&self) -> u32 {
417        self.workgroup_size_x
418    }
419
420    /// Workgroup count needed to cover every ring slot.
421    #[must_use]
422    pub fn worker_groups(&self) -> u32 {
423        self.slot_count / self.workgroup_size_x
424    }
425
426    pub(super) fn validate_ring_bytes(&self, ring_bytes: &[u8]) -> Result<(), PipelineError> {
427        let expected_ring_bytes = protocol::ring_byte_len(self.slot_count).ok_or_else(|| {
428            PipelineError::Backend(
429                "megakernel ring byte length overflowed usize. Fix: split the ring into smaller dispatch shards."
430                    .to_string(),
431            )
432        })?;
433        if ring_bytes.len() != expected_ring_bytes {
434            return Err(PipelineError::Backend(format!(
435                "megakernel ring buffer has {} bytes, expected {expected_ring_bytes} for {} slots. Fix: build ring bytes with Megakernel::encode_empty_ring(slot_count) for this handle.",
436                ring_bytes.len(),
437                self.slot_count
438            )));
439        }
440        Ok(())
441    }
442
443    pub(super) fn launch_geometry(&self) -> MegakernelLaunchGeometry {
444        MegakernelLaunchGeometry {
445            workgroup_size_x: self.workgroup_size_x,
446            slot_count: self.slot_count,
447            dispatch_grid: [self.slot_count / self.workgroup_size_x, 1, 1],
448        }
449    }
450
451    fn dispatch_once_into(
452        &self,
453        inputs: &[&[u8]; 4],
454        config: &DispatchConfig,
455        outputs: &mut OutputBuffers,
456    ) -> Result<(), vyre_driver::BackendError> {
457        if self.has_grid_sync && !self.backend.supports_grid_sync() {
458            return vyre_driver::grid_sync::dispatch_with_grid_sync_split_into(
459                self.backend.as_ref(),
460                &self.program,
461                inputs,
462                config,
463                outputs,
464            );
465        }
466        let pipeline = self.pipeline.load();
467        pipeline
468            .inner
469            .dispatch_borrowed_into(inputs, config, outputs)
470    }
471
472    fn dispatch_persistent_handles_once_into(
473        &self,
474        inputs: &[Resource; 4],
475        config: &DispatchConfig,
476        outputs: &mut OutputBuffers,
477    ) -> Result<(), vyre_driver::BackendError> {
478        let pipeline = self.pipeline.load();
479        pipeline
480            .inner
481            .dispatch_persistent_handles_into(inputs, config, outputs)
482    }
483
484    fn dispatch_persistent_handle_rows_once_into(
485        &self,
486        rows: &[[Resource; 4]],
487        config: &DispatchConfig,
488        outputs: &mut Vec<OutputBuffers>,
489    ) -> Result<(), vyre_driver::BackendError> {
490        let pipeline = self.pipeline.load();
491        pipeline
492            .inner
493            .dispatch_persistent_handle_rows_into(rows, config, outputs)
494    }
495}
496
497impl MegakernelRecoveryPolicy {
498    fn allows_retry(self, error: &vyre_driver::BackendError) -> bool {
499        self.retry_device_loss_once && backend_error_indicates_device_loss(error)
500    }
501}
502
503fn validate_bootstrap_geometry(
504    slot_count: u32,
505    workgroup_size_x: u32,
506) -> Result<(), PipelineError> {
507    if slot_count == 0 || workgroup_size_x == 0 || slot_count % workgroup_size_x != 0 {
508        return Err(PipelineError::QueueFull {
509            queue: "submission",
510            fix: "slot_count must be a non-zero multiple of workgroup_size_x",
511        });
512    }
513    Ok(())
514}
515
516pub(super) fn total_len<const N: usize>(buffers: [&[u8]; N]) -> Result<u64, PipelineError> {
517    let mut total = 0u64;
518    for buffer in buffers {
519        total = checked_add_u64(
520            total,
521            usize_to_u64(buffer.len(), "megakernel input buffer length")?,
522            "megakernel input byte total",
523        )?;
524    }
525    Ok(total)
526}
527
528pub(super) fn output_bytes(outputs: &[Vec<u8>]) -> Result<u64, PipelineError> {
529    let mut total = 0u64;
530    for output in outputs {
531        total = checked_add_u64(
532            total,
533            usize_to_u64(output.len(), "megakernel output buffer length")?,
534            "megakernel output byte total",
535        )?;
536    }
537    Ok(total)
538}
539
540pub(super) fn nested_output_bytes(outputs: &[Vec<Vec<u8>>]) -> Result<u64, PipelineError> {
541    let mut total = 0u64;
542    for row in outputs {
543        total = checked_add_u64(
544            total,
545            output_bytes(row)?,
546            "megakernel nested output byte total",
547        )?;
548    }
549    Ok(total)
550}
551
552pub(super) fn output_count_u32(outputs: &[Vec<u8>]) -> Result<u32, PipelineError> {
553    count_u32(outputs.len(), "megakernel output buffer count")
554}
555
556pub(super) fn nested_output_count_u32(outputs: &[Vec<Vec<u8>>]) -> Result<u32, PipelineError> {
557    let mut total = 0usize;
558    for row in outputs {
559        total = total.checked_add(row.len()).ok_or_else(|| {
560            PipelineError::Backend(
561                "megakernel nested output buffer count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
562            )
563        })?;
564    }
565    count_u32(total, "megakernel nested output buffer count")
566}
567
568pub(super) fn resident_row_count_u32(rows: usize) -> Result<u32, PipelineError> {
569    count_u32(rows, "megakernel resident resource row count")
570}
571
572pub(super) fn resident_handle_count_u32(rows: usize) -> Result<u32, PipelineError> {
573    let handles = rows
574        .checked_mul(MegakernelResidentHandles::ABI_RESOURCE_COUNT)
575        .ok_or_else(|| {
576            PipelineError::Backend(
577                "megakernel resident resource handle count overflowed usize. Fix: split resident rows before dispatch.".to_string(),
578            )
579        })?;
580    count_u32(handles, "megakernel resident resource handle count")
581}
582
583pub(super) fn reserve_output_shell<T>(
584    out: &mut Vec<T>,
585    capacity: usize,
586    label: &'static str,
587) -> Result<(), PipelineError> {
588    reserve_vec_capacity(out, capacity, label)
589}
590
591pub(super) fn nanos_u64(nanos: u128) -> Result<u64, PipelineError> {
592    u64::try_from(nanos).map_err(|source| {
593        PipelineError::Backend(format!(
594            "megakernel latency cannot fit u64 nanoseconds: {source}. Fix: timeout or split the dispatch before telemetry overflows."
595        ))
596    })
597}
598
599fn usize_to_u64(value: usize, label: &str) -> Result<u64, PipelineError> {
600    u64::try_from(value).map_err(|source| {
601        PipelineError::Backend(format!(
602            "{label} cannot fit u64: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
603        ))
604    })
605}
606
607fn count_u32(value: usize, label: &str) -> Result<u32, PipelineError> {
608    u32::try_from(value).map_err(|source| {
609        PipelineError::Backend(format!(
610            "{label} cannot fit u32: {source}. Fix: split the megakernel dispatch before telemetry/accounting."
611        ))
612    })
613}
614
615fn checked_add_u64(left: u64, right: u64, label: &str) -> Result<u64, PipelineError> {
616    left.checked_add(right).ok_or_else(|| {
617        PipelineError::Backend(format!(
618            "{label} overflowed u64. Fix: split the megakernel dispatch before telemetry/accounting."
619        ))
620    })
621}
622
623fn checked_add_u32(left: u32, right: u32, label: &str) -> Result<u32, PipelineError> {
624    left.checked_add(right).ok_or_else(|| {
625        PipelineError::Backend(format!(
626            "{label} overflowed u32. Fix: split the megakernel dispatch before telemetry/accounting."
627        ))
628    })
629}
630
631#[cfg(test)]
632mod tests;