use crate::device::Device;
#[cfg(not(target_os = "macos"))]
use crate::error::CudaError;
use crate::error::CudaResult;
/// Per-SM hardware limits that bound how many blocks/warps can be resident.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceOccupancyInfo {
    /// Number of streaming multiprocessors on the device.
    pub sm_count: u32,
    /// Maximum resident threads per SM.
    pub max_threads_per_sm: u32,
    /// Maximum resident thread blocks per SM.
    pub max_blocks_per_sm: u32,
    /// 32-bit registers available per SM.
    pub max_registers_per_sm: u32,
    /// Shared-memory bytes available per SM.
    pub max_shared_memory_per_sm: u32,
    /// Threads per warp (32 on every profile this module ships).
    pub warp_size: u32,
}

impl DeviceOccupancyInfo {
    /// Maximum resident warps per SM; returns 0 for a degenerate zero warp
    /// size instead of dividing by zero.
    fn max_warps_per_sm(&self) -> u32 {
        match self.warp_size {
            0 => 0,
            ws => self.max_threads_per_sm / ws,
        }
    }

    /// Builds the canonical limit profile for a compute capability.
    ///
    /// Unknown capabilities fall back to the sm_86 profile rather than
    /// failing, so callers always get usable (if approximate) limits.
    #[must_use]
    pub fn for_compute_capability(sm_major: u32, sm_minor: u32) -> Self {
        // Every profile in this table shares a 64 Ki-register file per SM and
        // a warp width of 32; only the remaining limits vary by generation.
        // Tuple order: (sm_count, max_threads, max_blocks, max_smem_bytes).
        let (sm_count, max_threads_per_sm, max_blocks_per_sm, max_shared_memory_per_sm) =
            match (sm_major, sm_minor) {
                (7, 5) => (68, 1024, 16, 65536),
                (8, 0) => (108, 2048, 32, 167936),
                (8, 6) => (84, 1536, 16, 102400),
                (8, 9) => (76, 1536, 24, 101376),
                (9, 0) => (132, 2048, 32, 232448),
                (10, 0) => (132, 2048, 32, 262144),
                (12, 0) => (148, 2048, 32, 262144),
                // Fallback: same values as the (8, 6) profile above.
                _ => (84, 1536, 16, 102400),
            };
        Self {
            sm_count,
            max_threads_per_sm,
            max_blocks_per_sm,
            max_registers_per_sm: 65536,
            max_shared_memory_per_sm,
            warp_size: 32,
        }
    }
}
/// Which hardware resource caps the number of resident blocks per SM.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LimitingFactor {
    /// Bounded by the per-SM thread (warp) capacity.
    Threads,
    /// Bounded by the per-SM register file.
    Registers,
    /// Bounded by per-SM shared memory.
    SharedMemory,
    /// Bounded by the per-SM resident-block limit.
    Blocks,
    /// No meaningful limit could be attributed (degenerate inputs).
    None,
}
/// Result of an occupancy estimate for one launch configuration.
#[derive(Debug, Clone, Copy)]
pub struct OccupancyEstimate {
    /// Warps concurrently resident per SM for this configuration.
    pub active_warps_per_sm: u32,
    /// Hardware maximum resident warps per SM.
    pub max_warps_per_sm: u32,
    /// `active_warps_per_sm / max_warps_per_sm` in [0.0, 1.0].
    pub occupancy_ratio: f64,
    /// Resource that bounds the estimate.
    pub limiting_factor: LimitingFactor,
}
/// Computes theoretical occupancy estimates against a fixed set of per-SM
/// hardware limits.
#[derive(Debug, Clone)]
pub struct OccupancyCalculator {
    // Per-SM limits all estimates are derived from.
    info: DeviceOccupancyInfo,
}

impl OccupancyCalculator {
    /// Creates a calculator for the given per-SM hardware limits.
    pub fn new(device_info: DeviceOccupancyInfo) -> Self {
        Self { info: device_info }
    }

    /// Returns the hardware limits this calculator was constructed with.
    pub fn device_info(&self) -> &DeviceOccupancyInfo {
        &self.info
    }

