use std::ffi::c_int;
use crate::error::{CudaError, CudaResult};
use crate::ffi::CUfunction_attribute;
use crate::loader::try_driver;
use crate::module::Function;
/// Reads one integer attribute of `func` through the driver's
/// `cu_func_get_attribute` entry point.
///
/// `attrib` is passed to the driver as its `i32` discriminant.
///
/// # Errors
///
/// - Propagates any error returned by `try_driver`.
/// - Returns `CudaError::NotSupported` when the driver table holds no
///   `cu_func_get_attribute` entry.
/// - Propagates the error produced by `cuda_call!` for the driver call.
fn get_attribute(func: &Function, attrib: CUfunction_attribute) -> CudaResult<i32> {
    // Out-parameter handed to the driver call; starts at 0.
    let mut out: c_int = 0;

    let driver = try_driver()?;
    let query = driver
        .cu_func_get_attribute
        .ok_or(CudaError::NotSupported)?;

    crate::cuda_call!(query(&mut out, attrib as i32, func.raw()))?;
    Ok(out)
}
/// Writes `value` to one attribute of `func` through the driver's
/// `cu_func_set_attribute` entry point.
///
/// `attrib` is passed to the driver as its `i32` discriminant; `value` is
/// forwarded unchanged, with no range checking at this layer.
///
/// # Errors
///
/// - Propagates any error returned by `try_driver`.
/// - Returns `CudaError::NotSupported` when the driver table holds no
///   `cu_func_set_attribute` entry.
/// - Otherwise returns whatever `cuda_call!` yields for the driver call.
fn set_attribute(func: &Function, attrib: CUfunction_attribute, value: i32) -> CudaResult<()> {
    let driver = try_driver()?;
    let setter = driver
        .cu_func_set_attribute
        .ok_or(CudaError::NotSupported)?;

    let outcome = crate::cuda_call!(setter(func.raw(), attrib as i32, value));
    outcome
}
/// Attribute accessors for [`Function`].
///
/// Every getter forwards to the module-private `get_attribute` helper and
/// every setter to `set_attribute`; errors from those helpers are returned
/// to the caller unchanged.
impl Function {
    /// Returns the value of the `NumRegs` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn num_registers(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::NumRegs;
        get_attribute(self, attrib)
    }

    /// Returns the value of the `SharedSizeBytes` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn shared_memory_bytes(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::SharedSizeBytes;
        get_attribute(self, attrib)
    }

    /// Returns the value of the `MaxThreadsPerBlock` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn max_threads_per_block_attr(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::MaxThreadsPerBlock;
        get_attribute(self, attrib)
    }

    /// Returns the value of the `LocalSizeBytes` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn local_memory_bytes(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::LocalSizeBytes;
        get_attribute(self, attrib)
    }

    /// Returns the value of the `PtxVersion` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn ptx_version(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::PtxVersion;
        get_attribute(self, attrib)
    }

    /// Returns the value of the `BinaryVersion` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn binary_version(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::BinaryVersion;
        get_attribute(self, attrib)
    }

    /// Returns the current value of the `MaxDynamicSharedSizeBytes` attribute.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `get_attribute`.
    pub fn max_dynamic_shared_memory(&self) -> CudaResult<i32> {
        let attrib = CUfunction_attribute::MaxDynamicSharedSizeBytes;
        get_attribute(self, attrib)
    }

    /// Sets the `MaxDynamicSharedSizeBytes` attribute to `bytes`.
    ///
    /// `bytes` is forwarded as-is; no validation happens here.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `set_attribute`.
    pub fn set_max_dynamic_shared_memory(&self, bytes: i32) -> CudaResult<()> {
        let attrib = CUfunction_attribute::MaxDynamicSharedSizeBytes;
        set_attribute(self, attrib, bytes)
    }

    /// Sets the `PreferredSharedMemoryCarveout` attribute to `percent`.
    ///
    /// NOTE(review): `percent` is forwarded without a range check; the set
    /// of accepted values is decided by the driver — confirm against the
    /// driver documentation before relying on a particular range.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `set_attribute`.
    pub fn set_preferred_shared_memory_carveout(&self, percent: i32) -> CudaResult<()> {
        let attrib = CUfunction_attribute::PreferredSharedMemoryCarveout;
        set_attribute(self, attrib, percent)
    }
}
#[cfg(test)]
mod tests {
    /// Pins the integer discriminants of the `CUfunction_attribute` variants
    /// that this module casts with `as i32` before handing them to the driver.
    #[test]
    fn function_attribute_enum_values() {
        use crate::ffi::CUfunction_attribute as Attr;

        // (actual discriminant, expected discriminant, variant name for diagnostics)
        let table = [
            (Attr::MaxThreadsPerBlock as i32, 0, "MaxThreadsPerBlock"),
            (Attr::NumRegs as i32, 4, "NumRegs"),
            (Attr::PtxVersion as i32, 5, "PtxVersion"),
            (
                Attr::MaxDynamicSharedSizeBytes as i32,
                8,
                "MaxDynamicSharedSizeBytes",
            ),
        ];

        for &(actual, expected, variant) in table.iter() {
            assert_eq!(actual, expected, "unexpected discriminant for {}", variant);
        }
    }
}