// vyre_driver/megakernel_barrier.rs

//! Backend-neutral megakernel barrier planning for dependency-typed waves.
//!
//! The planner is pure and deterministic: it converts a wave dependency DAG
//! into the minimum number of global-synchronization layers implied by those
//! dependencies. Waves inside one layer are independent and can be fused into
//! one cooperative megakernel phase without inserting a host-side barrier.

8use crate::accounting::{checked_add_usize_count, ArithmeticOverflow};
9use crate::reservation_policy::{
10    reserve_typed_vec_to_capacity as reserve_vec_to_capacity, ReservationPolicy,
11};
12
/// Reservation policy shared by every fallible buffer reservation in this
/// module; `reserve_vec` passes it to the crate's `reserve_vec_to_capacity`.
///
/// Both strings are forwarded verbatim to `ReservationPolicy::new`. From their
/// text, the first names this subsystem and the second is a remediation hint.
/// NOTE(review): confirm the argument roles against `ReservationPolicy::new`.
const MEGAKERNEL_BARRIER_RESERVATION: ReservationPolicy = ReservationPolicy::new(
    "megakernel barrier planner",
    "shard the dependency graph before barrier planning",
);

/// Directed dependency between two megakernel dataflow waves.
///
/// The planner requires both endpoints to lie in `0..wave_count` and requires
/// `before != after`. Violations are reported as
/// [`MegakernelBarrierPlanError::InvalidWave`] and
/// [`MegakernelBarrierPlanError::SelfDependency`] respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelWaveDependency {
    /// Wave that must complete first.
    pub before: usize,
    /// Wave that can run only after `before` has completed.
    pub after: usize,
}

/// One barrier-free group (a topological layer) of independent megakernel waves.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MegakernelBarrierGroup {
    /// Wave indices that can run before the next global synchronization point.
    ///
    /// No wave in a group depends on another wave of the same group; all of a
    /// wave's predecessors live in earlier groups.
    pub waves: Vec<usize>,
}

/// Barrier plan for megakernel execution.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MegakernelBarrierPlan {
    /// Ordered barrier-free wave groups, in execution order.
    pub groups: Vec<MegakernelBarrierGroup>,
    /// Number of global synchronization points required between groups:
    /// `groups.len() - 1`, or `0` when there are no groups.
    pub global_barriers: usize,
}

/// Caller-owned scratch for repeated megakernel barrier planning.
///
/// This keeps CSR adjacency, indegree, and ready-layer buffers reusable across
/// frontier-planning calls. Returned barrier groups still own their wave lists;
/// the scratch removes the temporary O(waves + dependencies) planning
/// allocations from steady-state callers.
#[derive(Debug, Default)]
pub struct MegakernelBarrierScratch {
    /// Per-wave out-degree. Afterwards it is overwritten with row starts and
    /// used as the per-wave CSR write cursor.
    outgoing_counts: Vec<usize>,
    /// Per-wave in-degree, decremented as predecessors get scheduled.
    indegree: Vec<usize>,
    /// CSR row offsets into `outgoing_targets` (`wave_count + 1` entries).
    outgoing_offsets: Vec<usize>,
    /// CSR column data: the `after` wave of every dependency, grouped by `before`.
    outgoing_targets: Vec<usize>,
    /// Waves of the layer currently being emitted.
    ready: Vec<usize>,
    /// Waves whose last predecessor is in the current layer.
    next_ready: Vec<usize>,
}