    /// Estimates per-SM occupancy for a launch configuration.
    ///
    /// * `block_size` — threads per block.
    /// * `registers_per_thread` — register budget per thread; 0 means
    ///   unconstrained.
    /// * `shared_memory` — shared-memory bytes per block; 0 means
    ///   unconstrained.
    ///
    /// Degenerate inputs (zero block size, zero warp size, zero warp
    /// capacity) return a zeroed estimate with `LimitingFactor::None`.
    pub fn estimate_occupancy(
        &self,
        block_size: u32,
        registers_per_thread: u32,
        shared_memory: u32,
    ) -> OccupancyEstimate {
        let max_warps = self.info.max_warps_per_sm();
        if block_size == 0 || self.info.warp_size == 0 || max_warps == 0 {
            return OccupancyEstimate {
                active_warps_per_sm: 0,
                max_warps_per_sm: max_warps,
                occupancy_ratio: 0.0,
                limiting_factor: LimitingFactor::None,
            };
        }
        // Blocks are allocated in whole warps: round the thread count up.
        let warps_per_block = block_size.div_ceil(self.info.warp_size);
        let blocks_by_block_limit = self.info.max_blocks_per_sm;
        // warps_per_block > 0 here; a block wider than the SM yields 0.
        let blocks_by_threads = max_warps / warps_per_block;
        let blocks_by_registers = if registers_per_thread == 0 {
            u32::MAX // unconstrained
        } else {
            // FIX: the original computed
            // `registers_per_thread * warps_per_block * warp_size` with
            // unchecked multiplication, which can overflow u32 (panic in
            // debug builds, silently wrap in release) for large register
            // requests. Overflow now means "no block fits" — register-limited.
            match registers_per_thread
                .checked_mul(warps_per_block)
                .and_then(|w| w.checked_mul(self.info.warp_size))
            {
                None => 0,
                Some(regs_per_block) => self
                    .info
                    .max_registers_per_sm
                    .checked_div(regs_per_block)
                    .unwrap_or(u32::MAX),
            }
        };
        let blocks_by_smem = if shared_memory == 0 {
            u32::MAX // unconstrained
        } else if self.info.max_shared_memory_per_sm == 0 {
            0
        } else {
            self.info.max_shared_memory_per_sm / shared_memory
        };
        let active_blocks = blocks_by_block_limit
            .min(blocks_by_threads)
            .min(blocks_by_registers)
            .min(blocks_by_smem);
        // active_blocks <= blocks_by_threads, so this product cannot exceed
        // max_warps; the min below is a belt-and-braces clamp.
        let active_warps = active_blocks * warps_per_block;
        let clamped_warps = active_warps.min(max_warps);
        // max_warps > 0 was established by the early return above.
        let ratio = f64::from(clamped_warps) / f64::from(max_warps);
        // Attribute the bottleneck. On ties the priority is shared memory,
        // then registers, then threads, then the resident-block limit.
        let limiting_factor = if active_blocks == 0 {
            if blocks_by_smem == 0 {
                LimitingFactor::SharedMemory
            } else if blocks_by_registers == 0 {
                LimitingFactor::Registers
            } else if blocks_by_threads == 0 {
                LimitingFactor::Threads
            } else {
                LimitingFactor::Blocks
            }
        } else if active_blocks == blocks_by_smem
            && blocks_by_smem
                <= blocks_by_registers
                    .min(blocks_by_threads)
                    .min(blocks_by_block_limit)
        {
            LimitingFactor::SharedMemory
        } else if active_blocks == blocks_by_registers
            && blocks_by_registers <= blocks_by_threads.min(blocks_by_block_limit)
        {
            LimitingFactor::Registers
        } else if active_blocks == blocks_by_threads && blocks_by_threads <= blocks_by_block_limit
        {
            LimitingFactor::Threads
        } else if active_blocks == blocks_by_block_limit {
            LimitingFactor::Blocks
        } else {
            LimitingFactor::None
        };
        OccupancyEstimate {
            active_warps_per_sm: clamped_warps,
            max_warps_per_sm: max_warps,
            occupancy_ratio: ratio,
            limiting_factor,
        }
    }
}
/// One sample of an occupancy sweep: the estimate at a single block size.
#[derive(Debug, Clone, Copy)]
pub struct OccupancyPoint {
    /// Threads per block at this sample.
    pub block_size: u32,
    /// Occupancy ratio in [0.0, 1.0] at this block size.
    pub occupancy: f64,
    /// Active warps per SM at this block size.
    pub active_warps: u32,
    /// Resource limiting occupancy at this block size.
    pub limiting_factor: LimitingFactor,
}
/// Block-size sweeps and best-configuration selection.
pub struct OccupancyGrid;

