1use vyre_driver::backend::BackendError;
4
5mod cache;
6use super::planner::{MegakernelGridLimits, MegakernelGridRequest, MegakernelLaunchGeometry};
7use super::staging_reserve::try_reserve_vec_capacity;
8
/// Coarse occupancy class of the megakernel work queue relative to the number
/// of lanes a launch provides (computed by `classify_pressure`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MegakernelQueuePressure {
    /// No queued work at all.
    Empty,
    /// Fewer queued items than launched lanes.
    Light,
    /// At least one item per lane, but short of the saturation wave count.
    Balanced,
    /// Queue covers `saturated_waves` waves of lanes, or work was requeued.
    Saturated,
}
21
/// How the megakernel should execute queued work.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MegakernelExecutionMode {
    /// Default mode, used when no JIT trigger fires.
    Interpreter,
    /// Chosen for large queues, hot opcode/window promotion, or fused dense dispatch.
    Jit,
}
30
/// Dispatch shape selected from frontier density, graph size, and memory pressure.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MegakernelDispatchTopology {
    /// The queue is empty.
    Empty,
    /// Frontier density at or below `sparse_frontier_threshold_bps`
    /// (this includes an unknown density of 0).
    SparseFrontier,
    /// Frontier density strictly between the sparse and dense thresholds.
    HybridFrontier,
    /// Frontier density at or above `dense_frontier_threshold_bps`.
    DenseFrontier,
    /// Dense frontier on a graph reaching `fusion_edge_threshold` edges with
    /// hot opcodes or hot windows promoted.
    FusedDense,
    /// Memory pressure at or above `memory_pressure_threshold_bps`; worker
    /// groups are shed.
    MemoryConstrained,
}
51
/// Snapshot of the launch-recommendation cache counters.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelLaunchCacheStats {
    /// Number of recommendations currently stored.
    pub entries: usize,
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that had to compute a fresh recommendation.
    pub misses: u64,
}
62
/// Inputs to [`MegakernelLaunchPolicy::recommend`]: queue state, device limits,
/// and optional scheduling/graph/memory telemetry.
///
/// Zero generally means "unknown / not supplied" for the optional signals.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MegakernelLaunchRequest {
    /// Number of queued work items.
    pub queue_len: u32,
    /// Worker groups requested before topology-driven scaling.
    pub requested_worker_groups: u32,
    /// Workgroup size along x, forwarded to the grid limits.
    pub max_workgroup_size_x: u32,
    /// Per-dimension workgroup-count limit, forwarded to the grid limits.
    pub max_compute_workgroups_per_dimension: u32,
    /// Invocations-per-workgroup limit, forwarded to the grid limits.
    pub max_compute_invocations_per_workgroup: u32,
    /// Explicit sparse-hit capacity; 0 derives it from the queue.
    pub requested_hit_capacity: u32,
    /// Expected hits per queued item when deriving capacity (0 is treated as 1).
    pub expected_hits_per_item: u32,
    /// Count of hot opcodes, compared against `hot_opcode_threshold`.
    pub hot_opcode_count: u32,
    /// Count of hot windows, compared against `hot_window_threshold`.
    pub hot_window_count: u32,
    /// Requeues observed; any non-zero value marks the queue saturated.
    pub requeue_count: u64,
    /// Oldest priority age observed, compared against `priority_age_threshold`.
    pub max_priority_age: u32,
    /// Resident graph node count; 0 disables density inference and fusion.
    pub graph_node_count: u32,
    /// Resident graph edge count, compared against `fusion_edge_threshold`.
    pub graph_edge_count: u32,
    /// Frontier density in basis points; 0 is inferred from `queue_len` and
    /// `graph_node_count` when both are non-zero.
    pub frontier_density_bps: u16,
    /// Memory pressure in basis points; 0 is inferred from resident bytes and
    /// the budget when both are non-zero.
    pub memory_pressure_bps: u16,
    /// Bytes already resident on the device.
    pub resident_device_bytes: u64,
    /// Device memory budget in bytes; 0 disables the budget check.
    pub device_memory_budget_bytes: u64,
}
103
104impl MegakernelLaunchRequest {
105 #[must_use]
107 pub const fn direct(
108 queue_len: u32,
109 requested_worker_groups: u32,
110 max_workgroup_size_x: u32,
111 ) -> Self {
112 Self {
113 queue_len,
114 requested_worker_groups,
115 max_workgroup_size_x,
116 max_compute_workgroups_per_dimension: requested_worker_groups,
117 max_compute_invocations_per_workgroup: max_workgroup_size_x,
118 requested_hit_capacity: 0,
119 expected_hits_per_item: 1,
120 hot_opcode_count: 0,
121 hot_window_count: 0,
122 requeue_count: 0,
123 max_priority_age: 0,
124 graph_node_count: 0,
125 graph_edge_count: 0,
126 frontier_density_bps: 0,
127 memory_pressure_bps: 0,
128 resident_device_bytes: 0,
129 device_memory_budget_bytes: 0,
130 }
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct MegakernelLaunchRecommendation {
137 pub geometry: MegakernelLaunchGeometry,
139 pub worker_groups: u32,
141 pub hit_capacity: u32,
143 pub pressure: MegakernelQueuePressure,
145 pub execution_mode: MegakernelExecutionMode,
147 pub topology: MegakernelDispatchTopology,
150 pub promote_hot_opcodes: bool,
152 pub promote_hot_windows: bool,
154 pub age_priority_work: bool,
156 pub estimated_peak_device_bytes: u64,
158 pub device_memory_budget_bytes: u64,
160}
161
/// Running scheduler telemetry for priority requeues and aged promotions.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct PriorityRequeueAccounting {
    /// Total requeues recorded.
    pub requeue_count: u64,
    /// Total aged promotions recorded.
    pub aged_promotions: u64,
    /// Largest age (in ticks) seen by either recorder.
    pub max_priority_age: u32,
}
172
173impl PriorityRequeueAccounting {
174 pub fn record_requeue(&mut self, age_ticks: u32) {
176 self.requeue_count = self.requeue_count.checked_add(1).unwrap_or_else(|| {
177 panic!("megakernel priority requeue_count overflowed u64. Fix: drain scheduler telemetry before counters reach u64::MAX.")
178 });
179 self.max_priority_age = self.max_priority_age.max(age_ticks);
180 }
181
182 pub fn record_aged_promotion(&mut self, age_ticks: u32) {
184 self.aged_promotions = self.aged_promotions.checked_add(1).unwrap_or_else(|| {
185 panic!("megakernel aged_promotions overflowed u64. Fix: drain scheduler telemetry before counters reach u64::MAX.")
186 });
187 self.max_priority_age = self.max_priority_age.max(age_ticks);
188 }
189}
190
191#[must_use]
205pub fn diffuse_priority_across_siblings(
206 priority_stalks: &[f64],
207 restriction_diag: &[f64],
208 damping: f64,
209 iterations: u32,
210) -> Vec<f64> {
211 try_diffuse_priority_across_siblings(priority_stalks, restriction_diag, damping, iterations)
212 .unwrap_or_else(|source| {
213 panic!(
214 "megakernel priority diffusion allocation failed: {source}. Fix: shard the priority sibling set before diffusion."
215 )
216 })
217}
218
219pub fn try_diffuse_priority_across_siblings(
227 priority_stalks: &[f64],
228 restriction_diag: &[f64],
229 damping: f64,
230 iterations: u32,
231) -> Result<Vec<f64>, BackendError> {
232 let mut current = Vec::new();
233 let mut next = Vec::new();
234 try_diffuse_priority_across_siblings_into(
235 priority_stalks,
236 restriction_diag,
237 damping,
238 iterations,
239 &mut current,
240 &mut next,
241 )?;
242 Ok(current)
243}
244
245pub fn diffuse_priority_across_siblings_into(
247 priority_stalks: &[f64],
248 restriction_diag: &[f64],
249 damping: f64,
250 iterations: u32,
251 out: &mut Vec<f64>,
252 scratch: &mut Vec<f64>,
253) {
254 try_diffuse_priority_across_siblings_into(
255 priority_stalks,
256 restriction_diag,
257 damping,
258 iterations,
259 out,
260 scratch,
261 )
262 .unwrap_or_else(|source| {
263 panic!(
264 "megakernel priority diffusion allocation failed: {source}. Fix: shard the priority sibling set before diffusion."
265 )
266 });
267}
268
269pub fn try_diffuse_priority_across_siblings_into(
276 priority_stalks: &[f64],
277 restriction_diag: &[f64],
278 damping: f64,
279 iterations: u32,
280 out: &mut Vec<f64>,
281 scratch: &mut Vec<f64>,
282) -> Result<(), BackendError> {
283 out.clear();
284 reserve_target_capacity(out, priority_stalks.len(), "priority diffusion output")?;
285 out.extend_from_slice(priority_stalks);
286 scratch.clear();
287 if priority_stalks.len() != restriction_diag.len() {
288 return Ok(());
289 }
290 for _ in 0..iterations {
291 diffuse_step_into(out, restriction_diag, damping, scratch)?;
292 std::mem::swap(out, scratch);
293 }
294 Ok(())
295}
296
297#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
299pub struct MegakernelLaunchPolicy {
300 pub sizing: super::planner::MegakernelSizingPolicy,
302 pub min_hit_capacity: u32,
304 pub hit_capacity_multiplier: u32,
306 pub saturated_waves: u32,
308 pub hot_opcode_threshold: u32,
310 pub hot_window_threshold: u32,
312 pub jit_queue_len_threshold: u32,
314 pub priority_age_threshold: u32,
316 pub sparse_frontier_threshold_bps: u16,
318 pub dense_frontier_threshold_bps: u16,
320 pub memory_pressure_threshold_bps: u16,
322 pub fusion_edge_threshold: u32,
324 pub scratch_bytes_per_hit: u32,
326}
327
328impl Default for MegakernelLaunchPolicy {
329 fn default() -> Self {
330 Self::standard()
331 }
332}
333
/// Dead-band half-width (basis points) around the sparse and dense
/// frontier-density thresholds inside which a previous topology is retained.
const FRONTIER_TOPOLOGY_HYSTERESIS_BPS: u16 = 250;
/// Margin (basis points) below `memory_pressure_threshold_bps` inside which a
/// previously memory-constrained launch stays memory-constrained.
const MEMORY_TOPOLOGY_HYSTERESIS_BPS: u16 = 250;
336
337impl MegakernelLaunchPolicy {
338 #[must_use]
340 pub const fn standard() -> Self {
341 Self {
342 sizing: super::planner::MegakernelSizingPolicy::standard(),
343 min_hit_capacity: 1024,
344 hit_capacity_multiplier: 2,
345 saturated_waves: 4,
346 hot_opcode_threshold: 8,
347 hot_window_threshold: 4,
348 jit_queue_len_threshold: 4096,
349 priority_age_threshold: 32,
350 sparse_frontier_threshold_bps: 500,
351 dense_frontier_threshold_bps: 4_000,
352 memory_pressure_threshold_bps: 8_500,
353 fusion_edge_threshold: 65_536,
354 scratch_bytes_per_hit: 16,
355 }
356 }
357
358 #[must_use]
360 pub fn launch_cache_stats() -> MegakernelLaunchCacheStats {
361 cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow().stats())
362 }
363
364 pub fn reset_launch_cache_for_thread() {
366 cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow_mut().clear());
367 }
368
369 pub fn recommend(
376 &self,
377 request: MegakernelLaunchRequest,
378 ) -> Result<MegakernelLaunchRecommendation, BackendError> {
379 self.recommend_inner(request, None)
380 }
381
382 pub fn recommend_with_previous_topology(
396 &self,
397 request: MegakernelLaunchRequest,
398 previous_topology: MegakernelDispatchTopology,
399 ) -> Result<MegakernelLaunchRecommendation, BackendError> {
400 self.recommend_inner(request, Some(previous_topology))
401 }
402
403 fn recommend_inner(
404 &self,
405 request: MegakernelLaunchRequest,
406 previous_topology: Option<MegakernelDispatchTopology>,
407 ) -> Result<MegakernelLaunchRecommendation, BackendError> {
408 let cache_key = cache::LaunchRecommendationCacheKey {
409 policy: *self,
410 request,
411 };
412 if previous_topology.is_none() {
413 if let Some(cached) =
414 cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow_mut().get(&cache_key))
415 {
416 return Ok(cached);
417 }
418 }
419
420 let effective_request = self.infer_missing_scale_signals(request)?;
421 let promote_hot_opcodes = effective_request.hot_opcode_count >= self.hot_opcode_threshold;
422 let promote_hot_windows = effective_request.hot_window_count >= self.hot_window_threshold;
423 let raw_topology =
424 self.dispatch_topology_for(effective_request, promote_hot_opcodes, promote_hot_windows);
425 let topology = self.stabilize_topology(
426 raw_topology,
427 effective_request,
428 previous_topology,
429 promote_hot_opcodes,
430 promote_hot_windows,
431 );
432 let scheduled_request = self.apply_topology_worker_policy(effective_request, topology)?;
433 let grid = self.sizing.calculate_optimal_grid(
434 MegakernelGridRequest::new(
435 scheduled_request.queue_len,
436 scheduled_request.requested_worker_groups,
437 ),
438 MegakernelGridLimits::new(
439 scheduled_request.max_workgroup_size_x,
440 scheduled_request.max_compute_workgroups_per_dimension,
441 scheduled_request.max_compute_invocations_per_workgroup,
442 ),
443 )?;
444 let geometry = grid.geometry;
445 let worker_groups = grid.worker_groups;
446 let lanes = u64::from(geometry.dispatch_grid[0])
447 .checked_mul(u64::from(geometry.workgroup_size_x))
448 .ok_or_else(|| {
449 BackendError::new(
450 "megakernel launch lane count overflowed u64. Fix: reduce dispatch grid or workgroup size.",
451 )
452 })?;
453 let pressure = classify_pressure(
454 effective_request.queue_len,
455 lanes,
456 effective_request.requeue_count,
457 self,
458 )?;
459 let hit_capacity = self.hit_capacity_for(effective_request)?;
460 let estimated_peak_device_bytes =
461 self.estimated_peak_device_bytes(effective_request, hit_capacity)?;
462 if effective_request.device_memory_budget_bytes != 0
463 && estimated_peak_device_bytes > effective_request.device_memory_budget_bytes
464 {
465 return Err(BackendError::DeviceOutOfMemory {
466 requested: estimated_peak_device_bytes,
467 available: effective_request.device_memory_budget_bytes,
468 });
469 }
470 let execution_mode = if effective_request.queue_len >= self.jit_queue_len_threshold
471 || promote_hot_opcodes
472 || promote_hot_windows
473 || topology == MegakernelDispatchTopology::FusedDense
474 {
475 MegakernelExecutionMode::Jit
476 } else {
477 MegakernelExecutionMode::Interpreter
478 };
479 let age_priority_work = effective_request.requeue_count > 0
480 || effective_request.max_priority_age >= self.priority_age_threshold;
481
482 let recommendation = MegakernelLaunchRecommendation {
483 geometry,
484 worker_groups,
485 hit_capacity,
486 pressure,
487 execution_mode,
488 topology,
489 promote_hot_opcodes,
490 promote_hot_windows,
491 age_priority_work,
492 estimated_peak_device_bytes,
493 device_memory_budget_bytes: effective_request.device_memory_budget_bytes,
494 };
495 if previous_topology.is_none() {
496 cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| {
497 cache.borrow_mut().insert(cache_key, recommendation);
498 });
499 }
500 Ok(recommendation)
501 }
502
503 fn hit_capacity_for(&self, request: MegakernelLaunchRequest) -> Result<u32, BackendError> {
504 if request.requested_hit_capacity != 0 {
505 return Ok(request.requested_hit_capacity);
506 }
507 let expected_hits = request.expected_hits_per_item.max(1);
508 let multiplier = if request.memory_pressure_bps >= self.memory_pressure_threshold_bps {
509 1
510 } else {
511 self.hit_capacity_multiplier
512 };
513 let derived = request
514 .queue_len
515 .checked_mul(expected_hits)
516 .and_then(|value| value.checked_mul(multiplier))
517 .ok_or_else(|| {
518 BackendError::new(
519 "megakernel sparse-hit capacity overflowed u32. Fix: lower queue length, expected_hits_per_item, or hit_capacity_multiplier.",
520 )
521 })?;
522 Ok(derived.max(self.min_hit_capacity))
523 }
524
525 fn estimated_peak_device_bytes(
526 &self,
527 request: MegakernelLaunchRequest,
528 hit_capacity: u32,
529 ) -> Result<u64, BackendError> {
530 let scratch_bytes = u64::from(hit_capacity)
531 .checked_mul(u64::from(self.scratch_bytes_per_hit))
532 .ok_or_else(|| {
533 BackendError::new(
534 "megakernel scratch byte estimate overflowed u64. Fix: lower hit capacity or scratch_bytes_per_hit.",
535 )
536 })?;
537 request
538 .resident_device_bytes
539 .checked_add(scratch_bytes)
540 .ok_or_else(|| {
541 BackendError::new(
542 "megakernel peak resident byte estimate overflowed u64. Fix: reduce resident buffers or scratch capacity.",
543 )
544 })
545 }
546
547 fn infer_missing_scale_signals(
548 &self,
549 mut request: MegakernelLaunchRequest,
550 ) -> Result<MegakernelLaunchRequest, BackendError> {
551 if request.frontier_density_bps == 0
552 && request.queue_len != 0
553 && request.graph_node_count != 0
554 {
555 let active_nodes = u64::from(request.queue_len.min(request.graph_node_count));
556 let density = active_nodes
557 .checked_mul(10_000)
558 .ok_or_else(|| {
559 BackendError::new(
560 "megakernel frontier-density numerator overflowed u64. Fix: shard the resident graph before launch.",
561 )
562 })?
563 .checked_div(u64::from(request.graph_node_count))
564 .unwrap_or(0)
565 .clamp(1, 10_000);
566 request.frontier_density_bps = u16::try_from(density).map_err(|error| {
567 BackendError::new(format!(
568 "megakernel frontier density cannot fit u16: {error}. Fix: clamp density before ABI encoding."
569 ))
570 })?;
571 }
572 if request.memory_pressure_bps == 0
573 && request.device_memory_budget_bytes != 0
574 && request.resident_device_bytes != 0
575 {
576 let pressure = (u128::from(request.resident_device_bytes)
577 .checked_mul(10_000)
578 .ok_or_else(|| {
579 BackendError::new(
580 "megakernel memory-pressure numerator overflowed u128. Fix: reduce resident device bytes before launch.",
581 )
582 })?
583 / u128::from(request.device_memory_budget_bytes))
584 .min(10_000);
585 request.memory_pressure_bps = u16::try_from(pressure).map_err(|error| {
586 BackendError::new(format!(
587 "megakernel memory pressure cannot fit u16: {error}. Fix: clamp pressure before ABI encoding."
588 ))
589 })?;
590 }
591 Ok(request)
592 }
593
594 fn apply_topology_worker_policy(
595 &self,
596 mut request: MegakernelLaunchRequest,
597 topology: MegakernelDispatchTopology,
598 ) -> Result<MegakernelLaunchRequest, BackendError> {
599 if topology == MegakernelDispatchTopology::MemoryConstrained
600 && request.memory_pressure_bps != 0
601 && request.requested_worker_groups > 1
602 {
603 let pressure_span = u32::from(
604 10_000_u16
605 .checked_sub(self.memory_pressure_threshold_bps)
606 .ok_or_else(|| {
607 BackendError::new(
608 "megakernel memory-pressure threshold exceeds 10000 bps. Fix: configure threshold in basis points.",
609 )
610 })?,
611 )
612 .max(1);
613 let over_threshold = u32::from(
614 match request
615 .memory_pressure_bps
616 .checked_sub(self.memory_pressure_threshold_bps)
617 {
618 Some(value) => value,
619 None => 0,
620 },
621 )
622 .min(pressure_span);
623 let shed_bps = 2_500_u32
624 .checked_add(
625 over_threshold
626 .checked_mul(2_500)
627 .ok_or_else(|| {
628 BackendError::new(
629 "megakernel memory-pressure worker shed overflowed u32. Fix: lower pressure telemetry before launch.",
630 )
631 })?
632 / pressure_span,
633 )
634 .ok_or_else(|| {
635 BackendError::new(
636 "megakernel memory-pressure worker shed overflowed u32. Fix: lower pressure telemetry before launch.",
637 )
638 })?;
639 let keep_bps = 10_000_u32.checked_sub(shed_bps).ok_or_else(|| {
640 BackendError::new(
641 "megakernel memory-pressure worker keep ratio underflowed. Fix: keep shed_bps within 0..=10000.",
642 )
643 })?;
644 let scaled = u64::from(request.requested_worker_groups)
645 .checked_mul(u64::from(keep_bps))
646 .ok_or_else(|| {
647 BackendError::new(
648 "megakernel memory-constrained worker count overflowed u64. Fix: reduce requested worker groups.",
649 )
650 })?
651 / 10_000;
652 request.requested_worker_groups = u32::try_from(scaled)
653 .map_err(|error| {
654 BackendError::new(format!(
655 "megakernel memory-constrained worker count cannot fit u32: {error}. Fix: reduce requested worker groups."
656 ))
657 })?
658 .max(1);
659 }
660 if topology == MegakernelDispatchTopology::SparseFrontier
661 && request.graph_node_count != 0
662 && request.frontier_density_bps != 0
663 && request.requested_worker_groups > 1
664 {
665 let sparse_span = u32::from(self.sparse_frontier_threshold_bps).max(1);
666 let density = u32::from(request.frontier_density_bps).clamp(1, sparse_span);
667 let scaled = u64::from(request.requested_worker_groups)
668 .checked_mul(u64::from(density))
669 .ok_or_else(|| {
670 BackendError::new(
671 "megakernel sparse-frontier worker count overflowed u64. Fix: reduce requested worker groups.",
672 )
673 })?
674 / u64::from(sparse_span);
675 let warp_floor = request.requested_worker_groups.min(32);
676 request.requested_worker_groups = u32::try_from(scaled)
677 .map_err(|error| {
678 BackendError::new(format!(
679 "megakernel sparse-frontier worker count cannot fit u32: {error}. Fix: reduce requested worker groups."
680 ))
681 })?
682 .max(warp_floor)
683 .min(request.requested_worker_groups);
684 }
685 Ok(request)
686 }
687
688 fn dispatch_topology_for(
689 &self,
690 request: MegakernelLaunchRequest,
691 promote_hot_opcodes: bool,
692 promote_hot_windows: bool,
693 ) -> MegakernelDispatchTopology {
694 if request.queue_len == 0 {
695 return MegakernelDispatchTopology::Empty;
696 }
697 if request.memory_pressure_bps >= self.memory_pressure_threshold_bps {
698 return MegakernelDispatchTopology::MemoryConstrained;
699 }
700 if request.frontier_density_bps <= self.sparse_frontier_threshold_bps {
701 return MegakernelDispatchTopology::SparseFrontier;
702 }
703 let dense = request.frontier_density_bps >= self.dense_frontier_threshold_bps;
704 let graph_is_large =
705 request.graph_node_count > 0 && request.graph_edge_count >= self.fusion_edge_threshold;
706 if dense && graph_is_large && (promote_hot_opcodes || promote_hot_windows) {
707 return MegakernelDispatchTopology::FusedDense;
708 }
709 if dense {
710 return MegakernelDispatchTopology::DenseFrontier;
711 }
712 MegakernelDispatchTopology::HybridFrontier
713 }
714
715 fn stabilize_topology(
716 &self,
717 raw_topology: MegakernelDispatchTopology,
718 request: MegakernelLaunchRequest,
719 previous_topology: Option<MegakernelDispatchTopology>,
720 promote_hot_opcodes: bool,
721 promote_hot_windows: bool,
722 ) -> MegakernelDispatchTopology {
723 if raw_topology == MegakernelDispatchTopology::Empty {
724 return raw_topology;
725 }
726 if raw_topology == MegakernelDispatchTopology::MemoryConstrained {
727 return raw_topology;
728 }
729 let Some(previous_topology) = previous_topology else {
730 return raw_topology;
731 };
732 if previous_topology == MegakernelDispatchTopology::MemoryConstrained
733 && request.memory_pressure_bps
734 >= hysteresis_sub(
735 self.memory_pressure_threshold_bps,
736 MEMORY_TOPOLOGY_HYSTERESIS_BPS,
737 )
738 {
739 return MegakernelDispatchTopology::MemoryConstrained;
740 }
741
742 match previous_topology {
743 MegakernelDispatchTopology::SparseFrontier
744 if raw_topology != MegakernelDispatchTopology::SparseFrontier
745 && request.frontier_density_bps
746 <= hysteresis_add(
747 self.sparse_frontier_threshold_bps,
748 FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
749 ) =>
750 {
751 MegakernelDispatchTopology::SparseFrontier
752 }
753 MegakernelDispatchTopology::HybridFrontier
754 if raw_topology == MegakernelDispatchTopology::SparseFrontier
755 && request.frontier_density_bps
756 >= hysteresis_sub(
757 self.sparse_frontier_threshold_bps,
758 FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
759 ) =>
760 {
761 MegakernelDispatchTopology::HybridFrontier
762 }
763 MegakernelDispatchTopology::HybridFrontier
764 if matches!(
765 raw_topology,
766 MegakernelDispatchTopology::DenseFrontier
767 | MegakernelDispatchTopology::FusedDense
768 ) && request.frontier_density_bps
769 <= hysteresis_add(
770 self.dense_frontier_threshold_bps,
771 FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
772 ) =>
773 {
774 MegakernelDispatchTopology::HybridFrontier
775 }
776 MegakernelDispatchTopology::DenseFrontier
777 if raw_topology == MegakernelDispatchTopology::HybridFrontier
778 && request.frontier_density_bps
779 >= hysteresis_sub(
780 self.dense_frontier_threshold_bps,
781 FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
782 ) =>
783 {
784 MegakernelDispatchTopology::DenseFrontier
785 }
786 MegakernelDispatchTopology::FusedDense
787 if raw_topology == MegakernelDispatchTopology::HybridFrontier
788 && request.frontier_density_bps
789 >= hysteresis_sub(
790 self.dense_frontier_threshold_bps,
791 FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
792 )
793 && request.graph_edge_count >= self.fusion_edge_threshold
794 && (promote_hot_opcodes || promote_hot_windows) =>
795 {
796 MegakernelDispatchTopology::FusedDense
797 }
798 _ => raw_topology,
799 }
800 }
801
802 #[must_use]
813 pub fn autotune_hit_capacity_multiplier(
814 &self,
815 candidate_multipliers: &[u32],
816 costs: &[f64],
817 ) -> u32 {
818 if candidate_multipliers.is_empty() || costs.is_empty() {
819 return self.hit_capacity_multiplier;
820 }
821 let n = candidate_multipliers.len().min(costs.len());
822 let chosen = best_cost_index(&costs[..n]);
823 candidate_multipliers
824 .get(chosen)
825 .copied()
826 .unwrap_or(self.hit_capacity_multiplier)
827 }
828
829 #[must_use]
835 pub fn autotune_workgroup_size(
836 &self,
837 candidate_sizes: &[u32],
838 costs: &[f64],
839 current_size: u32,
840 ) -> u32 {
841 if candidate_sizes.is_empty() || costs.is_empty() {
842 return current_size;
843 }
844 let n = candidate_sizes.len().min(costs.len());
845 let chosen = best_cost_index(&costs[..n]);
846 candidate_sizes.get(chosen).copied().unwrap_or(current_size)
847 }
848
849 #[must_use]
866 pub fn natural_gradient_autotune_step(
867 m_inv_sqrt: &[f64],
868 grad: &[f64],
869 n: u32,
870 learning_rate: f64,
871 ) -> Vec<f64> {
872 Self::try_natural_gradient_autotune_step(m_inv_sqrt, grad, n, learning_rate)
873 .unwrap_or_else(|source| {
874 panic!(
875 "megakernel natural-gradient autotune allocation failed: {source}. Fix: shard the autotune surface."
876 )
877 })
878 }
879
880 pub fn try_natural_gradient_autotune_step(
887 m_inv_sqrt: &[f64],
888 grad: &[f64],
889 n: u32,
890 learning_rate: f64,
891 ) -> Result<Vec<f64>, BackendError> {
892 let mut out = Vec::new();
893 Self::try_natural_gradient_autotune_step_into(
894 m_inv_sqrt,
895 grad,
896 n,
897 learning_rate,
898 &mut out,
899 )?;
900 Ok(out)
901 }
902
903 pub fn natural_gradient_autotune_step_into(
905 m_inv_sqrt: &[f64],
906 grad: &[f64],
907 n: u32,
908 learning_rate: f64,
909 out: &mut Vec<f64>,
910 ) {
911 Self::try_natural_gradient_autotune_step_into(m_inv_sqrt, grad, n, learning_rate, out)
912 .unwrap_or_else(|source| {
913 panic!(
914 "megakernel natural-gradient autotune allocation failed: {source}. Fix: shard the autotune surface."
915 )
916 });
917 }
918
919 pub fn try_natural_gradient_autotune_step_into(
927 m_inv_sqrt: &[f64],
928 grad: &[f64],
929 n: u32,
930 learning_rate: f64,
931 out: &mut Vec<f64>,
932 ) -> Result<(), BackendError> {
933 let n = u32_to_usize_or_panic(n, "natural-gradient dimension");
934 out.clear();
935 let Some(required_matrix_len) = n.checked_mul(n) else {
936 return Ok(());
937 };
938 if m_inv_sqrt.len() < required_matrix_len || grad.len() < n {
939 return Ok(());
940 }
941 reserve_target_capacity(out, n, "natural-gradient output")?;
942 out.resize(n, 0.0);
943 for row in 0..n {
944 let mut acc = 0.0;
945 for col in 0..n {
946 acc += m_inv_sqrt[row * n + col] * grad[col];
947 }
948 out[row] = -learning_rate * acc;
949 }
950 Ok(())
951 }
952}
953
954
955fn diffuse_step_into(
956 stalks: &[f64],
957 restriction_diag: &[f64],
958 damping: f64,
959 out: &mut Vec<f64>,
960) -> Result<(), BackendError> {
961 out.clear();
962 reserve_target_capacity(out, stalks.len(), "priority diffusion scratch")?;
963 out.resize(stalks.len(), 0.0);
964 for ((slot, &stalk), &restriction) in out
965 .iter_mut()
966 .zip(stalks.iter())
967 .zip(restriction_diag.iter())
968 {
969 *slot = stalk - damping * restriction * stalk;
970 }
971 Ok(())
972}
973
974fn reserve_target_capacity<T>(
975 out: &mut Vec<T>,
976 target_capacity: usize,
977 label: &'static str,
978) -> Result<(), BackendError> {
979 try_reserve_vec_capacity(out, target_capacity).map_err(|source| {
980 BackendError::new(format!(
981 "megakernel {label} reservation failed for {target_capacity} element(s): {source}. Fix: shard the policy input before launch-policy math."
982 ))
983 })
984}
985
/// Returns the index of the lowest cost in `costs`; ties go to the earliest index.
///
/// NaN costs are skipped. `f64::total_cmp` orders sign-bit-set NaN below every
/// number, and some hardware (e.g. x86 SSE's default NaN from `0.0 / 0.0`)
/// produces exactly that. Without the filter, a failed measurement would be
/// picked as the best candidate.
///
/// Falls back to index `0` when `costs` is empty or every cost is NaN.
fn best_cost_index(costs: &[f64]) -> usize {
    debug_assert!(!costs.is_empty());
    costs
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, cost)| !cost.is_nan())
        // `min_by` returns the first of several equal minima, preserving the
        // earliest-index tie-break.
        .min_by(|&(_, lhs), &(_, rhs)| lhs.total_cmp(&rhs))
        .map_or(0, |(index, _)| index)
}
998
/// Widens `value` to `usize`, panicking with a `label`-prefixed message on
/// targets where `usize` is narrower than 32 bits.
fn u32_to_usize_or_panic(value: u32, label: &'static str) -> usize {
    usize::try_from(value).unwrap_or_else(|error| {
        panic!("{label} cannot fit usize: {error}. Fix: shard the autotune surface.")
    })
}
1007
/// Upper hysteresis bound `value + hysteresis`.
///
/// # Panics
///
/// Panics if the sum overflows `u16`.
fn hysteresis_add(value: u16, hysteresis: u16) -> u16 {
    let Some(upper_bound) = value.checked_add(hysteresis) else {
        panic!(
            "megakernel topology hysteresis upper bound overflowed u16. Fix: lower topology threshold or hysteresis."
        )
    };
    upper_bound
}
1015
/// Lower hysteresis bound `value - hysteresis`.
///
/// # Panics
///
/// Panics if the difference underflows `u16` (threshold below the band width).
fn hysteresis_sub(value: u16, hysteresis: u16) -> u16 {
    match value.checked_sub(hysteresis) {
        Some(lower_bound) => lower_bound,
        None => panic!(
            "megakernel topology hysteresis lower bound underflowed u16. Fix: lower hysteresis or raise topology threshold."
        ),
    }
}
1023
1024fn classify_pressure(
1025 queue_len: u32,
1026 lanes: u64,
1027 requeue_count: u64,
1028 policy: &MegakernelLaunchPolicy,
1029) -> Result<MegakernelQueuePressure, BackendError> {
1030 if queue_len == 0 {
1031 return Ok(MegakernelQueuePressure::Empty);
1032 }
1033 let lanes = lanes.max(1);
1034 let queue_len = u64::from(queue_len);
1035 let saturated_lanes = lanes
1036 .checked_mul(u64::from(policy.saturated_waves))
1037 .ok_or_else(|| {
1038 BackendError::new(
1039 "megakernel pressure wave threshold overflowed u64. Fix: reduce worker lanes or saturated_waves.",
1040 )
1041 })?;
1042 if requeue_count > 0 || queue_len >= saturated_lanes {
1043 Ok(MegakernelQueuePressure::Saturated)
1044 } else if queue_len >= lanes {
1045 Ok(MegakernelQueuePressure::Balanced)
1046 } else {
1047 Ok(MegakernelQueuePressure::Light)
1048 }
1049}
1050
1051#[cfg(test)]
1052mod tests;
1053