1use crate::device::Device;
33#[cfg(not(target_os = "macos"))]
34use crate::error::CudaError;
35use crate::error::CudaResult;
36
/// Per-SM hardware resource limits that bound how many thread blocks can be
/// resident on one streaming multiprocessor at a time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceOccupancyInfo {
    /// Number of streaming multiprocessors on the device.
    pub sm_count: u32,
    /// Maximum number of resident threads per SM.
    pub max_threads_per_sm: u32,
    /// Maximum number of resident thread blocks per SM.
    pub max_blocks_per_sm: u32,
    /// Size of the register file per SM, in 32-bit registers.
    pub max_registers_per_sm: u32,
    /// Shared memory available per SM, in bytes.
    pub max_shared_memory_per_sm: u32,
    /// Threads per warp.
    pub warp_size: u32,
}

impl DeviceOccupancyInfo {
    /// Number of warp slots per SM (`max_threads_per_sm / warp_size`), or `0`
    /// when the warp size is zero.
    fn max_warps_per_sm(&self) -> u32 {
        self.max_threads_per_sm
            .checked_div(self.warp_size)
            .unwrap_or(0)
    }

    /// Returns a built-in resource profile for compute capability
    /// `sm_major.sm_minor`.
    ///
    /// Every profile uses 65536 registers per SM and a warp size of 32. Compute
    /// capabilities not listed in the table fall back to the sm_86 profile.
    ///
    /// NOTE(review): several table entries look worth re-checking against the CUDA
    /// Programming Guide's compute-capability table. The sm_89 and sm_90
    /// shared-memory values match the per-block opt-in maxima rather than per-SM
    /// capacity. The sm_100 and sm_120 rows (SM counts, shared memory, thread
    /// limits) also look questionable. The values are pinned by the tests, so
    /// confirm before changing.
    #[must_use]
    pub fn for_compute_capability(sm_major: u32, sm_minor: u32) -> Self {
        // (sm_count, max_threads_per_sm, max_blocks_per_sm, max_shared_memory_per_sm)
        let (sm_count, max_threads_per_sm, max_blocks_per_sm, max_shared_memory_per_sm) =
            match (sm_major, sm_minor) {
                (7, 5) => (68, 1024, 16, 65_536),
                (8, 0) => (108, 2048, 32, 167_936),
                (8, 6) => (84, 1536, 16, 102_400),
                (8, 9) => (76, 1536, 24, 101_376),
                (9, 0) => (132, 2048, 32, 232_448),
                (10, 0) => (132, 2048, 32, 262_144),
                (12, 0) => (148, 2048, 32, 262_144),
                // Unknown architectures reuse the sm_86 numbers.
                _ => (84, 1536, 16, 102_400),
            };

        Self {
            sm_count,
            max_threads_per_sm,
            max_blocks_per_sm,
            max_registers_per_sm: 65_536,
            max_shared_memory_per_sm,
            warp_size: 32,
        }
    }
}
162
/// The resource that caps the number of resident blocks per SM in an
/// occupancy estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LimitingFactor {
    /// Warp slots: the block's warps do not fit more times into the SM's thread limit.
    Threads,
    /// The per-SM register file.
    Registers,
    /// The per-SM shared memory capacity.
    SharedMemory,
    /// The hardware cap on resident blocks per SM.
    Blocks,
    /// No limiting resource was determined (e.g. a degenerate configuration).
    None,
}
181
182#[derive(Debug, Clone, Copy)]
188pub struct OccupancyEstimate {
189 pub active_warps_per_sm: u32,
191 pub max_warps_per_sm: u32,
193 pub occupancy_ratio: f64,
195 pub limiting_factor: LimitingFactor,
197}
198
199#[derive(Debug, Clone)]
208pub struct OccupancyCalculator {
209 info: DeviceOccupancyInfo,
210}
211
212impl OccupancyCalculator {
213 pub fn new(device_info: DeviceOccupancyInfo) -> Self {
215 Self { info: device_info }
216 }
217
218 pub fn device_info(&self) -> &DeviceOccupancyInfo {
220 &self.info
221 }
222
223 pub fn estimate_occupancy(
231 &self,
232 block_size: u32,
233 registers_per_thread: u32,
234 shared_memory: u32,
235 ) -> OccupancyEstimate {
236 let max_warps = self.info.max_warps_per_sm();
237
238 if block_size == 0 || self.info.warp_size == 0 || max_warps == 0 {
240 return OccupancyEstimate {
241 active_warps_per_sm: 0,
242 max_warps_per_sm: max_warps,
243 occupancy_ratio: 0.0,
244 limiting_factor: LimitingFactor::None,
245 };
246 }
247
248 let warps_per_block = block_size.div_ceil(self.info.warp_size);
249
250 let blocks_by_block_limit = self.info.max_blocks_per_sm;
252
253 let blocks_by_threads = max_warps.checked_div(warps_per_block).unwrap_or(0);
255
256 let blocks_by_registers = if registers_per_thread == 0 || warps_per_block == 0 {
258 u32::MAX } else {
260 let regs_per_block = registers_per_thread * warps_per_block * self.info.warp_size;
261 self.info
262 .max_registers_per_sm
263 .checked_div(regs_per_block)
264 .unwrap_or(u32::MAX)
265 };
266
267 let blocks_by_smem = if shared_memory == 0 {
269 u32::MAX } else if self.info.max_shared_memory_per_sm == 0 {
271 0
272 } else {
273 self.info.max_shared_memory_per_sm / shared_memory
274 };
275
276 let active_blocks = blocks_by_block_limit
278 .min(blocks_by_threads)
279 .min(blocks_by_registers)
280 .min(blocks_by_smem);
281
282 let active_warps = active_blocks * warps_per_block;
283 let clamped_warps = active_warps.min(max_warps);
284 let ratio = if max_warps > 0 {
285 clamped_warps as f64 / max_warps as f64
286 } else {
287 0.0
288 };
289
290 let effective = active_blocks;
292 let limiting_factor = if effective == 0 {
293 if blocks_by_smem == 0 {
294 LimitingFactor::SharedMemory
295 } else if blocks_by_registers == 0 {
296 LimitingFactor::Registers
297 } else if blocks_by_threads == 0 {
298 LimitingFactor::Threads
299 } else {
300 LimitingFactor::Blocks
301 }
302 } else if effective == blocks_by_smem
303 && blocks_by_smem
304 <= blocks_by_registers
305 .min(blocks_by_threads)
306 .min(blocks_by_block_limit)
307 {
308 LimitingFactor::SharedMemory
309 } else if effective == blocks_by_registers
310 && blocks_by_registers <= blocks_by_threads.min(blocks_by_block_limit)
311 {
312 LimitingFactor::Registers
313 } else if effective == blocks_by_threads && blocks_by_threads <= blocks_by_block_limit {
314 LimitingFactor::Threads
315 } else if effective == blocks_by_block_limit {
316 LimitingFactor::Blocks
317 } else {
318 LimitingFactor::None
319 };
320
321 OccupancyEstimate {
322 active_warps_per_sm: clamped_warps,
323 max_warps_per_sm: max_warps,
324 occupancy_ratio: ratio,
325 limiting_factor,
326 }
327 }
328}
329
330#[derive(Debug, Clone, Copy)]
336pub struct OccupancyPoint {
337 pub block_size: u32,
339 pub occupancy: f64,
341 pub active_warps: u32,
343 pub limiting_factor: LimitingFactor,
345}
346
347pub struct OccupancyGrid;
349
350impl OccupancyGrid {
351 pub fn sweep(
354 calculator: &OccupancyCalculator,
355 registers_per_thread: u32,
356 shared_memory: u32,
357 ) -> Vec<OccupancyPoint> {
358 let ws = calculator.info.warp_size;
359 if ws == 0 {
360 return Vec::new();
361 }
362 let max_threads = calculator.info.max_threads_per_sm;
363 let mut points = Vec::new();
364 let mut bs = ws;
365 while bs <= max_threads {
366 let est = calculator.estimate_occupancy(bs, registers_per_thread, shared_memory);
367 points.push(OccupancyPoint {
368 block_size: bs,
369 occupancy: est.occupancy_ratio,
370 active_warps: est.active_warps_per_sm,
371 limiting_factor: est.limiting_factor,
372 });
373 bs += ws;
374 }
375 points
376 }
377
378 pub fn best_block_size(points: &[OccupancyPoint]) -> u32 {
383 let mut best: Option<&OccupancyPoint> = Option::None;
384 for pt in points {
385 best = Some(match best {
386 Option::None => pt,
387 Some(prev) => {
388 if pt.occupancy > prev.occupancy
389 || (pt.occupancy == prev.occupancy && pt.block_size < prev.block_size)
390 {
391 pt
392 } else {
393 prev
394 }
395 }
396 });
397 }
398 best.map_or(0, |p| p.block_size)
399 }
400}
401
402pub struct DynamicSmemOccupancy;
408
409impl DynamicSmemOccupancy {
410 pub fn with_smem_function<F>(
412 calculator: &OccupancyCalculator,
413 smem_fn: F,
414 registers_per_thread: u32,
415 ) -> Vec<OccupancyPoint>
416 where
417 F: Fn(u32) -> u32,
418 {
419 let ws = calculator.info.warp_size;
420 if ws == 0 {
421 return Vec::new();
422 }
423 let max_threads = calculator.info.max_threads_per_sm;
424 let mut points = Vec::new();
425 let mut bs = ws;
426 while bs <= max_threads {
427 let smem = smem_fn(bs);
428 let est = calculator.estimate_occupancy(bs, registers_per_thread, smem);
429 points.push(OccupancyPoint {
430 block_size: bs,
431 occupancy: est.occupancy_ratio,
432 active_warps: est.active_warps_per_sm,
433 limiting_factor: est.limiting_factor,
434 });
435 bs += ws;
436 }
437 points
438 }
439
440 pub fn linear_smem(bytes_per_thread: u32) -> impl Fn(u32) -> u32 {
444 move |block_size: u32| block_size * bytes_per_thread
445 }
446
447 pub fn tile_smem(tile_size: u32, element_size: u32) -> impl Fn(u32) -> u32 {
451 move |_block_size: u32| tile_size * tile_size * element_size
452 }
453}
454
/// Dimensions of a thread-block cluster, measured in blocks along each axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClusterConfig {
    /// Blocks along the x axis.
    pub cluster_x: u32,
    /// Blocks along the y axis.
    pub cluster_y: u32,
    /// Blocks along the z axis.
    pub cluster_z: u32,
}

