use std::{
borrow::Cow,
ffi::CString,
fmt::{self, Display, Formatter},
marker::PhantomData,
mem::{ManuallyDrop, MaybeUninit, align_of, size_of},
ptr,
sync::Arc,
};
use singe_cuda_sys::driver;
use crate::{
context::Context,
dim::Dim3,
error::{Error, Result},
graph::{ExecutableGraph, Graph, GraphNode},
kernel::{self, ModuleKernelHandle},
memory::{DeviceMemory, ManagedMemory},
stream::{GraphRecordable, Stream, StreamCaptureScope},
try_ffi,
types::{DeviceFunction, FunctionAttribute, SharedMemoryCarveout},
utility::{to_u32, to_u64},
view::{DeviceRepr, DeviceSlice, DeviceSliceMut, DeviceView, DeviceViewMut},
};
bitflags::bitflags! {
/// Bit flags accepted by the `*_with_flags` occupancy queries on [`KernelFunction`].
///
/// Each constant takes its value from the same-named `driver::CUoccupancy_flags` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OccupancyFlags: u32 {
/// Value of `CU_OCCUPANCY_DEFAULT`.
const DEFAULT = driver::CUoccupancy_flags::CU_OCCUPANCY_DEFAULT as _;
/// Value of `CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE`.
const DISABLE_CACHING_OVERRIDE = driver::CUoccupancy_flags::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE as _;
}
}
impl Display for OccupancyFlags {
    /// Writes the driver names of the selected flags, separated by `" | "`.
    ///
    /// A flag whose value is zero (the driver defines `CU_OCCUPANCY_DEFAULT` as `0`) is
    /// "contained" in every flag set, so `contains` cannot be used to decide whether it was
    /// selected. Such a flag is printed only when the whole set is empty; otherwise it would
    /// be prepended to every non-empty set and omitted from the one set it actually names.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let named = [
            (Self::DEFAULT, "CU_OCCUPANCY_DEFAULT"),
            (
                Self::DISABLE_CACHING_OVERRIDE,
                "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE",
            ),
        ];
        let mut first = true;
        for (flag, name) in named {
            // Zero-valued flags describe the empty set; everything else uses bit containment.
            let selected = if flag.is_empty() {
                self.is_empty()
            } else {
                self.contains(flag)
            };
            if !selected {
                continue;
            }
            if !first {
                f.write_str(" | ")?;
            }
            first = false;
            f.write_str(name)?;
        }
        Ok(())
    }
}
/// Snapshot of a kernel function's attributes, as assembled by
/// [`KernelFunction::attributes`].
///
/// Every field holds the value reported for the correspondingly named
/// [`FunctionAttribute`]; boolean fields are `true` when the reported value is non-zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FunctionAttributes {
/// `FunctionAttribute::SharedSizeBytes`, cast to `usize`.
pub shared_size_bytes: usize,
/// `FunctionAttribute::ConstSizeBytes`, cast to `usize`.
pub const_size_bytes: usize,
/// `FunctionAttribute::LocalSizeBytes`, cast to `usize`.
pub local_size_bytes: usize,
/// `FunctionAttribute::MaxThreadsPerBlock`.
pub max_threads_per_block: i32,
/// `FunctionAttribute::NumRegs`.
pub num_regs: i32,
/// `FunctionAttribute::PtxVersion`.
pub ptx_version: i32,
/// `FunctionAttribute::BinaryVersion`.
pub binary_version: i32,
/// `FunctionAttribute::CacheModeCa` was non-zero.
pub cache_mode_ca: bool,
/// `FunctionAttribute::MaxDynamicSharedSizeBytes`.
pub max_dynamic_shared_size_bytes: i32,
/// `FunctionAttribute::PreferredSharedMemoryCarveout`.
pub preferred_shared_memory_carveout: i32,
/// `FunctionAttribute::ClusterSizeMustBeSet` was non-zero.
pub cluster_dim_must_be_set: bool,
/// `FunctionAttribute::RequiredClusterWidth`.
pub required_cluster_width: i32,
/// `FunctionAttribute::RequiredClusterHeight`.
pub required_cluster_height: i32,
/// `FunctionAttribute::RequiredClusterDepth`.
pub required_cluster_depth: i32,
/// `FunctionAttribute::ClusterSchedulingPolicyPreference`.
pub cluster_scheduling_policy_preference: i32,
/// `FunctionAttribute::NonPortableClusterSizeAllowed` was non-zero.
pub non_portable_cluster_size_allowed: bool,
}
/// Pair of values written by `cuOccupancyMaxPotentialBlockSizeWithFlags`, as returned from
/// [`KernelFunction::occupancy_max_potential_block_size_with_flags`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OccupancyMaxPotentialBlockSize {
/// The driver's `minGridSize` output.
pub min_grid_size: i32,
/// The driver's `blockSize` output.
pub block_size: i32,
}
/// Launch shape used by the cluster occupancy queries on [`KernelFunction`].
///
/// Fields are private; [`ClusterLaunchConfig::new`] rejects zero dimensions and shared-memory
/// sizes that do not pass the `u32` conversion check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClusterLaunchConfig {
// Grid dimensions forwarded as `gridDimX/Y/Z`.
grid_dim: Dim3,
// Block dimensions forwarded as `blockDimX/Y/Z`.
block_dim: Dim3,
// Forwarded as `sharedMemBytes` after a cast to `u32`.
shared_memory_bytes: usize,
}
/// A driver module handle paired with the context it is used with.
///
/// When `owns_handle` is set, dropping the value binds `ctx` and calls `cuModuleUnload`.
#[derive(Debug)]
pub struct Module {
// Raw driver handle.
handle: driver::CUmodule,
// Context that is bound before context-dependent driver calls (e.g. in `global` and `Drop`).
ctx: Arc<Context>,
// `true` for `from_raw`, `false` for `from_borrowed_raw`; controls unloading on drop.
owns_handle: bool,
}
/// Address and size of a module-level global, as reported by `cuModuleGetGlobal_v2`.
///
/// The borrow of the originating [`Module`] ties the value's lifetime to that module.
#[derive(Debug, Clone, Copy)]
pub struct Global<'a> {
// Address reported by the driver, cast to an untyped pointer.
ptr: *mut (),
// Size reported by the driver.
size: usize,
// Held only for its lifetime; never read.
_module: &'a Module,
}
/// A raw `CUtexref` handle whose lifetime is tied to a borrowed [`Module`].
///
/// NOTE(review): no constructor is visible in this part of the file — confirm where values
/// of this type are created.
#[derive(Debug, Clone, Copy)]
pub struct TextureReference<'a> {
// Raw driver handle returned by `as_raw`.
handle: driver::CUtexref,
// Held only for its lifetime; never read.
_module: &'a Module,
}
/// A raw `CUsurfref` handle whose lifetime is tied to a borrowed [`Module`].
///
/// NOTE(review): no constructor is visible in this part of the file — confirm where values
/// of this type are created.
#[derive(Debug, Clone, Copy)]
pub struct SurfaceReference<'a> {
// Raw driver handle returned by `as_raw`.
handle: driver::CUsurfref,
// Held only for its lifetime; never read.
_module: &'a Module,
}
/// A byte buffer holding a module image, either borrowed or owned.
#[derive(Debug, Clone)]
pub struct ModuleImage<'a> {
// Borrowed for `new`, owned for `from_vec` / `from_string`.
data: Cow<'a, [u8]>,
}
/// A device function handle together with a borrow of the [`Module`] it was looked up in.
#[derive(Debug)]
pub struct KernelFunction<'a> {
// Function handle passed to the driver for launches and queries.
handle: DeviceFunction,
// Source module; its context is bound before driver calls.
module: &'a Module,
}
/// A kernel launch that has been described but not yet issued.
///
/// Created by [`KernelFunction::launch_operation`]; its `GraphRecordable` implementation
/// issues the launch on the stream of the capture scope it is recorded into.
#[derive(Debug)]
pub struct KernelLaunchOperation<'kernel, 'config, P> {
// Kernel to launch.
function: &'kernel KernelFunction<'kernel>,
// Grid/block/shared-memory configuration for the launch.
config: &'config LaunchConfig,
// Launch arguments, consumed when the operation is recorded.
params: P,
}
/// Grid dimensions, block dimensions and dynamic shared-memory size for a kernel launch.
///
/// Fields are private; the constructors reject zero dimensions and shared-memory sizes that
/// do not pass the `u32` conversion check.
#[derive(Debug, Clone)]
pub struct LaunchConfig {
// Passed to `cuLaunchKernel` as the grid dimensions.
grid_dim: Dim3,
// Passed to `cuLaunchKernel` as the block dimensions.
block_dim: Dim3,
// Passed to `cuLaunchKernel` after a cast to `u32`.
shared_memory_bytes: usize,
}
/// Ordered list of kernel arguments built up at run time.
///
/// Arguments are either borrowed for `'a` or copied into storage owned by this value.
pub struct KernelParameters<'a> {
// One entry per argument, in push order.
arguments: Vec<KernelParameter<'a>>,
}
// Largest argument count for which `RawKernelPointers` keeps the pointer table in a
// fixed-size array instead of a `Vec`.
const INLINE_KERNEL_ARGUMENTS: usize = 16;
// Byte capacity of `InlineKernelArgument`; larger owned arguments are boxed.
const INLINE_KERNEL_ARGUMENT_BYTES: usize = 16;
// Private module holding the supertrait of `KernelLaunchArgs`; because the module is not
// public, code outside this file's module cannot name `Sealed` to implement it.
mod private {
pub trait Sealed {}
}
/// A value that knows how to append itself to a [`KernelParameters`] list.
///
/// Used by [`KernelParameters::push`].
pub trait PushKernelArg {
/// Appends `self` to `params`.
fn push_to<'a>(self, params: &mut KernelParameters<'a>);
}
/// A complete set of kernel launch arguments.
///
/// Implemented in this file for [`KernelParameters`] (by value and by `&mut`), for `()`,
/// and for tuples of references up to 16 elements. The `private::Sealed` supertrait
/// restricts who can implement it.
pub trait KernelLaunchArgs<'a>: private::Sealed {
/// Builds the table of argument pointers and hands it to `f`, returning `f`'s result.
#[doc(hidden)]
fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R;
}
/// Borrowed table of untyped argument pointers, one per kernel argument.
#[doc(hidden)]
pub struct EncodedKernelArgs<'a> {
// Pointer table; its base address is what `as_mut_ptr` returns.
pointers: &'a mut [*mut ()],
}
// Element of a tuple passed as launch arguments; implemented below for `&'a T` and
// `&'a mut T`.
trait KernelTupleArgument<'a> {
// Converts the reference into the untyped pointer placed in the argument table.
fn into_kernel_argument_ptr(self) -> *mut ();
}
// One entry of a `KernelParameters` list.
enum KernelParameter<'a> {
// Pointer to caller-owned data; the marker records the `'a` borrow the pointer came from.
Borrowed {
ptr: *mut (),
_marker: PhantomData<&'a ()>,
},
// Value copied into storage owned by the parameter list.
Owned(OwnedKernelArgument),
}
impl fmt::Debug for KernelParameter<'_> {
    /// Formats the entry as a tuple named after its variant, with the stored pointer or
    /// owned argument as the single field.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Owned(owned) => {
                let mut out = f.debug_tuple("Owned");
                out.field(owned);
                out.finish()
            }
            Self::Borrowed { ptr: address, .. } => {
                let mut out = f.debug_tuple("Borrowed");
                out.field(address);
                out.finish()
            }
        }
    }
}
// Storage for an argument value owned by the parameter list.
enum OwnedKernelArgument {
// Value stored in the fixed-size inline buffer.
Inline(InlineKernelArgument),
// Value stored on the heap behind a trait object (chosen when it does not fit inline).
Boxed(Box<dyn KernelArgumentStorage>),
}
impl fmt::Debug for OwnedKernelArgument {
    /// Formats inline storage as `Inline(..)` with its value as a field; boxed storage is
    /// shown as a non-exhaustive `Boxed` tuple because the trait object is not `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        if let Self::Inline(inline) = self {
            let mut out = f.debug_tuple("Inline");
            out.field(inline);
            return out.finish();
        }
        let mut out = f.debug_tuple("Boxed");
        out.finish_non_exhaustive()
    }
}
// Object-safe view of an owned argument value: all that is needed is its address.
trait KernelArgumentStorage {
// Returns the address of the stored value as an untyped pointer.
fn as_mut_ptr(&mut self) -> *mut ();
}
impl<T> KernelArgumentStorage for T {
fn as_mut_ptr(&mut self) -> *mut () {
ptr::from_mut(self).cast()
}
}
// Fixed-size, 16-byte-aligned buffer into which a small argument value is written.
// The bytes are `MaybeUninit` because a value may occupy only part of the buffer.
#[derive(Clone, Copy)]
#[repr(C, align(16))]
struct InlineKernelArgument {
bytes: [MaybeUninit<u8>; INLINE_KERNEL_ARGUMENT_BYTES],
}
impl fmt::Debug for InlineKernelArgument {
    /// Prints only the type name; the possibly-uninitialised bytes are never read.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("InlineKernelArgument");
        out.finish_non_exhaustive()
    }
}
impl fmt::Debug for KernelParameters<'_> {
    /// Prints the number of stored arguments rather than the arguments themselves.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let argument_count = self.arguments.len();
        let mut out = f.debug_struct("KernelParameters");
        out.field("arguments", &argument_count);
        out.finish()
    }
}
impl Module {
    /// Takes ownership of a raw module handle; the module is unloaded when the value drops.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NullHandle`] when `handle` is null.
    ///
    /// # Safety
    ///
    /// Only the null check is performed here. NOTE(review): presumably `handle` must be a
    /// live module loaded in `ctx` that nothing else will unload — confirm against callers.
    pub unsafe fn from_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Result<Self> {
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(Self {
            handle,
            ctx,
            owns_handle: true,
        })
    }