impl OccupancyGrid {
    /// Evaluates occupancy at every warp-aligned block size from one warp up
    /// to the SM's thread capacity. Returns an empty vector for a degenerate
    /// zero warp size.
    pub fn sweep(
        calculator: &OccupancyCalculator,
        registers_per_thread: u32,
        shared_memory: u32,
    ) -> Vec<OccupancyPoint> {
        let warp = calculator.info.warp_size;
        if warp == 0 {
            return Vec::new();
        }
        // One sample per whole warp count that fits in the SM.
        (1..=calculator.info.max_threads_per_sm / warp)
            .map(|warps| {
                let block_size = warps * warp;
                let est =
                    calculator.estimate_occupancy(block_size, registers_per_thread, shared_memory);
                OccupancyPoint {
                    block_size,
                    occupancy: est.occupancy_ratio,
                    active_warps: est.active_warps_per_sm,
                    limiting_factor: est.limiting_factor,
                }
            })
            .collect()
    }

    /// Returns the smallest block size achieving the highest occupancy, or 0
    /// when `points` is empty.
    pub fn best_block_size(points: &[OccupancyPoint]) -> u32 {
        let mut winner: Option<&OccupancyPoint> = None;
        for candidate in points {
            // Higher occupancy wins; on ties, prefer the smaller block size.
            let better = match winner {
                None => true,
                Some(current) => {
                    candidate.occupancy > current.occupancy
                        || (candidate.occupancy == current.occupancy
                            && candidate.block_size < current.block_size)
                }
            };
            if better {
                winner = Some(candidate);
            }
        }
        winner.map_or(0, |p| p.block_size)
    }
}
/// Sweeps where dynamic shared memory varies with block size.
pub struct DynamicSmemOccupancy;

impl DynamicSmemOccupancy {
    /// Like `OccupancyGrid::sweep`, but the shared-memory requirement is
    /// computed per block size via `smem_fn`. Returns an empty vector for a
    /// degenerate zero warp size.
    pub fn with_smem_function<F>(
        calculator: &OccupancyCalculator,
        smem_fn: F,
        registers_per_thread: u32,
    ) -> Vec<OccupancyPoint>
    where
        F: Fn(u32) -> u32,
    {
        let warp = calculator.info.warp_size;
        if warp == 0 {
            return Vec::new();
        }
        (1..=calculator.info.max_threads_per_sm / warp)
            .map(|warps| warps * warp)
            .map(|block_size| {
                let est = calculator.estimate_occupancy(
                    block_size,
                    registers_per_thread,
                    smem_fn(block_size),
                );
                OccupancyPoint {
                    block_size,
                    occupancy: est.occupancy_ratio,
                    active_warps: est.active_warps_per_sm,
                    limiting_factor: est.limiting_factor,
                }
            })
            .collect()
    }

    /// Shared memory proportional to the thread count.
    pub fn linear_smem(bytes_per_thread: u32) -> impl Fn(u32) -> u32 {
        move |block_size| bytes_per_thread * block_size
    }

    /// Fixed square-tile shared memory, independent of block size.
    pub fn tile_smem(tile_size: u32, element_size: u32) -> impl Fn(u32) -> u32 {
        move |_| tile_size * tile_size * element_size
    }
}
/// Dimensions of a thread-block cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClusterConfig {
    /// Blocks along the cluster's x dimension.
    pub cluster_x: u32,
    /// Blocks along the cluster's y dimension.
    pub cluster_y: u32,
    /// Blocks along the cluster's z dimension.
    pub cluster_z: u32,
}

