// oxicuda_ptx/profile_guided.rs
//! Profile-guided code generation for PTX kernels.
//!
//! This module uses profiling data (from autotuning or `nsight` runs) to make
//! informed decisions about PTX instruction selection, loop unrolling, memory
//! access strategies, and tile sizing.  The [`ProfileGuidedOptimizer`] ingests
//! a [`ProfileData`] snapshot and emits a set of [`CodeGenDecision`]s that
//! downstream kernel builders can apply.
8
9use std::fmt;
10
11use crate::arch::SmVersion;
12
13// ---------------------------------------------------------------------------
14// Profile data types
15// ---------------------------------------------------------------------------
16
17/// Collected profiling information for a single kernel invocation.
18///
19/// This is the primary input to the profile-guided optimizer. Typically
20/// constructed from autotune results or external profiler output.
21#[derive(Debug, Clone)]
22pub struct ProfileData {
23    /// Name of the profiled kernel.
24    pub kernel_name: String,
25    /// Target SM architecture the profile was gathered on.
26    pub sm_version: SmVersion,
27    /// Aggregate performance metrics.
28    pub metrics: ProfileMetrics,
29    /// Hot instruction indices with stall information.
30    pub hotspots: Vec<HotSpot>,
31    /// Per-branch taken/not-taken statistics.
32    pub branch_stats: Vec<BranchProfile>,
33    /// Memory access coalescing and caching statistics.
34    pub memory_access_pattern: MemoryAccessProfile,
35}
36
/// Aggregate GPU performance metrics.
///
/// All fields are fractions in 0.0–1.0 except [`ProfileMetrics::ipc`],
/// which is an absolute instructions-per-clock figure.
#[derive(Debug, Clone, Copy)]
pub struct ProfileMetrics {
    /// Fraction of the theoretical maximum occupancy achieved.
    pub achieved_occupancy: f64,
    /// Fraction of peak compute throughput utilised.
    pub compute_throughput: f64,
    /// Fraction of peak memory bandwidth utilised.
    pub memory_throughput: f64,
    /// L2 cache hit rate.
    pub l2_hit_rate: f64,
    /// Shared memory transaction efficiency.
    pub shared_memory_efficiency: f64,
    /// Fraction of warps with all lanes active.
    pub warp_execution_efficiency: f64,
    /// Instructions retired per clock cycle (not a ratio).
    pub ipc: f64,
}
55
56/// A single hot instruction with cycle count and stall classification.
57#[derive(Debug, Clone)]
58pub struct HotSpot {
59    /// Index into the instruction stream.
60    pub instruction_index: usize,
61    /// Total cycles spent at this instruction.
62    pub cycle_count: u64,
63    /// Dominant stall category at this instruction.
64    pub stall_reason: StallReason,
65}
66
/// Reason a warp stalled at a particular instruction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StallReason {
    /// No significant stall.
    None,
    /// Waiting for a memory operation to complete.
    MemoryDependency,
    /// Waiting for a prior arithmetic result.
    ExecutionDependency,
    /// Blocked on a synchronisation barrier.
    SyncBarrier,
    /// Instruction cache miss.
    InstructionFetch,
    /// Any other stall category.
    Other(String),
}

impl fmt::Display for StallReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed variants map to static labels; `Other` carries its payload.
        let label = match self {
            Self::None => "none",
            Self::MemoryDependency => "memory_dependency",
            Self::ExecutionDependency => "execution_dependency",
            Self::SyncBarrier => "sync_barrier",
            Self::InstructionFetch => "instruction_fetch",
            Self::Other(s) => return write!(f, "other({s})"),
        };
        f.write_str(label)
    }
}
96
/// Taken/not-taken statistics for a single branch site.
#[derive(Debug, Clone, Copy)]
pub struct BranchProfile {
    /// Index of the branch instruction.
    pub branch_index: usize,
    /// Number of times the branch was taken.
    pub taken_count: u64,
    /// Number of times the branch was *not* taken.
    pub not_taken_count: u64,
}