    /// Wraps a raw module handle without taking ownership; dropping the value does not
    /// unload the module.
    ///
    /// # Safety
    ///
    /// No validation is performed, not even a null check. NOTE(review): presumably `handle`
    /// must stay valid for as long as the returned value is used — confirm against callers.
    pub const unsafe fn from_borrowed_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Self {
        Self {
            handle,
            ctx,
            owns_handle: false,
        }
    }

    /// Looks up the function called `name` in this module.
    ///
    /// The module's context is bound before the driver call, matching [`Module::global`]
    /// and the `Drop` implementation.
    ///
    /// # Errors
    ///
    /// Fails when `name` contains an interior NUL byte, when binding the context fails,
    /// when `cuModuleGetFunction` reports an error, or with [`Error::NullHandle`] when the
    /// driver succeeds but yields a null function handle.
    pub fn function(&self, name: &str) -> Result<KernelFunction<'_>> {
        let c_name = CString::new(name)?;
        let mut function_handle = ptr::null_mut();
        self.ctx.bind()?;
        // SAFETY: `function_handle` is a live local the driver writes into and `c_name` is a
        // NUL-terminated string that outlives the call.
        unsafe {
            try_ffi!(driver::cuModuleGetFunction(
                &raw mut function_handle,
                self.handle,
                c_name.as_ptr(),
            ))?;
        }
        if function_handle.is_null() {
            return Err(Error::NullHandle);
        }
        // SAFETY: the handle was just returned by the driver for this module and is non-null.
        unsafe {
            let function = DeviceFunction::from_raw(function_handle);
            Ok(KernelFunction::from_raw(function, self))
        }
    }

    /// Returns the number of functions the driver reports for this module.
    ///
    /// # Errors
    ///
    /// Propagates any error from `cuModuleGetFunctionCount`.
    pub fn function_count(&self) -> Result<usize> {
        let mut count = 0;
        // SAFETY: `count` is a live local the driver writes into.
        unsafe {
            try_ffi!(driver::cuModuleGetFunctionCount(
                &raw mut count,
                self.handle
            ))?;
        }
        Ok(count as usize)
    }

    /// Returns the raw module handle without giving up ownership.
    pub const fn as_raw(&self) -> driver::CUmodule {
        self.handle
    }

    /// Consumes the module and returns the raw handle without unloading it.
    ///
    /// The value is wrapped in `ManuallyDrop`, so neither the `Drop` implementation nor the
    /// destructor of the stored `Arc<Context>` runs; that context reference is not released.
    pub fn into_raw(self) -> driver::CUmodule {
        let module = ManuallyDrop::new(self);
        module.handle
    }

    /// Looks up the global called `name` and returns its address and size.
    ///
    /// # Errors
    ///
    /// Fails when `name` contains an interior NUL byte, when binding the context fails, or
    /// when `cuModuleGetGlobal_v2` reports an error.
    pub fn global(&self, name: &str) -> Result<Global<'_>> {
        let c_name = CString::new(name)?;
        let mut ptr = 0;
        let mut size = 0;
        self.ctx.bind()?;
        // SAFETY: both outputs are live locals and `c_name` is a NUL-terminated string that
        // outlives the call.
        unsafe {
            try_ffi!(driver::cuModuleGetGlobal_v2(
                &raw mut ptr,
                &raw mut size,
                self.handle,
                c_name.as_ptr(),
            ))?;
        }
        Ok(Global {
            ptr: ptr as _,
            size: size as _,
            _module: self,
        })
    }
}
impl Drop for Module {
    /// Unloads the module when this value owns its handle.
    ///
    /// Failures cannot be returned from `drop`; they are printed to stderr in debug builds
    /// and otherwise ignored. If the context cannot be bound the unload is skipped.
    fn drop(&mut self) {
        if self.owns_handle {
            if let Err(err) = self.ctx.bind() {
                #[cfg(debug_assertions)]
                eprintln!("failed to bind context before unloading module: {err}");
            } else {
                let unloaded = unsafe { try_ffi!(driver::cuModuleUnload(self.handle)) };
                if let Err(err) = unloaded {
                    #[cfg(debug_assertions)]
                    eprintln!("failed to unload cuda module: {err}");
                }
            }
        }
    }
}
// SAFETY: NOTE(review): these impls assert that a `Module` (a raw driver handle plus an
// `Arc<Context>`) may be moved to and shared between threads. The justification is not
// visible in this file — confirm that the handle is usable from any thread once the
// context has been bound.
unsafe impl Send for Module {}
unsafe impl Sync for Module {}
impl<'a> ModuleImage<'a> {
    /// Wraps borrowed image bytes without copying them.
    pub const fn new(data: &'a [u8]) -> Self {
        let data = Cow::Borrowed(data);
        Self { data }
    }