impl ClusterConfig {
    /// Total thread blocks in the cluster (product of the three dimensions).
    pub fn total_blocks(&self) -> u32 {
        [self.cluster_x, self.cluster_y, self.cluster_z]
            .into_iter()
            .product()
    }
}
/// Result of a cluster-aware occupancy estimate.
#[derive(Debug, Clone, Copy)]
pub struct ClusterOccupancyEstimate {
    /// Blocks grouped into each cluster (echoes the requested cluster size).
    pub blocks_per_cluster: u32,
    /// Whole clusters resident per SM.
    pub clusters_per_sm: u32,
    /// Occupancy ratio counting only warps belonging to whole clusters.
    pub effective_occupancy: f64,
    /// Shared-memory bytes consumed by one full cluster.
    pub cluster_smem_total: u32,
}
/// Occupancy estimation for thread-block clusters.
pub struct ClusterOccupancy;

impl ClusterOccupancy {
    /// Estimates how many whole clusters of `cluster_size` blocks fit per SM.
    ///
    /// Only warps belonging to complete clusters count toward
    /// `effective_occupancy`; partial clusters are discarded. A zero
    /// `cluster_size` yields an all-zero estimate.
    pub fn estimate_cluster_occupancy(
        calculator: &OccupancyCalculator,
        block_size: u32,
        cluster_size: u32,
        registers_per_thread: u32,
        shared_memory: u32,
    ) -> ClusterOccupancyEstimate {
        if cluster_size == 0 {
            return ClusterOccupancyEstimate {
                blocks_per_cluster: 0,
                clusters_per_sm: 0,
                effective_occupancy: 0.0,
                cluster_smem_total: 0,
            };
        }
        // Start from the single-block estimate, then round down to whole
        // clusters.
        let per_block =
            calculator.estimate_occupancy(block_size, registers_per_thread, shared_memory);
        let max_warps = per_block.max_warps_per_sm;
        let warps_per_block = match calculator.info.warp_size {
            0 => 0,
            ws => block_size.div_ceil(ws),
        };
        // warps_per_block may be 0 on degenerate input; treat as no blocks.
        let blocks_per_sm = per_block
            .active_warps_per_sm
            .checked_div(warps_per_block)
            .unwrap_or(0);
        let clusters_per_sm = blocks_per_sm / cluster_size;
        let usable_blocks = clusters_per_sm * cluster_size;
        let usable_warps = usable_blocks * warps_per_block;
        let effective_occupancy = if max_warps == 0 {
            0.0
        } else {
            f64::from(usable_warps.min(max_warps)) / f64::from(max_warps)
        };
        ClusterOccupancyEstimate {
            blocks_per_cluster: cluster_size,
            clusters_per_sm,
            effective_occupancy,
            cluster_smem_total: cluster_size * shared_memory,
        }
    }
}
impl Device {
    /// Queries this device's per-SM occupancy limits.
    ///
    /// On macOS the CUDA driver queries are compiled out and a fixed profile
    /// is returned (the same values as `for_compute_capability(8, 6)`), so
    /// occupancy math stays usable. On other targets every limit is read from
    /// the device, and any failed query is mapped to
    /// `CudaError::NotInitialized`.
    ///
    /// # Errors
    /// Returns `CudaError::NotInitialized` if any device-attribute query
    /// fails (non-macOS path only).
    pub fn occupancy_info(&self) -> CudaResult<DeviceOccupancyInfo> {
        #[cfg(target_os = "macos")]
        {
            // `let _ = self;` keeps the receiver "used" when the driver
            // queries below are compiled out.
            let _ = self; Ok(DeviceOccupancyInfo {
                sm_count: 84,
                max_threads_per_sm: 1536,
                max_blocks_per_sm: 16,
                max_registers_per_sm: 65536,
                max_shared_memory_per_sm: 102400,
                warp_size: 32,
            })
        }
        #[cfg(not(target_os = "macos"))]
        {
            // Each accessor appears to return a signed attribute value that
            // is cast with `as u32`; assumes the driver never reports a
            // negative limit — TODO confirm the accessor return types.
            let sm_count = self
                .multiprocessor_count()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_threads_per_sm = self
                .max_threads_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_blocks_per_sm = self
                .max_blocks_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_registers_per_sm = self
                .max_registers_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_shared_memory_per_sm = self
                .max_shared_memory_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let warp_size = self
                .warp_size()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            Ok(DeviceOccupancyInfo {
                sm_count,
                max_threads_per_sm,
                max_blocks_per_sm,
                max_registers_per_sm,
                max_shared_memory_per_sm,
                warp_size,
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Reference limits modelled on an Ampere GA102-class SM (matches the
    // (8, 6) entry of `for_compute_capability`, with 82 SMs).
    fn ampere_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo {
            sm_count: 82,
            max_threads_per_sm: 1536,
            max_blocks_per_sm: 16,
            max_registers_per_sm: 65536,
            max_shared_memory_per_sm: 102400,
            warp_size: 32,
        }
    }

    // --- Core estimate_occupancy behavior -------------------------------

    #[test]
    fn test_basic_occupancy_estimation() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(256, 32, 0);
        // 1536 threads / 32-wide warps = 48 warps per SM.
        assert_eq!(est.max_warps_per_sm, 48);
        assert!(est.occupancy_ratio > 0.0);
        assert!(est.active_warps_per_sm > 0);
    }

    #[test]
    fn test_full_occupancy() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(32, 16, 0);
        // One warp per block, capped by the 16-block SM limit.
        assert_eq!(est.active_warps_per_sm, 16);
    }

    // --- Limiting-factor attribution ------------------------------------

    #[test]
    fn test_limiting_factor_threads() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(1024, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Threads);
    }

    #[test]
    fn test_limiting_factor_registers() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(256, 128, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Registers);
    }

