Skip to main content

vyre_runtime/megakernel/
policy.rs

1//! Resident megakernel launch policy and queue-pressure decisions.
2
3use vyre_driver::backend::BackendError;
4
5mod cache;
6use super::planner::{MegakernelGridLimits, MegakernelGridRequest, MegakernelLaunchGeometry};
7use super::staging_reserve::try_reserve_vec_capacity;
8
9/// Host-side pressure classification for one megakernel launch.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
11pub enum MegakernelQueuePressure {
12    /// No logical slots are queued.
13    Empty,
14    /// The queue is below the available worker lanes.
15    Light,
16    /// The queue is large enough to keep the submitted workers occupied.
17    Balanced,
18    /// The queue is several waves deep or already showing requeue pressure.
19    Saturated,
20}
21
22/// Interpreter/JIT route selected by the launch policy.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24pub enum MegakernelExecutionMode {
25    /// Use the generic opcode interpreter.
26    Interpreter,
27    /// Use a fused payload processor for hot windows or opcodes.
28    Jit,
29}
30
31/// Scale-aware execution topology selected for one megakernel launch.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
33pub enum MegakernelDispatchTopology {
34    /// Nothing is queued.
35    Empty,
36    /// Low frontier density; prefer sparse frontier expansion and avoid
37    /// block-wide dense scans.
38    SparseFrontier,
39    /// Mid-density frontier; combine sparse frontier queues with dense block
40    /// tiles instead of forcing either extreme.
41    HybridFrontier,
42    /// High frontier density; prefer dense block propagation with coalesced
43    /// scans.
44    DenseFrontier,
45    /// High-density graph with enough hot structure to justify fused waves.
46    FusedDense,
47    /// Memory pressure is high enough that bounded occupancy is more important
48    /// than maximizing active waves.
49    MemoryConstrained,
50}
51
52/// Thread-local launch recommendation cache telemetry.
53#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct MegakernelLaunchCacheStats {
55    /// Live cache entries retained in the current thread.
56    pub entries: usize,
57    /// Cache hits served without recomputing launch geometry.
58    pub hits: u64,
59    /// Cache misses that required policy recomputation.
60    pub misses: u64,
61}
62
63/// Inputs for one launch-policy recommendation.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
65pub struct MegakernelLaunchRequest {
66    /// Logical ring slots or work items queued for this launch.
67    pub queue_len: u32,
68    /// Caller-requested worker workgroup ceiling. Zero means derive from occupancy.
69    pub requested_worker_groups: u32,
70    /// Adapter maximum workgroup size in the x dimension.
71    pub max_workgroup_size_x: u32,
72    /// Adapter maximum compute workgroups per dimension.
73    pub max_compute_workgroups_per_dimension: u32,
74    /// Adapter maximum invocations per compute workgroup.
75    pub max_compute_invocations_per_workgroup: u32,
76    /// Caller-requested sparse-hit capacity. Zero means derive from queue shape.
77    pub requested_hit_capacity: u32,
78    /// Expected sparse hits per queued item when deriving hit capacity.
79    pub expected_hits_per_item: u32,
80    /// Count of opcodes observed hot enough for promotion.
81    pub hot_opcode_count: u32,
82    /// Count of ticketed route windows observed hot enough for promotion.
83    pub hot_window_count: u32,
84    /// Slots requeued by priority scheduling since the last recommendation.
85    pub requeue_count: u64,
86    /// Maximum priority age observed since the last recommendation.
87    pub max_priority_age: u32,
88    /// Nodes in the resident dependency graph. Zero means the caller has no
89    /// graph-shape telemetry for this launch.
90    pub graph_node_count: u32,
91    /// Edges in the resident dependency graph. Zero means the caller has no
92    /// graph-shape telemetry for this launch.
93    pub graph_edge_count: u32,
94    /// Active frontier density in basis points relative to graph nodes.
95    pub frontier_density_bps: u16,
96    /// Device-memory pressure in basis points relative to the active budget.
97    pub memory_pressure_bps: u16,
98    /// Device-resident bytes already required by this dispatch family.
99    pub resident_device_bytes: u64,
100    /// Hard device-memory budget for this launch. Zero means unbounded.
101    pub device_memory_budget_bytes: u64,
102}
103
104impl MegakernelLaunchRequest {
105    /// Construct a direct-dispatch request with conservative defaults.
106    #[must_use]
107    pub const fn direct(
108        queue_len: u32,
109        requested_worker_groups: u32,
110        max_workgroup_size_x: u32,
111    ) -> Self {
112        Self {
113            queue_len,
114            requested_worker_groups,
115            max_workgroup_size_x,
116            max_compute_workgroups_per_dimension: requested_worker_groups,
117            max_compute_invocations_per_workgroup: max_workgroup_size_x,
118            requested_hit_capacity: 0,
119            expected_hits_per_item: 1,
120            hot_opcode_count: 0,
121            hot_window_count: 0,
122            requeue_count: 0,
123            max_priority_age: 0,
124            graph_node_count: 0,
125            graph_edge_count: 0,
126            frontier_density_bps: 0,
127            memory_pressure_bps: 0,
128            resident_device_bytes: 0,
129            device_memory_budget_bytes: 0,
130        }
131    }
132}
133
134/// Policy output consumed by runtime dispatchers and batch builders.
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct MegakernelLaunchRecommendation {
137    /// Padded launch geometry for the ring protocol.
138    pub geometry: MegakernelLaunchGeometry,
139    /// Worker workgroups selected for the dispatch.
140    pub worker_groups: u32,
141    /// Sparse-hit capacity selected for the dispatch.
142    pub hit_capacity: u32,
143    /// Queue pressure classification.
144    pub pressure: MegakernelQueuePressure,
145    /// Interpreter or JIT route selected from telemetry.
146    pub execution_mode: MegakernelExecutionMode,
147    /// Scale-aware dispatch topology selected from graph shape, frontier
148    /// density, and memory pressure.
149    pub topology: MegakernelDispatchTopology,
150    /// True when hot opcode counters justify fused opcode promotion.
151    pub promote_hot_opcodes: bool,
152    /// True when ticketed route windows justify fused window promotion.
153    pub promote_hot_windows: bool,
154    /// True when aged/requeued priority work should be lifted on the next publish.
155    pub age_priority_work: bool,
156    /// Estimated peak device bytes needed by the resident launch plan.
157    pub estimated_peak_device_bytes: u64,
158    /// Hard device-memory budget applied to this recommendation. Zero means unbounded.
159    pub device_memory_budget_bytes: u64,
160}
161
162/// Requeue and aging counters produced by priority-aware schedulers.
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
164pub struct PriorityRequeueAccounting {
165    /// Number of slots requeued due to contention or quota pressure.
166    pub requeue_count: u64,
167    /// Number of slots promoted because their priority age crossed policy.
168    pub aged_promotions: u64,
169    /// Largest age observed for any queued priority slot.
170    pub max_priority_age: u32,
171}
172
173impl PriorityRequeueAccounting {
174    /// Record one requeue event.
175    pub fn record_requeue(&mut self, age_ticks: u32) {
176        self.requeue_count = self.requeue_count.checked_add(1).unwrap_or_else(|| {
177            panic!("megakernel priority requeue_count overflowed u64. Fix: drain scheduler telemetry before counters reach u64::MAX.")
178        });
179        self.max_priority_age = self.max_priority_age.max(age_ticks);
180    }
181
182    /// Record one priority-aging promotion.
183    pub fn record_aged_promotion(&mut self, age_ticks: u32) {
184        self.aged_promotions = self.aged_promotions.checked_add(1).unwrap_or_else(|| {
185            panic!("megakernel aged_promotions overflowed u64. Fix: drain scheduler telemetry before counters reach u64::MAX.")
186        });
187        self.max_priority_age = self.max_priority_age.max(age_ticks);
188    }
189}
190
191/// Diffuse priority signals across a set of priority-class siblings
192/// via sheaf diffusion (P-RUNTIME-3). Higher-priority siblings pull
193/// neighbors toward higher priority; lower-priority siblings drag
194/// down. After a few diffusion steps, each item's priority reflects
195/// both its own age and its neighborhood pressure  -  letting requeue
196/// decisions be group-aware without hand-rolling a propagation pass.
197///
198/// `priority_stalks` is the per-item priority value (caller's choice
199/// of scale; higher = more urgent). `restriction_diag` is the
200/// per-item transmission coefficient (1.0 = freely shares priority,
201/// 0.0 = isolated). `damping` controls the diffusion rate in [0, 1].
202///
203/// Returns the post-diffusion priority vector, same shape as input.
204#[must_use]
205pub fn diffuse_priority_across_siblings(
206    priority_stalks: &[f64],
207    restriction_diag: &[f64],
208    damping: f64,
209    iterations: u32,
210) -> Vec<f64> {
211    try_diffuse_priority_across_siblings(priority_stalks, restriction_diag, damping, iterations)
212        .unwrap_or_else(|source| {
213            panic!(
214                "megakernel priority diffusion allocation failed: {source}. Fix: shard the priority sibling set before diffusion."
215            )
216        })
217}
218
219/// Diffuse priority signals across priority-class siblings with fallible
220/// output staging.
221///
222/// # Errors
223///
224/// Returns [`BackendError`] when host staging cannot be reserved for the
225/// priority vector.
226pub fn try_diffuse_priority_across_siblings(
227    priority_stalks: &[f64],
228    restriction_diag: &[f64],
229    damping: f64,
230    iterations: u32,
231) -> Result<Vec<f64>, BackendError> {
232    let mut current = Vec::new();
233    let mut next = Vec::new();
234    try_diffuse_priority_across_siblings_into(
235        priority_stalks,
236        restriction_diag,
237        damping,
238        iterations,
239        &mut current,
240        &mut next,
241    )?;
242    Ok(current)
243}
244
245/// Diffuse priority signals into caller-owned storage.
246pub fn diffuse_priority_across_siblings_into(
247    priority_stalks: &[f64],
248    restriction_diag: &[f64],
249    damping: f64,
250    iterations: u32,
251    out: &mut Vec<f64>,
252    scratch: &mut Vec<f64>,
253) {
254    try_diffuse_priority_across_siblings_into(
255        priority_stalks,
256        restriction_diag,
257        damping,
258        iterations,
259        out,
260        scratch,
261    )
262    .unwrap_or_else(|source| {
263        panic!(
264            "megakernel priority diffusion allocation failed: {source}. Fix: shard the priority sibling set before diffusion."
265        )
266    });
267}
268
269/// Diffuse priority signals into caller-owned storage with fallible staging.
270///
271/// # Errors
272///
273/// Returns [`BackendError`] when host staging cannot be reserved for the
274/// priority vector.
275pub fn try_diffuse_priority_across_siblings_into(
276    priority_stalks: &[f64],
277    restriction_diag: &[f64],
278    damping: f64,
279    iterations: u32,
280    out: &mut Vec<f64>,
281    scratch: &mut Vec<f64>,
282) -> Result<(), BackendError> {
283    out.clear();
284    reserve_target_capacity(out, priority_stalks.len(), "priority diffusion output")?;
285    out.extend_from_slice(priority_stalks);
286    scratch.clear();
287    if priority_stalks.len() != restriction_diag.len() {
288        return Ok(());
289    }
290    for _ in 0..iterations {
291        diffuse_step_into(out, restriction_diag, damping, scratch)?;
292        std::mem::swap(out, scratch);
293    }
294    Ok(())
295}
296
297/// Single policy surface for megakernel launch sizing and telemetry-driven routing.
298#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
299pub struct MegakernelLaunchPolicy {
300    /// Sizing policy for worker counts and grid geometry.
301    pub sizing: super::planner::MegakernelSizingPolicy,
302    /// Minimum capacity for sparse-hit results.
303    pub min_hit_capacity: u32,
304    /// Multiplier for expected hits to determine capacity.
305    pub hit_capacity_multiplier: u32,
306    /// Number of waves that define a saturated queue.
307    pub saturated_waves: u32,
308    /// Threshold for promoting hot opcodes to JIT.
309    pub hot_opcode_threshold: u32,
310    /// Threshold for promoting hot windows to JIT.
311    pub hot_window_threshold: u32,
312    /// Queue length threshold to prefer JIT over interpreter.
313    pub jit_queue_len_threshold: u32,
314    /// Priority age threshold to trigger aging promotions.
315    pub priority_age_threshold: u32,
316    /// Frontier density at or below this value uses sparse expansion.
317    pub sparse_frontier_threshold_bps: u16,
318    /// Frontier density at or above this value uses dense propagation.
319    pub dense_frontier_threshold_bps: u16,
320    /// Memory pressure at or above this value uses the memory-constrained path.
321    pub memory_pressure_threshold_bps: u16,
322    /// Minimum graph edge count before dense hot work is eligible for fusion.
323    pub fusion_edge_threshold: u32,
324    /// Conservative resident scratch bytes needed per sparse-hit entry.
325    pub scratch_bytes_per_hit: u32,
326}
327
328impl Default for MegakernelLaunchPolicy {
329    fn default() -> Self {
330        Self::standard()
331    }
332}
333
334const FRONTIER_TOPOLOGY_HYSTERESIS_BPS: u16 = 250;
335const MEMORY_TOPOLOGY_HYSTERESIS_BPS: u16 = 250;
336
337impl MegakernelLaunchPolicy {
338    /// Standard launch policy used by VYRE megakernel dispatchers.
339    #[must_use]
340    pub const fn standard() -> Self {
341        Self {
342            sizing: super::planner::MegakernelSizingPolicy::standard(),
343            min_hit_capacity: 1024,
344            hit_capacity_multiplier: 2,
345            saturated_waves: 4,
346            hot_opcode_threshold: 8,
347            hot_window_threshold: 4,
348            jit_queue_len_threshold: 4096,
349            priority_age_threshold: 32,
350            sparse_frontier_threshold_bps: 500,
351            dense_frontier_threshold_bps: 4_000,
352            memory_pressure_threshold_bps: 8_500,
353            fusion_edge_threshold: 65_536,
354            scratch_bytes_per_hit: 16,
355        }
356    }
357
358    /// Return launch recommendation cache telemetry for the current thread.
359    #[must_use]
360    pub fn launch_cache_stats() -> MegakernelLaunchCacheStats {
361        cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow().stats())
362    }
363
364    /// Clear launch recommendation cache entries and counters for this thread.
365    pub fn reset_launch_cache_for_thread() {
366        cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow_mut().clear());
367    }
368
369    /// Recommend geometry, hit capacity, and interpreter/JIT route.
370    ///
371    /// # Errors
372    ///
373    /// Returns [`BackendError`] when required adapter limits are zero or derived
374    /// launch values cannot fit the u32 ring protocol.
375    pub fn recommend(
376        &self,
377        request: MegakernelLaunchRequest,
378    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
379        self.recommend_inner(request, None)
380    }
381
382    /// Recommend a launch while preserving the previous topology inside a
383    /// narrow hysteresis band.
384    ///
385    /// CUDA resident graphs and long-running dataflow streams should use this
386    /// entry point when they can track the last successful topology. It prevents
387    /// borderline frontier-density or memory-pressure telemetry from repeatedly
388    /// switching kernel variants, invalidating launch plans, and disturbing
389    /// cache locality at scale.
390    ///
391    /// # Errors
392    ///
393    /// Returns [`BackendError`] when required adapter limits are zero or derived
394    /// launch values cannot fit the u32 ring protocol.
395    pub fn recommend_with_previous_topology(
396        &self,
397        request: MegakernelLaunchRequest,
398        previous_topology: MegakernelDispatchTopology,
399    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
400        self.recommend_inner(request, Some(previous_topology))
401    }
402
403    fn recommend_inner(
404        &self,
405        request: MegakernelLaunchRequest,
406        previous_topology: Option<MegakernelDispatchTopology>,
407    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
408        let cache_key = cache::LaunchRecommendationCacheKey {
409            policy: *self,
410            request,
411        };
412        if previous_topology.is_none() {
413            if let Some(cached) =
414                cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| cache.borrow_mut().get(&cache_key))
415            {
416                return Ok(cached);
417            }
418        }
419
420        let effective_request = self.infer_missing_scale_signals(request)?;
421        let promote_hot_opcodes = effective_request.hot_opcode_count >= self.hot_opcode_threshold;
422        let promote_hot_windows = effective_request.hot_window_count >= self.hot_window_threshold;
423        let raw_topology =
424            self.dispatch_topology_for(effective_request, promote_hot_opcodes, promote_hot_windows);
425        let topology = self.stabilize_topology(
426            raw_topology,
427            effective_request,
428            previous_topology,
429            promote_hot_opcodes,
430            promote_hot_windows,
431        );
432        let scheduled_request = self.apply_topology_worker_policy(effective_request, topology)?;
433        let grid = self.sizing.calculate_optimal_grid(
434            MegakernelGridRequest::new(
435                scheduled_request.queue_len,
436                scheduled_request.requested_worker_groups,
437            ),
438            MegakernelGridLimits::new(
439                scheduled_request.max_workgroup_size_x,
440                scheduled_request.max_compute_workgroups_per_dimension,
441                scheduled_request.max_compute_invocations_per_workgroup,
442            ),
443        )?;
444        let geometry = grid.geometry;
445        let worker_groups = grid.worker_groups;
446        let lanes = u64::from(geometry.dispatch_grid[0])
447            .checked_mul(u64::from(geometry.workgroup_size_x))
448            .ok_or_else(|| {
449                BackendError::new(
450                    "megakernel launch lane count overflowed u64. Fix: reduce dispatch grid or workgroup size.",
451                )
452            })?;
453        let pressure = classify_pressure(
454            effective_request.queue_len,
455            lanes,
456            effective_request.requeue_count,
457            self,
458        )?;
459        let hit_capacity = self.hit_capacity_for(effective_request)?;
460        let estimated_peak_device_bytes =
461            self.estimated_peak_device_bytes(effective_request, hit_capacity)?;
462        if effective_request.device_memory_budget_bytes != 0
463            && estimated_peak_device_bytes > effective_request.device_memory_budget_bytes
464        {
465            return Err(BackendError::DeviceOutOfMemory {
466                requested: estimated_peak_device_bytes,
467                available: effective_request.device_memory_budget_bytes,
468            });
469        }
470        let execution_mode = if effective_request.queue_len >= self.jit_queue_len_threshold
471            || promote_hot_opcodes
472            || promote_hot_windows
473            || topology == MegakernelDispatchTopology::FusedDense
474        {
475            MegakernelExecutionMode::Jit
476        } else {
477            MegakernelExecutionMode::Interpreter
478        };
479        let age_priority_work = effective_request.requeue_count > 0
480            || effective_request.max_priority_age >= self.priority_age_threshold;
481
482        let recommendation = MegakernelLaunchRecommendation {
483            geometry,
484            worker_groups,
485            hit_capacity,
486            pressure,
487            execution_mode,
488            topology,
489            promote_hot_opcodes,
490            promote_hot_windows,
491            age_priority_work,
492            estimated_peak_device_bytes,
493            device_memory_budget_bytes: effective_request.device_memory_budget_bytes,
494        };
495        if previous_topology.is_none() {
496            cache::LAUNCH_RECOMMENDATION_CACHE.with(|cache| {
497                cache.borrow_mut().insert(cache_key, recommendation);
498            });
499        }
500        Ok(recommendation)
501    }
502
503    fn hit_capacity_for(&self, request: MegakernelLaunchRequest) -> Result<u32, BackendError> {
504        if request.requested_hit_capacity != 0 {
505            return Ok(request.requested_hit_capacity);
506        }
507        let expected_hits = request.expected_hits_per_item.max(1);
508        let multiplier = if request.memory_pressure_bps >= self.memory_pressure_threshold_bps {
509            1
510        } else {
511            self.hit_capacity_multiplier
512        };
513        let derived = request
514            .queue_len
515            .checked_mul(expected_hits)
516            .and_then(|value| value.checked_mul(multiplier))
517            .ok_or_else(|| {
518                BackendError::new(
519                    "megakernel sparse-hit capacity overflowed u32. Fix: lower queue length, expected_hits_per_item, or hit_capacity_multiplier.",
520                )
521            })?;
522        Ok(derived.max(self.min_hit_capacity))
523    }
524
525    fn estimated_peak_device_bytes(
526        &self,
527        request: MegakernelLaunchRequest,
528        hit_capacity: u32,
529    ) -> Result<u64, BackendError> {
530        let scratch_bytes = u64::from(hit_capacity)
531            .checked_mul(u64::from(self.scratch_bytes_per_hit))
532            .ok_or_else(|| {
533                BackendError::new(
534                    "megakernel scratch byte estimate overflowed u64. Fix: lower hit capacity or scratch_bytes_per_hit.",
535                )
536            })?;
537        request
538            .resident_device_bytes
539            .checked_add(scratch_bytes)
540            .ok_or_else(|| {
541                BackendError::new(
542                    "megakernel peak resident byte estimate overflowed u64. Fix: reduce resident buffers or scratch capacity.",
543                )
544            })
545    }
546
547    fn infer_missing_scale_signals(
548        &self,
549        mut request: MegakernelLaunchRequest,
550    ) -> Result<MegakernelLaunchRequest, BackendError> {
551        if request.frontier_density_bps == 0
552            && request.queue_len != 0
553            && request.graph_node_count != 0
554        {
555            let active_nodes = u64::from(request.queue_len.min(request.graph_node_count));
556            let density = active_nodes
557                .checked_mul(10_000)
558                .ok_or_else(|| {
559                    BackendError::new(
560                        "megakernel frontier-density numerator overflowed u64. Fix: shard the resident graph before launch.",
561                    )
562                })?
563                .checked_div(u64::from(request.graph_node_count))
564                .unwrap_or(0)
565                .clamp(1, 10_000);
566            request.frontier_density_bps = u16::try_from(density).map_err(|error| {
567                BackendError::new(format!(
568                    "megakernel frontier density cannot fit u16: {error}. Fix: clamp density before ABI encoding."
569                ))
570            })?;
571        }
572        if request.memory_pressure_bps == 0
573            && request.device_memory_budget_bytes != 0
574            && request.resident_device_bytes != 0
575        {
576            let pressure = (u128::from(request.resident_device_bytes)
577                .checked_mul(10_000)
578                .ok_or_else(|| {
579                    BackendError::new(
580                        "megakernel memory-pressure numerator overflowed u128. Fix: reduce resident device bytes before launch.",
581                    )
582                })?
583                / u128::from(request.device_memory_budget_bytes))
584            .min(10_000);
585            request.memory_pressure_bps = u16::try_from(pressure).map_err(|error| {
586                BackendError::new(format!(
587                    "megakernel memory pressure cannot fit u16: {error}. Fix: clamp pressure before ABI encoding."
588                ))
589            })?;
590        }
591        Ok(request)
592    }
593
594    fn apply_topology_worker_policy(
595        &self,
596        mut request: MegakernelLaunchRequest,
597        topology: MegakernelDispatchTopology,
598    ) -> Result<MegakernelLaunchRequest, BackendError> {
599        if topology == MegakernelDispatchTopology::MemoryConstrained
600            && request.memory_pressure_bps != 0
601            && request.requested_worker_groups > 1
602        {
603            let pressure_span = u32::from(
604                10_000_u16
605                    .checked_sub(self.memory_pressure_threshold_bps)
606                    .ok_or_else(|| {
607                        BackendError::new(
608                            "megakernel memory-pressure threshold exceeds 10000 bps. Fix: configure threshold in basis points.",
609                        )
610                    })?,
611            )
612            .max(1);
613            let over_threshold = u32::from(
614                match request
615                    .memory_pressure_bps
616                    .checked_sub(self.memory_pressure_threshold_bps)
617                {
618                    Some(value) => value,
619                    None => 0,
620                },
621            )
622            .min(pressure_span);
623            let shed_bps = 2_500_u32
624                .checked_add(
625                    over_threshold
626                        .checked_mul(2_500)
627                        .ok_or_else(|| {
628                            BackendError::new(
629                                "megakernel memory-pressure worker shed overflowed u32. Fix: lower pressure telemetry before launch.",
630                            )
631                        })?
632                        / pressure_span,
633                )
634                .ok_or_else(|| {
635                    BackendError::new(
636                        "megakernel memory-pressure worker shed overflowed u32. Fix: lower pressure telemetry before launch.",
637                    )
638                })?;
639            let keep_bps = 10_000_u32.checked_sub(shed_bps).ok_or_else(|| {
640                BackendError::new(
641                    "megakernel memory-pressure worker keep ratio underflowed. Fix: keep shed_bps within 0..=10000.",
642                )
643            })?;
644            let scaled = u64::from(request.requested_worker_groups)
645                .checked_mul(u64::from(keep_bps))
646                .ok_or_else(|| {
647                    BackendError::new(
648                        "megakernel memory-constrained worker count overflowed u64. Fix: reduce requested worker groups.",
649                    )
650                })?
651                / 10_000;
652            request.requested_worker_groups = u32::try_from(scaled)
653                .map_err(|error| {
654                    BackendError::new(format!(
655                        "megakernel memory-constrained worker count cannot fit u32: {error}. Fix: reduce requested worker groups."
656                    ))
657                })?
658                .max(1);
659        }
660        if topology == MegakernelDispatchTopology::SparseFrontier
661            && request.graph_node_count != 0
662            && request.frontier_density_bps != 0
663            && request.requested_worker_groups > 1
664        {
665            let sparse_span = u32::from(self.sparse_frontier_threshold_bps).max(1);
666            let density = u32::from(request.frontier_density_bps).clamp(1, sparse_span);
667            let scaled = u64::from(request.requested_worker_groups)
668                .checked_mul(u64::from(density))
669                .ok_or_else(|| {
670                    BackendError::new(
671                        "megakernel sparse-frontier worker count overflowed u64. Fix: reduce requested worker groups.",
672                    )
673                })?
674                / u64::from(sparse_span);
675            let warp_floor = request.requested_worker_groups.min(32);
676            request.requested_worker_groups = u32::try_from(scaled)
677                .map_err(|error| {
678                    BackendError::new(format!(
679                        "megakernel sparse-frontier worker count cannot fit u32: {error}. Fix: reduce requested worker groups."
680                    ))
681                })?
682                .max(warp_floor)
683                .min(request.requested_worker_groups);
684        }
685        Ok(request)
686    }
687
688    fn dispatch_topology_for(
689        &self,
690        request: MegakernelLaunchRequest,
691        promote_hot_opcodes: bool,
692        promote_hot_windows: bool,
693    ) -> MegakernelDispatchTopology {
694        if request.queue_len == 0 {
695            return MegakernelDispatchTopology::Empty;
696        }
697        if request.memory_pressure_bps >= self.memory_pressure_threshold_bps {
698            return MegakernelDispatchTopology::MemoryConstrained;
699        }
700        if request.frontier_density_bps <= self.sparse_frontier_threshold_bps {
701            return MegakernelDispatchTopology::SparseFrontier;
702        }
703        let dense = request.frontier_density_bps >= self.dense_frontier_threshold_bps;
704        let graph_is_large =
705            request.graph_node_count > 0 && request.graph_edge_count >= self.fusion_edge_threshold;
706        if dense && graph_is_large && (promote_hot_opcodes || promote_hot_windows) {
707            return MegakernelDispatchTopology::FusedDense;
708        }
709        if dense {
710            return MegakernelDispatchTopology::DenseFrontier;
711        }
712        MegakernelDispatchTopology::HybridFrontier
713    }
714
715    fn stabilize_topology(
716        &self,
717        raw_topology: MegakernelDispatchTopology,
718        request: MegakernelLaunchRequest,
719        previous_topology: Option<MegakernelDispatchTopology>,
720        promote_hot_opcodes: bool,
721        promote_hot_windows: bool,
722    ) -> MegakernelDispatchTopology {
723        if raw_topology == MegakernelDispatchTopology::Empty {
724            return raw_topology;
725        }
726        if raw_topology == MegakernelDispatchTopology::MemoryConstrained {
727            return raw_topology;
728        }
729        let Some(previous_topology) = previous_topology else {
730            return raw_topology;
731        };
732        if previous_topology == MegakernelDispatchTopology::MemoryConstrained
733            && request.memory_pressure_bps
734                >= hysteresis_sub(
735                    self.memory_pressure_threshold_bps,
736                    MEMORY_TOPOLOGY_HYSTERESIS_BPS,
737                )
738        {
739            return MegakernelDispatchTopology::MemoryConstrained;
740        }
741
742        match previous_topology {
743            MegakernelDispatchTopology::SparseFrontier
744                if raw_topology != MegakernelDispatchTopology::SparseFrontier
745                    && request.frontier_density_bps
746                        <= hysteresis_add(
747                            self.sparse_frontier_threshold_bps,
748                            FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
749                        ) =>
750            {
751                MegakernelDispatchTopology::SparseFrontier
752            }
753            MegakernelDispatchTopology::HybridFrontier
754                if raw_topology == MegakernelDispatchTopology::SparseFrontier
755                    && request.frontier_density_bps
756                        >= hysteresis_sub(
757                            self.sparse_frontier_threshold_bps,
758                            FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
759                        ) =>
760            {
761                MegakernelDispatchTopology::HybridFrontier
762            }
763            MegakernelDispatchTopology::HybridFrontier
764                if matches!(
765                    raw_topology,
766                    MegakernelDispatchTopology::DenseFrontier
767                        | MegakernelDispatchTopology::FusedDense
768                ) && request.frontier_density_bps
769                    <= hysteresis_add(
770                        self.dense_frontier_threshold_bps,
771                        FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
772                    ) =>
773            {
774                MegakernelDispatchTopology::HybridFrontier
775            }
776            MegakernelDispatchTopology::DenseFrontier
777                if raw_topology == MegakernelDispatchTopology::HybridFrontier
778                    && request.frontier_density_bps
779                        >= hysteresis_sub(
780                            self.dense_frontier_threshold_bps,
781                            FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
782                        ) =>
783            {
784                MegakernelDispatchTopology::DenseFrontier
785            }
786            MegakernelDispatchTopology::FusedDense
787                if raw_topology == MegakernelDispatchTopology::HybridFrontier
788                    && request.frontier_density_bps
789                        >= hysteresis_sub(
790                            self.dense_frontier_threshold_bps,
791                            FRONTIER_TOPOLOGY_HYSTERESIS_BPS,
792                        )
793                    && request.graph_edge_count >= self.fusion_edge_threshold
794                    && (promote_hot_opcodes || promote_hot_windows) =>
795            {
796                MegakernelDispatchTopology::FusedDense
797            }
798            _ => raw_topology,
799        }
800    }
801
802    /// Select the best `hit_capacity_multiplier` from a candidate set.
803    ///
804    /// `candidate_multipliers` are the multipliers to try; `costs[i]`
805    /// is the observed dispatch latency (or any minimization metric)
806    /// when `candidate_multipliers[i]` was used. Lower cost wins; the
807    /// minimum observed cost selects the multiplier.
808    ///
809    /// Returns the chosen multiplier. If `candidate_multipliers` is
810    /// empty, returns the policy's existing `hit_capacity_multiplier`.
811    ///
812    #[must_use]
813    pub fn autotune_hit_capacity_multiplier(
814        &self,
815        candidate_multipliers: &[u32],
816        costs: &[f64],
817    ) -> u32 {
818        if candidate_multipliers.is_empty() || costs.is_empty() {
819            return self.hit_capacity_multiplier;
820        }
821        let n = candidate_multipliers.len().min(costs.len());
822        let chosen = best_cost_index(&costs[..n]);
823        candidate_multipliers
824            .get(chosen)
825            .copied()
826            .unwrap_or(self.hit_capacity_multiplier)
827    }
828
829    /// Select the best workgroup-size from a candidate set.
830    ///
831    /// `candidate_sizes[i]` is paired
832    /// with `costs[i]` (lower is better). Returns the chosen size or
833    /// the policy's `sizing.default_workgroup_size_x()` fallback.
834    #[must_use]
835    pub fn autotune_workgroup_size(
836        &self,
837        candidate_sizes: &[u32],
838        costs: &[f64],
839        current_size: u32,
840    ) -> u32 {
841        if candidate_sizes.is_empty() || costs.is_empty() {
842            return current_size;
843        }
844        let n = candidate_sizes.len().min(costs.len());
845        let chosen = best_cost_index(&costs[..n]);
846        candidate_sizes.get(chosen).copied().unwrap_or(current_size)
847    }
848
849    /// Compute the next-step parameter delta for a continuous autotune
850    /// knob using a Fisher-preconditioned natural-gradient step.
851    ///
852    /// `m_inv_sqrt`: inverse-square-root of the Fisher block (n×n
853    /// row-major). Passing an identity matrix reduces the natural
854    /// gradient to plain gradient descent.
855    ///
856    /// `grad`: plain gradient ∂latency/∂param (length n).
857    ///
858    /// Returns the parameter delta `-lr · M_inv_sqrt · grad`.
859    ///
860    /// P-DRIVER-8: every continuous autotune knob (workgroup size,
861    /// hit-capacity, fixpoint iteration count, …) should follow the
862    /// natural-gradient direction by default  -  Fisher-preconditioned
863    /// descent converges 5-10× faster than plain gradient on the
864    /// elongated-valley latency surfaces typical of GPU autotuning.
865    #[must_use]
866    pub fn natural_gradient_autotune_step(
867        m_inv_sqrt: &[f64],
868        grad: &[f64],
869        n: u32,
870        learning_rate: f64,
871    ) -> Vec<f64> {
872        Self::try_natural_gradient_autotune_step(m_inv_sqrt, grad, n, learning_rate)
873            .unwrap_or_else(|source| {
874                panic!(
875                    "megakernel natural-gradient autotune allocation failed: {source}. Fix: shard the autotune surface."
876                )
877            })
878    }
879
880    /// Compute the next-step parameter delta with fallible output staging.
881    ///
882    /// # Errors
883    ///
884    /// Returns [`BackendError`] when host staging cannot be reserved for the
885    /// natural-gradient vector.
886    pub fn try_natural_gradient_autotune_step(
887        m_inv_sqrt: &[f64],
888        grad: &[f64],
889        n: u32,
890        learning_rate: f64,
891    ) -> Result<Vec<f64>, BackendError> {
892        let mut out = Vec::new();
893        Self::try_natural_gradient_autotune_step_into(
894            m_inv_sqrt,
895            grad,
896            n,
897            learning_rate,
898            &mut out,
899        )?;
900        Ok(out)
901    }
902
903    /// Compute the natural-gradient autotune step into caller-owned storage.
904    pub fn natural_gradient_autotune_step_into(
905        m_inv_sqrt: &[f64],
906        grad: &[f64],
907        n: u32,
908        learning_rate: f64,
909        out: &mut Vec<f64>,
910    ) {
911        Self::try_natural_gradient_autotune_step_into(m_inv_sqrt, grad, n, learning_rate, out)
912            .unwrap_or_else(|source| {
913                panic!(
914                    "megakernel natural-gradient autotune allocation failed: {source}. Fix: shard the autotune surface."
915                )
916            });
917    }
918
919    /// Compute the natural-gradient autotune step into caller-owned storage
920    /// with fallible host staging.
921    ///
922    /// # Errors
923    ///
924    /// Returns [`BackendError`] when host staging cannot be reserved for the
925    /// natural-gradient vector.
926    pub fn try_natural_gradient_autotune_step_into(
927        m_inv_sqrt: &[f64],
928        grad: &[f64],
929        n: u32,
930        learning_rate: f64,
931        out: &mut Vec<f64>,
932    ) -> Result<(), BackendError> {
933        let n = u32_to_usize_or_panic(n, "natural-gradient dimension");
934        out.clear();
935        let Some(required_matrix_len) = n.checked_mul(n) else {
936            return Ok(());
937        };
938        if m_inv_sqrt.len() < required_matrix_len || grad.len() < n {
939            return Ok(());
940        }
941        reserve_target_capacity(out, n, "natural-gradient output")?;
942        out.resize(n, 0.0);
943        for row in 0..n {
944            let mut acc = 0.0;
945            for col in 0..n {
946                acc += m_inv_sqrt[row * n + col] * grad[col];
947            }
948            out[row] = -learning_rate * acc;
949        }
950        Ok(())
951    }
952}
953
954
955fn diffuse_step_into(
956    stalks: &[f64],
957    restriction_diag: &[f64],
958    damping: f64,
959    out: &mut Vec<f64>,
960) -> Result<(), BackendError> {
961    out.clear();
962    reserve_target_capacity(out, stalks.len(), "priority diffusion scratch")?;
963    out.resize(stalks.len(), 0.0);
964    for ((slot, &stalk), &restriction) in out
965        .iter_mut()
966        .zip(stalks.iter())
967        .zip(restriction_diag.iter())
968    {
969        *slot = stalk - damping * restriction * stalk;
970    }
971    Ok(())
972}
973
974fn reserve_target_capacity<T>(
975    out: &mut Vec<T>,
976    target_capacity: usize,
977    label: &'static str,
978) -> Result<(), BackendError> {
979    try_reserve_vec_capacity(out, target_capacity).map_err(|source| {
980        BackendError::new(format!(
981            "megakernel {label} reservation failed for {target_capacity} element(s): {source}. Fix: shard the policy input before launch-policy math."
982        ))
983    })
984}
985
986fn best_cost_index(costs: &[f64]) -> usize {
987    debug_assert!(!costs.is_empty());
988    let mut best = 0;
989    let mut best_cost = costs[0];
990    for (index, &cost) in costs.iter().enumerate().skip(1) {
991        if cost.total_cmp(&best_cost).is_lt() {
992            best = index;
993            best_cost = cost;
994        }
995    }
996    best
997}
998
999fn u32_to_usize_or_panic(value: u32, label: &'static str) -> usize {
1000    match usize::try_from(value) {
1001        Ok(value) => value,
1002        Err(error) => {
1003            panic!("{label} cannot fit usize: {error}. Fix: shard the autotune surface.")
1004        }
1005    }
1006}
1007
1008fn hysteresis_add(value: u16, hysteresis: u16) -> u16 {
1009    value.checked_add(hysteresis).unwrap_or_else(|| {
1010        panic!(
1011            "megakernel topology hysteresis upper bound overflowed u16. Fix: lower topology threshold or hysteresis."
1012        )
1013    })
1014}
1015
1016fn hysteresis_sub(value: u16, hysteresis: u16) -> u16 {
1017    value.checked_sub(hysteresis).unwrap_or_else(|| {
1018        panic!(
1019            "megakernel topology hysteresis lower bound underflowed u16. Fix: lower hysteresis or raise topology threshold."
1020        )
1021    })
1022}
1023
1024fn classify_pressure(
1025    queue_len: u32,
1026    lanes: u64,
1027    requeue_count: u64,
1028    policy: &MegakernelLaunchPolicy,
1029) -> Result<MegakernelQueuePressure, BackendError> {
1030    if queue_len == 0 {
1031        return Ok(MegakernelQueuePressure::Empty);
1032    }
1033    let lanes = lanes.max(1);
1034    let queue_len = u64::from(queue_len);
1035    let saturated_lanes = lanes
1036        .checked_mul(u64::from(policy.saturated_waves))
1037        .ok_or_else(|| {
1038            BackendError::new(
1039                "megakernel pressure wave threshold overflowed u64. Fix: reduce worker lanes or saturated_waves.",
1040            )
1041        })?;
1042    if requeue_count > 0 || queue_len >= saturated_lanes {
1043        Ok(MegakernelQueuePressure::Saturated)
1044    } else if queue_len >= lanes {
1045        Ok(MegakernelQueuePressure::Balanced)
1046    } else {
1047        Ok(MegakernelQueuePressure::Light)
1048    }
1049}
1050
1051#[cfg(test)]
1052mod tests;
1053