use std::sync::Arc;
#[cfg(feature = "gpu")]
use cudarc::driver::CudaContext;
use crate::error::{DbxError, DbxResult};
/// Per-kernel resource figures consumed by the occupancy calculations.
///
/// The block-size calculation treats a value of `0` in
/// `registers_per_thread` or `shared_mem_per_block` as "no constraint from
/// this resource" and skips the corresponding limit.
#[derive(Debug, Clone)]
pub struct OccupancyParams {
    /// Number of registers each thread needs.
    pub registers_per_thread: usize,
    /// Shared memory each block needs, in the same unit as the calculator's
    /// shared-memory limit (presumably bytes — confirm against callers).
    pub shared_mem_per_block: usize,
    /// Number of threads the caller intends to launch per block.
    pub threads_per_block: usize,
}
/// Estimates launch block sizes from a fixed set of per-block resource limits.
///
/// Only available when the crate is built with the `gpu` feature.
#[cfg(feature = "gpu")]
pub struct OccupancyCalculator {
    /// CUDA context supplied at construction; stored but not read by the
    /// calculation methods.
    device: Arc<CudaContext>,
    /// Largest number of threads a single block may contain.
    max_threads_per_block: usize,
    /// Shared-memory budget available to a single block.
    max_shared_mem_per_block: usize,
    /// Register budget available to a single block.
    max_registers_per_block: usize,
}
#[cfg(feature = "gpu")]
impl OccupancyCalculator {
    /// Creates a calculator associated with `device`.
    ///
    /// The per-block limits are fixed constants rather than values queried
    /// from `device`: 1024 threads, 48 KiB of shared memory and 65536
    /// registers.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`; the `DbxResult`
    /// return type is kept for interface stability.
    pub fn new(device: Arc<CudaContext>) -> DbxResult<Self> {
        let max_threads_per_block = 1024;
        let max_shared_mem_per_block = 48 * 1024;
        let max_registers_per_block = 65536;

        Ok(Self {
            device,
            max_threads_per_block,
            max_shared_mem_per_block,
            max_registers_per_block,
        })
    }

    /// Computes the largest block size, as a multiple of 32 threads, that
    /// fits within the calculator's limits for the given kernel `params`.
    ///
    /// Starting from `max_threads_per_block`, the result is capped by:
    /// - shared memory: the number of `shared_mem_per_block`-sized
    ///   allocations that fit in the shared-memory limit, multiplied by
    ///   `params.threads_per_block` (skipped when `shared_mem_per_block` is 0);
    /// - registers: the register limit divided by `registers_per_thread`
    ///   (skipped when `registers_per_thread` is 0).
    ///
    /// # Errors
    ///
    /// Returns `DbxError::Gpu` when the rounded block size is below 32,
    /// i.e. the requested resources do not leave room for 32 threads.
    pub fn calculate_optimal_block_size(&self, params: &OccupancyParams) -> DbxResult<usize> {
        let mut block_size = self.max_threads_per_block;

        if params.shared_mem_per_block > 0 {
            let max_blocks_by_shmem = self.max_shared_mem_per_block / params.shared_mem_per_block;
            // `threads_per_block` is caller-controlled; saturate instead of
            // overflowing (panic in debug, silent wrap-around in release).
            // A saturated value is harmless because of the `min` below.
            let max_threads_by_shmem = max_blocks_by_shmem.saturating_mul(params.threads_per_block);
            block_size = block_size.min(max_threads_by_shmem);
        }

        if params.registers_per_thread > 0 {
            let max_threads_by_regs = self.max_registers_per_block / params.registers_per_thread;
            block_size = block_size.min(max_threads_by_regs);
        }

        // Round down to a multiple of 32 (the warp size on NVIDIA GPUs).
        block_size = (block_size / 32) * 32;
        if block_size < 32 {
            return Err(DbxError::Gpu(
                "Insufficient resources for kernel execution".to_string(),
            ));
        }

        Ok(block_size)
    }

    /// Returns the block size chosen by
    /// [`calculate_optimal_block_size`](Self::calculate_optimal_block_size)
    /// as a percentage (0.0–100.0) of `max_threads_per_block`.
    ///
    /// This is a ratio of block size to the per-block thread limit only; it
    /// does not model how many blocks are resident on a multiprocessor.
    ///
    /// # Errors
    ///
    /// Propagates the error from `calculate_optimal_block_size`.
    pub fn calculate_occupancy(&self, params: &OccupancyParams) -> DbxResult<f64> {
        let optimal_block_size = self.calculate_optimal_block_size(params)?;
        let occupancy = optimal_block_size as f64 / self.max_threads_per_block as f64;
        Ok(occupancy * 100.0)
    }
}
/// Placeholder calculator used when the crate is built without the `gpu`
/// feature. It cannot be constructed: [`OccupancyCalculator::new`] always
/// fails.
#[cfg(not(feature = "gpu"))]
pub struct OccupancyCalculator;

// `OccupancyParams` is deliberately NOT redefined here. It is declared once,
// unconditionally, earlier in this file; a second
// `#[cfg(not(feature = "gpu"))]` copy defines the same name twice in non-GPU
// builds and fails to compile (E0428).

#[cfg(not(feature = "gpu"))]
impl OccupancyCalculator {
    /// Stub constructor for builds without GPU support.
    ///
    /// # Errors
    ///
    /// Always returns `DbxError::NotImplemented`.
    pub fn new(_device: ()) -> DbxResult<Self> {
        Err(DbxError::NotImplemented(
            "GPU acceleration is not enabled".to_string(),
        ))
    }
}