impl BranchProfile {
    /// Returns the fraction of executions where the branch was taken.
    ///
    /// Returns 0.0 if neither path was ever executed.
    #[must_use]
    pub fn taken_ratio(&self) -> f64 {
        // saturating_add: pathological counters must not wrap the total.
        let total = self.taken_count.saturating_add(self.not_taken_count);
        if total == 0 {
            return 0.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let ratio = self.taken_count as f64 / total as f64;
        ratio
    }

    /// Returns `true` if the branch is biased beyond `threshold` in
    /// either direction (taken ratio > threshold or < 1 − threshold).
    ///
    /// A branch that was never executed carries no bias information and
    /// is reported as unbiased.
    #[must_use]
    pub fn is_biased(&self, threshold: f64) -> bool {
        // Bug fix: a never-executed branch has taken_ratio() == 0.0, which
        // is below `1 - threshold` and was previously misclassified as
        // biased — wrongly triggering predication for dead branch sites.
        if self.taken_count == 0 && self.not_taken_count == 0 {
            return false;
        }
        let ratio = self.taken_ratio();
        ratio > threshold || ratio < (1.0 - threshold)
    }
}

impl fmt::Display for BranchProfile {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "branch[{}]: taken={} not_taken={} ratio={:.2}%",
            self.branch_index,
            self.taken_count,
            self.not_taken_count,
            self.taken_ratio() * 100.0,
        )
    }
}
144
/// Memory access pattern statistics (all fields are fractions in 0.0–1.0).
#[derive(Debug, Clone, Copy)]
pub struct MemoryAccessProfile {
    /// Fraction of global loads that are coalesced.
    pub coalesced_ratio: f64,
    /// Fraction of shared memory accesses with bank conflicts.
    pub bank_conflict_rate: f64,
    /// Average fraction of each cache line actually consumed.
    pub cache_line_utilization: f64,
}

impl fmt::Display for MemoryAccessProfile {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render each ratio as a percentage with one decimal place.
        let to_pct = |ratio: f64| ratio * 100.0;
        write!(
            f,
            "coalesced={:.1}% bank_conflicts={:.1}% cache_util={:.1}%",
            to_pct(self.coalesced_ratio),
            to_pct(self.bank_conflict_rate),
            to_pct(self.cache_line_utilization),
        )
    }
}
167
168// ---------------------------------------------------------------------------
169// Bottleneck classification
170// ---------------------------------------------------------------------------
171
/// High-level classification of a kernel's performance bottleneck.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bottleneck {
    /// Limited by arithmetic throughput.
    ComputeBound,
    /// Limited by memory bandwidth.
    MemoryBound,
    /// Limited by instruction or data latency (pipeline bubbles).
    LatencyBound,
    /// No single bottleneck dominates — the kernel is reasonably balanced.
    Balanced,
}

impl fmt::Display for Bottleneck {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::ComputeBound => "compute-bound",
            Self::MemoryBound => "memory-bound",
            Self::LatencyBound => "latency-bound",
            Self::Balanced => "balanced",
        };
        f.write_str(label)
    }
}
195
196// ---------------------------------------------------------------------------
197// Code generation decisions
198// ---------------------------------------------------------------------------
199
/// A concrete optimisation decision derived from profiling data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodeGenDecision {
    /// Unroll a hot loop by the given factor.
    UnrollLoop {
        /// Number of iterations to unroll.
        factor: u32,
    },
    /// Convert a heavily biased branch to a predicated instruction.
    PredicateBranch,
    /// Insert prefetch instructions at the given distance (in iterations).
    PrefetchMemory {
        /// Prefetch lookahead distance.
        distance: u32,
    },
    /// Increase occupancy by targeting the given number of blocks per SM.
    IncreaseOccupancy {
        /// Desired concurrent blocks per SM.
        target_blocks: u32,
    },
    /// Use larger tile dimensions for a compute-bound GEMM.
    UseLargerTiles {
        /// Tile size in the M dimension.
        tile_m: u32,
        /// Tile size in the N dimension.
        tile_n: u32,
    },
    /// Promote global memory loads to shared memory.
    SwitchToSharedMemory,
    /// Enable split-K parallelism.
    EnableSplitK {
        /// Number of K-dimension slices.
        k_slices: u32,
    },
}

impl fmt::Display for CodeGenDecision {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnrollLoop { factor } => write!(f, "unroll loop x{factor}"),
            Self::PredicateBranch => write!(f, "convert branch to predicated"),
            Self::PrefetchMemory { distance } => write!(f, "insert prefetch (distance={distance})"),
            Self::IncreaseOccupancy { target_blocks } => {
                write!(f, "increase occupancy to {target_blocks} blocks/SM")
            }
            Self::UseLargerTiles { tile_m, tile_n } => {
                write!(f, "use larger tiles ({tile_m}x{tile_n})")
            }
            Self::SwitchToSharedMemory => write!(f, "switch to shared memory"),
            Self::EnableSplitK { k_slices } => write!(f, "enable split-K ({k_slices} slices)"),
        }
    }
}
257
258// ---------------------------------------------------------------------------
259// Tile configuration
260// ---------------------------------------------------------------------------
261
/// Suggested tile configuration produced by the optimizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TileConfig {
    /// Tile size in the M dimension.
    pub tile_m: u32,
    /// Tile size in the N dimension.
    pub tile_n: u32,
    /// Tile size in the K dimension.
    pub tile_k: u32,
}

