use core::ffi::c_void;
use baracuda_cuda_sys::{driver, CUstream};
use baracuda_types::KernelArg;
use crate::error::{check, Result};
use crate::module::Function;
use crate::stream::Stream;
/// Three-dimensional extent used for a launch's grid and block sizes.
///
/// Convertible from a `u32` (giving `(x, 1, 1)`), a `(u32, u32)` (giving
/// `(x, y, 1)`) or a `(u32, u32, u32)`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub struct Dim3 {
    /// Extent along x.
    pub x: u32,
    /// Extent along y.
    pub y: u32,
    /// Extent along z.
    pub z: u32,
}
impl From<u32> for Dim3 {
fn from(x: u32) -> Self {
Self { x, y: 1, z: 1 }
}
}
impl From<(u32, u32)> for Dim3 {
fn from((x, y): (u32, u32)) -> Self {
Self { x, y, z: 1 }
}
}
impl From<(u32, u32, u32)> for Dim3 {
fn from((x, y, z): (u32, u32, u32)) -> Self {
Self { x, y, z }
}
}
impl Function {
    /// Starts building a launch of this kernel.
    ///
    /// The returned [`LaunchBuilder`] starts with a 1×1×1 grid, 1×1×1 blocks,
    /// no dynamic shared memory, no explicit stream and no arguments. Nothing
    /// is submitted until the builder's `launch` method is called.
    #[inline]
    pub fn launch(&self) -> LaunchBuilder<'_> {
        // `Dim3` is `Copy`, so one unit extent serves both grid and block.
        let unit = Dim3 { x: 1, y: 1, z: 1 };
        LaunchBuilder {
            function: self,
            stream: None,
            grid: unit,
            block: unit,
            shared_mem_bytes: 0,
            args: Vec::new(),
        }
    }
}
/// Builder for a single kernel launch, created by [`Function::launch`].
///
/// Configure the grid, block, dynamic shared memory, stream and arguments,
/// then submit with one of the `unsafe` launch methods.
#[must_use = "the launch builder does nothing until `launch()` is called"]
pub struct LaunchBuilder<'f> {
    /// Kernel to launch.
    function: &'f Function,
    /// Stream to launch on; `None` means a null stream handle is passed.
    stream: Option<&'f Stream>,
    /// Grid dimensions.
    grid: Dim3,
    /// Block dimensions.
    block: Dim3,
    /// Dynamic shared-memory size in bytes.
    shared_mem_bytes: u32,
    /// Type-erased argument pointers, in the order they were added.
    args: Vec<*mut c_void>,
}
impl core::fmt::Debug for LaunchBuilder<'_> {
    /// Summarizes the launch configuration.
    ///
    /// Argument pointers are reported only by count; the function and stream
    /// are omitted (marked by `..` via `finish_non_exhaustive`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let arg_count = self.args.len();
        let mut out = f.debug_struct("LaunchBuilder");
        out.field("grid", &self.grid);
        out.field("block", &self.block);
        out.field("shared_mem_bytes", &self.shared_mem_bytes);
        out.field("arg_count", &arg_count);
        out.finish_non_exhaustive()
    }
}
impl<'f> LaunchBuilder<'f> {
#[inline]
pub fn grid(mut self, grid: impl Into<Dim3>) -> Self {
self.grid = grid.into();
self
}
#[inline]
pub fn block(mut self, block: impl Into<Dim3>) -> Self {
self.block = block.into();
self
}
#[inline]
pub fn shared_mem_bytes(mut self, bytes: u32) -> Self {
self.shared_mem_bytes = bytes;
self
}
#[inline]
pub fn stream(mut self, stream: &'f Stream) -> Self {
self.stream = Some(stream);
self
}
#[inline]
pub fn arg<K: KernelArg>(mut self, arg: K) -> Self {
self.args.push(arg.as_kernel_arg_ptr());
self
}
pub unsafe fn launch(mut self) -> Result<()> { unsafe {
let d = driver()?;
let cu = d.cu_launch_kernel()?;
let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
let args_ptr = if self.args.is_empty() {
core::ptr::null_mut()
} else {
self.args.as_mut_ptr()
};
check(cu(
self.function.as_raw(),
self.grid.x,
self.grid.y,
self.grid.z,
self.block.x,
self.block.y,
self.block.z,
self.shared_mem_bytes,
stream_handle,
args_ptr,
core::ptr::null_mut(), ))
}}
pub unsafe fn launch_ex(
mut self,
attributes: &mut [baracuda_cuda_sys::types::CUlaunchAttribute],
) -> Result<()> { unsafe {
let d = driver()?;
let cu = d.cu_launch_kernel_ex()?;
let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
let args_ptr = if self.args.is_empty() {
core::ptr::null_mut()
} else {
self.args.as_mut_ptr()
};
let config = baracuda_cuda_sys::types::CUlaunchConfig {
grid_dim_x: self.grid.x,
grid_dim_y: self.grid.y,
grid_dim_z: self.grid.z,
block_dim_x: self.block.x,
block_dim_y: self.block.y,
block_dim_z: self.block.z,
shared_mem_bytes: self.shared_mem_bytes,
stream: stream_handle,
attrs: if attributes.is_empty() {
core::ptr::null_mut()
} else {
attributes.as_mut_ptr()
},
num_attrs: attributes.len() as core::ffi::c_uint,
};
check(cu(
&config,
self.function.as_raw(),
args_ptr,
core::ptr::null_mut(),
))
}}
}