impl ClusterConfig {
    /// Total number of blocks in one cluster (`cluster_x * cluster_y * cluster_z`).
    ///
    /// Saturates at `u32::MAX` instead of panicking (debug) or wrapping (release)
    /// when the product does not fit in `u32`. A zero in any dimension yields `0`.
    pub fn total_blocks(&self) -> u32 {
        self.cluster_x
            .saturating_mul(self.cluster_y)
            .saturating_mul(self.cluster_z)
    }
}
476
/// Result of a cluster-aware occupancy estimate.
#[derive(Debug, Clone, Copy)]
pub struct ClusterOccupancyEstimate {
    /// Number of blocks grouped into each cluster.
    pub blocks_per_cluster: u32,
    /// Whole clusters that fit into one SM's resident block count.
    pub clusters_per_sm: u32,
    /// Occupancy ratio counting only blocks that belong to whole clusters.
    pub effective_occupancy: f64,
    /// Total shared memory requested by one cluster, in bytes.
    pub cluster_smem_total: u32,
}
489
490pub struct ClusterOccupancy;
492
493impl ClusterOccupancy {
494 pub fn estimate_cluster_occupancy(
508 calculator: &OccupancyCalculator,
509 block_size: u32,
510 cluster_size: u32,
511 registers_per_thread: u32,
512 shared_memory: u32,
513 ) -> ClusterOccupancyEstimate {
514 if cluster_size == 0 {
515 return ClusterOccupancyEstimate {
516 blocks_per_cluster: 0,
517 clusters_per_sm: 0,
518 effective_occupancy: 0.0,
519 cluster_smem_total: 0,
520 };
521 }
522
523 let est = calculator.estimate_occupancy(block_size, registers_per_thread, shared_memory);
525
526 let max_warps = est.max_warps_per_sm;
527 let warps_per_block = if calculator.info.warp_size == 0 {
528 0
529 } else {
530 block_size.div_ceil(calculator.info.warp_size)
531 };
532
533 let blocks_per_sm = est
535 .active_warps_per_sm
536 .checked_div(warps_per_block)
537 .unwrap_or(0);
538
539 let clusters_per_sm = blocks_per_sm / cluster_size;
541 let active_blocks = clusters_per_sm * cluster_size;
542 let active_warps = active_blocks * warps_per_block;
543
544 let effective_occupancy = if max_warps > 0 {
545 (active_warps.min(max_warps)) as f64 / max_warps as f64
546 } else {
547 0.0
548 };
549
550 ClusterOccupancyEstimate {
551 blocks_per_cluster: cluster_size,
552 clusters_per_sm,
553 effective_occupancy,
554 cluster_smem_total: cluster_size * shared_memory,
555 }
556 }
557}
558
559impl Device {
564 pub fn occupancy_info(&self) -> CudaResult<DeviceOccupancyInfo> {
575 #[cfg(target_os = "macos")]
577 {
578 let _ = self; Ok(DeviceOccupancyInfo {
580 sm_count: 84,
581 max_threads_per_sm: 1536,
582 max_blocks_per_sm: 16,
583 max_registers_per_sm: 65536,
584 max_shared_memory_per_sm: 102400,
585 warp_size: 32,
586 })
587 }
588
589 #[cfg(not(target_os = "macos"))]
590 {
591 let sm_count = self
592 .multiprocessor_count()
593 .map(|v| v as u32)
594 .map_err(|_| CudaError::NotInitialized)?;
595 let max_threads_per_sm = self
596 .max_threads_per_multiprocessor()
597 .map(|v| v as u32)
598 .map_err(|_| CudaError::NotInitialized)?;
599 let max_blocks_per_sm = self
600 .max_blocks_per_multiprocessor()
601 .map(|v| v as u32)
602 .map_err(|_| CudaError::NotInitialized)?;
603 let max_registers_per_sm = self
604 .max_registers_per_multiprocessor()
605 .map(|v| v as u32)
606 .map_err(|_| CudaError::NotInitialized)?;
607 let max_shared_memory_per_sm = self
608 .max_shared_memory_per_multiprocessor()
609 .map(|v| v as u32)
610 .map_err(|_| CudaError::NotInitialized)?;
611 let warp_size = self
612 .warp_size()
613 .map(|v| v as u32)
614 .map_err(|_| CudaError::NotInitialized)?;
615
616 Ok(DeviceOccupancyInfo {
617 sm_count,
618 max_threads_per_sm,
619 max_blocks_per_sm,
620 max_registers_per_sm,
621 max_shared_memory_per_sm,
622 warp_size,
623 })
624 }
625 }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Ampere-like profile: 1536 threads / 32 = 48 warp slots per SM.
    fn ampere_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo {
            sm_count: 82,
            max_threads_per_sm: 1536,
            max_blocks_per_sm: 16,
            max_registers_per_sm: 65536,
            max_shared_memory_per_sm: 102400,
            warp_size: 32,
        }
    }

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn test_basic_occupancy_estimation() {
        // 8 warps/block: warp slots allow 48 / 8 = 6 blocks, registers allow
        // 65536 / (32 * 256) = 8, so warp slots bind at full occupancy.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(256, 32, 0);
        assert_eq!(est.max_warps_per_sm, 48);
        assert_eq!(est.active_warps_per_sm, 48);
        assert_eq!(est.occupancy_ratio, 1.0);
        assert_eq!(est.limiting_factor, LimitingFactor::Threads);
    }

    #[test]
    fn test_single_warp_blocks_hit_block_limit() {
        // One-warp blocks are capped by the 16-block slot limit: 16 of 48 warps.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(32, 16, 0);
        assert_eq!(est.active_warps_per_sm, 16);
        assert_eq!(est.limiting_factor, LimitingFactor::Blocks);
        assert!(approx_eq(est.occupancy_ratio, 1.0 / 3.0));
    }

    #[test]
    fn test_limiting_factor_threads() {
        // 32 warps/block: only one block fits in 48 warp slots.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(1024, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Threads);
        assert_eq!(est.active_warps_per_sm, 32);
    }

    #[test]
    fn test_limiting_factor_registers() {
        // 128 * 256 = 32768 registers/block: two blocks fit, i.e. 16 warps.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(256, 128, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Registers);
        assert_eq!(est.active_warps_per_sm, 16);
    }

    #[test]
    fn test_limiting_factor_shared_memory() {
        // 102400 / 51200 = 2 blocks of 4 warps each.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(128, 16, 51200);
        assert_eq!(est.limiting_factor, LimitingFactor::SharedMemory);
        assert_eq!(est.active_warps_per_sm, 8);
    }

    #[test]
    fn test_limiting_factor_blocks() {
        let info = DeviceOccupancyInfo {
            max_blocks_per_sm: 4,
            ..ampere_info()
        };
        let calc = OccupancyCalculator::new(info);
        let est = calc.estimate_occupancy(64, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Blocks);
        assert_eq!(est.active_warps_per_sm, 8);
    }

    #[test]
    fn test_limiting_factor_none_zero_block() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(0, 32, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::None);
        assert_eq!(est.active_warps_per_sm, 0);
        assert_eq!(est.occupancy_ratio, 0.0);
    }

    #[test]
    fn test_sweep_returns_points() {
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        // 32, 64, ..., 1536
        assert_eq!(points.len(), 48);
        assert_eq!(points[0].block_size, 32);
        assert_eq!(points[47].block_size, 1536);
    }

    #[test]
    fn test_best_block_size() {
        // 96 threads (3 warps) is the smallest block that fills all 48 warp slots:
        // 16 blocks * 3 warps = 48.
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        let best = OccupancyGrid::best_block_size(&points);
        assert_eq!(best, 96);
        assert_eq!(best % 32, 0);
    }

    #[test]
    fn test_best_block_size_empty() {
        assert_eq!(OccupancyGrid::best_block_size(&[]), 0);
    }

    #[test]
    fn test_dynamic_smem_linear() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::linear_smem(8);
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        assert_eq!(points.len(), 48);
        let first_occ = points[0].occupancy;
        let last_occ = points[points.len() - 1].occupancy;
        assert!((0.0..=1.0).contains(&first_occ));
        assert!((0.0..=1.0).contains(&last_occ));
        // 32 threads: block-slot limited at 16 of 48 warps.
        assert!(approx_eq(first_occ, 1.0 / 3.0));
        // 1536 threads: a single 48-warp block fills the SM.
        assert_eq!(last_occ, 1.0);
    }

    #[test]
    fn test_dynamic_smem_tile() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::tile_smem(16, 4);
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        assert!(!points.is_empty());
        assert_eq!(points.len(), 48);
    }

    #[test]
    fn test_cluster_occupancy_basic() {
        // 12 resident blocks of 4 warps -> 6 whole clusters of 2 -> 48 warps.
        let calc = OccupancyCalculator::new(ampere_info());
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 2, 32, 4096);
        assert_eq!(result.blocks_per_cluster, 2);
        assert_eq!(result.clusters_per_sm, 6);
        assert_eq!(result.effective_occupancy, 1.0);
        assert_eq!(result.cluster_smem_total, 2 * 4096);
    }

    #[test]
    fn test_cluster_occupancy_zero_cluster() {
        let calc = OccupancyCalculator::new(ampere_info());
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 0, 32, 0);
        assert_eq!(result.clusters_per_sm, 0);
        assert_eq!(result.effective_occupancy, 0.0);
    }

    #[test]
    fn test_cluster_config_total_blocks() {
        let cfg = ClusterConfig {
            cluster_x: 2,
            cluster_y: 3,
            cluster_z: 4,
        };
        assert_eq!(cfg.total_blocks(), 24);
    }

    #[test]
    fn test_block_size_exceeds_max() {
        // 64 warps per block cannot fit into 48 warp slots.
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(2048, 32, 0);
        assert_eq!(est.active_warps_per_sm, 0);
        assert_eq!(est.occupancy_ratio, 0.0);
        assert_eq!(est.limiting_factor, LimitingFactor::Threads);
    }

    fn sm100_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo::for_compute_capability(10, 0)
    }

    fn sm120_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo::for_compute_capability(12, 0)
    }

    #[test]
    fn test_sm100_device_info_attributes() {
        let info = sm100_info();
        assert_eq!(info.sm_count, 132, "sm_100 table entry: SM count");
        assert_eq!(info.max_threads_per_sm, 2048);
        assert_eq!(info.max_blocks_per_sm, 32);
        assert_eq!(
            info.max_shared_memory_per_sm, 262144,
            "sm_100 table entry: 256 KiB shared/SM"
        );
        assert_eq!(info.warp_size, 32);
    }

    #[test]
    fn test_sm120_device_info_attributes() {
        let info = sm120_info();
        assert_eq!(info.sm_count, 148, "sm_120 table entry: SM count");
        assert_eq!(info.max_threads_per_sm, 2048);
        assert_eq!(info.max_blocks_per_sm, 32);
        assert_eq!(
            info.max_shared_memory_per_sm, 262144,
            "sm_120 table entry: 256 KiB shared/SM"
        );
        assert_eq!(info.warp_size, 32);
    }

    #[test]
    fn test_sm100_occupancy_estimation() {
        let calc = OccupancyCalculator::new(sm100_info());
        let est = calc.estimate_occupancy(256, 0, 0);
        assert!(
            est.occupancy_ratio > 0.0,
            "sm_100 profile must report positive occupancy"
        );
        assert!(
            est.active_warps_per_sm <= 64,
            "Active warps must not exceed hardware limit"
        );
    }

    #[test]
    fn test_sm120_full_occupancy() {
        let calc = OccupancyCalculator::new(sm120_info());
        let est = calc.estimate_occupancy(64, 0, 0);
        assert_eq!(est.occupancy_ratio, 1.0, "Should reach full occupancy");
        assert_eq!(est.active_warps_per_sm, 64);
    }

    #[test]
    fn test_sm100_large_shared_memory_limit() {
        // Shared memory and warp slots both allow exactly 2 blocks here; the
        // tie-break reports shared memory.
        let calc = OccupancyCalculator::new(sm100_info());
        let smem_per_block = 131_072u32;
        let est = calc.estimate_occupancy(1024, 0, smem_per_block);
        assert!(
            matches!(est.limiting_factor, LimitingFactor::SharedMemory),
            "Large smem must be the bottleneck"
        );
    }

    #[test]
    fn test_for_compute_capability_unknown_falls_back() {
        let info = DeviceOccupancyInfo::for_compute_capability(99, 99);
        assert_eq!(info, DeviceOccupancyInfo::for_compute_capability(8, 6));
        let calc = OccupancyCalculator::new(info);
        let est = calc.estimate_occupancy(256, 0, 0);
        assert!(est.occupancy_ratio > 0.0);
    }

    #[test]
    fn test_sm100_vs_sm90_shared_memory_capacity() {
        let hopper = DeviceOccupancyInfo::for_compute_capability(9, 0);
        let blackwell = sm100_info();
        assert!(
            blackwell.max_shared_memory_per_sm > hopper.max_shared_memory_per_sm,
            "sm_100 table entry must list more smem than sm_90"
        );
    }

    #[test]
    fn test_sm120_vs_sm100_sm_count() {
        let sm100 = sm100_info();
        let sm120 = sm120_info();
        assert!(
            sm120.sm_count > sm100.sm_count,
            "sm_120 table entry must list more SMs than sm_100"
        );
    }

    #[test]
    fn test_for_compute_capability_all_known_arches() {
        let arches = [(7, 5), (8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (12, 0)];
        for (major, minor) in arches {
            let info = DeviceOccupancyInfo::for_compute_capability(major, minor);
            assert_eq!(info.warp_size, 32, "sm_{major}{minor} warp_size must be 32");
            assert!(info.sm_count > 0);
            assert!(info.max_threads_per_sm > 0);
        }
    }
}