impl fmt::Display for TileConfig {
    /// Renders the config as `MxNxK`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { tile_m, tile_n, tile_k } = self;
        write!(f, "{tile_m}x{tile_n}x{tile_k}")
    }
}
278
279// ---------------------------------------------------------------------------
280// KernelProfile — mutable configuration that decisions are applied to
281// ---------------------------------------------------------------------------
282
/// Mutable kernel configuration that the profile-guided optimizer adjusts.
///
/// Downstream builders read these fields after optimisation to generate the
/// final PTX code.
#[derive(Debug, Clone)]
pub struct KernelProfile {
    /// Tile size in the M dimension.
    pub tile_m: u32,
    /// Tile size in the N dimension.
    pub tile_n: u32,
    /// Tile size in the K dimension.
    pub tile_k: u32,
    /// Loop unroll factor.
    pub unroll_factor: u32,
    /// Whether shared memory staging is enabled.
    pub use_shared_memory: bool,
    /// Target register count per thread (0 = no constraint).
    pub register_target: u32,
    /// Number of split-K slices (1 = disabled).
    pub split_k: u32,
}

impl KernelProfile {
    /// Creates a new `KernelProfile` with sensible defaults:
    /// 64x64x8 tiles, no unrolling, no shared-memory staging,
    /// no register constraint, split-K disabled.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tile_m: 64,
            tile_n: 64,
            tile_k: 8,
            unroll_factor: 1,
            use_shared_memory: false,
            register_target: 0,
            split_k: 1,
        }
    }
}