59impl MegakernelBarrierScratch {
60    /// Allocate reusable scratch for a known megakernel dependency shape,
61    /// returning a typed planner error when the shape cannot be represented.
62    ///
63    /// # Errors
64    ///
65    /// Returns [`MegakernelBarrierPlanError`] when the scratch capacity cannot
66    /// be represented or reserved.
67    pub fn try_with_capacity(
68        wave_count: usize,
69        dependency_count: usize,
70    ) -> Result<Self, MegakernelBarrierPlanError> {
71        let mut scratch = Self::default();
72        scratch.try_reserve_shape(wave_count, dependency_count)?;
73        Ok(scratch)
74    }
75
76    fn try_reserve_shape(
77        &mut self,
78        wave_count: usize,
79        dependency_count: usize,
80    ) -> Result<(), MegakernelBarrierPlanError> {
81        let offset_capacity =
82            wave_count
83                .checked_add(1)
84                .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
85                    field: "barrier scratch wave offsets",
86                })?;
87        reserve_vec(&mut self.outgoing_counts, wave_count, "outgoing counts")?;
88        reserve_vec(&mut self.indegree, wave_count, "indegree")?;
89        reserve_vec(
90            &mut self.outgoing_offsets,
91            offset_capacity,
92            "outgoing offsets",
93        )?;
94        reserve_vec(
95            &mut self.outgoing_targets,
96            dependency_count,
97            "outgoing targets",
98        )?;
99        reserve_vec(&mut self.ready, wave_count, "ready wave layer")?;
100        reserve_vec(&mut self.next_ready, wave_count, "next ready wave layer")?;
101        Ok(())
102    }
103
104    /// Retained wave-index capacity across CSR planning buffers.
105    #[must_use]
106    pub fn wave_capacity(&self) -> usize {
107        let offset_wave_capacity = if self.outgoing_offsets.capacity() == 0 {
108            0
109        } else {
110            self.outgoing_offsets.capacity() - 1
111        };
112        self.outgoing_counts
113            .capacity()
114            .min(self.indegree.capacity())
115            .min(offset_wave_capacity)
116    }
117
118    /// Retained dependency-edge capacity for CSR adjacency targets.
119    #[must_use]
120    pub fn dependency_capacity(&self) -> usize {
121        self.outgoing_targets.capacity()
122    }
123}
124
/// Barrier planning failure.
///
/// The [`Display`](std::fmt::Display) text of every variant includes a `Fix:`
/// remediation hint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MegakernelBarrierPlanError {
    /// A dependency references a wave outside `0..wave_count`.
    ///
    /// Also returned, for the first dependency, when `wave_count == 0` and any
    /// dependency is supplied.
    InvalidWave {
        /// Declared number of waves.
        wave_count: usize,
        /// `before` endpoint of the offending dependency. Only one of the two
        /// endpoints needs to be out of range, so this one may be valid.
        before: usize,
        /// `after` endpoint of the offending dependency. Only one of the two
        /// endpoints needs to be out of range, so this one may be valid.
        after: usize,
    },
    /// A wave was declared to depend on itself.
    SelfDependency {
        /// Self-dependent wave index.
        wave: usize,
    },
    /// The dependency graph contains a cycle and cannot be scheduled.
    Cycle {
        /// Number of waves that could not be scheduled: waves on a cycle plus
        /// every wave that transitively depends on one.
        unscheduled_waves: usize,
    },
    /// Dependency CSR arithmetic overflowed.
    ///
    /// Despite the name, the overflowing quantities are element counts,
    /// offsets, and cursors rather than byte sizes.
    ByteCountOverflow {
        /// Field being computed.
        field: &'static str,
    },
    /// Planner scratch/result storage could not be reserved.
    StorageReserveFailed {
        /// Field being reserved.
        field: &'static str,
        /// Number of elements requested.
        requested: usize,
        /// Allocator error text.
        message: String,
    },
}

/// Lets the crate's shared checked-arithmetic helpers report overflow as
/// [`MegakernelBarrierPlanError::ByteCountOverflow`]. Presumably this is how
/// `checked_add_usize_count::<MegakernelBarrierPlanError>` in
/// `group_capacity_hint` builds its error (TODO confirm in `accounting`).
impl ArithmeticOverflow for MegakernelBarrierPlanError {
    fn arithmetic_overflow(field: &'static str) -> Self {
        Self::ByteCountOverflow { field }
    }
}

169impl std::fmt::Display for MegakernelBarrierPlanError {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        match self {
172            Self::InvalidWave {
173                wave_count,
174                before,
175                after,
176            } => write!(
177                f,
178                "megakernel dependency references invalid wave before={before} after={after} for wave_count={wave_count}. Fix: emit dependencies only over normalized wave indices."
179            ),
180            Self::SelfDependency { wave } => write!(
181                f,
182                "megakernel wave {wave} depends on itself. Fix: remove the self-edge or split the wave into distinct producer/consumer phases."
183            ),
184            Self::Cycle { unscheduled_waves } => write!(
185                f,
186                "megakernel wave dependency graph contains a cycle with {unscheduled_waves} unscheduled waves. Fix: break the cyclic dataflow edge or insert an explicit iterative fixed-point kernel."
187            ),
188            Self::ByteCountOverflow { field } => write!(
189                f,
190                "megakernel barrier planner overflowed while computing {field}. Fix: shard the dependency graph before barrier planning."
191            ),
192            Self::StorageReserveFailed {
193                field,
194                requested,
195                message,
196            } => write!(
197                f,
198                "megakernel barrier planner could not reserve {requested} {field} entries: {message}. Fix: shard the dependency graph before barrier planning."
199            ),
200        }
201    }
202}
203
/// Standard error integration. No variant wraps an underlying error, so the
/// default `source()` (always `None`) is kept.
impl std::error::Error for MegakernelBarrierPlanError {}

