Skip to main content

vyre_driver/
megakernel_frontier.rs

1//! Backend-neutral frontier memory planning for dependency-aware megakernels.
2//!
3//! Backends can choose different execution topologies, but the memory envelope
4//! of dependency-layered frontier waves is a backend-neutral contract. This
5//! module plans that envelope once, including dependency barriers, fused-group
6//! splitting under an explicit byte budget, peak byte accounting, and readback
7//! pressure amortization.
8
9use crate::accounting::{
10    checked_add_u64_count as checked_add, checked_mul_u64_count as checked_mul,
11};
12use crate::megakernel_barrier::{
13    plan_megakernel_barriers_with_scratch, MegakernelBarrierGroup, MegakernelBarrierPlan,
14    MegakernelBarrierPlanError, MegakernelBarrierScratch, MegakernelWaveDependency,
15};
16use crate::reservation_policy::{
17    reserve_typed_vec_to_capacity as reserve_vec_to_capacity, ReservationPolicy,
18};
19
20const MEGAKERNEL_FRONTIER_RESERVATION: ReservationPolicy = ReservationPolicy::new(
21    "megakernel frontier memory planner",
22    "shard the frontier wave group or split the fused phase",
23);
24
/// Per-wave memory envelope fed into the frontier memory planner.
///
/// All three quantities are raw byte counts; the planner sums them per
/// barrier-free group and separately derives a fused budget figure from them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelFrontierWave {
    /// Bytes of resident frontier data this wave touches.
    pub frontier_bytes: u64,
    /// Temporary scratch bytes the wave needs, stated before any topology
    /// scaling is applied.
    pub scratch_bytes: u64,
    /// Bytes of output the wave produces.
    pub output_bytes: u64,
}
35
36/// Dependency-aware megakernel frontier memory plan.
37#[derive(Clone, Debug, Eq, PartialEq)]
38pub struct MegakernelFrontierMemoryPlan {
39    /// Minimum global-barrier grouping after memory-budget splitting.
40    pub barriers: MegakernelBarrierPlan,
41    /// Peak frontier bytes across any fused barrier-free group.
42    pub peak_frontier_bytes: u64,
43    /// Peak scratch bytes across any fused barrier-free group.
44    pub peak_scratch_bytes: u64,
45    /// Peak output bytes across any fused barrier-free group.
46    pub peak_output_bytes: u64,
47    /// Readback pressure after combining runtime telemetry with static
48    /// fused-wave output volume.
49    pub amortized_readback_bytes: u64,
50    /// Widest barrier-free group in wave count.
51    pub max_group_width: usize,
52}
53
54/// Frontier memory planning failure.
55#[derive(Clone, Debug, Eq, PartialEq)]
56pub enum MegakernelFrontierMemoryPlanError {
57    /// Dependency graph cannot be barrier-planned.
58    Barrier(MegakernelBarrierPlanError),
59    /// Peak wave bytes overflowed while grouping a barrier-free phase.
60    ByteCountOverflow {
61        /// Field being accumulated.
62        field: &'static str,
63    },
64    /// Static graph or fused frontier bytes exceed the caller-approved budget.
65    GroupOverBudget {
66        /// Required bytes before topology selection.
67        required_bytes: u64,
68        /// Caller-provided budget.
69        budget_bytes: u64,
70        /// Budget region being checked.
71        field: &'static str,
72    },
73    /// Frontier planning result storage could not be reserved.
74    StorageReserveFailed {
75        /// Field being reserved.
76        field: &'static str,
77        /// Number of elements requested.
78        requested: usize,
79        /// Allocator error text.
80        message: String,
81    },
82}
83
84impl crate::accounting::ArithmeticOverflow for MegakernelFrontierMemoryPlanError {
85    fn arithmetic_overflow(field: &'static str) -> Self {
86        Self::ByteCountOverflow { field }
87    }
88}
89
90impl std::fmt::Display for MegakernelFrontierMemoryPlanError {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        match self {
93            Self::Barrier(error) => error.fmt(f),
94            Self::ByteCountOverflow { field } => write!(
95                f,
96                "megakernel frontier memory planner overflowed while accumulating {field}. Fix: shard the frontier wave group or split the fused phase."
97            ),
98            Self::GroupOverBudget {
99                required_bytes,
100                budget_bytes,
101                field,
102            } => write!(
103                f,
104                "megakernel frontier memory planner requires {required_bytes} bytes for {field} but budget allows {budget_bytes}. Fix: shard the graph/frontier waves or raise the explicit megakernel budget."
105            ),
106            Self::StorageReserveFailed {
107                field,
108                requested,
109                message,
110            } => write!(
111                f,
112                "megakernel frontier memory planner could not reserve {requested} {field} entries: {message}. Fix: shard the frontier waves before planning."
113            ),
114        }
115    }
116}
117
118impl std::error::Error for MegakernelFrontierMemoryPlanError {}
119
120impl From<MegakernelBarrierPlanError> for MegakernelFrontierMemoryPlanError {
121    fn from(error: MegakernelBarrierPlanError) -> Self {
122        Self::Barrier(error)
123    }
124}
125
126/// Plan dependency-aware frontier memory using caller-owned barrier scratch.
127///
128/// # Errors
129///
130/// Returns [`MegakernelFrontierMemoryPlanError`] when dependencies are invalid,
131/// counters overflow, or the requested graph/frontier envelope cannot fit the
132/// explicit budget.
133pub fn plan_megakernel_frontier_memory_with_scratch(
134    waves: &[MegakernelFrontierWave],
135    dependencies: &[MegakernelWaveDependency],
136    resident_graph_bytes: u64,
137    budget_bytes: u64,
138    readback_bytes: u64,
139    scratch: &mut MegakernelBarrierScratch,
140) -> Result<MegakernelFrontierMemoryPlan, MegakernelFrontierMemoryPlanError> {
141    let barriers = plan_megakernel_barriers_with_scratch(waves.len(), dependencies, scratch)?;
142    let group_budget_bytes = budget_bytes.checked_sub(resident_graph_bytes).ok_or(
143        MegakernelFrontierMemoryPlanError::GroupOverBudget {
144            required_bytes: resident_graph_bytes,
145            budget_bytes,
146            field: "resident graph bytes",
147        },
148    )?;
149    let barriers = split_barrier_groups_to_memory_budget(barriers, waves, group_budget_bytes)?;
150    let mut peak_frontier_bytes = 0u64;
151    let mut peak_scratch_bytes = 0u64;
152    let mut peak_output_bytes = 0u64;
153    let mut max_group_width = 0usize;
154    for group in &barriers.groups {
155        let mut group_frontier_bytes = 0u64;
156        let mut group_scratch_bytes = 0u64;
157        let mut group_output_bytes = 0u64;
158        max_group_width = max_group_width.max(group.waves.len());
159        for &wave_index in &group.waves {
160            let wave = waves[wave_index];
161            group_frontier_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
162                group_frontier_bytes,
163                wave.frontier_bytes,
164                "frontier wave bytes",
165            )?;
166            group_scratch_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
167                group_scratch_bytes,
168                wave.scratch_bytes,
169                "scratch wave bytes",
170            )?;
171            group_output_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
172                group_output_bytes,
173                wave.output_bytes,
174                "output wave bytes",
175            )?;
176        }
177        peak_frontier_bytes = peak_frontier_bytes.max(group_frontier_bytes);
178        peak_scratch_bytes = peak_scratch_bytes.max(group_scratch_bytes);
179        peak_output_bytes = peak_output_bytes.max(group_output_bytes);
180    }
181
182    Ok(MegakernelFrontierMemoryPlan {
183        barriers,
184        peak_frontier_bytes,
185        peak_scratch_bytes,
186        peak_output_bytes,
187        amortized_readback_bytes: readback_bytes.max(peak_output_bytes),
188        max_group_width,
189    })
190}
191
192fn split_barrier_groups_to_memory_budget(
193    barriers: MegakernelBarrierPlan,
194    waves: &[MegakernelFrontierWave],
195    group_budget_bytes: u64,
196) -> Result<MegakernelBarrierPlan, MegakernelFrontierMemoryPlanError> {
197    let mut groups = Vec::new();
198    reserve_vec::<MegakernelBarrierGroup>(
199        &mut groups,
200        barriers.groups.len(),
201        "split barrier groups",
202    )?;
203    for group in barriers.groups {
204        split_one_barrier_group_to_memory_budget(group, waves, group_budget_bytes, &mut groups)?;
205    }
206    Ok(MegakernelBarrierPlan {
207        global_barriers: if groups.is_empty() {
208            0
209        } else {
210            groups.len() - 1
211        },
212        groups,
213    })
214}
215
216fn split_one_barrier_group_to_memory_budget(
217    group: MegakernelBarrierGroup,
218    waves: &[MegakernelFrontierWave],
219    group_budget_bytes: u64,
220    groups: &mut Vec<MegakernelBarrierGroup>,
221) -> Result<(), MegakernelFrontierMemoryPlanError> {
222    let mut current = Vec::new();
223    reserve_vec::<usize>(
224        &mut current,
225        group.waves.len().min(8),
226        "current split barrier group",
227    )?;
228    let mut current_bytes = 0u64;
229    for wave_index in group.waves {
230        let wave_bytes = megakernel_frontier_fused_wave_budget_bytes(waves[wave_index])?;
231        let combined = checked_add::<MegakernelFrontierMemoryPlanError>(
232            current_bytes,
233            wave_bytes,
234            "barrier group fused wave budget bytes",
235        )?;
236        if current.is_empty() && wave_bytes > group_budget_bytes {
237            return Err(MegakernelFrontierMemoryPlanError::GroupOverBudget {
238                required_bytes: wave_bytes,
239                budget_bytes: group_budget_bytes,
240                field: "single fused frontier wave bytes",
241            });
242        }
243        if !current.is_empty() && combined > group_budget_bytes {
244            groups.push(MegakernelBarrierGroup {
245                waves: std::mem::take(&mut current),
246            });
247            current_bytes = 0;
248        }
249        current.push(wave_index);
250        current_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
251            current_bytes,
252            wave_bytes,
253            "barrier group fused wave budget bytes",
254        )?;
255    }
256    if !current.is_empty() {
257        groups.push(MegakernelBarrierGroup { waves: current });
258    }
259    Ok(())
260}
261
262/// Compute the byte budget used to decide whether one frontier wave can fit in
263/// a fused barrier-free resident group.
264pub fn megakernel_frontier_fused_wave_budget_bytes(
265    wave: MegakernelFrontierWave,
266) -> Result<u64, MegakernelFrontierMemoryPlanError> {
267    let fused_scratch_bytes = checked_mul::<MegakernelFrontierMemoryPlanError>(
268        wave.scratch_bytes,
269        4,
270        "fused wave scratch bytes",
271    )?;
272    let bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
273        wave.frontier_bytes,
274        fused_scratch_bytes,
275        "fused wave bytes",
276    )?;
277    checked_add::<MegakernelFrontierMemoryPlanError>(bytes, wave.output_bytes, "fused wave bytes")
278}
279
280fn reserve_vec<T>(
281    vec: &mut Vec<T>,
282    target_capacity: usize,
283    item: &'static str,
284) -> Result<(), MegakernelFrontierMemoryPlanError> {
285    reserve_vec_to_capacity(
286        MEGAKERNEL_FRONTIER_RESERVATION,
287        vec,
288        target_capacity,
289        item,
290        storage_reserve_failed,
291    )
292}
293
294fn storage_reserve_failed(
295    field: &'static str,
296    requested: usize,
297    message: String,
298) -> MegakernelFrontierMemoryPlanError {
299    MegakernelFrontierMemoryPlanError::StorageReserveFailed {
300        field,
301        requested,
302        message,
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::{
        megakernel_frontier_fused_wave_budget_bytes, plan_megakernel_frontier_memory_with_scratch,
        MegakernelFrontierMemoryPlanError, MegakernelFrontierWave,
    };
    use crate::megakernel_barrier::{MegakernelBarrierScratch, MegakernelWaveDependency};

    /// Shorthand constructor for a wave envelope.
    fn wave(frontier_bytes: u64, scratch_bytes: u64, output_bytes: u64) -> MegakernelFrontierWave {
        MegakernelFrontierWave {
            frontier_bytes,
            scratch_bytes,
            output_bytes,
        }
    }

    /// Shorthand constructor for a `before -> after` dependency edge.
    fn dep(before: usize, after: usize) -> MegakernelWaveDependency {
        MegakernelWaveDependency { before, after }
    }

    // Diamond DAG (0 -> {1, 2} -> 3): peaks come from the heaviest group and
    // runtime readback telemetry dominates the static output volume.
    #[test]
    fn frontier_memory_plan_uses_peak_barrier_group_memory() {
        let mut scratch = MegakernelBarrierScratch::default();
        let waves = [
            wave(1_024, 512, 256),
            wave(2_048, 1_024, 512),
            wave(4_096, 2_048, 1_024),
            wave(8_192, 4_096, 2_048),
        ];
        let dependencies = [dep(0, 1), dep(0, 2), dep(1, 3), dep(2, 3)];
        let plan = plan_megakernel_frontier_memory_with_scratch(
            &waves,
            &dependencies,
            16_000,
            128 * 1024,
            1 << 20,
            &mut scratch,
        )
        .expect("Fix: frontier-typed megakernel memory plan should fit the budget.");

        assert_eq!(plan.barriers.global_barriers, 2);
        assert_eq!(plan.barriers.groups[1].waves, vec![1, 2]);
        assert_eq!(plan.peak_frontier_bytes, 8_192);
        assert_eq!(plan.peak_scratch_bytes, 4_096);
        assert_eq!(plan.peak_output_bytes, 2_048);
        assert_eq!(plan.amortized_readback_bytes, 1 << 20);
        assert_eq!(plan.max_group_width, 2);
    }

    // With zero runtime readback, the static group output sets the pressure.
    #[test]
    fn frontier_memory_uses_static_group_output_to_amortize_readback() {
        let mut scratch = MegakernelBarrierScratch::default();
        let waves = [wave(1_024, 512, 3_072), wave(1_024, 512, 3_072)];
        let plan = plan_megakernel_frontier_memory_with_scratch(
            &waves,
            &[],
            16_000,
            128 * 1024,
            0,
            &mut scratch,
        )
        .expect("Fix: static output-amortized frontier memory plan should fit the budget.");

        assert_eq!(plan.peak_output_bytes, 6_144);
        assert_eq!(plan.amortized_readback_bytes, 6_144);
    }

    // Each wave costs 60 fused bytes, so a 100-byte budget holds one per group.
    #[test]
    fn frontier_memory_splits_independent_layers_to_fit_fused_budget() {
        let mut scratch = MegakernelBarrierScratch::default();
        let waves = [wave(10, 10, 10), wave(10, 10, 10), wave(10, 10, 10)];
        let plan =
            plan_megakernel_frontier_memory_with_scratch(&waves, &[], 0, 100, 4_096, &mut scratch)
                .expect("Fix: independent frontier waves should split into budget-fit chunks.");

        assert_eq!(plan.barriers.groups.len(), 3);
        assert_eq!(plan.barriers.global_barriers, 2);
        assert_eq!(plan.max_group_width, 1);
        assert_eq!(plan.peak_frontier_bytes, 10);
        assert_eq!(plan.peak_scratch_bytes, 10);
        assert_eq!(plan.peak_output_bytes, 10);
    }

    #[test]
    fn frontier_memory_rejects_graph_and_single_wave_over_budget() {
        let mut scratch = MegakernelBarrierScratch::default();

        let graph_error = plan_megakernel_frontier_memory_with_scratch(
            &[wave(1, 1, 1)],
            &[],
            1_600,
            1_000,
            0,
            &mut scratch,
        )
        .expect_err("resident graph bytes above budget must fail before split planning");
        assert_eq!(
            graph_error,
            MegakernelFrontierMemoryPlanError::GroupOverBudget {
                required_bytes: 1_600,
                budget_bytes: 1_000,
                field: "resident graph bytes",
            }
        );

        let wave_error = plan_megakernel_frontier_memory_with_scratch(
            &[wave(100, 100, 100)],
            &[],
            0,
            500,
            0,
            &mut scratch,
        )
        .expect_err("single fused wave above group budget must fail before topology planning");
        assert_eq!(
            wave_error,
            MegakernelFrontierMemoryPlanError::GroupOverBudget {
                required_bytes: 600,
                budget_bytes: 500,
                field: "single fused frontier wave bytes",
            }
        );
    }

    // 16 + 4 * 16 + 16 = 96.
    #[test]
    fn frontier_fused_wave_budget_uses_topology_scratch_multiplier() {
        let fused = megakernel_frontier_fused_wave_budget_bytes(wave(16, 16, 16))
            .expect("Fix: fused frontier wave budget should fit");
        assert_eq!(fused, 96);
    }

    #[test]
    fn frontier_memory_fails_loudly_on_wave_byte_overflow() {
        let mut scratch = MegakernelBarrierScratch::default();
        let waves = [wave(u64::MAX, 1, 1), wave(1, 1, 1)];
        let error = plan_megakernel_frontier_memory_with_scratch(
            &waves,
            &[],
            2,
            u64::MAX,
            0,
            &mut scratch,
        )
        .expect_err("Fix: overflowed frontier wave bytes must fail before launch planning.");

        assert_eq!(
            error,
            MegakernelFrontierMemoryPlanError::ByteCountOverflow {
                field: "fused wave bytes"
            }
        );
    }

    // Sweeps 32 x 32 layered grids where slot `s` of layer `l` feeds slot `s`
    // of layer `l + 1`, so every layer is exactly one barrier-free group.
    #[test]
    fn generated_frontier_memory_profiles_preserve_peak_and_budget_for_1024_shapes() {
        let mut scratch = MegakernelBarrierScratch::default();
        for width in 1u64..=32 {
            for depth in 1u64..=32 {
                let mut waves = Vec::new();
                let mut dependencies = Vec::new();
                for layer in 0..depth {
                    for slot in 0..width {
                        waves.push(wave(width, slot + 1, layer + 1));
                        if layer + 1 < depth {
                            dependencies.push(dep(
                                (layer * width + slot) as usize,
                                ((layer + 1) * width + slot) as usize,
                            ));
                        }
                    }
                }

                let plan = plan_megakernel_frontier_memory_with_scratch(
                    &waves,
                    &dependencies,
                    256,
                    u64::MAX / 2,
                    7,
                    &mut scratch,
                )
                .expect("Fix: generated frontier memory DAG should plan under large budget.");

                assert_eq!(plan.barriers.groups.len(), depth as usize);
                assert_eq!(plan.max_group_width, width as usize);
                assert_eq!(plan.peak_frontier_bytes, width * width);
                assert_eq!(plan.peak_scratch_bytes, width * (width + 1) / 2);
                assert_eq!(plan.peak_output_bytes, width * depth);
                assert_eq!(plan.amortized_readback_bytes, 7.max(width * depth));
            }
        }
    }
}