use crate::device_context::with_default_device_policy;
use crate::device_future::DeviceFuture;
use crate::device_operation::{DeviceOp, ExecutionContext};
use crate::error::DeviceError;
use anyhow::{Context, Result};
use cuda_core::sys::CUdeviceptr;
use cuda_core::{launch_kernel, CudaFunction, CudaStream, DType, LaunchConfig};
use std::ffi::c_void;
use std::fmt::Debug;
use std::future::IntoFuture;
use std::sync::Arc;
use std::vec::Vec;
/// Signature of a type-erased deleter that frees one boxed kernel argument.
type ArgDeleter = unsafe fn(*mut c_void);

/// A kernel launch assembled argument-by-argument and executed on a CUDA stream.
///
/// Each argument pushed through the `push_*` methods is boxed; its address is
/// appended to [`args`](Self::args), which is passed as the parameter array to
/// `launch_kernel`. The boxes are owned by this value and are freed — using
/// their original type, so the allocation layout matches — when it is dropped.
#[derive(Debug)]
pub struct AsyncKernelLaunch {
    /// The kernel to launch.
    pub func: Arc<CudaFunction>,
    /// Type-erased pointers to the kernel arguments, in push order.
    ///
    /// Pointers inserted directly into this vector (rather than via the
    /// `push_*` methods) are not owned by the launch and are never freed by it.
    pub args: Vec<*mut c_void>,
    /// Launch geometry; must be set via `set_launch_config` before launching.
    cfg: Option<LaunchConfig>,
    /// Every allocation created by `push_arg_raw`, paired with the deleter that
    /// rebuilds the original `Box<T>` so it is freed with `T`'s layout.
    owned_args: Vec<(*mut c_void, ArgDeleter)>,
}

// SAFETY: the pointers created by the `push_*` methods refer to heap
// allocations owned exclusively by this value, so moving it to another thread
// moves that ownership along with it.
// NOTE(review): this also assumes every pushed argument type is itself `Send`;
// nothing here enforces that — confirm for all `DType` implementors.
unsafe impl Send for AsyncKernelLaunch {}

impl Drop for AsyncKernelLaunch {
    /// Frees every owned argument with the deleter recorded for its concrete type.
    fn drop(&mut self) {
        for (ptr, deleter) in self.owned_args.drain(..) {
            // SAFETY: each pair was recorded exactly once by `push_arg_raw`: `ptr`
            // came from `Box::<T>::into_raw` and `deleter` is `drop_boxed::<T>` for
            // that same `T`. Draining guarantees no pair is freed twice.
            unsafe { deleter(ptr) };
        }
    }
}

impl AsyncKernelLaunch {
    /// Creates a launch of `func` with no arguments and no launch configuration.
    pub fn new(func: Arc<CudaFunction>) -> AsyncKernelLaunch {
        AsyncKernelLaunch {
            func,
            args: Vec::new(),
            cfg: None,
            owned_args: Vec::new(),
        }
    }

    /// Appends `arg` as kernel argument(s); what is pushed is decided by `T`'s
    /// [`KernelArgument`] implementation.
    #[inline]
    pub fn push_arg<T: KernelArgument>(&mut self, arg: T) -> &mut Self {
        arg.push_arg(self);
        self
    }

    /// Appends kernel argument(s) derived from a shared `Arc<T>`; what is pushed
    /// is decided by `T`'s [`ArcKernelArgument`] implementation.
    #[inline]
    pub fn push_arg_arc<T: ArcKernelArgument>(&mut self, arg: &Arc<T>) -> &mut Self {
        arg.push_arg_arc(self);
        self
    }

    /// Appends a raw device pointer as a kernel argument.
    ///
    /// # Safety
    /// NOTE(review): nothing here validates `ptr`; the caller is presumably
    /// responsible for it being a live device allocation of the type the kernel
    /// expects, for as long as the kernel runs — confirm the intended contract.
    pub unsafe fn push_device_ptr(&mut self, ptr: CUdeviceptr) -> &mut Self {
        self.push_arg_raw(Box::new(ptr))
    }

    /// Takes ownership of `arg` and appends its address to the argument list.
    ///
    /// The allocation is recorded together with a typed deleter, so `Drop`
    /// frees it as a `Box<T>` (correct layout, `T`'s destructor runs).
    ///
    /// # Safety
    /// NOTE(review): the caller is presumably responsible for `T` matching the
    /// kernel parameter at this position; memory management itself is handled here.
    unsafe fn push_arg_raw<T>(&mut self, arg: Box<T>) -> &mut Self {
        let ptr = Box::into_raw(arg).cast::<c_void>();
        let deleter: ArgDeleter = Self::drop_boxed::<T>;
        self.owned_args.push((ptr, deleter));
        self.args.push(ptr);
        self
    }