206/// Plan minimum global barriers for a megakernel wave dependency DAG.
207///
208/// The returned groups are Kahn topological layers. That is the minimum number
209/// of dependency-implied execution rounds for a DAG when every ready wave may
210/// execute in the same cooperative phase.
211///
212/// # Errors
213///
214/// Returns [`MegakernelBarrierPlanError`] when dependencies are invalid,
215/// cyclic, overflow counters, or cannot reserve planner storage.
216pub fn plan_megakernel_barriers(
217    wave_count: usize,
218    dependencies: &[MegakernelWaveDependency],
219) -> Result<MegakernelBarrierPlan, MegakernelBarrierPlanError> {
220    let mut scratch = MegakernelBarrierScratch::try_with_capacity(wave_count, dependencies.len())?;
221    plan_megakernel_barriers_with_scratch(wave_count, dependencies, &mut scratch)
222}
223
224/// Plan minimum global barriers using caller-owned temporary storage.
225///
226/// # Errors
227///
228/// Returns [`MegakernelBarrierPlanError`] when dependencies are invalid,
229/// cyclic, overflow counters, or cannot reserve planner storage.
230pub fn plan_megakernel_barriers_with_scratch(
231    wave_count: usize,
232    dependencies: &[MegakernelWaveDependency],
233    scratch: &mut MegakernelBarrierScratch,
234) -> Result<MegakernelBarrierPlan, MegakernelBarrierPlanError> {
235    scratch.try_reserve_shape(wave_count, dependencies.len())?;
236    if wave_count == 0 {
237        if !dependencies.is_empty() {
238            return Err(MegakernelBarrierPlanError::InvalidWave {
239                wave_count,
240                before: dependencies[0].before,
241                after: dependencies[0].after,
242            });
243        }
244        return Ok(MegakernelBarrierPlan {
245            global_barriers: 0,
246            groups: Vec::new(),
247        });
248    }
249    if dependencies.is_empty() {
250        let mut waves = Vec::new();
251        reserve_vec(&mut waves, wave_count, "independent wave group")?;
252        for wave in 0..wave_count {
253            waves.push(wave);
254        }
255        let mut groups = Vec::new();
256        reserve_vec(&mut groups, 1, "barrier groups")?;
257        groups.push(MegakernelBarrierGroup { waves });
258        return Ok(MegakernelBarrierPlan {
259            global_barriers: 0,
260            groups,
261        });
262    }
263
264    fill_barrier_vec_zeroed(&mut scratch.outgoing_counts, wave_count, "outgoing counts")?;
265    fill_barrier_vec_zeroed(&mut scratch.indegree, wave_count, "indegree")?;
266    for dependency in dependencies {
267        if dependency.before >= wave_count || dependency.after >= wave_count {
268            return Err(MegakernelBarrierPlanError::InvalidWave {
269                wave_count,
270                before: dependency.before,
271                after: dependency.after,
272            });
273        }
274        if dependency.before == dependency.after {
275            return Err(MegakernelBarrierPlanError::SelfDependency {
276                wave: dependency.before,
277            });
278        }
279        scratch.outgoing_counts[dependency.before] = scratch.outgoing_counts[dependency.before]
280            .checked_add(1)
281            .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
282                field: "outgoing dependency count",
283            })?;
284        scratch.indegree[dependency.after] = scratch.indegree[dependency.after]
285            .checked_add(1)
286            .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
287                field: "incoming dependency count",
288            })?;
289    }
290
291    scratch.outgoing_offsets.clear();
292    scratch.outgoing_offsets.push(0usize);
293    for count in &scratch.outgoing_counts {
294        let next = scratch
295            .outgoing_offsets
296            .last()
297            .copied()
298            .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
299                field: "outgoing offset seed",
300            })?
301            .checked_add(*count)
302            .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
303                field: "outgoing dependency offsets",
304            })?;
305        scratch.outgoing_offsets.push(next);
306    }
307    fill_barrier_vec_zeroed(
308        &mut scratch.outgoing_targets,
309        dependencies.len(),
310        "outgoing targets",
311    )?;
312    scratch
313        .outgoing_counts
314        .copy_from_slice(&scratch.outgoing_offsets[..wave_count]);
315    for dependency in dependencies {
316        let offset = scratch.outgoing_counts[dependency.before];
317        scratch.outgoing_targets[offset] = dependency.after;
318        scratch.outgoing_counts[dependency.before] =
319            offset
320                .checked_add(1)
321                .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
322                    field: "outgoing target cursor",
323                })?;
324    }
325
326    scratch.ready.clear();
327    for (wave, degree) in scratch.indegree.iter().copied().enumerate() {
328        if degree == 0 {
329            scratch.ready.push(wave);
330        }
331    }
332
333    let mut scheduled = 0usize;
334    let mut groups = Vec::new();
335    reserve_vec(
336        &mut groups,
337        group_capacity_hint(wave_count, dependencies.len())?,
338        "barrier groups",
339    )?;
340    scratch.next_ready.clear();
341    while !scratch.ready.is_empty() {
342        scratch.next_ready.clear();
343        for &wave in &scratch.ready {
344            for &next in &scratch.outgoing_targets
345                [scratch.outgoing_offsets[wave]..scratch.outgoing_offsets[wave + 1]]
346            {
347                scratch.indegree[next] -= 1;
348                if scratch.indegree[next] == 0 {
349                    scratch.next_ready.push(next);
350                }
351            }
352        }
353        scheduled += scratch.ready.len();
354        groups.push(MegakernelBarrierGroup {
355            waves: std::mem::take(&mut scratch.ready),
356        });
357        std::mem::swap(&mut scratch.ready, &mut scratch.next_ready);
358    }
359
360    if scheduled != wave_count {
361        return Err(MegakernelBarrierPlanError::Cycle {
362            unscheduled_waves: wave_count - scheduled,
363        });
364    }
365
366    Ok(MegakernelBarrierPlan {
367        global_barriers: if groups.is_empty() {
368            0
369        } else {
370            groups.len() - 1
371        },
372        groups,
373    })
374}
375
376fn group_capacity_hint(
377    wave_count: usize,
378    dependency_count: usize,
379) -> Result<usize, MegakernelBarrierPlanError> {
380    if wave_count == 0 {
381        Ok(0)
382    } else {
383        let dependency_layer_cap = checked_add_usize_count::<MegakernelBarrierPlanError>(
384            dependency_count,
385            1,
386            "barrier group capacity hint",
387        )?;
388        Ok(wave_count.min(dependency_layer_cap))
389    }
390}
391
392fn fill_barrier_vec_zeroed(
393    vec: &mut Vec<usize>,
394    len: usize,
395    field: &'static str,
396) -> Result<(), MegakernelBarrierPlanError> {
397    vec.clear();
398    reserve_vec(vec, len, field)?;
399    vec.extend((0..len).map(|_| 0));
400    Ok(())
401}
402
/// Fallibly reserve storage in `vec` for `target_capacity` elements under this
/// module's `MEGAKERNEL_BARRIER_RESERVATION` policy.
///
/// This is a thin adapter over the crate's shared `reserve_vec_to_capacity`
/// helper. `item` labels the buffer, and reservation failures are built by
/// `storage_reserve_failed` as `MegakernelBarrierPlanError::StorageReserveFailed`.
/// NOTE(review): whether `target_capacity` is a total capacity or an
/// additional count is decided by `reserve_vec_to_capacity`; its name suggests
/// a total-capacity target. Confirm.
fn reserve_vec<T>(
    vec: &mut Vec<T>,
    target_capacity: usize,
    item: &'static str,
) -> Result<(), MegakernelBarrierPlanError> {
    reserve_vec_to_capacity(
        MEGAKERNEL_BARRIER_RESERVATION,
        vec,
        target_capacity,
        item,
        storage_reserve_failed,
    )
}