impl Default for KernelProfile {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for KernelProfile {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let smem = if self.use_shared_memory { "on" } else { "off" };
        write!(
            f,
            "tile={}x{}x{} unroll={} smem={smem} regs={} split_k={}",
            self.tile_m,
            self.tile_n,
            self.tile_k,
            self.unroll_factor,
            self.register_target,
            self.split_k,
        )
    }
}
342
343// ---------------------------------------------------------------------------
344// Thresholds (internal constants)
345// ---------------------------------------------------------------------------
346
/// Fraction of peak compute throughput at or above which a kernel is
/// classified as compute-bound (see `classify_bottleneck`).
const COMPUTE_BOUND_THRESHOLD: f64 = 0.7;
/// Fraction of peak memory bandwidth at or above which a kernel is
/// classified as memory-bound.
const MEMORY_BOUND_THRESHOLD: f64 = 0.7;
/// Instructions-per-clock below which a kernel (with neither unit saturated
/// and low occupancy) is classified as latency-bound.
const LATENCY_BOUND_IPC_THRESHOLD: f64 = 1.0;
/// Default branch-bias threshold (0.9 = 90 %) used by `analyze` when
/// deciding whether to recommend predication.
const DEFAULT_BRANCH_BIAS_THRESHOLD: f64 = 0.9;
/// Achieved occupancy below which `analyze` recommends increasing it.
const LOW_OCCUPANCY_THRESHOLD: f64 = 0.5;
/// Global-load coalescing ratio below which shared-memory staging is
/// recommended.
const POOR_COALESCING_THRESHOLD: f64 = 0.5;
/// Memory-throughput fraction above which prefetch insertion is considered
/// beneficial.
const PREFETCH_MEMORY_THROUGHPUT_THRESHOLD: f64 = 0.5;
361
362// ---------------------------------------------------------------------------
363// ProfileGuidedOptimizer
364// ---------------------------------------------------------------------------
365
366/// Analyses profiling data and produces [`CodeGenDecision`]s.
367#[derive(Debug, Clone)]
368pub struct ProfileGuidedOptimizer {
369    profile: ProfileData,
370}
371
372impl ProfileGuidedOptimizer {
373    /// Create a new optimizer from the given profile data.
374    #[must_use]
375    pub const fn new(profile: ProfileData) -> Self {
376        Self { profile }
377    }
378
379    /// Classify the kernel's dominant bottleneck.
380    #[must_use]
381    pub fn classify_bottleneck(&self) -> Bottleneck {
382        let m = &self.profile.metrics;
383
384        let compute_heavy = m.compute_throughput >= COMPUTE_BOUND_THRESHOLD;
385        let memory_heavy = m.memory_throughput >= MEMORY_BOUND_THRESHOLD;
386
387        match (compute_heavy, memory_heavy) {
388            (true, false) => Bottleneck::ComputeBound,
389            (false, true) => Bottleneck::MemoryBound,
390            (true, true) => Bottleneck::Balanced,
391            (false, false) => {
392                // Neither unit is saturated — check IPC for latency bound.
393                if m.ipc < LATENCY_BOUND_IPC_THRESHOLD
394                    && m.achieved_occupancy < LOW_OCCUPANCY_THRESHOLD
395                {
396                    Bottleneck::LatencyBound
397                } else {
398                    Bottleneck::Balanced
399                }
400            }
401        }
402    }
403
404    /// Produce a list of optimisation decisions based on the profile.
405    ///
406    /// The returned decisions are ordered from most impactful to least.
407    #[must_use]
408    pub fn analyze(&self) -> Vec<CodeGenDecision> {
409        let mut decisions = Vec::new();
410        let bottleneck = self.classify_bottleneck();
411
412        // --- Unroll hot loops ---
413        let unroll = self.suggest_unroll_factor();
414        if unroll > 1 {
415            decisions.push(CodeGenDecision::UnrollLoop { factor: unroll });
416        }
417
418        // --- Branch predication ---
419        for bp in &self.profile.branch_stats {
420            if bp.is_biased(DEFAULT_BRANCH_BIAS_THRESHOLD) {
421                decisions.push(CodeGenDecision::PredicateBranch);
422                break; // one decision covers all biased branches
423            }
424        }
425
426        // --- Memory-bound specific ---
427        if bottleneck == Bottleneck::MemoryBound || bottleneck == Bottleneck::Balanced {
428            let mem = &self.profile.memory_access_pattern;
429            if mem.coalesced_ratio < POOR_COALESCING_THRESHOLD {
430                decisions.push(CodeGenDecision::SwitchToSharedMemory);
431            }
432            if self.profile.metrics.memory_throughput > PREFETCH_MEMORY_THROUGHPUT_THRESHOLD {
433                let distance = self.suggest_prefetch_distance();
434                decisions.push(CodeGenDecision::PrefetchMemory { distance });
435            }
436        }
437
438        // --- Occupancy ---
439        if self.profile.metrics.achieved_occupancy < LOW_OCCUPANCY_THRESHOLD {
440            let target = self.suggest_target_blocks();
441            decisions.push(CodeGenDecision::IncreaseOccupancy {
442                target_blocks: target,
443            });
444        }
445
446        // --- Compute-bound tile sizing ---
447        if bottleneck == Bottleneck::ComputeBound {
448            decisions.push(CodeGenDecision::UseLargerTiles {
449                tile_m: 128,
450                tile_n: 128,
451            });
452        }
453
454        // --- Split-K for tall-skinny K ---
455        if bottleneck == Bottleneck::LatencyBound {
456            decisions.push(CodeGenDecision::EnableSplitK { k_slices: 4 });
457        }
458
459        decisions
460    }
461
462    /// Suggest a tile configuration for a GEMM of the given dimensions.
463    #[must_use]
464    pub fn suggest_tile_config(&self, m: u32, n: u32, k: u32) -> TileConfig {
465        let bottleneck = self.classify_bottleneck();
466        let caps = self.profile.sm_version.capabilities();
467
468        // Base tile sizes depend on bottleneck classification.
469        let (base_m, base_n) = match bottleneck {
470            Bottleneck::ComputeBound => {
471                if caps.has_wgmma {
472                    (256, 128) // Hopper+ can sustain larger tiles
473                } else if caps.has_ampere_mma {
474                    (128, 128)
475                } else {
476                    (128, 64)
477                }
478            }
479            Bottleneck::MemoryBound => (64, 64),
480            Bottleneck::LatencyBound => (64, 32),
481            Bottleneck::Balanced => (128, 64),
482        };
483
484        // Clamp to problem dimensions.
485        let tile_m = base_m.min(m);
486        let tile_n = base_n.min(n);
487
488        // K tile: for memory-bound kernels use deeper K tiles for reuse.
489        let tile_k = match bottleneck {
490            Bottleneck::MemoryBound => 32.min(k),
491            Bottleneck::ComputeBound => 16.min(k),
492            _ => 8.min(k),
493        };
494
495        TileConfig {
496            tile_m,
497            tile_n,
498            tile_k,
499        }
500    }
501
502    /// Suggest an unroll factor based on hotspot and IPC data.
503    #[must_use]
504    pub fn suggest_unroll_factor(&self) -> u32 {
505        let m = &self.profile.metrics;
506
507        // Count memory-dependency stalls — unrolling helps hide latency.
508        let mem_stalls = self
509            .profile
510            .hotspots
511            .iter()
512            .filter(|h| h.stall_reason == StallReason::MemoryDependency)
513            .count();
514
515        if mem_stalls >= 3 {
516            return 8;
517        }
518
519        if m.ipc < 1.0 {
520            return 4;
521        }
522
523        if m.ipc < 2.0 {
524            return 2;
525        }
526
527        1
528    }
529
530    // --- private helpers ---
531
532    /// Suggest prefetch distance based on memory throughput and L2 hit rate.
533    fn suggest_prefetch_distance(&self) -> u32 {
534        let m = &self.profile.metrics;
535        if m.l2_hit_rate < 0.3 {
536            4 // deep prefetch for poor caching
537        } else if m.l2_hit_rate < 0.6 {
538            2
539        } else {
540            1
541        }
542    }
543
544    /// Suggest target concurrent blocks per SM.
545    fn suggest_target_blocks(&self) -> u32 {
546        let max_threads = self.profile.sm_version.max_threads_per_sm();
547        // Aim for 75 % of max threads at 128 threads/block.
548        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
549        let target_threads = (f64::from(max_threads) * 0.75) as u32;
550        let blocks = target_threads / 128;
551        blocks.max(2)
552    }
553}
554
555// ---------------------------------------------------------------------------
556// apply_profile_decisions
557// ---------------------------------------------------------------------------
558
559/// Apply a set of [`CodeGenDecision`]s to a mutable [`KernelProfile`].
560///
561/// Returns a human-readable log of every change that was made.
562pub fn apply_profile_decisions(
563    decisions: &[CodeGenDecision],
564    config: &mut KernelProfile,
565) -> Vec<String> {
566    let mut log = Vec::with_capacity(decisions.len());
567
568    for decision in decisions {
569        match decision {
570            CodeGenDecision::UnrollLoop { factor } => {
571                let prev = config.unroll_factor;
572                config.unroll_factor = *factor;
573                log.push(format!("unroll factor: {prev} -> {factor}"));
574            }
575            CodeGenDecision::PredicateBranch => {
576                log.push("enabled branch predication".to_string());
577            }
578            CodeGenDecision::PrefetchMemory { distance } => {
579                log.push(format!("enabled prefetch with distance {distance}"));
580            }
581            CodeGenDecision::IncreaseOccupancy { target_blocks } => {
582                // Reduce register pressure to fit more blocks.
583                let new_target = 255 / target_blocks;
584                let prev = config.register_target;
585                config.register_target = new_target;
586                log.push(format!(
587                    "register target: {prev} -> {new_target} (for {target_blocks} blocks/SM)"
588                ));
589            }
590            CodeGenDecision::UseLargerTiles { tile_m, tile_n } => {
591                let prev_m = config.tile_m;
592                let prev_n = config.tile_n;
593                config.tile_m = *tile_m;
594                config.tile_n = *tile_n;
595                log.push(format!("tile size: {prev_m}x{prev_n} -> {tile_m}x{tile_n}"));
596            }
597            CodeGenDecision::SwitchToSharedMemory => {
598                config.use_shared_memory = true;
599                log.push("enabled shared memory staging".to_string());
600            }
601            CodeGenDecision::EnableSplitK { k_slices } => {
602                let prev = config.split_k;
603                config.split_k = *k_slices;
604                log.push(format!("split-K: {prev} -> {k_slices} slices"));
605            }
606        }
607    }
608
609    log
610}
611
612// ---------------------------------------------------------------------------
613// Tests
614// ---------------------------------------------------------------------------
615
616#[cfg(test)]
617mod tests {
618    use super::*;
619
620    /// Helper to build a `ProfileData` with the given metrics.
621    fn make_profile(metrics: ProfileMetrics) -> ProfileData {
622        ProfileData {
623            kernel_name: "test_kernel".to_string(),
624            sm_version: SmVersion::Sm80,
625            metrics,
626            hotspots: Vec::new(),
627            branch_stats: Vec::new(),
628            memory_access_pattern: MemoryAccessProfile {
629                coalesced_ratio: 0.9,
630                bank_conflict_rate: 0.05,
631                cache_line_utilization: 0.85,
632            },
633        }
634    }
635
636    fn balanced_metrics() -> ProfileMetrics {
637        ProfileMetrics {
638            achieved_occupancy: 0.75,
639            compute_throughput: 0.5,
640            memory_throughput: 0.5,
641            l2_hit_rate: 0.6,
642            shared_memory_efficiency: 0.9,
643            warp_execution_efficiency: 0.95,
644            ipc: 2.5,
645        }
646    }
647
648    fn compute_bound_metrics() -> ProfileMetrics {
649        ProfileMetrics {
650            achieved_occupancy: 0.8,
651            compute_throughput: 0.85,
652            memory_throughput: 0.3,
653            l2_hit_rate: 0.7,
654            shared_memory_efficiency: 0.9,
655            warp_execution_efficiency: 0.95,
656            ipc: 3.0,
657        }
658    }
659
660    fn memory_bound_metrics() -> ProfileMetrics {
661        ProfileMetrics {
662            achieved_occupancy: 0.7,
663            compute_throughput: 0.2,
664            memory_throughput: 0.85,
665            l2_hit_rate: 0.4,
666            shared_memory_efficiency: 0.6,
667            warp_execution_efficiency: 0.9,
668            ipc: 1.5,
669        }
670    }
671
672    fn latency_bound_metrics() -> ProfileMetrics {
673        ProfileMetrics {
674            achieved_occupancy: 0.3,
675            compute_throughput: 0.15,
676            memory_throughput: 0.2,
677            l2_hit_rate: 0.5,
678            shared_memory_efficiency: 0.7,
679            warp_execution_efficiency: 0.8,
680            ipc: 0.5,
681        }
682    }
683
684    // --- Bottleneck classification tests ---
685
686    #[test]
687    fn classify_compute_bound() {
688        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
689        assert_eq!(opt.classify_bottleneck(), Bottleneck::ComputeBound);
690    }
691
692    #[test]
693    fn classify_memory_bound() {
694        let opt = ProfileGuidedOptimizer::new(make_profile(memory_bound_metrics()));
695        assert_eq!(opt.classify_bottleneck(), Bottleneck::MemoryBound);
696    }
697
698    #[test]
699    fn classify_latency_bound() {
700        let opt = ProfileGuidedOptimizer::new(make_profile(latency_bound_metrics()));
701        assert_eq!(opt.classify_bottleneck(), Bottleneck::LatencyBound);
702    }
703
704    #[test]
705    fn classify_balanced() {
706        let opt = ProfileGuidedOptimizer::new(make_profile(balanced_metrics()));
707        assert_eq!(opt.classify_bottleneck(), Bottleneck::Balanced);
708    }
709
710    #[test]
711    fn classify_both_saturated_is_balanced() {
712        let mut m = balanced_metrics();
713        m.compute_throughput = 0.8;
714        m.memory_throughput = 0.8;
715        let opt = ProfileGuidedOptimizer::new(make_profile(m));
716        assert_eq!(opt.classify_bottleneck(), Bottleneck::Balanced);
717    }
718
719    // --- Decision generation tests ---
720
721    #[test]
722    fn compute_bound_suggests_larger_tiles() {
723        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
724        let decisions = opt.analyze();
725        assert!(
726            decisions
727                .iter()
728                .any(|d| matches!(d, CodeGenDecision::UseLargerTiles { .. })),
729            "expected UseLargerTiles in {decisions:?}"
730        );
731    }
732
733    #[test]
734    fn memory_bound_with_poor_coalescing_suggests_shared_mem() {
735        let mut profile = make_profile(memory_bound_metrics());
736        profile.memory_access_pattern.coalesced_ratio = 0.3;
737        let opt = ProfileGuidedOptimizer::new(profile);
738        let decisions = opt.analyze();
739        assert!(
740            decisions
741                .iter()
742                .any(|d| matches!(d, CodeGenDecision::SwitchToSharedMemory)),
743            "expected SwitchToSharedMemory in {decisions:?}"
744        );
745    }
746
747    #[test]
748    fn latency_bound_suggests_split_k() {
749        let opt = ProfileGuidedOptimizer::new(make_profile(latency_bound_metrics()));
750        let decisions = opt.analyze();
751        assert!(
752            decisions
753                .iter()
754                .any(|d| matches!(d, CodeGenDecision::EnableSplitK { .. })),
755            "expected EnableSplitK in {decisions:?}"
756        );
757    }
758
759    #[test]
760    fn low_occupancy_suggests_increase() {
761        let mut m = balanced_metrics();
762        m.achieved_occupancy = 0.3;
763        let opt = ProfileGuidedOptimizer::new(make_profile(m));
764        let decisions = opt.analyze();
765        assert!(
766            decisions
767                .iter()
768                .any(|d| matches!(d, CodeGenDecision::IncreaseOccupancy { .. })),
769            "expected IncreaseOccupancy in {decisions:?}"
770        );
771    }
772
773    // --- Branch bias tests ---
774
775    #[test]
776    fn branch_profile_taken_ratio() {
777        let bp = BranchProfile {
778            branch_index: 0,
779            taken_count: 900,
780            not_taken_count: 100,
781        };
782        let ratio = bp.taken_ratio();
783        assert!((ratio - 0.9).abs() < 1e-9);
784    }
785
786    #[test]
787    fn branch_profile_zero_executions() {
788        let bp = BranchProfile {
789            branch_index: 0,
790            taken_count: 0,
791            not_taken_count: 0,
792        };
793        assert!((bp.taken_ratio() - 0.0).abs() < 1e-9);
794    }
795
796    #[test]
797    fn branch_bias_detection() {
798        let bp = BranchProfile {
799            branch_index: 0,
800            taken_count: 950,
801            not_taken_count: 50,
802        };
803        assert!(bp.is_biased(0.9));
804        assert!(!bp.is_biased(0.96));
805    }
806
807    #[test]
808    fn biased_branch_triggers_predication() {
809        let mut profile = make_profile(balanced_metrics());
810        profile.branch_stats.push(BranchProfile {
811            branch_index: 0,
812            taken_count: 980,
813            not_taken_count: 20,
814        });
815        let opt = ProfileGuidedOptimizer::new(profile);
816        let decisions = opt.analyze();
817        assert!(
818            decisions
819                .iter()
820                .any(|d| matches!(d, CodeGenDecision::PredicateBranch)),
821            "expected PredicateBranch in {decisions:?}"
822        );
823    }
824
825    // --- Unroll factor tests ---
826
827    #[test]
828    fn unroll_factor_high_mem_stalls() {
829        let mut profile = make_profile(balanced_metrics());
830        for i in 0..4 {
831            profile.hotspots.push(HotSpot {
832                instruction_index: i,
833                cycle_count: 500,
834                stall_reason: StallReason::MemoryDependency,
835            });
836        }
837        let opt = ProfileGuidedOptimizer::new(profile);
838        assert_eq!(opt.suggest_unroll_factor(), 8);
839    }
840
841    #[test]
842    fn unroll_factor_low_ipc() {
843        let mut m = balanced_metrics();
844        m.ipc = 0.8;
845        let opt = ProfileGuidedOptimizer::new(make_profile(m));
846        assert_eq!(opt.suggest_unroll_factor(), 4);
847    }
848
849    #[test]
850    fn unroll_factor_moderate_ipc() {
851        let mut m = balanced_metrics();
852        m.ipc = 1.5;
853        let opt = ProfileGuidedOptimizer::new(make_profile(m));
854        assert_eq!(opt.suggest_unroll_factor(), 2);
855    }
856
857    #[test]
858    fn unroll_factor_high_ipc_no_unroll() {
859        let m = balanced_metrics(); // ipc = 2.5
860        let opt = ProfileGuidedOptimizer::new(make_profile(m));
861        assert_eq!(opt.suggest_unroll_factor(), 1);
862    }
863
864    // --- Tile suggestion tests ---
865
866    #[test]
867    fn tile_config_compute_bound_ampere() {
868        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
869        let tc = opt.suggest_tile_config(512, 512, 256);
870        assert_eq!(tc.tile_m, 128);
871        assert_eq!(tc.tile_n, 128);
872        assert_eq!(tc.tile_k, 16);
873    }
874
875    #[test]
876    fn tile_config_compute_bound_hopper() {
877        let mut profile = make_profile(compute_bound_metrics());
878        profile.sm_version = SmVersion::Sm90;
879        let opt = ProfileGuidedOptimizer::new(profile);
880        let tc = opt.suggest_tile_config(512, 512, 256);
881        assert_eq!(tc.tile_m, 256);
882        assert_eq!(tc.tile_n, 128);
883    }
884
885    #[test]
886    fn tile_config_clamps_to_problem_size() {
887        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
888        let tc = opt.suggest_tile_config(32, 16, 4);
889        assert_eq!(tc.tile_m, 32);
890        assert_eq!(tc.tile_n, 16);
891        assert_eq!(tc.tile_k, 4);
892    }
893
894    #[test]
895    fn tile_config_memory_bound_uses_deep_k() {
896        let opt = ProfileGuidedOptimizer::new(make_profile(memory_bound_metrics()));
897        let tc = opt.suggest_tile_config(512, 512, 256);
898        assert_eq!(tc.tile_k, 32);
899    }
900
901    // --- apply_profile_decisions tests ---
902
903    #[test]
904    fn apply_decisions_updates_config() {
905        let decisions = vec![
906            CodeGenDecision::UnrollLoop { factor: 4 },
907            CodeGenDecision::SwitchToSharedMemory,
908            CodeGenDecision::EnableSplitK { k_slices: 8 },
909            CodeGenDecision::UseLargerTiles {
910                tile_m: 128,
911                tile_n: 256,
912            },
913        ];
914        let mut config = KernelProfile::new();
915        let log = apply_profile_decisions(&decisions, &mut config);
916
917        assert_eq!(config.unroll_factor, 4);
918        assert!(config.use_shared_memory);
919        assert_eq!(config.split_k, 8);
920        assert_eq!(config.tile_m, 128);
921        assert_eq!(config.tile_n, 256);
922        assert_eq!(log.len(), 4);
923    }
924
925    #[test]
926    fn apply_increase_occupancy_sets_register_target() {
927        let decisions = vec![CodeGenDecision::IncreaseOccupancy { target_blocks: 4 }];
928        let mut config = KernelProfile::new();
929        let log = apply_profile_decisions(&decisions, &mut config);
930        // 255 / 4 = 63
931        assert_eq!(config.register_target, 63);
932        assert_eq!(log.len(), 1);
933    }
934
935    // --- Display trait tests ---
936
937    #[test]
938    fn display_bottleneck() {
939        assert_eq!(format!("{}", Bottleneck::ComputeBound), "compute-bound");
940        assert_eq!(format!("{}", Bottleneck::MemoryBound), "memory-bound");
941        assert_eq!(format!("{}", Bottleneck::LatencyBound), "latency-bound");
942        assert_eq!(format!("{}", Bottleneck::Balanced), "balanced");
943    }
944
945    #[test]
946    fn display_stall_reason() {
947        assert_eq!(format!("{}", StallReason::None), "none");
948        assert_eq!(
949            format!("{}", StallReason::MemoryDependency),
950            "memory_dependency"
951        );
952        assert_eq!(
953            format!("{}", StallReason::Other("pipe_busy".to_string())),
954            "other(pipe_busy)"
955        );
956    }
957
958    #[test]
959    fn display_code_gen_decision() {
960        let d = CodeGenDecision::UnrollLoop { factor: 4 };
961        assert_eq!(format!("{d}"), "unroll loop x4");
962        let d = CodeGenDecision::EnableSplitK { k_slices: 8 };
963        assert_eq!(format!("{d}"), "enable split-K (8 slices)");
964    }
965
966    #[test]
967    fn display_kernel_profile() {
968        let kp = KernelProfile::new();
969        let s = format!("{kp}");
970        assert!(s.contains("tile=64x64x8"));
971        assert!(s.contains("smem=off"));
972    }
973
974    #[test]
975    fn display_tile_config() {
976        let tc = TileConfig {
977            tile_m: 128,
978            tile_n: 64,
979            tile_k: 16,
980        };
981        assert_eq!(format!("{tc}"), "128x64x16");
982    }
983
984    #[test]
985    fn display_memory_access_profile() {
986        let m = MemoryAccessProfile {
987            coalesced_ratio: 0.95,
988            bank_conflict_rate: 0.02,
989            cache_line_utilization: 0.88,
990        };
991        let s = format!("{m}");
992        assert!(s.contains("coalesced=95.0%"));
993    }
994
995    #[test]
996    fn display_branch_profile() {
997        let bp = BranchProfile {
998            branch_index: 3,
999            taken_count: 750,
1000            not_taken_count: 250,
1001        };
1002        let s = format!("{bp}");
1003        assert!(s.contains("branch[3]"));
1004        assert!(s.contains("75.00%"));
1005    }
1006
1007    // --- End-to-end: profile -> decisions -> applied config ---
1008
1009    #[test]
1010    fn end_to_end_compute_bound_pipeline() {
1011        let profile = make_profile(compute_bound_metrics());
1012        let opt = ProfileGuidedOptimizer::new(profile);
1013        assert_eq!(opt.classify_bottleneck(), Bottleneck::ComputeBound);
1014
1015        let decisions = opt.analyze();
1016        let mut config = KernelProfile::new();
1017        let log = apply_profile_decisions(&decisions, &mut config);
1018
1019        // Compute-bound should have enlarged tiles.
1020        assert!(config.tile_m >= 128);
1021        assert!(!log.is_empty());
1022    }
1023}