1use crate::accounting::{
10 checked_add_u64_count as checked_add, checked_mul_u64_count as checked_mul,
11};
12use crate::megakernel_barrier::{
13 plan_megakernel_barriers_with_scratch, MegakernelBarrierGroup, MegakernelBarrierPlan,
14 MegakernelBarrierPlanError, MegakernelBarrierScratch, MegakernelWaveDependency,
15};
16use crate::reservation_policy::{
17 reserve_typed_vec_to_capacity as reserve_vec_to_capacity, ReservationPolicy,
18};
19
/// Reservation policy used by [`reserve_vec`] for every buffer this planner grows.
///
/// The two strings presumably label the owner and the suggested fix in
/// reservation-failure diagnostics (confirm against `ReservationPolicy::new`).
const MEGAKERNEL_FRONTIER_RESERVATION: ReservationPolicy = ReservationPolicy::new(
    "megakernel frontier memory planner",
    "shard the frontier wave group or split the fused phase",
);
24
/// Per-wave memory demand fed to the frontier memory planner.
///
/// Each field is summed over the waves of a barrier group to compute the plan's
/// peaks. The fields are also combined as `frontier + 4 * scratch + output` to
/// decide whether waves fit together in one fused group.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelFrontierWave {
    /// Frontier bytes attributed to this wave.
    pub frontier_bytes: u64,
    /// Scratch bytes attributed to this wave; charged four times in the fused budget.
    pub scratch_bytes: u64,
    /// Output bytes attributed to this wave; group sums feed the amortized readback size.
    pub output_bytes: u64,
}
35
/// Result of [`plan_megakernel_frontier_memory_with_scratch`]: the memory-split
/// barrier plan plus the per-group memory peaks it implies.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MegakernelFrontierMemoryPlan {
    /// Barrier plan after groups were split to fit the group memory budget.
    /// Its `global_barriers` is the number of boundaries between consecutive groups.
    pub barriers: MegakernelBarrierPlan,
    /// Largest sum of `frontier_bytes` over the waves of any single group.
    pub peak_frontier_bytes: u64,
    /// Largest sum of `scratch_bytes` over the waves of any single group
    /// (unscaled: the fused 4x multiplier is not applied here).
    pub peak_scratch_bytes: u64,
    /// Largest sum of `output_bytes` over the waves of any single group.
    pub peak_output_bytes: u64,
    /// `max(readback_bytes, peak_output_bytes)`: the requested readback size,
    /// raised to at least the largest group's output.
    pub amortized_readback_bytes: u64,
    /// Largest number of waves in any group after splitting.
    pub max_group_width: usize,
}
53
/// Errors reported by the megakernel frontier memory planner.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MegakernelFrontierMemoryPlanError {
    /// Barrier planning for the wave dependency graph failed.
    Barrier(MegakernelBarrierPlanError),
    /// A checked `u64` byte computation overflowed.
    ByteCountOverflow {
        /// Name of the quantity being computed when the overflow happened.
        field: &'static str,
    },
    /// A memory requirement does not fit the budget available to it.
    GroupOverBudget {
        /// Bytes the rejected item needs.
        required_bytes: u64,
        /// Bytes available to it: the total budget when checking the resident
        /// graph, the per-group budget when checking a single fused wave.
        budget_bytes: u64,
        /// What was being checked, e.g. `"resident graph bytes"`.
        field: &'static str,
    },
    /// Growing internal planner storage failed.
    StorageReserveFailed {
        /// Label of the buffer being reserved.
        field: &'static str,
        /// Number of entries that was requested.
        requested: usize,
        /// Message reported by the reservation helper.
        message: String,
    },
}
83
84impl crate::accounting::ArithmeticOverflow for MegakernelFrontierMemoryPlanError {
85 fn arithmetic_overflow(field: &'static str) -> Self {
86 Self::ByteCountOverflow { field }
87 }
88}
89
90impl std::fmt::Display for MegakernelFrontierMemoryPlanError {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 match self {
93 Self::Barrier(error) => error.fmt(f),
94 Self::ByteCountOverflow { field } => write!(
95 f,
96 "megakernel frontier memory planner overflowed while accumulating {field}. Fix: shard the frontier wave group or split the fused phase."
97 ),
98 Self::GroupOverBudget {
99 required_bytes,
100 budget_bytes,
101 field,
102 } => write!(
103 f,
104 "megakernel frontier memory planner requires {required_bytes} bytes for {field} but budget allows {budget_bytes}. Fix: shard the graph/frontier waves or raise the explicit megakernel budget."
105 ),
106 Self::StorageReserveFailed {
107 field,
108 requested,
109 message,
110 } => write!(
111 f,
112 "megakernel frontier memory planner could not reserve {requested} {field} entries: {message}. Fix: shard the frontier waves before planning."
113 ),
114 }
115 }
116}
117
// Marker impl only: `source()` is not overridden, so a wrapped barrier error is
// surfaced through `Display` rather than through the error source chain.
impl std::error::Error for MegakernelFrontierMemoryPlanError {}
119
120impl From<MegakernelBarrierPlanError> for MegakernelFrontierMemoryPlanError {
121 fn from(error: MegakernelBarrierPlanError) -> Self {
122 Self::Barrier(error)
123 }
124}
125
126pub fn plan_megakernel_frontier_memory_with_scratch(
134 waves: &[MegakernelFrontierWave],
135 dependencies: &[MegakernelWaveDependency],
136 resident_graph_bytes: u64,
137 budget_bytes: u64,
138 readback_bytes: u64,
139 scratch: &mut MegakernelBarrierScratch,
140) -> Result<MegakernelFrontierMemoryPlan, MegakernelFrontierMemoryPlanError> {
141 let barriers = plan_megakernel_barriers_with_scratch(waves.len(), dependencies, scratch)?;
142 let group_budget_bytes = budget_bytes.checked_sub(resident_graph_bytes).ok_or(
143 MegakernelFrontierMemoryPlanError::GroupOverBudget {
144 required_bytes: resident_graph_bytes,
145 budget_bytes,
146 field: "resident graph bytes",
147 },
148 )?;
149 let barriers = split_barrier_groups_to_memory_budget(barriers, waves, group_budget_bytes)?;
150 let mut peak_frontier_bytes = 0u64;
151 let mut peak_scratch_bytes = 0u64;
152 let mut peak_output_bytes = 0u64;
153 let mut max_group_width = 0usize;
154 for group in &barriers.groups {
155 let mut group_frontier_bytes = 0u64;
156 let mut group_scratch_bytes = 0u64;
157 let mut group_output_bytes = 0u64;
158 max_group_width = max_group_width.max(group.waves.len());
159 for &wave_index in &group.waves {
160 let wave = waves[wave_index];
161 group_frontier_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
162 group_frontier_bytes,
163 wave.frontier_bytes,
164 "frontier wave bytes",
165 )?;
166 group_scratch_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
167 group_scratch_bytes,
168 wave.scratch_bytes,
169 "scratch wave bytes",
170 )?;
171 group_output_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
172 group_output_bytes,
173 wave.output_bytes,
174 "output wave bytes",
175 )?;
176 }
177 peak_frontier_bytes = peak_frontier_bytes.max(group_frontier_bytes);
178 peak_scratch_bytes = peak_scratch_bytes.max(group_scratch_bytes);
179 peak_output_bytes = peak_output_bytes.max(group_output_bytes);
180 }
181
182 Ok(MegakernelFrontierMemoryPlan {
183 barriers,
184 peak_frontier_bytes,
185 peak_scratch_bytes,
186 peak_output_bytes,
187 amortized_readback_bytes: readback_bytes.max(peak_output_bytes),
188 max_group_width,
189 })
190}
191
192fn split_barrier_groups_to_memory_budget(
193 barriers: MegakernelBarrierPlan,
194 waves: &[MegakernelFrontierWave],
195 group_budget_bytes: u64,
196) -> Result<MegakernelBarrierPlan, MegakernelFrontierMemoryPlanError> {
197 let mut groups = Vec::new();
198 reserve_vec::<MegakernelBarrierGroup>(
199 &mut groups,
200 barriers.groups.len(),
201 "split barrier groups",
202 )?;
203 for group in barriers.groups {
204 split_one_barrier_group_to_memory_budget(group, waves, group_budget_bytes, &mut groups)?;
205 }
206 Ok(MegakernelBarrierPlan {
207 global_barriers: if groups.is_empty() {
208 0
209 } else {
210 groups.len() - 1
211 },
212 groups,
213 })
214}
215
216fn split_one_barrier_group_to_memory_budget(
217 group: MegakernelBarrierGroup,
218 waves: &[MegakernelFrontierWave],
219 group_budget_bytes: u64,
220 groups: &mut Vec<MegakernelBarrierGroup>,
221) -> Result<(), MegakernelFrontierMemoryPlanError> {
222 let mut current = Vec::new();
223 reserve_vec::<usize>(
224 &mut current,
225 group.waves.len().min(8),
226 "current split barrier group",
227 )?;
228 let mut current_bytes = 0u64;
229 for wave_index in group.waves {
230 let wave_bytes = megakernel_frontier_fused_wave_budget_bytes(waves[wave_index])?;
231 let combined = checked_add::<MegakernelFrontierMemoryPlanError>(
232 current_bytes,
233 wave_bytes,
234 "barrier group fused wave budget bytes",
235 )?;
236 if current.is_empty() && wave_bytes > group_budget_bytes {
237 return Err(MegakernelFrontierMemoryPlanError::GroupOverBudget {
238 required_bytes: wave_bytes,
239 budget_bytes: group_budget_bytes,
240 field: "single fused frontier wave bytes",
241 });
242 }
243 if !current.is_empty() && combined > group_budget_bytes {
244 groups.push(MegakernelBarrierGroup {
245 waves: std::mem::take(&mut current),
246 });
247 current_bytes = 0;
248 }
249 current.push(wave_index);
250 current_bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
251 current_bytes,
252 wave_bytes,
253 "barrier group fused wave budget bytes",
254 )?;
255 }
256 if !current.is_empty() {
257 groups.push(MegakernelBarrierGroup { waves: current });
258 }
259 Ok(())
260}
261
262pub fn megakernel_frontier_fused_wave_budget_bytes(
265 wave: MegakernelFrontierWave,
266) -> Result<u64, MegakernelFrontierMemoryPlanError> {
267 let fused_scratch_bytes = checked_mul::<MegakernelFrontierMemoryPlanError>(
268 wave.scratch_bytes,
269 4,
270 "fused wave scratch bytes",
271 )?;
272 let bytes = checked_add::<MegakernelFrontierMemoryPlanError>(
273 wave.frontier_bytes,
274 fused_scratch_bytes,
275 "fused wave bytes",
276 )?;
277 checked_add::<MegakernelFrontierMemoryPlanError>(bytes, wave.output_bytes, "fused wave bytes")
278}
279
280fn reserve_vec<T>(
281 vec: &mut Vec<T>,
282 target_capacity: usize,
283 item: &'static str,
284) -> Result<(), MegakernelFrontierMemoryPlanError> {
285 reserve_vec_to_capacity(
286 MEGAKERNEL_FRONTIER_RESERVATION,
287 vec,
288 target_capacity,
289 item,
290 storage_reserve_failed,
291 )
292}
293
294fn storage_reserve_failed(
295 field: &'static str,
296 requested: usize,
297 message: String,
298) -> MegakernelFrontierMemoryPlanError {
299 MegakernelFrontierMemoryPlanError::StorageReserveFailed {
300 field,
301 requested,
302 message,
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::{
        megakernel_frontier_fused_wave_budget_bytes, plan_megakernel_frontier_memory_with_scratch,
        MegakernelFrontierMemoryPlanError, MegakernelFrontierWave,
    };
    use crate::megakernel_barrier::{MegakernelBarrierScratch, MegakernelWaveDependency};

    // Diamond DAG 0 -> {1, 2} -> 3. The assertions expect barrier groups
    // [0], [1, 2], [3]. The widest group needs 6_656 + 13_312 = 19_968 fused bytes,
    // which fits the 128 KiB - 16_000 group budget, so no memory split happens.
    // Every peak comes from the last group (wave 3). The requested readback
    // (1 MiB) dominates the peak output.
    #[test]
    fn frontier_memory_plan_uses_peak_barrier_group_memory() {
        let mut scratch = MegakernelBarrierScratch::default();
        let plan = plan_megakernel_frontier_memory_with_scratch(
            &[
                MegakernelFrontierWave {
                    frontier_bytes: 1_024,
                    scratch_bytes: 512,
                    output_bytes: 256,
                },
                MegakernelFrontierWave {
                    frontier_bytes: 2_048,
                    scratch_bytes: 1_024,
                    output_bytes: 512,
                },
                MegakernelFrontierWave {
                    frontier_bytes: 4_096,
                    scratch_bytes: 2_048,
                    output_bytes: 1_024,
                },
                MegakernelFrontierWave {
                    frontier_bytes: 8_192,
                    scratch_bytes: 4_096,
                    output_bytes: 2_048,
                },
            ],
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 0,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 3,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
            16_000,
            128 * 1024,
            1 << 20,
            &mut scratch,
        )
        .expect("Fix: frontier-typed megakernel memory plan should fit the budget.");

        assert_eq!(plan.barriers.global_barriers, 2);
        assert_eq!(plan.barriers.groups[1].waves, vec![1, 2]);
        assert_eq!(plan.peak_frontier_bytes, 8_192);
        assert_eq!(plan.peak_scratch_bytes, 4_096);
        assert_eq!(plan.peak_output_bytes, 2_048);
        assert_eq!(plan.amortized_readback_bytes, 1 << 20);
        assert_eq!(plan.max_group_width, 2);
    }

    // Two independent waves share one group. Their summed output (2 * 3_072)
    // exceeds the zero requested readback, so it becomes the amortized readback.
    #[test]
    fn frontier_memory_uses_static_group_output_to_amortize_readback() {
        let mut scratch = MegakernelBarrierScratch::default();
        let plan = plan_megakernel_frontier_memory_with_scratch(
            &[
                MegakernelFrontierWave {
                    frontier_bytes: 1_024,
                    scratch_bytes: 512,
                    output_bytes: 3_072,
                },
                MegakernelFrontierWave {
                    frontier_bytes: 1_024,
                    scratch_bytes: 512,
                    output_bytes: 3_072,
                },
            ],
            &[],
            16_000,
            128 * 1024,
            0,
            &mut scratch,
        )
        .expect("Fix: static output-amortized frontier memory plan should fit the budget.");

        assert_eq!(plan.peak_output_bytes, 6_144);
        assert_eq!(plan.amortized_readback_bytes, 6_144);
    }

    // Each wave's fused budget is 10 + 4 * 10 + 10 = 60 bytes, so no two waves fit
    // together in the 100-byte budget. The three independent waves therefore split
    // into three single-wave groups separated by two barriers.
    #[test]
    fn frontier_memory_splits_independent_layers_to_fit_fused_budget() {
        let mut scratch = MegakernelBarrierScratch::default();
        let waves = [
            MegakernelFrontierWave {
                frontier_bytes: 10,
                scratch_bytes: 10,
                output_bytes: 10,
            },
            MegakernelFrontierWave {
                frontier_bytes: 10,
                scratch_bytes: 10,
                output_bytes: 10,
            },
            MegakernelFrontierWave {
                frontier_bytes: 10,
                scratch_bytes: 10,
                output_bytes: 10,
            },
        ];
        let plan =
            plan_megakernel_frontier_memory_with_scratch(&waves, &[], 0, 100, 4_096, &mut scratch)
                .expect("Fix: independent frontier waves should split into budget-fit chunks.");

        assert_eq!(plan.barriers.groups.len(), 3);
        assert_eq!(plan.barriers.global_barriers, 2);
        assert_eq!(plan.max_group_width, 1);
        assert_eq!(plan.peak_frontier_bytes, 10);
        assert_eq!(plan.peak_scratch_bytes, 10);
        assert_eq!(plan.peak_output_bytes, 10);
    }

    // Covers two rejections:
    // - A resident graph larger than the total budget is reported with the
    //   resident bytes as the requirement.
    // - A lone wave with fused budget 100 + 4 * 100 + 100 = 600 bytes cannot fit
    //   a 500-byte group budget.
    #[test]
    fn frontier_memory_rejects_graph_and_single_wave_over_budget() {
        let mut scratch = MegakernelBarrierScratch::default();
        let graph_error = plan_megakernel_frontier_memory_with_scratch(
            &[MegakernelFrontierWave {
                frontier_bytes: 1,
                scratch_bytes: 1,
                output_bytes: 1,
            }],
            &[],
            1_600,
            1_000,
            0,
            &mut scratch,
        )
        .expect_err("resident graph bytes above budget must fail before split planning");
        assert_eq!(
            graph_error,
            MegakernelFrontierMemoryPlanError::GroupOverBudget {
                required_bytes: 1_600,
                budget_bytes: 1_000,
                field: "resident graph bytes",
            }
        );

        let wave_error = plan_megakernel_frontier_memory_with_scratch(
            &[MegakernelFrontierWave {
                frontier_bytes: 100,
                scratch_bytes: 100,
                output_bytes: 100,
            }],
            &[],
            0,
            500,
            0,
            &mut scratch,
        )
        .expect_err("single fused wave above group budget must fail before topology planning");
        assert_eq!(
            wave_error,
            MegakernelFrontierMemoryPlanError::GroupOverBudget {
                required_bytes: 600,
                budget_bytes: 500,
                field: "single fused frontier wave bytes",
            }
        );
    }

    // 16 + 4 * 16 + 16 = 96: scratch is charged four times in the fused budget.
    #[test]
    fn frontier_fused_wave_budget_uses_topology_scratch_multiplier() {
        assert_eq!(
            megakernel_frontier_fused_wave_budget_bytes(MegakernelFrontierWave {
                frontier_bytes: 16,
                scratch_bytes: 16,
                output_bytes: 16,
            })
            .expect("Fix: fused frontier wave budget should fit"),
            96
        );
    }

    // Wave 0's fused budget (u64::MAX + 4 * 1 + 1) overflows while adding its
    // scratch charge to its frontier. This is reported as the "fused wave bytes"
    // accumulation, before any grouping decision.
    #[test]
    fn frontier_memory_fails_loudly_on_wave_byte_overflow() {
        let mut scratch = MegakernelBarrierScratch::default();
        let error = plan_megakernel_frontier_memory_with_scratch(
            &[
                MegakernelFrontierWave {
                    frontier_bytes: u64::MAX,
                    scratch_bytes: 1,
                    output_bytes: 1,
                },
                MegakernelFrontierWave {
                    frontier_bytes: 1,
                    scratch_bytes: 1,
                    output_bytes: 1,
                },
            ],
            &[],
            2,
            u64::MAX,
            0,
            &mut scratch,
        )
        .expect_err("Fix: overflowed frontier wave bytes must fail before launch planning.");

        assert_eq!(
            error,
            MegakernelFrontierMemoryPlanError::ByteCountOverflow {
                field: "fused wave bytes"
            }
        );
    }

    // Builds `width` independent chains of length `depth` (wave index =
    // layer * width + slot), for all 32 x 32 shapes. The assertions expect layer L
    // to become barrier group L under the huge budget.
    // Per layer:
    // - frontier sums to width * width;
    // - scratch sums to 1 + ... + width;
    // - output sums to width * (layer + 1), peaking at the last layer.
    #[test]
    fn generated_frontier_memory_profiles_preserve_peak_and_budget_for_1024_shapes() {
        let mut scratch = MegakernelBarrierScratch::default();
        for width in 1u64..=32 {
            for depth in 1u64..=32 {
                let mut waves = Vec::new();
                let mut dependencies = Vec::new();
                for layer in 0..depth {
                    for slot in 0..width {
                        waves.push(MegakernelFrontierWave {
                            frontier_bytes: width,
                            scratch_bytes: slot + 1,
                            output_bytes: layer + 1,
                        });
                        if layer + 1 < depth {
                            dependencies.push(MegakernelWaveDependency {
                                before: (layer * width + slot) as usize,
                                after: ((layer + 1) * width + slot) as usize,
                            });
                        }
                    }
                }

                let plan = plan_megakernel_frontier_memory_with_scratch(
                    &waves,
                    &dependencies,
                    256,
                    u64::MAX / 2,
                    7,
                    &mut scratch,
                )
                .expect("Fix: generated frontier memory DAG should plan under large budget.");

                assert_eq!(plan.barriers.groups.len(), depth as usize);
                assert_eq!(plan.max_group_width, width as usize);
                assert_eq!(plan.peak_frontier_bytes, width * width);
                assert_eq!(plan.peak_scratch_bytes, width * (width + 1) / 2);
                assert_eq!(plan.peak_output_bytes, width * depth);
                assert_eq!(plan.amortized_readback_bytes, 7.max(width * depth));
            }
        }
    }
}
569}