417fn storage_reserve_failed(
418    field: &'static str,
419    requested: usize,
420    message: String,
421) -> MegakernelBarrierPlanError {
422    MegakernelBarrierPlanError::StorageReserveFailed {
423        field,
424        requested,
425        message,
426    }
427}
428
// Unit tests for the barrier planner. One test also scans this file's own
// source text through `include_str!`, so code and comments anywhere in the
// file must avoid the patterns it forbids. Those patterns are assembled with
// `concat!` so that the test itself does not match them.
#[cfg(test)]
mod tests {
    use super::{
        plan_megakernel_barriers, plan_megakernel_barriers_with_scratch,
        MegakernelBarrierPlanError, MegakernelBarrierScratch, MegakernelWaveDependency,
    };

    /// With no dependencies, every wave lands in one group and no barrier is needed.
    #[test]
    fn independent_waves_share_one_barrier_free_group() {
        let plan = plan_megakernel_barriers(4, &[])
            .expect("Fix: independent megakernel waves should not need barriers.");

        assert_eq!(plan.global_barriers, 0);
        assert_eq!(plan.groups.len(), 1);
        assert_eq!(plan.groups[0].waves, vec![0, 1, 2, 3]);
    }

    /// A linear chain forces one layer per wave.
    #[test]
    fn dependency_chain_requires_one_barrier_between_each_wave() {
        let plan = plan_megakernel_barriers(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
        )
        .expect("Fix: acyclic megakernel wave chain should be schedulable.");

        assert_eq!(plan.global_barriers, 3);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves, vec![1]);
        assert_eq!(plan.groups[2].waves, vec![2]);
        assert_eq!(plan.groups[3].waves, vec![3]);
    }

    /// The two independent middle waves of a diamond share one layer.
    #[test]
    fn diamond_dependencies_fuse_middle_waves() {
        let plan = plan_megakernel_barriers(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 0,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 3,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
        )
        .expect("Fix: diamond megakernel dependencies should preserve middle-wave fusion.");

        assert_eq!(plan.global_barriers, 2);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves, vec![1, 2]);
        assert_eq!(plan.groups[2].waves, vec![3]);
    }

    /// Out-of-range endpoints, self-edges, and cycles each map to their own
    /// error variant.
    #[test]
    fn invalid_self_and_cyclic_dependencies_fail_loudly() {
        let invalid = plan_megakernel_barriers(
            2,
            &[MegakernelWaveDependency {
                before: 0,
                after: 2,
            }],
        )
        .expect_err("Fix: invalid megakernel wave index must fail before planning.");
        assert!(matches!(
            invalid,
            MegakernelBarrierPlanError::InvalidWave { .. }
        ));

        let self_edge = plan_megakernel_barriers(
            2,
            &[MegakernelWaveDependency {
                before: 1,
                after: 1,
            }],
        )
        .expect_err("Fix: self-dependent megakernel waves must fail before planning.");
        assert_eq!(
            self_edge,
            MegakernelBarrierPlanError::SelfDependency { wave: 1 }
        );

        let cycle = plan_megakernel_barriers(
            2,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 0,
                },
            ],
        )
        .expect_err("Fix: cyclic megakernel dependencies require explicit fixed-point kernels.");
        assert_eq!(
            cycle,
            MegakernelBarrierPlanError::Cycle {
                unscheduled_waves: 2
            }
        );
    }

    /// Wide fan-out check, plus source-text guards that pin the CSR layout and
    /// the checked, fallible arithmetic and reservation style of the planner.
    #[test]
    fn barrier_planner_uses_csr_adjacency_for_wide_wave_graphs() {
        let dependencies = (1..1_025)
            .map(|after| MegakernelWaveDependency { before: 0, after })
            .collect::<Vec<_>>();
        let plan = plan_megakernel_barriers(1_025, &dependencies)
            .expect("Fix: wide megakernel dependency fanout must schedule without per-wave adjacency allocation.");

        assert_eq!(plan.global_barriers, 1);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves.len(), 1_024);

        // Everything below inspects this file's source text, not runtime behavior.
        let src = include_str!("megakernel_barrier.rs");
        assert!(
            !src.contains(concat!("vec![", "Vec::new(); wave_count]")),
            "Fix: megakernel barrier planner must use contiguous CSR adjacency instead of allocating one Vec per wave."
        );
        assert!(
            !src.contains(concat!("outgoing_offsets[..wave_count]", ".to_vec()")),
            "Fix: megakernel barrier planner must reuse the counts buffer as the CSR write cursor instead of allocating an O(wave_count) cursor Vec."
        );
        assert!(
            !src.contains(concat!("Vec", "Deque")),
            "Fix: megakernel barrier planner should use contiguous current/next ready vectors, not deque queue mechanics, for wide wave layers."
        );
        assert!(
            !src.contains(concat!("saturating", "_add")),
            "Fix: megakernel barrier dependency accounting is bounded by the validated graph shape and must not hide invariant violations with saturating arithmetic."
        );
        assert!(
            src.contains("field: \"outgoing dependency count\"")
                && src.contains("field: \"incoming dependency count\"")
                && src.contains("field: \"outgoing dependency offsets\"")
                && src.contains("field: \"outgoing target cursor\""),
            "Fix: megakernel barrier CSR construction must use checked arithmetic for dependency counters, offsets, and cursors."
        );
        assert!(
            src.contains("reserve_typed_vec_to_capacity as reserve_vec_to_capacity")
                && src.contains("fn fill_barrier_vec_zeroed(")
                && src.contains("StorageReserveFailed"),
            "Fix: megakernel barrier staging must reserve through shared fallible driver staging instead of panicking under scale pressure."
        );
        assert!(
            !src.contains(concat!("Vec::with_capacity", "(wave_count)"))
                && !src.contains(concat!(".reserve", "(wave_count)"))
                && !src.contains(concat!("scratch.outgoing_counts", ".resize"))
                && !src.contains(concat!("scratch.indegree", ".resize"))
                && !src.contains(concat!("scratch.outgoing_targets", ".resize")),
            "Fix: megakernel barrier planner must not use infallible capacity growth in release topology planning."
        );
        assert!(
            !src.contains(concat!(
                "scratch.outgoing_counts[dependency.before]",
                " += 1"
            ))
                && !src.contains(concat!("scratch.indegree[dependency.after]", " += 1"))
                && !src.contains(concat!(
                    "let next = scratch.outgoing_offsets.last().copied().unwrap_or(0)",
                    " + *count"
                )),
            "Fix: megakernel barrier planning must not use unchecked usize arithmetic for CSR construction."
        );
    }

    /// Planning a small graph with scratch sized for a large one must not
    /// shrink the retained CSR capacity.
    #[test]
    fn barrier_planner_reuses_caller_owned_csr_scratch_across_shapes() {
        let mut scratch = MegakernelBarrierScratch::try_with_capacity(1_025, 1_024)
            .expect("Fix: wide reusable megakernel barrier scratch should fit");
        let wide_dependencies = (1..1_025)
            .map(|after| MegakernelWaveDependency { before: 0, after })
            .collect::<Vec<_>>();
        let wide = plan_megakernel_barriers_with_scratch(1_025, &wide_dependencies, &mut scratch)
            .expect("Fix: wide megakernel dependency fanout should plan with reusable scratch");
        let wave_capacity = scratch.wave_capacity();
        let dependency_capacity = scratch.dependency_capacity();

        assert_eq!(wide.groups[1].waves.len(), 1_024);

        let narrow = plan_megakernel_barriers_with_scratch(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
            &mut scratch,
        )
        .expect("Fix: narrow megakernel dependency chain should reuse larger scratch");

        assert_eq!(narrow.global_barriers, 3);
        assert!(scratch.wave_capacity() >= wave_capacity);
        assert!(scratch.dependency_capacity() >= dependency_capacity);
    }

    /// Sweeps 64 widths by 32 depths (2048 shapes) of parallel chains. Each
    /// plan must have exactly `depth` groups of `width` waves, all planned with
    /// one reused scratch.
    #[test]
    fn generated_layered_dags_preserve_exact_barrier_depth_for_2048_shapes() {
        let mut scratch = MegakernelBarrierScratch::default();
        for width in 1usize..=64 {
            for depth in 1usize..=32 {
                let wave_count = width * depth;
                let mut dependencies = Vec::new();
                // Wave `base + slot` feeds the same slot in the next layer.
                for layer in 0..depth.saturating_sub(1) {
                    let base = layer * width;
                    let next = base + width;
                    for slot in 0..width {
                        dependencies.push(MegakernelWaveDependency {
                            before: base + slot,
                            after: next + slot,
                        });
                    }
                }

                let plan =
                    plan_megakernel_barriers_with_scratch(wave_count, &dependencies, &mut scratch)
                        .expect("Fix: generated layered megakernel DAG should be schedulable");

                assert_eq!(plan.groups.len(), depth);
                assert_eq!(plan.global_barriers, depth - 1);
                for group in &plan.groups {
                    assert_eq!(group.waves.len(), width);
                }
            }
        }
    }
}