use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::driver::{CudaFunction, CudaModule, LaunchConfig, PushKernelArg};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions, Ptx};
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::envelope;
use crate::stream::StreamAllocator;
/// Library tag for this module: used as the `lib` field of
/// `GpuError::LibraryError` and passed to `envelope::run_kernel`.
const LIB: &str = "nvrtc";
/// Compile options a caller can attach to an `NvrtcMsg::Compile` request.
///
/// Each field is copied into the same-named field of cudarc's
/// `CompileOptions` by `into_cudarc`; all other `CompileOptions` fields keep
/// their defaults. `None` leaves the corresponding option unset.
#[derive(Debug, Clone, Default)]
pub struct NvrtcOpts {
/// Forwarded to `CompileOptions::ftz`.
pub ftz: Option<bool>,
/// Forwarded to `CompileOptions::maxrregcount`.
pub maxrregcount: Option<usize>,
/// Forwarded to `CompileOptions::name`.
pub name: Option<String>,
/// Forwarded to `CompileOptions::use_fast_math`.
pub use_fast_math: Option<bool>,
}
impl NvrtcOpts {
    /// Consumes the options and produces the equivalent cudarc
    /// `CompileOptions`.
    ///
    /// Only the four fields carried by `NvrtcOpts` are set; everything else
    /// comes from `CompileOptions::default()`.
    fn into_cudarc(self) -> CompileOptions {
        // Exhaustive destructuring: adding a field to `NvrtcOpts` without
        // forwarding it here becomes a compile error.
        let Self {
            ftz,
            maxrregcount,
            name,
            use_fast_math,
        } = self;
        CompileOptions {
            ftz,
            maxrregcount,
            name,
            use_fast_math,
            ..Default::default()
        }
    }
}
/// Handle to a kernel function loaded by `handle_compile`.
///
/// Cloning is cheap: the loaded function sits behind an `Arc`.
#[derive(Clone)]
pub struct KernelHandle {
// The loaded CUDA function that `handle_launch` passes to the launch builder.
func: Arc<CudaFunction>,
// Value of `DeviceState::generation()` when the handle was created;
// `handle_launch` rejects the handle if the current generation differs.
generation: u64,
// Hash of the source text, as produced by `hash_src`. Not read by the
// visible code, hence the lint allowance.
#[allow(dead_code)]
src_hash: u64,
/// Kernel name that was passed to `load_function` when the handle was built.
pub name: String,
}
impl KernelHandle {
/// Returns the device-state generation recorded when this handle was created.
pub fn generation(&self) -> u64 {
self.generation
}
}
/// One argument of a kernel launch, pushed to the launch builder in the
/// order given in `NvrtcMsg::Launch::args`.
///
/// `DevSlice*` variants carry a `GpuRef<T>` that `handle_launch` resolves via
/// `access()` before launching; the remaining variants carry a host value that
/// is passed to the builder by reference.
pub enum KernelArg {
/// Device buffer of `f32`.
DevSliceF32(GpuRef<f32>),
/// Device buffer of `f64`.
DevSliceF64(GpuRef<f64>),
/// Device buffer of `i32`.
DevSliceI32(GpuRef<i32>),
/// Device buffer of `u32`.
DevSliceU32(GpuRef<u32>),
/// Device buffer of `u8`.
DevSliceU8(GpuRef<u8>),
/// `f32` value argument.
ScalarF32(f32),
/// `f64` value argument.
ScalarF64(f64),
/// `i32` value argument.
ScalarI32(i32),
/// `u32` value argument.
ScalarU32(u32),
/// `u64` value argument.
ScalarU64(u64),
/// `usize` value argument.
Usize(usize),
}
/// Messages handled by `NvrtcActor`.
///
/// In mock mode both variants are answered with `GpuError::Unrecoverable`.
pub enum NvrtcMsg {
/// Compile `src` and look up `kernel_name` in the resulting module.
/// The outcome of `handle_compile` is sent on `reply`.
Compile {
/// Source text handed to the NVRTC compiler.
src: String,
/// Name of the function to load from the compiled module.
kernel_name: String,
/// Compile options converted to cudarc's `CompileOptions`.
opts: NvrtcOpts,
/// Receives the kernel handle or the compile/load error.
reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
},
/// Launch a previously compiled kernel via `handle_launch`.
Launch {
/// Handle obtained from an earlier `Compile`.
kernel: KernelHandle,
/// Arguments, pushed to the launch builder in order.
args: Vec<KernelArg>,
/// Launch configuration passed to the launch call unchanged.
cfg: LaunchConfig,
/// Receives the launch result; on the non-error path it is handed to
/// `envelope::run_kernel`.
reply: oneshot::Sender<Result<(), GpuError>>,
},
}
/// Actor that serves `NvrtcMsg` requests, either against a real CUDA context
/// or in mock mode (see `NvrtcActor::props` / `NvrtcActor::mock_props`).
pub struct NvrtcActor {
// Mode-specific state; selects between real handling and mock replies.
inner: NvrtcInner,
}
/// Cloneable newtype around a shared `CudaModule` that is asserted to be
/// `Send + Sync` so it can live in the actor's module cache.
#[derive(Clone)]
struct SendModule(Arc<CudaModule>);