    /// Reclaims and drops a `Box<T>` previously leaked by `push_arg_raw`.
    ///
    /// # Safety
    /// `ptr` must have been produced by `Box::<T>::into_raw` for this exact `T`
    /// and must not be used again afterwards.
    unsafe fn drop_boxed<T>(ptr: *mut c_void) {
        drop(Box::from_raw(ptr.cast::<T>()));
    }

    /// Sets (or replaces) the launch configuration used by `launch`.
    pub fn set_launch_config(&mut self, cfg: LaunchConfig) -> &mut Self {
        self.cfg = Some(cfg);
        self
    }

    /// Launches the kernel on `stream`, consuming this launch; the boxed
    /// arguments are freed when `self` drops at the end of this call.
    ///
    /// # Errors
    /// Returns `DeviceError::Launch` if no launch configuration was set, or the
    /// `launch_kernel` failure (with the args and config attached as context).
    ///
    /// # Safety
    /// NOTE(review): kernel launches are unchecked; the caller presumably must
    /// ensure the pushed arguments match the kernel signature — confirm.
    unsafe fn launch(mut self, stream: &Arc<CudaStream>) -> Result<(), DeviceError> {
        let cfg = self.cfg.ok_or_else(|| {
            DeviceError::Launch("Await called before launching the kernel.".to_string())
        })?;
        launch_kernel(
            self.func.cu_function(),
            cfg.grid_dim,
            cfg.block_dim,
            cfg.shared_mem_bytes,
            stream.cu_stream(),
            &mut self.args,
        )
        .with_context(|| {
            format!(
                "\nFailed to launch kernel.\nargs: {:#?}\ncfg: {:#?}",
                self.args, cfg
            )
        })?;
        Ok(())
    }
}
/// A kernel argument that is supplied from behind an `Arc`.
///
/// Implementors receive a shared reference to their `Arc` and decide which
/// kernel parameter(s) to append to the launch.
pub trait ArcKernelArgument {
    /// Appends the kernel parameter(s) for this shared value to `launch`.
    fn push_arg_arc(self: &Arc<Self>, launch: &mut AsyncKernelLaunch);
}
/// A value that can be handed to an [`AsyncKernelLaunch`] as kernel
/// argument(s), consuming the value in the process.
pub trait KernelArgument {
    /// Appends the kernel parameter(s) for `self` to `target`.
    fn push_arg(self, target: &mut AsyncKernelLaunch);
}
/// Every [`DType`] value is a by-value kernel argument: it is boxed and the
/// box is handed to the launcher, which appends its address to the argument list.
impl<T: DType> KernelArgument for T {
    fn push_arg(self, launcher: &mut AsyncKernelLaunch) {
        let boxed = Box::new(self);
        // SAFETY: the freshly boxed value is moved into `launcher`, which owns
        // the allocation from here on and releases it when it is dropped.
        unsafe {
            launcher.push_arg_raw(boxed);
        }
    }
}
/// Running an [`AsyncKernelLaunch`] as a device operation launches the kernel
/// on the stream carried by the execution context; it produces no value.
impl DeviceOp for AsyncKernelLaunch {
    type Output = ();

    /// Launches the kernel on `ctx`'s CUDA stream, consuming the launch.
    unsafe fn execute(
        self,
        ctx: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let stream = ctx.get_cuda_stream();
        self.launch(stream)
    }
}
/// Allows `launch.await`: the launch is paired with the next stream from the
/// default device policy and wrapped in a [`DeviceFuture`].
///
/// Any error from the policy (either from obtaining the policy or from
/// `next_stream`) yields an already-failed future instead of panicking.
impl IntoFuture for AsyncKernelLaunch {
    type Output = Result<(), DeviceError>;
    type IntoFuture = DeviceFuture<(), AsyncKernelLaunch>;

    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let next = policy.next_stream()?;
            let mut pending = DeviceFuture::new();
            pending.device_operation = Some(self);
            pending.execution_context = Some(ExecutionContext::new(next));
            Ok(pending)
        });
        match scheduled {
            Ok(Ok(pending)) => pending,
            // The two error arms are kept separate: their error types may differ.
            Ok(Err(stream_err)) => DeviceFuture::failed(stream_err),
            Err(policy_err) => DeviceFuture::failed(policy_err),
        }
    }
}