    /// Takes ownership of image bytes.
    pub fn from_vec(data: Vec<u8>) -> Self {
        let data = Cow::Owned(data);
        Self { data }
    }

    /// Takes ownership of a textual image, storing its bytes.
    pub fn from_string(data: String) -> Self {
        let bytes = data.into_bytes();
        Self::from_vec(bytes)
    }

    /// Returns the address of the first image byte as an untyped pointer.
    pub fn as_ptr(&self) -> *const () {
        let first: *const u8 = self.data.as_ptr();
        first.cast::<()>()
    }

    /// Returns the image bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }
}
impl Global<'_> {
    /// Returns the stored address of the global.
    pub const fn as_ptr(&self) -> *mut () {
        let Self { ptr: address, .. } = *self;
        address
    }

    /// Returns the stored size of the global in bytes.
    pub const fn byte_len(&self) -> usize {
        let Self { size, .. } = *self;
        size
    }
}
impl TextureReference<'_> {
    /// Returns the raw texture reference handle.
    pub const fn as_raw(&self) -> driver::CUtexref {
        let Self { handle, .. } = *self;
        handle
    }
}
impl SurfaceReference<'_> {
    /// Returns the raw surface reference handle.
    pub const fn as_raw(&self) -> driver::CUsurfref {
        let Self { handle, .. } = *self;
        handle
    }
}
impl KernelFunction<'_> {
pub const unsafe fn from_raw(handle: DeviceFunction, module: &Module) -> KernelFunction<'_> {
KernelFunction { handle, module }
}
pub const unsafe fn launch_operation<'kernel, 'config, P>(
&'kernel self,
config: &'config LaunchConfig,
params: P,
) -> KernelLaunchOperation<'kernel, 'config, P> {
KernelLaunchOperation {
function: self,
config,
params,
}
}
fn check_graph_context(&self, graph: &Graph) -> Result<()> {
if matches!(graph.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
return Err(Error::GraphContextMismatch);
}
Ok(())
}
fn check_executable_graph_context(&self, executable: &ExecutableGraph) -> Result<()> {
if matches!(executable.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
return Err(Error::GraphContextMismatch);
}
Ok(())
}
pub fn launch<'a, P>(&self, config: &LaunchConfig, params: P) -> Result<()>
where
P: KernelLaunchArgs<'a>,
{
self.module.ctx.bind()?;
params.with_encoded_arguments(|mut arguments| unsafe {
try_ffi!(driver::cuLaunchKernel(
self.handle.as_raw(),
config.grid_dim().x,
config.grid_dim().y,
config.grid_dim().z,
config.block_dim().x,
config.block_dim().y,
config.block_dim().z,
config.shared_memory_bytes_u32(),
ptr::null_mut(),
arguments.as_mut_ptr().cast(),
ptr::null_mut(),
))?;
Ok(())
})
}
pub fn launch_on<'a, P>(&self, config: &LaunchConfig, params: P, stream: &Stream) -> Result<()>
where
P: KernelLaunchArgs<'a>,
{
if stream.context() != self.module.ctx.as_ref() {
return Err(driver::CUresult::CUDA_ERROR_INVALID_CONTEXT.into());
}
self.module.ctx.bind()?;
params.with_encoded_arguments(|mut arguments| unsafe {
try_ffi!(driver::cuLaunchKernel(
self.handle.as_raw(),
config.grid_dim().x,
config.grid_dim().y,
config.grid_dim().z,
config.block_dim().x,
config.block_dim().y,
config.block_dim().z,
config.shared_memory_bytes_u32(),
stream.as_raw(),
arguments.as_mut_ptr().cast(),
ptr::null_mut(),
))?;
Ok(())
})
}
pub unsafe fn add_to_graph<'a, P>(
&self,
graph: &mut Graph,
dependencies: &[GraphNode],
config: &LaunchConfig,
params: P,
) -> Result<GraphNode>
where
P: KernelLaunchArgs<'a>,
{
self.check_graph_context(graph)?;
unsafe { graph.add_kernel_node(dependencies, self.handle, config, params) }
}
pub unsafe fn set_graph_node_params<'a, P>(
&self,
executable: &mut ExecutableGraph,
node: GraphNode,
config: &LaunchConfig,
params: P,
) -> Result<()>
where
P: KernelLaunchArgs<'a>,
{
self.check_executable_graph_context(executable)?;
unsafe { executable.set_kernel_node_params(node, self.handle, config, params) }
}
pub const fn module(&self) -> &Module {
self.module
}
pub fn name(&self) -> Result<String> {
kernel::name::<ModuleKernelHandle>(self.module.ctx.as_ref(), self.handle.as_raw())
}
pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
kernel::attribute::<ModuleKernelHandle>(
self.module.ctx.as_ref(),
self.handle.as_raw(),
attribute,
)
}
pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
kernel::set_attribute::<ModuleKernelHandle>(
self.module.ctx.as_ref(),
self.handle.as_raw(),
attribute,
value,
)
}
pub fn set_max_dynamic_shared_memory_bytes(&self, bytes: i32) -> Result<()> {
self.set_attribute(FunctionAttribute::MaxDynamicSharedSizeBytes, bytes)
}
pub fn set_preferred_shared_memory_carveout(
&self,
carveout: SharedMemoryCarveout,
) -> Result<()> {
self.set_attribute(
FunctionAttribute::PreferredSharedMemoryCarveout,
i32::from(carveout),
)
}
pub fn attributes(&self) -> Result<FunctionAttributes> {
Ok(FunctionAttributes {
shared_size_bytes: self.attribute(FunctionAttribute::SharedSizeBytes)? as usize,
const_size_bytes: self.attribute(FunctionAttribute::ConstSizeBytes)? as usize,
local_size_bytes: self.attribute(FunctionAttribute::LocalSizeBytes)? as usize,
max_threads_per_block: self.attribute(FunctionAttribute::MaxThreadsPerBlock)?,
num_regs: self.attribute(FunctionAttribute::NumRegs)?,
ptx_version: self.attribute(FunctionAttribute::PtxVersion)?,
binary_version: self.attribute(FunctionAttribute::BinaryVersion)?,
cache_mode_ca: self.attribute(FunctionAttribute::CacheModeCa)? != 0,
max_dynamic_shared_size_bytes: self
.attribute(FunctionAttribute::MaxDynamicSharedSizeBytes)?,
preferred_shared_memory_carveout: self
.attribute(FunctionAttribute::PreferredSharedMemoryCarveout)?,
cluster_dim_must_be_set: self.attribute(FunctionAttribute::ClusterSizeMustBeSet)? != 0,
required_cluster_width: self.attribute(FunctionAttribute::RequiredClusterWidth)?,
required_cluster_height: self.attribute(FunctionAttribute::RequiredClusterHeight)?,
required_cluster_depth: self.attribute(FunctionAttribute::RequiredClusterDepth)?,
cluster_scheduling_policy_preference: self
.attribute(FunctionAttribute::ClusterSchedulingPolicyPreference)?,
non_portable_cluster_size_allowed: self
.attribute(FunctionAttribute::NonPortableClusterSizeAllowed)?
!= 0,
})
}
pub fn occupancy_max_active_blocks_per_multiprocessor(
&self,
block_size: i32,
dynamic_shared_memory_bytes: usize,
) -> Result<i32> {
self.occupancy_max_active_blocks_per_multiprocessor_with_flags(
block_size,
dynamic_shared_memory_bytes,
OccupancyFlags::DEFAULT,
)
}
pub fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
&self,
block_size: i32,
dynamic_shared_memory_bytes: usize,
flags: OccupancyFlags,
) -> Result<i32> {
self.module.ctx.bind()?;
let dynamic_shared_memory_bytes =
validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
let mut blocks = 0;
unsafe {
try_ffi!(
driver::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
&raw mut blocks,
self.handle.as_raw(),
block_size,
dynamic_shared_memory_bytes,
flags.bits(),
)
)?;
}
Ok(blocks)
}
pub fn occupancy_available_dynamic_shared_memory_per_block(
&self,
num_blocks: i32,
block_size: i32,
) -> Result<usize> {
self.module.ctx.bind()?;
let mut bytes = 0;
unsafe {
try_ffi!(driver::cuOccupancyAvailableDynamicSMemPerBlock(
&raw mut bytes,
self.handle.as_raw(),
num_blocks,
block_size,
))?;
}
Ok(bytes as usize)
}
pub fn occupancy_max_potential_block_size(
&self,
dynamic_shared_memory_bytes: usize,
block_size_limit: i32,
) -> Result<OccupancyMaxPotentialBlockSize> {
self.occupancy_max_potential_block_size_with_flags(
dynamic_shared_memory_bytes,
block_size_limit,
OccupancyFlags::DEFAULT,
)
}
pub fn occupancy_max_potential_block_size_with_flags(
&self,
dynamic_shared_memory_bytes: usize,
block_size_limit: i32,
flags: OccupancyFlags,
) -> Result<OccupancyMaxPotentialBlockSize> {
self.module.ctx.bind()?;
let dynamic_shared_memory_bytes =
validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
let mut min_grid_size = 0;
let mut block_size = 0;
unsafe {
try_ffi!(driver::cuOccupancyMaxPotentialBlockSizeWithFlags(
&raw mut min_grid_size,
&raw mut block_size,
self.handle.as_raw(),
None,
dynamic_shared_memory_bytes,
block_size_limit,
flags.bits(),
))?;
}
Ok(OccupancyMaxPotentialBlockSize {
min_grid_size,
block_size,
})
}
pub fn occupancy_max_potential_cluster_size(&self, config: ClusterLaunchConfig) -> Result<i32> {
self.module.ctx.bind()?;
let mut cluster_size = 0;
let config = driver::CUlaunchConfig {
gridDimX: config.grid_dim().x,
gridDimY: config.grid_dim().y,
gridDimZ: config.grid_dim().z,
blockDimX: config.block_dim().x,
blockDimY: config.block_dim().y,
blockDimZ: config.block_dim().z,
sharedMemBytes: config.shared_memory_bytes_u32(),
hStream: ptr::null_mut(),
attrs: ptr::null_mut(),
numAttrs: 0,
};
unsafe {
try_ffi!(driver::cuOccupancyMaxPotentialClusterSize(
&raw mut cluster_size,
self.handle.as_raw(),
&raw const config,
))?;
}
Ok(cluster_size)
}
pub fn occupancy_max_active_clusters(&self, config: ClusterLaunchConfig) -> Result<i32> {
self.module.ctx.bind()?;
let mut clusters = 0;
let config = driver::CUlaunchConfig {
gridDimX: config.grid_dim().x,
gridDimY: config.grid_dim().y,
gridDimZ: config.grid_dim().z,
blockDimX: config.block_dim().x,
blockDimY: config.block_dim().y,
blockDimZ: config.block_dim().z,
sharedMemBytes: config.shared_memory_bytes_u32(),
hStream: ptr::null_mut(),
attrs: ptr::null_mut(),
numAttrs: 0,
};
unsafe {
try_ffi!(driver::cuOccupancyMaxActiveClusters(
&raw mut clusters,
self.handle.as_raw(),
&raw const config,
))?;
}
Ok(clusters)
}
pub const fn as_raw(&self) -> DeviceFunction {
self.handle
}
}
unsafe impl<'a, P> GraphRecordable for KernelLaunchOperation<'_, '_, P>
where
    P: KernelLaunchArgs<'a>,
{
    type Output = ();

    /// Issues the stored launch on the stream of `scope`.
    fn record(self, scope: &StreamCaptureScope<'_>) -> Result<Self::Output> {
        let Self {
            function,
            config,
            params,
        } = self;
        function.launch_on(config, params, scope.stream())
    }
}
impl LaunchConfig {
pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
validate_dim3(grid_dim, "grid_dim")?;
validate_dim3(block_dim, "block_dim")?;
validate_shared_memory_bytes(shared_memory_bytes)?;
Ok(Self::from_validated(
grid_dim,
block_dim,
shared_memory_bytes,
))
}
const fn from_validated(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Self {
Self {
grid_dim,
block_dim,
shared_memory_bytes,
}
}
pub const fn grid_dim(&self) -> Dim3 {
self.grid_dim
}
pub const fn block_dim(&self) -> Dim3 {
self.block_dim
}
pub const fn shared_memory_bytes(&self) -> usize {
self.shared_memory_bytes
}
pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
self.shared_memory_bytes as u32
}
pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
validate_shared_memory_bytes(shared_memory_bytes)?;
self.shared_memory_bytes = shared_memory_bytes;
Ok(self)
}
pub fn try_for_1d_grid(element_count: usize, block_size: usize) -> Result<Self> {
validate_block_dimension(block_size, "block_size")?;
let grid_size = element_count.div_ceil(block_size);
validate_grid_dimension(grid_size, "grid_size")?;
Ok(Self::from_validated(
Dim3::new(to_u32(grid_size, "grid_size")?, 1, 1),
Dim3::new(to_u32(block_size, "block_size")?, 1, 1),
0,
))
}
pub fn for_1d_grid(element_count: usize, block_size: usize) -> Self {
Self::try_for_1d_grid(element_count, block_size)
.expect("invalid 1d cuda launch configuration")
}
pub fn try_for_num_elems(element_count: usize, block_size: usize) -> Result<Self> {
Self::try_for_1d_grid(element_count, block_size)
}
pub fn for_num_elems(element_count: usize, block_size: usize) -> Self {
Self::try_for_num_elems(element_count, block_size)
.expect("invalid cuda launch configuration")
}
pub fn try_for_2d_grid(
width: usize,
height: usize,
block_width: usize,
block_height: usize,
) -> Result<Self> {
validate_block_dimension(block_width, "block_width")?;
validate_block_dimension(block_height, "block_height")?;
let grid_x = width.div_ceil(block_width);
let grid_y = height.div_ceil(block_height);
validate_grid_dimension(grid_x, "grid_x")?;
validate_grid_dimension(grid_y, "grid_y")?;
Ok(Self::from_validated(
Dim3::new(to_u32(grid_x, "grid_x")?, to_u32(grid_y, "grid_y")?, 1),
Dim3::new(
to_u32(block_width, "block_width")?,
to_u32(block_height, "block_height")?,
1,
),
0,
))
}
pub fn for_2d_grid(
width: usize,
height: usize,
block_width: usize,
block_height: usize,
) -> Self {
Self::try_for_2d_grid(width, height, block_width, block_height)
.expect("invalid 2d cuda launch configuration")
}
pub fn try_for_3d_grid(
width: usize,
height: usize,
depth: usize,
block_width: usize,
block_height: usize,
block_depth: usize,
) -> Result<Self> {
validate_block_dimension(block_width, "block_width")?;
validate_block_dimension(block_height, "block_height")?;
validate_block_dimension(block_depth, "block_depth")?;
let grid_x = width.div_ceil(block_width);
let grid_y = height.div_ceil(block_height);
let grid_z = depth.div_ceil(block_depth);
validate_grid_dimension(grid_x, "grid_x")?;
validate_grid_dimension(grid_y, "grid_y")?;
validate_grid_dimension(grid_z, "grid_z")?;
Ok(Self::from_validated(
Dim3::new(
to_u32(grid_x, "grid_x")?,
to_u32(grid_y, "grid_y")?,
to_u32(grid_z, "grid_z")?,
),
Dim3::new(
to_u32(block_width, "block_width")?,
to_u32(block_height, "block_height")?,
to_u32(block_depth, "block_depth")?,
),
0,
))
}
pub fn for_3d_grid(
width: usize,
height: usize,
depth: usize,
block_width: usize,
block_height: usize,
block_depth: usize,
) -> Self {
Self::try_for_3d_grid(width, height, depth, block_width, block_height, block_depth)
.expect("invalid 3d cuda launch configuration")
}
}
impl ClusterLaunchConfig {
pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
validate_dim3(grid_dim, "grid_dim")?;
validate_dim3(block_dim, "block_dim")?;
validate_shared_memory_bytes(shared_memory_bytes)?;
Ok(Self {
grid_dim,
block_dim,
shared_memory_bytes,
})
}
pub const fn grid_dim(&self) -> Dim3 {
self.grid_dim
}
pub const fn block_dim(&self) -> Dim3 {
self.block_dim
}
pub const fn shared_memory_bytes(&self) -> usize {
self.shared_memory_bytes
}
pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
self.shared_memory_bytes as u32
}
pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
validate_shared_memory_bytes(shared_memory_bytes)?;
self.shared_memory_bytes = shared_memory_bytes;
Ok(self)
}
}
/// Checks that no component of `value` is zero, reporting the first offending component as
/// `"{name}.x"`, `"{name}.y"` or `"{name}.z"`.
fn validate_dim3(value: Dim3, name: &str) -> Result<()> {
    let components = [(value.x, "x"), (value.y, "y"), (value.z, "z")];
    for (extent, axis) in components {
        validate_grid_dimension(extent as usize, &format!("{name}.{axis}"))?;
    }
    Ok(())
}
fn validate_grid_dimension(value: usize, name: &str) -> Result<()> {
if value == 0 {
return Err(Error::ZeroValue {
name: name.to_owned(),
});
}
Ok(())
}
fn validate_block_dimension(value: usize, name: &str) -> Result<()> {
if value == 0 {
return Err(Error::ZeroValue {
name: name.to_owned(),
});
}
Ok(())
}
/// Converts a shared-memory size to `u32` through `to_u32`, labelling any error with
/// `"shared_memory_bytes"`.
fn validate_shared_memory_bytes(value: usize) -> Result<u32> {
    const LABEL: &str = "shared_memory_bytes";
    to_u32(value, LABEL)
}
/// Converts a dynamic shared-memory size to `u64` through `to_u64`, labelling any error
/// with `"dynamic_shared_memory_bytes"`.
fn validate_dynamic_shared_memory_bytes(value: usize) -> Result<u64> {
    const LABEL: &str = "dynamic_shared_memory_bytes";
    to_u64(value, LABEL)
}
impl<'a> KernelParameters<'a> {
pub const fn new() -> Self {
Self {
arguments: Vec::new(),
}
}
pub fn arg<T: 'a>(&mut self, value: &'a T) -> &mut Self {
self.arguments.push(KernelParameter::Borrowed {
ptr: ptr::from_ref(value).cast_mut().cast::<()>(),
_marker: PhantomData,
});
self
}
pub fn arg_mut<T: 'a>(&mut self, value: &'a mut T) -> &mut Self {
self.arguments.push(KernelParameter::Borrowed {
ptr: ptr::from_mut(value).cast::<()>(),
_marker: PhantomData,
});
self
}
pub fn owned_arg<T: Copy + 'static>(&mut self, value: T) -> &mut Self {
let value = OwnedKernelArgument::from_value(value);
self.arguments.push(KernelParameter::Owned(value));
self
}
pub fn push<A: PushKernelArg>(&mut self, arg: A) -> &mut Self {
arg.push_to(self);
self
}
pub fn device_slice<T: DeviceRepr, S: DeviceSlice<T> + ?Sized>(
&mut self,
slice: &S,
) -> &mut Self {
self.owned_arg(slice.as_device_ptr())
}
pub fn device_slice_mut<T: DeviceRepr, S: DeviceSliceMut<T> + ?Sized>(
&mut self,
slice: &mut S,
) -> &mut Self {
self.owned_arg(slice.as_device_mut_ptr())
}
fn raw_pointers(&mut self) -> RawKernelPointers {
RawKernelPointers::from_parameters(self.arguments.as_mut_slice())
}
}
impl<'a> KernelParameter<'a> {
fn as_mut_ptr(&mut self) -> *mut () {
match self {
Self::Borrowed { ptr, .. } => *ptr,
Self::Owned(value) => value.as_mut_ptr(),
}
}
}
impl OwnedKernelArgument {
    /// Stores `value` inline when it fits the inline buffer's size and alignment, and on the
    /// heap otherwise.
    fn from_value<T: Copy + 'static>(value: T) -> Self {
        let fits_inline = size_of::<T>() <= INLINE_KERNEL_ARGUMENT_BYTES
            && align_of::<T>() <= align_of::<InlineKernelArgument>();
        if !fits_inline {
            return Self::Boxed(Box::new(value));
        }
        Self::Inline(InlineKernelArgument::from_value(value))
    }

    /// Returns the address of the stored value.
    fn as_mut_ptr(&mut self) -> *mut () {
        match self {
            // `as_mut()` first unwraps the box to the trait object, so the call dispatches
            // to the boxed value and yields its address rather than the address of the box.
            Self::Boxed(boxed) => boxed.as_mut().as_mut_ptr(),
            Self::Inline(inline) => inline.as_mut_ptr(),
        }
    }
}
impl InlineKernelArgument {
fn from_value<T: Copy>(value: T) -> Self {
let mut storage = Self {
bytes: [MaybeUninit::uninit(); INLINE_KERNEL_ARGUMENT_BYTES],
};
unsafe {
ptr::write(storage.as_mut_ptr().cast::<T>(), value);
}
storage
}
fn as_mut_ptr(&mut self) -> *mut () {
self.bytes.as_mut_ptr().cast()
}
}
// Table of per-argument pointers built from a parameter list.
enum RawKernelPointers {
// Fixed-size array used when there are at most `INLINE_KERNEL_ARGUMENTS` arguments;
// only the first `len` entries are meaningful.
Inline {
pointers: [*mut (); INLINE_KERNEL_ARGUMENTS],
len: usize,
},
// Heap-allocated table used for longer argument lists.
Heap(Vec<*mut ()>),
}
impl RawKernelPointers {
    /// Collects one pointer per parameter, using the inline array when the parameter count
    /// does not exceed `INLINE_KERNEL_ARGUMENTS` and a `Vec` otherwise.
    fn from_parameters(parameters: &mut [KernelParameter<'_>]) -> Self {
        let len = parameters.len();
        if len > INLINE_KERNEL_ARGUMENTS {
            let mut heap = Vec::with_capacity(len);
            for parameter in parameters.iter_mut() {
                heap.push(parameter.as_mut_ptr());
            }
            return Self::Heap(heap);
        }
        let mut pointers = [ptr::null_mut(); INLINE_KERNEL_ARGUMENTS];
        for (index, parameter) in parameters.iter_mut().enumerate() {
            pointers[index] = parameter.as_mut_ptr();
        }
        Self::Inline { pointers, len }
    }

    /// Returns the populated part of the table.
    fn as_mut_slice(&mut self) -> &mut [*mut ()] {
        match self {
            Self::Heap(heap) => &mut heap[..],
            Self::Inline { pointers, len } => {
                let used = *len;
                &mut pointers[..used]
            }
        }
    }
}
impl EncodedKernelArgs<'_> {
    /// Returns the base address of the pointer table.
    pub(crate) fn as_mut_ptr(&mut self) -> *mut *mut () {
        let table: &mut [*mut ()] = &mut *self.pointers;
        table.as_mut_ptr()
    }
}
impl<'a> KernelLaunchArgs<'a> for KernelParameters<'a> {
fn with_encoded_arguments<R>(mut self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
let mut pointers = self.raw_pointers();
f(EncodedKernelArgs {
pointers: pointers.as_mut_slice(),
})
}
}
impl private::Sealed for KernelParameters<'_> {}
impl<'a> KernelLaunchArgs<'a> for &mut KernelParameters<'a> {
fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
let mut pointers = self.raw_pointers();
f(EncodedKernelArgs {
pointers: pointers.as_mut_slice(),
})
}
}
impl private::Sealed for &mut KernelParameters<'_> {}
impl<'a> KernelLaunchArgs<'a> for () {
fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
let mut pointers: [*mut (); 0] = [];
f(EncodedKernelArgs {
pointers: &mut pointers,
})
}
}
impl private::Sealed for () {}
// Implements `private::Sealed` and `KernelLaunchArgs` for a tuple whose elements all
// implement `KernelTupleArgument`. The identifiers passed in serve both as the generic
// type parameters and as the names of the destructured tuple elements.
macro_rules! impl_kernel_arguments_for_tuple {
($($arg:ident),+ $(,)?) => {
impl<'a, $($arg),+> private::Sealed for ($($arg,)+)
where
$($arg: KernelTupleArgument<'a>,)+
{
}
impl<'a, $($arg),+> KernelLaunchArgs<'a> for ($($arg,)+)
where
$($arg: KernelTupleArgument<'a>,)+
{
fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
// The bindings reuse the upper-case type parameter names.
#[allow(non_snake_case)]
let ($($arg,)+) = self;
// One pointer per tuple element, in tuple order.
let mut pointers = [
$($arg.into_kernel_argument_ptr(),)+
];
f(EncodedKernelArgs {
pointers: &mut pointers,
})
}
}
};
}
impl<'a, T: 'a> KernelTupleArgument<'a> for &'a T {
fn into_kernel_argument_ptr(self) -> *mut () {
ptr::from_ref(self).cast_mut().cast()
}
}
impl<'a, T: 'a> KernelTupleArgument<'a> for &'a mut T {
fn into_kernel_argument_ptr(self) -> *mut () {
ptr::from_mut(self).cast()
}
}
// Tuples of 1 through 16 elements can be used as launch arguments.
impl_kernel_arguments_for_tuple!(A);
impl_kernel_arguments_for_tuple!(A, B);
impl_kernel_arguments_for_tuple!(A, B, C);
impl_kernel_arguments_for_tuple!(A, B, C, D);
impl_kernel_arguments_for_tuple!(A, B, C, D, E);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
// Implements `PushKernelArg` for each listed type by storing the value as an owned
// argument.
macro_rules! impl_push_scalar {
($($ty:ty),* $(,)?) => {
$(
impl PushKernelArg for $ty {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.owned_arg(self);
}
}
)*
};
}
// Integer and floating-point primitives are pushed by value.
impl_push_scalar!(
u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64,
);
impl<T: DeviceRepr> PushKernelArg for &DeviceMemory<T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.device_slice(self);
}
}
impl<T: DeviceRepr> PushKernelArg for &mut DeviceMemory<T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.device_slice_mut(self);
}
}
impl<T: DeviceRepr> PushKernelArg for &ManagedMemory<T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.device_slice(self);
}
}
impl<T: DeviceRepr> PushKernelArg for &mut ManagedMemory<T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.device_slice_mut(self);
}
}
impl<T: DeviceRepr> PushKernelArg for DeviceView<'_, T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.owned_arg(self.as_ptr());
}
}
impl<T: DeviceRepr> PushKernelArg for &DeviceView<'_, T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.owned_arg(self.as_device_ptr());
}
}
impl<T: DeviceRepr> PushKernelArg for &DeviceViewMut<'_, T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.owned_arg(self.as_device_ptr());
}
}
impl<T: DeviceRepr> PushKernelArg for &mut DeviceViewMut<'_, T> {
fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
params.owned_arg(self.as_device_mut_ptr());
}
}
impl Default for KernelParameters<'_> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
// 24 bytes: larger than the inline buffer, so it must take the boxed path.
#[derive(Clone, Copy)]
#[repr(C)]
struct LargeArgument {
words: [u64; 3],
}
// An oversized value is boxed, and the pointer reported for it is the one obtained by
// dispatching through the trait object.
#[test]
fn boxed_owned_kernel_argument_points_to_inner_value() {
let mut argument = OwnedKernelArgument::from_value(LargeArgument { words: [1, 2, 3] });
assert!(matches!(argument, OwnedKernelArgument::Boxed(_)));
let expected = match &mut argument {
OwnedKernelArgument::Boxed(value) => value.as_mut().as_mut_ptr(),
OwnedKernelArgument::Inline(_) => unreachable!(),
};
assert_eq!(argument.as_mut_ptr(), expected);
}
// Zero elements produce a zero grid size, and an explicit zero grid component is rejected
// with the component's name.
#[test]
fn launch_config_rejects_zero_grid_dimensions() {
let error = LaunchConfig::try_for_1d_grid(0, 128).unwrap_err();
assert!(matches!(error, Error::ZeroValue { name } if name == "grid_size"));
let error = LaunchConfig::new(Dim3::new(0, 1, 1), Dim3::new(128, 1, 1), 0).unwrap_err();
assert!(matches!(error, Error::ZeroValue { name } if name == "grid_dim.x"));
}
// One byte past `u32::MAX` is rejected as out of range.
#[test]
fn launch_config_rejects_invalid_shared_memory_size() {
let error = LaunchConfig::try_for_1d_grid(1, 128)
.unwrap()
.with_shared_memory_bytes(u32::MAX as usize + 1)
.unwrap_err();
assert!(matches!(error, Error::OutOfRange { name } if name == "shared_memory_bytes"));
}
// The largest accepted size survives the cast to `u32` unchanged.
#[test]
fn launch_config_exposes_checked_shared_memory_u32() {
let config = LaunchConfig::try_for_1d_grid(1, 128)
.unwrap()
.with_shared_memory_bytes(u32::MAX as usize)
.unwrap();
assert_eq!(config.shared_memory_bytes(), u32::MAX as usize);
assert_eq!(config.shared_memory_bytes_u32(), u32::MAX);
}
// Occupancy queries convert the size to `u64`, so both extremes of `usize` are accepted.
#[test]
fn occupancy_dynamic_shared_memory_uses_checked_driver_width() {
assert_eq!(validate_dynamic_shared_memory_bytes(0).unwrap(), 0);
assert_eq!(
validate_dynamic_shared_memory_bytes(usize::MAX).unwrap(),
usize::MAX as u64
);
}
// A valid cluster configuration round-trips through its accessors.
#[test]
fn cluster_launch_config_uses_checked_construction() {
let config = ClusterLaunchConfig::new(Dim3::new(1, 1, 1), Dim3::new(32, 1, 1), 0).unwrap();
assert_eq!(config.grid_dim(), Dim3::new(1, 1, 1));
assert_eq!(config.block_dim(), Dim3::new(32, 1, 1));
assert_eq!(config.shared_memory_bytes(), 0);
}
}