// SAFETY: NOTE(review) — soundness rests on `CudaModule` being usable from any
// thread; nothing in this file proves that. Confirm against the cudarc version
// in use (if `CudaModule` is already `Send + Sync` there, these impls and the
// wrapper are redundant).
unsafe impl Send for SendModule {}
// SAFETY: see the note on the `Send` impl above; the same assumption applies
// to shared references.
unsafe impl Sync for SendModule {}
/// Operating mode of `NvrtcActor`.
enum NvrtcInner {
/// Backed by a real CUDA context.
Real {
// Context used by `handle_compile` to load compiled modules.
ctx: Arc<cudarc::driver::CudaContext>,
// Stream used by `handle_launch` to build and issue launches.
stream: Arc<cudarc::driver::CudaStream>,
// Passed to `envelope::run_kernel` together with the stream.
completion: Arc<dyn CompletionStrategy>,
// Source of the generation counter stamped on and checked against handles.
state: Arc<DeviceState>,
// Cache of loaded modules, keyed by a `u64` hash computed in
// `handle_compile`.
modules: Mutex<HashMap<u64, SendModule>>,
},
/// No device: every message is answered with `GpuError::Unrecoverable`.
Mock,
}
impl NvrtcActor {
    /// Creates `Props` for an actor backed by a real CUDA context.
    ///
    /// The factory closure clones every shared handle and starts each actor
    /// instance with an empty module cache. `_allocator` is accepted but not
    /// used by the visible code.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
        ctx: Arc<cudarc::driver::CudaContext>,
    ) -> Props<Self> {
        Props::create(move || {
            let inner = NvrtcInner::Real {
                ctx: Arc::clone(&ctx),
                stream: Arc::clone(&stream),
                completion: Arc::clone(&completion),
                state: Arc::clone(&state),
                modules: Mutex::new(HashMap::new()),
            };
            Self { inner }
        })
    }

    /// Creates `Props` for an actor in mock mode, which rejects every message
    /// with `GpuError::Unrecoverable`.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| Self {
            inner: NvrtcInner::Mock,
        })
    }
}
#[async_trait]
impl Actor for NvrtcActor {
    type Msg = NvrtcMsg;

    /// Dispatches one message according to the actor's mode.
    ///
    /// * Mock mode: replies with `GpuError::Unrecoverable` to either message.
    /// * Real mode, `Compile`: sends the result of `handle_compile` on `reply`.
    /// * Real mode, `Launch`: delegates to `handle_launch`, which owns `reply`.
    ///
    /// A dropped receiver is ignored in every case.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: NvrtcMsg) {
        let mock_mode = || GpuError::Unrecoverable("NvrtcActor in mock mode".into());