    #[test]
    fn test_limiting_factor_shared_memory() {
        let calc = OccupancyCalculator::new(ampere_info());
        // 51200 B per block: exactly two blocks fit in 102400 B of smem.
        let est = calc.estimate_occupancy(128, 16, 51200);
        assert_eq!(est.limiting_factor, LimitingFactor::SharedMemory);
    }

    #[test]
    fn test_limiting_factor_blocks() {
        let info = DeviceOccupancyInfo {
            max_blocks_per_sm: 4,
            ..ampere_info()
        };
        let calc = OccupancyCalculator::new(info);
        let est = calc.estimate_occupancy(64, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Blocks);
    }

    #[test]
    fn test_limiting_factor_none_zero_block() {
        let calc = OccupancyCalculator::new(ampere_info());
        // Degenerate zero block size yields the zeroed estimate.
        let est = calc.estimate_occupancy(0, 32, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::None);
        assert_eq!(est.active_warps_per_sm, 0);
        assert_eq!(est.occupancy_ratio, 0.0);
    }

    // --- Sweeps and block-size selection --------------------------------

    #[test]
    fn test_sweep_returns_points() {
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        // One point per warp-aligned block size: 32, 64, ..., 1536.
        assert_eq!(points.len(), 48);
        assert_eq!(points[0].block_size, 32);
        assert_eq!(points[47].block_size, 1536);
    }

    #[test]
    fn test_best_block_size() {
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        let best = OccupancyGrid::best_block_size(&points);
        assert!(best > 0);
        assert_eq!(best % 32, 0);
    }

    #[test]
    fn test_best_block_size_empty() {
        assert_eq!(OccupancyGrid::best_block_size(&[]), 0);
    }

    // --- Dynamic shared-memory sweeps -----------------------------------

