use anyhow::{Context, Result, anyhow, bail};
use cudarc::driver::{CudaContext, CudaStream, LaunchArgs, LaunchConfig};
use cudarc::nvrtc::{Ptx, compile_ptx};
use std::sync::Arc;
use cudarc::driver::PushKernelArg;
use cudarc::driver::{DevicePtr, DevicePtrMut};
/// A CUDA kernel handle: one function together with the module it was looked up in.
///
/// Values are built with [`RawKernel::from_ptx`] or [`RawKernel::from_source`].
pub struct RawKernel {
    /// Module returned by the context when the PTX was loaded.
    ///
    /// NOTE(review): nothing visible here reads this field; presumably it is
    /// held so the module outlives `func` — confirm before removing it.
    module: Arc<cudarc::driver::CudaModule>,
    /// Function handle that is handed to a stream's launch builder.
    func: cudarc::driver::CudaFunction,
}
impl RawKernel {
    /// Loads `ptx` into `ctx` as a module and looks up the function called `name` in it.
    ///
    /// The PTX is cloned because the context's `load_module` takes it by value.
    ///
    /// # Errors
    ///
    /// Returns an error if the module cannot be loaded or if the function lookup
    /// fails. The error message names the kernel and the stage that failed; the
    /// underlying driver error is preserved as the source of the returned error.
    pub fn from_ptx(ptx: &Ptx, name: &str, ctx: &Arc<CudaContext>) -> anyhow::Result<Self> {
        let module = ctx
            .load_module(ptx.clone())
            .with_context(|| format!("failed to load PTX module for kernel `{}`", name))?;
        let func = module
            .load_function(name)
            .with_context(|| format!("failed to load function `{}` from PTX module", name))?;
        Ok(Self { module, func })
    }

    /// Compiles the source text `src` to PTX and then builds the kernel via
    /// [`RawKernel::from_ptx`].
    ///
    /// # Errors
    ///
    /// Returns an error if compilation fails, or any error from
    /// [`RawKernel::from_ptx`]. A compilation failure is annotated with the
    /// kernel name so it can be told apart from load-time failures.
    pub fn from_source(src: &str, name: &str, ctx: &Arc<CudaContext>) -> anyhow::Result<Self> {
        let ptx = compile_ptx(src)
            .with_context(|| format!("failed to compile source for kernel `{}`", name))?;
        Self::from_ptx(&ptx, name, ctx)
    }

    /// Starts a launch of this kernel's function on `stream`.
    ///
    /// The returned builder borrows both `self` and `stream` for `'a`; arguments
    /// are pushed onto it before it is launched.
    pub fn launch_builder<'a>(&'a self, stream: &'a CudaStream) -> LaunchArgs<'a> {
        stream.launch_builder(&self.func)
    }
}
/// Launches `$kernel` on `$stream` with launch configuration `$cfg`, pushing each
/// `$arg` onto the launch builder in the order given.
///
/// Forms accepted: `launch!(kernel, stream, cfg)`, `launch!(kernel, stream, cfg, a, b)`,
/// each optionally with a trailing comma.
///
/// The expansion is a block that evaluates to `Ok(())`. A launch failure is
/// propagated with `?`, i.e. it returns early from the *enclosing function*, so the
/// macro must be used inside a function whose error type can be built from the
/// launch error, and in a position where the `Result` type of `Ok(())` is known
/// (for example as the tail expression of that function).
///
/// # Safety
///
/// The launch itself is performed inside an `unsafe` block that the caller does not
/// see. The macro cannot verify that the pushed arguments match the kernel's
/// parameter list or that `$cfg` is valid for the kernel; the caller is responsible
/// for that.
#[macro_export]
macro_rules! launch {
    ($kernel:expr, $stream:expr, $cfg:expr $(, $arg:expr)* $(,)?) => {{
        // Bring `.arg(..)` into scope so call sites do not need their own import.
        // Allowed to be unused: an invocation with zero arguments never calls it.
        #[allow(unused_imports)]
        use ::cudarc::driver::PushKernelArg as _;

        // Bound through a `let` so a temporary stream expression (e.g. `&stream`)
        // lives for the whole block instead of only for one statement.
        let stream = &$stream;
        // Go through the public accessor: the `func` field is private, and field
        // privacy is checked where the macro is invoked, not where it is defined.
        let mut builder = $kernel.launch_builder(stream);
        $(builder.arg($arg);)*
        // Evaluated outside the `unsafe` block so caller-supplied code is never
        // implicitly treated as unsafe.
        let cfg = $cfg;
        // SAFETY: not checkable here — the caller must ensure the arguments pushed
        // above match the kernel's parameters and that `cfg` is valid for it.
        unsafe { builder.launch(cfg)?; }
        Ok(())
    }};
}