        match (&self.inner, msg) {
            (NvrtcInner::Mock, NvrtcMsg::Compile { reply, .. }) => {
                let _ = reply.send(Err(mock_mode()));
            }
            (NvrtcInner::Mock, NvrtcMsg::Launch { reply, .. }) => {
                let _ = reply.send(Err(mock_mode()));
            }
            (
                NvrtcInner::Real {
                    ctx,
                    state,
                    modules,
                    ..
                },
                NvrtcMsg::Compile {
                    src,
                    kernel_name,
                    opts,
                    reply,
                },
            ) => {
                let outcome = handle_compile(ctx, state, modules, src, kernel_name, opts);
                let _ = reply.send(outcome);
            }
            (
                NvrtcInner::Real {
                    stream,
                    completion,
                    state,
                    ..
                },
                NvrtcMsg::Launch {
                    kernel,
                    args,
                    cfg,
                    reply,
                },
            ) => {
                handle_launch(stream, completion, state, kernel, args, cfg, reply);
            }
        }
    }
}
/// Hashes kernel source text to a `u64` with the standard library's
/// `DefaultHasher` in its default (fixed-key) state.
///
/// The value is only meaningful within one build of the program: the
/// algorithm behind `DefaultHasher` is not specified across Rust releases.
fn hash_src(src: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{BuildHasher, BuildHasherDefault};

    // `hash_one` builds a fresh default hasher, feeds it `src` and finishes.
    BuildHasherDefault::<DefaultHasher>::default().hash_one(src)
}
/// Compiles `src` with NVRTC (or reuses a cached module) and loads
/// `kernel_name` from it.
///
/// The module cache is keyed by a hash of the source text and the compile
/// options together, so the same source requested with different options
/// gets its own module rather than the one built with the earlier options.
///
/// # Errors
///
/// Returns `GpuError::LibraryError` (tagged with `LIB`) when compilation,
/// module loading, or the function lookup fails. Failed compilations are not
/// cached.
fn handle_compile(
    ctx: &Arc<cudarc::driver::CudaContext>,
    state: &Arc<DeviceState>,
    modules: &Mutex<HashMap<u64, SendModule>>,
    src: String,
    kernel_name: String,
    opts: NvrtcOpts,
) -> Result<KernelHandle, GpuError> {
    use std::hash::{Hash, Hasher};

    let src_hash = hash_src(&src);

    // Fold every option into the key. `name` is included conservatively:
    // at worst this costs an extra compile, never a wrong module.
    let cache_key = {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        src_hash.hash(&mut hasher);
        opts.ftz.hash(&mut hasher);
        opts.maxrregcount.hash(&mut hasher);
        opts.name.hash(&mut hasher);
        opts.use_fast_math.hash(&mut hasher);
        hasher.finish()
    };

    let module = {
        // The guard is scoped to this block and released before the function
        // lookup; an early `?` return also drops it.
        let mut cache = modules.lock();
        if let Some(hit) = cache.get(&cache_key) {
            hit.clone()
        } else {
            let ptx: Ptx = compile_ptx_with_opts(&src, opts.into_cudarc()).map_err(|e| {
                GpuError::LibraryError {
                    lib: LIB,
                    msg: format!("compile_ptx: {e}"),
                }
            })?;
            let loaded = ctx.load_module(ptx).map_err(|e| GpuError::LibraryError {
                lib: LIB,
                msg: format!("load_module: {e}"),
            })?;
            let fresh = SendModule(loaded);
            cache.insert(cache_key, fresh.clone());
            fresh
        }
    };

    let func = module
        .0
        .load_function(&kernel_name)
        .map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("load_function {kernel_name}: {e}"),
        })?;

    Ok(KernelHandle {
        func: Arc::new(func),
        // Stamp the handle so `handle_launch` can detect a generation change.
        generation: state.generation(),
        src_hash,
        name: kernel_name,
    })
}
/// Validates a launch request and hands the actual launch to
/// `envelope::run_kernel`.
///
/// Steps:
/// 1. Reject the handle with `GpuError::GpuRefStale` if its generation differs
///    from `state.generation()`.
/// 2. Resolve every device-slice argument via `access()`; on the first failure
///    send that error on `reply` and stop. A clone of each resolved value is
///    kept in `gpu_owners`.
/// 3. Inside the closure given to `run_kernel`, push all arguments to the
///    launch builder in order and launch with `cfg`. On success the closure
///    returns `(gpu_owners, func, args)` (presumably so `run_kernel` keeps
///    them alive until completion — confirm against `envelope::run_kernel`).
///
/// Errors inside the closure — a failing `access()` or a failing launch — are
/// returned as `GpuError` values instead of panicking.
fn handle_launch(
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    state: &Arc<DeviceState>,
    kernel: KernelHandle,
    args: Vec<KernelArg>,
    cfg: LaunchConfig,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    if kernel.generation != state.generation() {
        let _ = reply.send(Err(GpuError::GpuRefStale(
            "nvrtc kernel from prior context generation",
        )));
        return;
    }

    let mut gpu_owners: Vec<Box<dyn std::any::Any + Send>> = Vec::new();

    // Resolve one device-slice argument and retain a clone of it, or report
    // the access error on `reply` and leave the function.
    macro_rules! pin_slice {
        ($gpu_ref:expr) => {
            match $gpu_ref.access() {
                Ok(slice) => gpu_owners.push(Box::new(slice.clone())),
                Err(e) => {
                    let _ = reply.send(Err(e));
                    return;
                }
            }
        };
    }

    for arg in &args {
        // Exhaustive on purpose: a new device-slice variant must be handled
        // here rather than silently skipping validation.
        match arg {
            KernelArg::DevSliceF32(g) => pin_slice!(g),
            KernelArg::DevSliceF64(g) => pin_slice!(g),
            KernelArg::DevSliceI32(g) => pin_slice!(g),
            KernelArg::DevSliceU32(g) => pin_slice!(g),
            KernelArg::DevSliceU8(g) => pin_slice!(g),
            KernelArg::ScalarF32(_)
            | KernelArg::ScalarF64(_)
            | KernelArg::ScalarI32(_)
            | KernelArg::ScalarU32(_)
            | KernelArg::ScalarU64(_)
            | KernelArg::Usize(_) => {}
        }
    }

    let func = kernel.func.clone();
    let stream_clone = stream.clone();
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let mut builder = stream_clone.launch_builder(&func);

        // Push one device-slice argument. The reference was validated above,
        // but if it has become inaccessible since then, return the error
        // instead of panicking inside the actor.
        macro_rules! push_slice {
            ($gpu_ref:expr) => {
                match $gpu_ref.access() {
                    Ok(slice) => {
                        builder.arg(&**slice);
                    }
                    Err(e) => return Err(e),
                }
            };
        }

        for arg in args.iter() {
            match arg {
                KernelArg::DevSliceF32(g) => push_slice!(g),
                KernelArg::DevSliceF64(g) => push_slice!(g),
                KernelArg::DevSliceI32(g) => push_slice!(g),
                KernelArg::DevSliceU32(g) => push_slice!(g),
                KernelArg::DevSliceU8(g) => push_slice!(g),
                KernelArg::ScalarF32(v) => {
                    builder.arg(v);
                }
                KernelArg::ScalarF64(v) => {
                    builder.arg(v);
                }
                KernelArg::ScalarI32(v) => {
                    builder.arg(v);
                }
                KernelArg::ScalarU32(v) => {
                    builder.arg(v);
                }
                KernelArg::ScalarU64(v) => {
                    builder.arg(v);
                }
                KernelArg::Usize(v) => {
                    builder.arg(v);
                }
            }
        }

        // SAFETY: NOTE(review) — nothing here checks that the number, order
        // and types of `args` match the kernel's signature, or that `cfg` is
        // valid for it; the sender of `NvrtcMsg::Launch` must guarantee that.
        // The kernel handle was checked against the current generation above.
        let res = unsafe { builder.launch(cfg) };
        match res {
            Ok(_) => Ok((gpu_owners, func, args)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("launch: {e}"),
            }),
        }
    });
}