    #[test]
    fn test_dynamic_smem_linear() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::linear_smem(8);
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        assert!(!points.is_empty());
        let first_occ = points[0].occupancy;
        let last_occ = points[points.len() - 1].occupancy;
        assert!((0.0..=1.0).contains(&first_occ));
        assert!((0.0..=1.0).contains(&last_occ));
    }

    #[test]
    fn test_dynamic_smem_tile() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::tile_smem(16, 4);
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        assert!(!points.is_empty());
    }

    // --- Cluster occupancy ----------------------------------------------

    #[test]
    fn test_cluster_occupancy_basic() {
        let calc = OccupancyCalculator::new(ampere_info());
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 2, 32, 4096);
        assert_eq!(result.blocks_per_cluster, 2);
        assert!(result.effective_occupancy >= 0.0 && result.effective_occupancy <= 1.0);
        assert_eq!(result.cluster_smem_total, 2 * 4096);
    }

    #[test]
    fn test_cluster_occupancy_zero_cluster() {
        let calc = OccupancyCalculator::new(ampere_info());
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 0, 32, 0);
        assert_eq!(result.clusters_per_sm, 0);
        assert_eq!(result.effective_occupancy, 0.0);
    }

    #[test]
    fn test_cluster_config_total_blocks() {
        let cfg = ClusterConfig {
            cluster_x: 2,
            cluster_y: 3,
            cluster_z: 4,
        };
        assert_eq!(cfg.total_blocks(), 24);
    }

    #[test]
    fn test_block_size_exceeds_max() {
        let calc = OccupancyCalculator::new(ampere_info());
        // A 2048-thread block cannot fit in a 1536-thread SM.
        let est = calc.estimate_occupancy(2048, 32, 0);
        assert_eq!(est.active_warps_per_sm, 0);
        assert_eq!(est.occupancy_ratio, 0.0);
    }

    // --- Blackwell (sm_100 / sm_120) profiles ---------------------------

    fn sm100_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo::for_compute_capability(10, 0)
    }

    fn sm120_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo::for_compute_capability(12, 0)
    }

    #[test]
    fn test_sm100_device_info_attributes() {
        let info = sm100_info();
        assert_eq!(info.sm_count, 132, "Blackwell B100 has 132 SMs");
        assert_eq!(info.max_threads_per_sm, 2048);
        assert_eq!(info.max_blocks_per_sm, 32);
        assert_eq!(info.max_shared_memory_per_sm, 262144, "256 KiB shared/SM");
        assert_eq!(info.warp_size, 32);
    }

    #[test]
    fn test_sm120_device_info_attributes() {
        let info = sm120_info();
        assert_eq!(info.sm_count, 148, "Blackwell B200 has 148 SMs");
        assert_eq!(info.max_threads_per_sm, 2048);
        assert_eq!(info.max_blocks_per_sm, 32);
        assert_eq!(info.max_shared_memory_per_sm, 262144, "256 KiB shared/SM");
        assert_eq!(info.warp_size, 32);
    }

    #[test]
    fn test_sm100_occupancy_estimation() {
        let calc = OccupancyCalculator::new(sm100_info());
        let est = calc.estimate_occupancy(256, 0, 0);
        assert!(
            est.occupancy_ratio > 0.0,
            "Blackwell B100 must report positive occupancy"
        );
        assert!(
            est.active_warps_per_sm <= 64,
            "Active warps must not exceed hardware limit"
        );
    }

    #[test]
    fn test_sm120_full_occupancy() {
        let calc = OccupancyCalculator::new(sm120_info());
        let est = calc.estimate_occupancy(64, 0, 0);
        assert_eq!(est.occupancy_ratio, 1.0, "Should reach full occupancy");
        assert_eq!(est.active_warps_per_sm, 64);
    }

    #[test]
    fn test_sm100_large_shared_memory_limit() {
        let calc = OccupancyCalculator::new(sm100_info());
        let smem_per_block = 131_072u32;
        let est = calc.estimate_occupancy(1024, 0, smem_per_block);
        assert!(
            matches!(est.limiting_factor, LimitingFactor::SharedMemory),
            "Large smem must be the bottleneck"
        );
    }

    #[test]
    fn test_for_compute_capability_unknown_falls_back() {
        let info = DeviceOccupancyInfo::for_compute_capability(99, 99);
        let calc = OccupancyCalculator::new(info);
        let est = calc.estimate_occupancy(256, 0, 0);
        assert!(est.occupancy_ratio > 0.0);
    }

    #[test]
    fn test_sm100_vs_sm90_shared_memory_capacity() {
        let hopper = DeviceOccupancyInfo::for_compute_capability(9, 0);
        let blackwell = sm100_info();
        assert!(
            blackwell.max_shared_memory_per_sm > hopper.max_shared_memory_per_sm,
            "Blackwell B100 must have larger smem than Hopper H100"
        );
    }

    #[test]
    fn test_sm120_vs_sm100_sm_count() {
        let b100 = sm100_info();
        let b200 = sm120_info();
        assert!(
            b200.sm_count > b100.sm_count,
            "Blackwell B200 must have more SMs than B100"
        );
    }

    #[test]
    fn test_for_compute_capability_all_known_arches() {
        let arches = [(7, 5), (8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (12, 0)];
        for (major, minor) in arches {
            let info = DeviceOccupancyInfo::for_compute_capability(major, minor);
            assert_eq!(info.warp_size, 32, "sm_{major}{minor} warp_size must be 32");
            assert!(info.sm_count > 0);
            assert!(info.max_threads_per_sm > 0);
        }
    }
}