1use std::fmt;
10
11use crate::arch::SmVersion;
12
/// Raw profiling data collected for a single kernel run, as consumed by
/// [`ProfileGuidedOptimizer`].
#[derive(Debug, Clone)]
pub struct ProfileData {
    /// Name of the profiled kernel.
    pub kernel_name: String,
    /// SM architecture the kernel was profiled on.
    pub sm_version: SmVersion,
    /// Aggregate hardware counters for the run.
    pub metrics: ProfileMetrics,
    /// Instruction-level hot spots observed during profiling.
    pub hotspots: Vec<HotSpot>,
    /// Taken/not-taken statistics for each profiled branch.
    pub branch_stats: Vec<BranchProfile>,
    /// Summary of the kernel's memory access behavior.
    pub memory_access_pattern: MemoryAccessProfile,
}
36
/// Aggregate hardware counters for one profiled kernel run.
///
/// The throughput/occupancy/efficiency fields are compared elsewhere in this
/// module against thresholds in [0, 1], so they are assumed to be normalized
/// fractions — confirm the profiler emits them on that scale.
#[derive(Debug, Clone, Copy)]
pub struct ProfileMetrics {
    /// Achieved occupancy (fraction of warp-slot capacity in use).
    pub achieved_occupancy: f64,
    /// Compute-pipe utilization.
    pub compute_throughput: f64,
    /// Memory-pipe utilization.
    pub memory_throughput: f64,
    /// L2 cache hit rate.
    pub l2_hit_rate: f64,
    /// Shared-memory access efficiency.
    pub shared_memory_efficiency: f64,
    /// Warp execution efficiency. NOTE(review): not read by the
    /// optimizer logic in this file.
    pub warp_execution_efficiency: f64,
    /// Instructions retired per cycle.
    pub ipc: f64,
}
55
/// A profiler-identified hot instruction.
#[derive(Debug, Clone)]
pub struct HotSpot {
    /// Index of the instruction within the kernel's instruction stream.
    pub instruction_index: usize,
    /// Cycles attributed to this instruction.
    pub cycle_count: u64,
    /// Dominant stall reason recorded at this instruction.
    pub stall_reason: StallReason,
}
66
/// Why a hot instruction stalled, as reported by the profiler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StallReason {
    /// No stall recorded.
    None,
    /// Waiting on an outstanding memory operation.
    MemoryDependency,
    /// Waiting on a prior instruction's result.
    ExecutionDependency,
    /// Waiting at a synchronization barrier.
    SyncBarrier,
    /// Waiting for the next instruction to be fetched.
    InstructionFetch,
    /// Any other profiler-reported reason, carried verbatim.
    Other(String),
}
83
84impl fmt::Display for StallReason {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 match self {
87 Self::None => f.write_str("none"),
88 Self::MemoryDependency => f.write_str("memory_dependency"),
89 Self::ExecutionDependency => f.write_str("execution_dependency"),
90 Self::SyncBarrier => f.write_str("sync_barrier"),
91 Self::InstructionFetch => f.write_str("instruction_fetch"),
92 Self::Other(s) => write!(f, "other({s})"),
93 }
94 }
95}
96
/// Execution counts for a single profiled branch.
#[derive(Debug, Clone, Copy)]
pub struct BranchProfile {
    /// Index identifying the branch within the kernel.
    pub branch_index: usize,
    /// Number of executions in which the branch was taken.
    pub taken_count: u64,
    /// Number of executions in which the branch was not taken.
    pub not_taken_count: u64,
}
107
108impl BranchProfile {
109 #[must_use]
113 pub fn taken_ratio(&self) -> f64 {
114 let total = self.taken_count + self.not_taken_count;
115 if total == 0 {
116 return 0.0;
117 }
118 #[allow(clippy::cast_precision_loss)]
119 let ratio = self.taken_count as f64 / total as f64;
120 ratio
121 }
122
123 #[must_use]
126 pub fn is_biased(&self, threshold: f64) -> bool {
127 let ratio = self.taken_ratio();
128 ratio > threshold || ratio < (1.0 - threshold)
129 }
130}
131
132impl fmt::Display for BranchProfile {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 write!(
135 f,
136 "branch[{}]: taken={} not_taken={} ratio={:.2}%",
137 self.branch_index,
138 self.taken_count,
139 self.not_taken_count,
140 self.taken_ratio() * 100.0,
141 )
142 }
143}
144
/// Summary of a kernel's memory access behavior.
///
/// All three fields are treated as fractions in [0, 1] (scaled to percent
/// for display).
#[derive(Debug, Clone, Copy)]
pub struct MemoryAccessProfile {
    /// Fraction of global accesses that were coalesced.
    pub coalesced_ratio: f64,
    /// Fraction of shared-memory accesses hitting bank conflicts.
    pub bank_conflict_rate: f64,
    /// Fraction of each fetched cache line actually used.
    pub cache_line_utilization: f64,
}
155
156impl fmt::Display for MemoryAccessProfile {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 write!(
159 f,
160 "coalesced={:.1}% bank_conflicts={:.1}% cache_util={:.1}%",
161 self.coalesced_ratio * 100.0,
162 self.bank_conflict_rate * 100.0,
163 self.cache_line_utilization * 100.0,
164 )
165 }
166}
167
/// Classification of what limits a kernel's performance, as decided by
/// [`ProfileGuidedOptimizer::classify_bottleneck`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bottleneck {
    /// Compute pipe saturated, memory pipe not.
    ComputeBound,
    /// Memory pipe saturated, compute pipe not.
    MemoryBound,
    /// Neither pipe saturated, with low IPC and low occupancy.
    LatencyBound,
    /// No single dominant limiter (includes both pipes saturated).
    Balanced,
}
184
185impl fmt::Display for Bottleneck {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 match self {
188 Self::ComputeBound => f.write_str("compute-bound"),
189 Self::MemoryBound => f.write_str("memory-bound"),
190 Self::LatencyBound => f.write_str("latency-bound"),
191 Self::Balanced => f.write_str("balanced"),
192 }
193 }
194}
195
/// A single profile-guided code-generation action, produced by
/// [`ProfileGuidedOptimizer::analyze`] and applied by
/// [`apply_profile_decisions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodeGenDecision {
    /// Unroll the loop by the given factor.
    UnrollLoop {
        /// Unroll factor (> 1).
        factor: u32,
    },
    /// Convert a heavily biased branch to predicated execution.
    PredicateBranch,
    /// Insert software prefetches ahead of use.
    PrefetchMemory {
        /// Prefetch look-ahead distance.
        distance: u32,
    },
    /// Raise occupancy by constraining per-thread register use.
    IncreaseOccupancy {
        /// Desired resident blocks per SM.
        target_blocks: u32,
    },
    /// Switch to a larger output tile.
    UseLargerTiles {
        /// New tile size in the M dimension.
        tile_m: u32,
        /// New tile size in the N dimension.
        tile_n: u32,
    },
    /// Stage poorly coalesced accesses through shared memory.
    SwitchToSharedMemory,
    /// Enable split-K decomposition.
    EnableSplitK {
        /// Number of K slices.
        k_slices: u32,
    },
}
235
236impl fmt::Display for CodeGenDecision {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 match self {
239 Self::UnrollLoop { factor } => write!(f, "unroll loop x{factor}"),
240 Self::PredicateBranch => f.write_str("convert branch to predicated"),
241 Self::PrefetchMemory { distance } => {
242 write!(f, "insert prefetch (distance={distance})")
243 }
244 Self::IncreaseOccupancy { target_blocks } => {
245 write!(f, "increase occupancy to {target_blocks} blocks/SM")
246 }
247 Self::UseLargerTiles { tile_m, tile_n } => {
248 write!(f, "use larger tiles ({tile_m}x{tile_n})")
249 }
250 Self::SwitchToSharedMemory => f.write_str("switch to shared memory"),
251 Self::EnableSplitK { k_slices } => {
252 write!(f, "enable split-K ({k_slices} slices)")
253 }
254 }
255 }
256}
257
/// A suggested M x N x K tile shape, as produced by
/// [`ProfileGuidedOptimizer::suggest_tile_config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TileConfig {
    /// Tile extent along M.
    pub tile_m: u32,
    /// Tile extent along N.
    pub tile_n: u32,
    /// Tile extent along K.
    pub tile_k: u32,
}
272
273impl fmt::Display for TileConfig {
274 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275 write!(f, "{}x{}x{}", self.tile_m, self.tile_n, self.tile_k)
276 }
277}
278
/// Tunable kernel code-generation configuration that profile-guided
/// decisions are applied to (see [`apply_profile_decisions`]).
#[derive(Debug, Clone)]
pub struct KernelProfile {
    /// Output tile extent along M.
    pub tile_m: u32,
    /// Output tile extent along N.
    pub tile_n: u32,
    /// Tile extent along K.
    pub tile_k: u32,
    /// Loop unroll factor (1 = no unrolling).
    pub unroll_factor: u32,
    /// Whether global accesses are staged through shared memory.
    pub use_shared_memory: bool,
    /// Per-thread register budget; 0 appears to mean "unconstrained" —
    /// TODO confirm against the code generator.
    pub register_target: u32,
    /// Number of split-K slices (1 = split-K disabled).
    pub split_k: u32,
}
304
305impl KernelProfile {
306 #[must_use]
308 pub const fn new() -> Self {
309 Self {
310 tile_m: 64,
311 tile_n: 64,
312 tile_k: 8,
313 unroll_factor: 1,
314 use_shared_memory: false,
315 register_target: 0,
316 split_k: 1,
317 }
318 }
319}
320
321impl Default for KernelProfile {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
327impl fmt::Display for KernelProfile {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 write!(
330 f,
331 "tile={}x{}x{} unroll={} smem={} regs={} split_k={}",
332 self.tile_m,
333 self.tile_n,
334 self.tile_k,
335 self.unroll_factor,
336 if self.use_shared_memory { "on" } else { "off" },
337 self.register_target,
338 self.split_k,
339 )
340 }
341}
342
/// Compute throughput at/above this fraction marks the kernel compute-heavy.
const COMPUTE_BOUND_THRESHOLD: f64 = 0.7;
/// Memory throughput at/above this fraction marks the kernel memory-heavy.
const MEMORY_BOUND_THRESHOLD: f64 = 0.7;
/// IPC below this (with low occupancy) suggests a latency-bound kernel.
const LATENCY_BOUND_IPC_THRESHOLD: f64 = 1.0;
/// Taken-ratio bias beyond which a branch is worth predicating.
const DEFAULT_BRANCH_BIAS_THRESHOLD: f64 = 0.9;
/// Achieved occupancy below this triggers occupancy-raising decisions.
const LOW_OCCUPANCY_THRESHOLD: f64 = 0.5;
/// Coalesced-access ratio below this suggests shared-memory staging.
const POOR_COALESCING_THRESHOLD: f64 = 0.5;
/// Memory throughput above this makes software prefetching worthwhile.
const PREFETCH_MEMORY_THROUGHPUT_THRESHOLD: f64 = 0.5;
361
/// Turns collected [`ProfileData`] into bottleneck classifications and
/// concrete [`CodeGenDecision`]s.
#[derive(Debug, Clone)]
pub struct ProfileGuidedOptimizer {
    // The profile this optimizer reasons over; set once at construction.
    profile: ProfileData,
}
371
372impl ProfileGuidedOptimizer {
373 #[must_use]
375 pub const fn new(profile: ProfileData) -> Self {
376 Self { profile }
377 }
378
379 #[must_use]
381 pub fn classify_bottleneck(&self) -> Bottleneck {
382 let m = &self.profile.metrics;
383
384 let compute_heavy = m.compute_throughput >= COMPUTE_BOUND_THRESHOLD;
385 let memory_heavy = m.memory_throughput >= MEMORY_BOUND_THRESHOLD;
386
387 match (compute_heavy, memory_heavy) {
388 (true, false) => Bottleneck::ComputeBound,
389 (false, true) => Bottleneck::MemoryBound,
390 (true, true) => Bottleneck::Balanced,
391 (false, false) => {
392 if m.ipc < LATENCY_BOUND_IPC_THRESHOLD
394 && m.achieved_occupancy < LOW_OCCUPANCY_THRESHOLD
395 {
396 Bottleneck::LatencyBound
397 } else {
398 Bottleneck::Balanced
399 }
400 }
401 }
402 }
403
404 #[must_use]
408 pub fn analyze(&self) -> Vec<CodeGenDecision> {
409 let mut decisions = Vec::new();
410 let bottleneck = self.classify_bottleneck();
411
412 let unroll = self.suggest_unroll_factor();
414 if unroll > 1 {
415 decisions.push(CodeGenDecision::UnrollLoop { factor: unroll });
416 }
417
418 for bp in &self.profile.branch_stats {
420 if bp.is_biased(DEFAULT_BRANCH_BIAS_THRESHOLD) {
421 decisions.push(CodeGenDecision::PredicateBranch);
422 break; }
424 }
425
426 if bottleneck == Bottleneck::MemoryBound || bottleneck == Bottleneck::Balanced {
428 let mem = &self.profile.memory_access_pattern;
429 if mem.coalesced_ratio < POOR_COALESCING_THRESHOLD {
430 decisions.push(CodeGenDecision::SwitchToSharedMemory);
431 }
432 if self.profile.metrics.memory_throughput > PREFETCH_MEMORY_THROUGHPUT_THRESHOLD {
433 let distance = self.suggest_prefetch_distance();
434 decisions.push(CodeGenDecision::PrefetchMemory { distance });
435 }
436 }
437
438 if self.profile.metrics.achieved_occupancy < LOW_OCCUPANCY_THRESHOLD {
440 let target = self.suggest_target_blocks();
441 decisions.push(CodeGenDecision::IncreaseOccupancy {
442 target_blocks: target,
443 });
444 }
445
446 if bottleneck == Bottleneck::ComputeBound {
448 decisions.push(CodeGenDecision::UseLargerTiles {
449 tile_m: 128,
450 tile_n: 128,
451 });
452 }
453
454 if bottleneck == Bottleneck::LatencyBound {
456 decisions.push(CodeGenDecision::EnableSplitK { k_slices: 4 });
457 }
458
459 decisions
460 }
461
462 #[must_use]
464 pub fn suggest_tile_config(&self, m: u32, n: u32, k: u32) -> TileConfig {
465 let bottleneck = self.classify_bottleneck();
466 let caps = self.profile.sm_version.capabilities();
467
468 let (base_m, base_n) = match bottleneck {
470 Bottleneck::ComputeBound => {
471 if caps.has_wgmma {
472 (256, 128) } else if caps.has_ampere_mma {
474 (128, 128)
475 } else {
476 (128, 64)
477 }
478 }
479 Bottleneck::MemoryBound => (64, 64),
480 Bottleneck::LatencyBound => (64, 32),
481 Bottleneck::Balanced => (128, 64),
482 };
483
484 let tile_m = base_m.min(m);
486 let tile_n = base_n.min(n);
487
488 let tile_k = match bottleneck {
490 Bottleneck::MemoryBound => 32.min(k),
491 Bottleneck::ComputeBound => 16.min(k),
492 _ => 8.min(k),
493 };
494
495 TileConfig {
496 tile_m,
497 tile_n,
498 tile_k,
499 }
500 }
501
502 #[must_use]
504 pub fn suggest_unroll_factor(&self) -> u32 {
505 let m = &self.profile.metrics;
506
507 let mem_stalls = self
509 .profile
510 .hotspots
511 .iter()
512 .filter(|h| h.stall_reason == StallReason::MemoryDependency)
513 .count();
514
515 if mem_stalls >= 3 {
516 return 8;
517 }
518
519 if m.ipc < 1.0 {
520 return 4;
521 }
522
523 if m.ipc < 2.0 {
524 return 2;
525 }
526
527 1
528 }
529
530 fn suggest_prefetch_distance(&self) -> u32 {
534 let m = &self.profile.metrics;
535 if m.l2_hit_rate < 0.3 {
536 4 } else if m.l2_hit_rate < 0.6 {
538 2
539 } else {
540 1
541 }
542 }
543
544 fn suggest_target_blocks(&self) -> u32 {
546 let max_threads = self.profile.sm_version.max_threads_per_sm();
547 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
549 let target_threads = (f64::from(max_threads) * 0.75) as u32;
550 let blocks = target_threads / 128;
551 blocks.max(2)
552 }
553}
554
555pub fn apply_profile_decisions(
563 decisions: &[CodeGenDecision],
564 config: &mut KernelProfile,
565) -> Vec<String> {
566 let mut log = Vec::with_capacity(decisions.len());
567
568 for decision in decisions {
569 match decision {
570 CodeGenDecision::UnrollLoop { factor } => {
571 let prev = config.unroll_factor;
572 config.unroll_factor = *factor;
573 log.push(format!("unroll factor: {prev} -> {factor}"));
574 }
575 CodeGenDecision::PredicateBranch => {
576 log.push("enabled branch predication".to_string());
577 }
578 CodeGenDecision::PrefetchMemory { distance } => {
579 log.push(format!("enabled prefetch with distance {distance}"));
580 }
581 CodeGenDecision::IncreaseOccupancy { target_blocks } => {
582 let new_target = 255 / target_blocks;
584 let prev = config.register_target;
585 config.register_target = new_target;
586 log.push(format!(
587 "register target: {prev} -> {new_target} (for {target_blocks} blocks/SM)"
588 ));
589 }
590 CodeGenDecision::UseLargerTiles { tile_m, tile_n } => {
591 let prev_m = config.tile_m;
592 let prev_n = config.tile_n;
593 config.tile_m = *tile_m;
594 config.tile_n = *tile_n;
595 log.push(format!("tile size: {prev_m}x{prev_n} -> {tile_m}x{tile_n}"));
596 }
597 CodeGenDecision::SwitchToSharedMemory => {
598 config.use_shared_memory = true;
599 log.push("enabled shared memory staging".to_string());
600 }
601 CodeGenDecision::EnableSplitK { k_slices } => {
602 let prev = config.split_k;
603 config.split_k = *k_slices;
604 log.push(format!("split-K: {prev} -> {k_slices} slices"));
605 }
606 }
607 }
608
609 log
610}
611
612#[cfg(test)]
617mod tests {
618 use super::*;
619
    // Builds a ProfileData around the given metrics, with empty hotspot and
    // branch lists and a healthy (well-coalesced) memory access pattern.
    fn make_profile(metrics: ProfileMetrics) -> ProfileData {
        ProfileData {
            kernel_name: "test_kernel".to_string(),
            sm_version: SmVersion::Sm80,
            metrics,
            hotspots: Vec::new(),
            branch_stats: Vec::new(),
            memory_access_pattern: MemoryAccessProfile {
                coalesced_ratio: 0.9,
                bank_conflict_rate: 0.05,
                cache_line_utilization: 0.85,
            },
        }
    }

    // Neither pipe saturated; healthy IPC and occupancy -> Balanced.
    fn balanced_metrics() -> ProfileMetrics {
        ProfileMetrics {
            achieved_occupancy: 0.75,
            compute_throughput: 0.5,
            memory_throughput: 0.5,
            l2_hit_rate: 0.6,
            shared_memory_efficiency: 0.9,
            warp_execution_efficiency: 0.95,
            ipc: 2.5,
        }
    }

    // Compute above COMPUTE_BOUND_THRESHOLD, memory well below.
    fn compute_bound_metrics() -> ProfileMetrics {
        ProfileMetrics {
            achieved_occupancy: 0.8,
            compute_throughput: 0.85,
            memory_throughput: 0.3,
            l2_hit_rate: 0.7,
            shared_memory_efficiency: 0.9,
            warp_execution_efficiency: 0.95,
            ipc: 3.0,
        }
    }

    // Memory above MEMORY_BOUND_THRESHOLD, compute well below.
    fn memory_bound_metrics() -> ProfileMetrics {
        ProfileMetrics {
            achieved_occupancy: 0.7,
            compute_throughput: 0.2,
            memory_throughput: 0.85,
            l2_hit_rate: 0.4,
            shared_memory_efficiency: 0.6,
            warp_execution_efficiency: 0.9,
            ipc: 1.5,
        }
    }

    // Neither pipe saturated, low IPC and low occupancy -> LatencyBound.
    fn latency_bound_metrics() -> ProfileMetrics {
        ProfileMetrics {
            achieved_occupancy: 0.3,
            compute_throughput: 0.15,
            memory_throughput: 0.2,
            l2_hit_rate: 0.5,
            shared_memory_efficiency: 0.7,
            warp_execution_efficiency: 0.8,
            ipc: 0.5,
        }
    }
683
    #[test]
    fn classify_compute_bound() {
        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
        assert_eq!(opt.classify_bottleneck(), Bottleneck::ComputeBound);
    }

    #[test]
    fn classify_memory_bound() {
        let opt = ProfileGuidedOptimizer::new(make_profile(memory_bound_metrics()));
        assert_eq!(opt.classify_bottleneck(), Bottleneck::MemoryBound);
    }

    #[test]
    fn classify_latency_bound() {
        let opt = ProfileGuidedOptimizer::new(make_profile(latency_bound_metrics()));
        assert_eq!(opt.classify_bottleneck(), Bottleneck::LatencyBound);
    }

    #[test]
    fn classify_balanced() {
        let opt = ProfileGuidedOptimizer::new(make_profile(balanced_metrics()));
        assert_eq!(opt.classify_bottleneck(), Bottleneck::Balanced);
    }

    // Both pipes above threshold is deliberately Balanced, not either bound.
    #[test]
    fn classify_both_saturated_is_balanced() {
        let mut m = balanced_metrics();
        m.compute_throughput = 0.8;
        m.memory_throughput = 0.8;
        let opt = ProfileGuidedOptimizer::new(make_profile(m));
        assert_eq!(opt.classify_bottleneck(), Bottleneck::Balanced);
    }

    #[test]
    fn compute_bound_suggests_larger_tiles() {
        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
        let decisions = opt.analyze();
        assert!(
            decisions
                .iter()
                .any(|d| matches!(d, CodeGenDecision::UseLargerTiles { .. })),
            "expected UseLargerTiles in {decisions:?}"
        );
    }

    #[test]
    fn memory_bound_with_poor_coalescing_suggests_shared_mem() {
        let mut profile = make_profile(memory_bound_metrics());
        // Push coalescing below POOR_COALESCING_THRESHOLD.
        profile.memory_access_pattern.coalesced_ratio = 0.3;
        let opt = ProfileGuidedOptimizer::new(profile);
        let decisions = opt.analyze();
        assert!(
            decisions
                .iter()
                .any(|d| matches!(d, CodeGenDecision::SwitchToSharedMemory)),
            "expected SwitchToSharedMemory in {decisions:?}"
        );
    }

    #[test]
    fn latency_bound_suggests_split_k() {
        let opt = ProfileGuidedOptimizer::new(make_profile(latency_bound_metrics()));
        let decisions = opt.analyze();
        assert!(
            decisions
                .iter()
                .any(|d| matches!(d, CodeGenDecision::EnableSplitK { .. })),
            "expected EnableSplitK in {decisions:?}"
        );
    }

    #[test]
    fn low_occupancy_suggests_increase() {
        let mut m = balanced_metrics();
        // Below LOW_OCCUPANCY_THRESHOLD.
        m.achieved_occupancy = 0.3;
        let opt = ProfileGuidedOptimizer::new(make_profile(m));
        let decisions = opt.analyze();
        assert!(
            decisions
                .iter()
                .any(|d| matches!(d, CodeGenDecision::IncreaseOccupancy { .. })),
            "expected IncreaseOccupancy in {decisions:?}"
        );
    }
772
    #[test]
    fn branch_profile_taken_ratio() {
        let bp = BranchProfile {
            branch_index: 0,
            taken_count: 900,
            not_taken_count: 100,
        };
        let ratio = bp.taken_ratio();
        assert!((ratio - 0.9).abs() < 1e-9);
    }

    // A never-executed branch must report 0.0, not divide by zero.
    #[test]
    fn branch_profile_zero_executions() {
        let bp = BranchProfile {
            branch_index: 0,
            taken_count: 0,
            not_taken_count: 0,
        };
        assert!((bp.taken_ratio() - 0.0).abs() < 1e-9);
    }

    // ratio = 0.95: biased at threshold 0.9, not biased at 0.96.
    #[test]
    fn branch_bias_detection() {
        let bp = BranchProfile {
            branch_index: 0,
            taken_count: 950,
            not_taken_count: 50,
        };
        assert!(bp.is_biased(0.9));
        assert!(!bp.is_biased(0.96));
    }

    #[test]
    fn biased_branch_triggers_predication() {
        let mut profile = make_profile(balanced_metrics());
        // ratio = 0.98, above DEFAULT_BRANCH_BIAS_THRESHOLD (0.9).
        profile.branch_stats.push(BranchProfile {
            branch_index: 0,
            taken_count: 980,
            not_taken_count: 20,
        });
        let opt = ProfileGuidedOptimizer::new(profile);
        let decisions = opt.analyze();
        assert!(
            decisions
                .iter()
                .any(|d| matches!(d, CodeGenDecision::PredicateBranch)),
            "expected PredicateBranch in {decisions:?}"
        );
    }

    // Four memory-dependency hot spots (>= 3) force the maximum factor.
    #[test]
    fn unroll_factor_high_mem_stalls() {
        let mut profile = make_profile(balanced_metrics());
        for i in 0..4 {
            profile.hotspots.push(HotSpot {
                instruction_index: i,
                cycle_count: 500,
                stall_reason: StallReason::MemoryDependency,
            });
        }
        let opt = ProfileGuidedOptimizer::new(profile);
        assert_eq!(opt.suggest_unroll_factor(), 8);
    }

    #[test]
    fn unroll_factor_low_ipc() {
        let mut m = balanced_metrics();
        m.ipc = 0.8;
        let opt = ProfileGuidedOptimizer::new(make_profile(m));
        assert_eq!(opt.suggest_unroll_factor(), 4);
    }

    #[test]
    fn unroll_factor_moderate_ipc() {
        let mut m = balanced_metrics();
        m.ipc = 1.5;
        let opt = ProfileGuidedOptimizer::new(make_profile(m));
        assert_eq!(opt.suggest_unroll_factor(), 2);
    }

    // Baseline metrics (ipc = 2.5) need no unrolling.
    #[test]
    fn unroll_factor_high_ipc_no_unroll() {
        let m = balanced_metrics();
        let opt = ProfileGuidedOptimizer::new(make_profile(m));
        assert_eq!(opt.suggest_unroll_factor(), 1);
    }
863
    // Sm80 (Ampere MMA, no WGMMA) compute-bound -> 128x128, deep-ish K.
    #[test]
    fn tile_config_compute_bound_ampere() {
        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
        let tc = opt.suggest_tile_config(512, 512, 256);
        assert_eq!(tc.tile_m, 128);
        assert_eq!(tc.tile_n, 128);
        assert_eq!(tc.tile_k, 16);
    }

    // Sm90 (WGMMA) compute-bound -> widest 256x128 tiles.
    #[test]
    fn tile_config_compute_bound_hopper() {
        let mut profile = make_profile(compute_bound_metrics());
        profile.sm_version = SmVersion::Sm90;
        let opt = ProfileGuidedOptimizer::new(profile);
        let tc = opt.suggest_tile_config(512, 512, 256);
        assert_eq!(tc.tile_m, 256);
        assert_eq!(tc.tile_n, 128);
    }

    // Tiles must never exceed the problem dimensions.
    #[test]
    fn tile_config_clamps_to_problem_size() {
        let opt = ProfileGuidedOptimizer::new(make_profile(compute_bound_metrics()));
        let tc = opt.suggest_tile_config(32, 16, 4);
        assert_eq!(tc.tile_m, 32);
        assert_eq!(tc.tile_n, 16);
        assert_eq!(tc.tile_k, 4);
    }

    #[test]
    fn tile_config_memory_bound_uses_deep_k() {
        let opt = ProfileGuidedOptimizer::new(make_profile(memory_bound_metrics()));
        let tc = opt.suggest_tile_config(512, 512, 256);
        assert_eq!(tc.tile_k, 32);
    }

    #[test]
    fn apply_decisions_updates_config() {
        let decisions = vec![
            CodeGenDecision::UnrollLoop { factor: 4 },
            CodeGenDecision::SwitchToSharedMemory,
            CodeGenDecision::EnableSplitK { k_slices: 8 },
            CodeGenDecision::UseLargerTiles {
                tile_m: 128,
                tile_n: 256,
            },
        ];
        let mut config = KernelProfile::new();
        let log = apply_profile_decisions(&decisions, &mut config);

        assert_eq!(config.unroll_factor, 4);
        assert!(config.use_shared_memory);
        assert_eq!(config.split_k, 8);
        assert_eq!(config.tile_m, 128);
        assert_eq!(config.tile_n, 256);
        // One log entry per applied decision.
        assert_eq!(log.len(), 4);
    }

    // register_target = 255 / target_blocks = 255 / 4 = 63.
    #[test]
    fn apply_increase_occupancy_sets_register_target() {
        let decisions = vec![CodeGenDecision::IncreaseOccupancy { target_blocks: 4 }];
        let mut config = KernelProfile::new();
        let log = apply_profile_decisions(&decisions, &mut config);
        assert_eq!(config.register_target, 63);
        assert_eq!(log.len(), 1);
    }
934
    #[test]
    fn display_bottleneck() {
        assert_eq!(format!("{}", Bottleneck::ComputeBound), "compute-bound");
        assert_eq!(format!("{}", Bottleneck::MemoryBound), "memory-bound");
        assert_eq!(format!("{}", Bottleneck::LatencyBound), "latency-bound");
        assert_eq!(format!("{}", Bottleneck::Balanced), "balanced");
    }

    #[test]
    fn display_stall_reason() {
        assert_eq!(format!("{}", StallReason::None), "none");
        assert_eq!(
            format!("{}", StallReason::MemoryDependency),
            "memory_dependency"
        );
        assert_eq!(
            format!("{}", StallReason::Other("pipe_busy".to_string())),
            "other(pipe_busy)"
        );
    }

    #[test]
    fn display_code_gen_decision() {
        let d = CodeGenDecision::UnrollLoop { factor: 4 };
        assert_eq!(format!("{d}"), "unroll loop x4");
        let d = CodeGenDecision::EnableSplitK { k_slices: 8 };
        assert_eq!(format!("{d}"), "enable split-K (8 slices)");
    }

    #[test]
    fn display_kernel_profile() {
        let kp = KernelProfile::new();
        let s = format!("{kp}");
        assert!(s.contains("tile=64x64x8"));
        assert!(s.contains("smem=off"));
    }

    #[test]
    fn display_tile_config() {
        let tc = TileConfig {
            tile_m: 128,
            tile_n: 64,
            tile_k: 16,
        };
        assert_eq!(format!("{tc}"), "128x64x16");
    }

    #[test]
    fn display_memory_access_profile() {
        let m = MemoryAccessProfile {
            coalesced_ratio: 0.95,
            bank_conflict_rate: 0.02,
            cache_line_utilization: 0.88,
        };
        let s = format!("{m}");
        assert!(s.contains("coalesced=95.0%"));
    }

    #[test]
    fn display_branch_profile() {
        let bp = BranchProfile {
            branch_index: 3,
            taken_count: 750,
            not_taken_count: 250,
        };
        let s = format!("{bp}");
        assert!(s.contains("branch[3]"));
        assert!(s.contains("75.00%"));
    }

    // Full pipeline: classify -> analyze -> apply; a compute-bound profile
    // must end up with tiles at least 128 wide in M.
    #[test]
    fn end_to_end_compute_bound_pipeline() {
        let profile = make_profile(compute_bound_metrics());
        let opt = ProfileGuidedOptimizer::new(profile);
        assert_eq!(opt.classify_bottleneck(), Bottleneck::ComputeBound);

        let decisions = opt.analyze();
        let mut config = KernelProfile::new();
        let log = apply_profile_decisions(&decisions, &mut config);

        assert!(config.tile_m >= 128);
        assert!(!log.is_empty());
    }
1023}