use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::driver::{CudaFunction, CudaModule, LaunchConfig, PushKernelArg};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions, Ptx};
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::{DevSliceArg, ScalarArg};
use crate::kernel::envelope;
use crate::nvrtc_cache::{hash_options, hash_source, CachedKernel, NvrtcCache, NvrtcCacheKey};
use crate::stream::StreamAllocator;
/// Library name reported in the `lib` field of every `GpuError::LibraryError` raised by this
/// module; also passed as the first argument to `envelope::run_kernel` for launches.
const LIB: &str = "nvrtc";
/// GPU architecture targeted by an NVRTC compilation.
///
/// Each variant maps to an NVRTC virtual-architecture name via [`SmArch::nvrtc_flag`] and to an
/// integer compute capability via [`SmArch::compute_capability`]. `Sm90a` has its own flag
/// (`compute_90a`) but reports the same compute capability (`90`) as `Sm90`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    Sm80,
    Sm86,
    Sm89,
    Sm90,
    Sm90a,
    Sm100,
    Sm120,
}

impl SmArch {
    /// Virtual-architecture name used in the `--gpu-architecture=` option
    /// (e.g. `"compute_86"`).
    pub fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Sm80 => "compute_80",
            Self::Sm86 => "compute_86",
            Self::Sm89 => "compute_89",
            Self::Sm90 => "compute_90",
            Self::Sm90a => "compute_90a",
            Self::Sm100 => "compute_100",
            Self::Sm120 => "compute_120",
        }
    }

    /// Compute capability as a single integer (e.g. `86`, `90`, `120`).
    ///
    /// `Sm90` and `Sm90a` both report `90`.
    pub fn compute_capability(self) -> u32 {
        match self {
            Self::Sm80 => 80,
            Self::Sm86 => 86,
            Self::Sm89 => 89,
            Self::Sm90 => 90,
            Self::Sm90a => 90,
            Self::Sm100 => 100,
            Self::Sm120 => 120,
        }
    }
}
/// C++ language standard requested from NVRTC.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CppStd {
    Cpp14,
    Cpp17,
    Cpp20,
}

impl CppStd {
    /// The complete NVRTC option selecting this standard (e.g. `"--std=c++17"`).
    pub fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Cpp14 => "--std=c++14",
            Self::Cpp17 => "--std=c++17",
            Self::Cpp20 => "--std=c++20",
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct NvrtcOpts {
pub ftz: Option<bool>,
pub maxrregcount: Option<usize>,
pub name: Option<String>,
pub use_fast_math: Option<bool>,
pub lto: bool,
pub cpp_std: Option<CppStd>,
pub arch: Option<SmArch>,
pub name_expressions: Vec<String>,
pub extra_options: Vec<String>,
pub include_paths: Vec<String>,
}
impl NvrtcOpts {
pub fn for_arch(arch: SmArch) -> Self {
Self {
arch: Some(arch),
..Default::default()
}
}
pub fn with_lto(mut self) -> Self {
self.lto = true;
self
}
pub fn with_cpp_std(mut self, std: CppStd) -> Self {
self.cpp_std = Some(std);
self
}
pub fn with_name_expression(mut self, expr: impl Into<String>) -> Self {
self.name_expressions.push(expr.into());
self
}
pub fn with_extra_option(mut self, opt: impl Into<String>) -> Self {
self.extra_options.push(opt.into());
self
}
pub fn with_include_path(mut self, path: impl Into<String>) -> Self {
self.include_paths.push(path.into());
self
}
pub fn build_flags(&self) -> Vec<String> {
let mut flags = Vec::new();
if let Some(v) = self.ftz {
flags.push(format!("--ftz={v}"));
}
if let Some(true) = self.use_fast_math {
flags.push("--use_fast_math".into());
}
if let Some(c) = self.maxrregcount {
flags.push(format!("--maxrregcount={c}"));
}
if let Some(s) = self.cpp_std {
flags.push(s.nvrtc_flag().to_string());
}
if self.lto {
flags.push("-dlto".into());
}
if let Some(a) = self.arch {
flags.push(format!("--gpu-architecture={}", a.nvrtc_flag()));
}
for path in &self.include_paths {
flags.push(format!("--include-path={path}"));
}
for opt in &self.extra_options {
flags.push(opt.clone());
}
flags
}
fn into_cudarc(self) -> CompileOptions {
let arch_flag = self.arch.map(|a| a.nvrtc_flag());
let mut extra: Vec<String> = Vec::new();
if let Some(s) = self.cpp_std {
extra.push(s.nvrtc_flag().to_string());
}
if self.lto {
extra.push("-dlto".into());
}
for opt in self.extra_options {
extra.push(opt);
}
CompileOptions {
ftz: self.ftz,
maxrregcount: self.maxrregcount,
name: self.name,
use_fast_math: self.use_fast_math,
include_paths: self.include_paths,
arch: arch_flag,
options: extra,
..Default::default()
}
}
}
/// A kernel function resolved from a loaded NVRTC module, plus its metadata.
///
/// Clones share the function, the lowered-name table and any retained PTX/cubin bytes
/// through `Arc`s.
#[derive(Clone)]
pub struct KernelHandle {
    func: Arc<CudaFunction>,
    /// Device-state generation recorded when the handle was created; compared at launch.
    generation: u64,
    /// Hash of the source text. Nothing reads it yet, hence the dead-code allowance.
    #[allow(dead_code)]
    src_hash: u64,
    /// Kernel function name as given in the compile request.
    pub name: String,
    lowered_names: Arc<HashMap<String, String>>,
    ptx: Option<Arc<Vec<u8>>>,
    cubin: Option<Arc<Vec<u8>>>,
}

impl KernelHandle {
    /// Device-state generation this handle was created under.
    pub fn generation(&self) -> u64 {
        self.generation
    }

    /// Lowered name registered for the name expression `expr`, if any.
    pub fn lowered_name(&self, expr: &str) -> Option<&str> {
        self.lowered_names.get(expr).map(String::as_str)
    }

    /// PTX bytes this handle was built from, when they were retained.
    pub fn ptx_bytes(&self) -> Option<&[u8]> {
        self.ptx.as_ref().map(|bytes| bytes.as_slice())
    }

    /// Cubin bytes, when they were retained.
    pub fn cubin_bytes(&self) -> Option<&[u8]> {
        self.cubin.as_ref().map(|bytes| bytes.as_slice())
    }
}
/// One kernel launch argument.
///
/// `DevSlice`, `Scalar` and `Usize` are the canonical forms. The typed variants are deprecated
/// aliases that [`KernelArg::canonicalize`] rewrites into `DevSlice` / `Scalar`.
pub enum KernelArg {
    /// Device-buffer argument.
    DevSlice(Box<dyn DevSliceArg>),
    /// Scalar-value argument.
    Scalar(Box<dyn ScalarArg>),
    /// A `usize` pushed as-is.
    Usize(usize),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceF32(GpuRef<f32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceF64(GpuRef<f64>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceI32(GpuRef<i32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceU32(GpuRef<u32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceU8(GpuRef<u8>),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarF32(f32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarF64(f64),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarI32(i32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarU32(u32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarU64(u64),
}

impl KernelArg {
    /// Rewrites a deprecated typed alias into its canonical boxed form. Canonical variants are
    /// returned unchanged.
    // The deprecated variants must be named here in order to fold them.
    #[allow(deprecated)]
    pub fn canonicalize(self) -> KernelArg {
        match self {
            Self::DevSliceF32(buf) => Self::DevSlice(Box::new(buf)),
            Self::DevSliceF64(buf) => Self::DevSlice(Box::new(buf)),
            Self::DevSliceI32(buf) => Self::DevSlice(Box::new(buf)),
            Self::DevSliceU32(buf) => Self::DevSlice(Box::new(buf)),
            Self::DevSliceU8(buf) => Self::DevSlice(Box::new(buf)),
            Self::ScalarF32(x) => Self::Scalar(Box::new(x)),
            Self::ScalarF64(x) => Self::Scalar(Box::new(x)),
            Self::ScalarI32(x) => Self::Scalar(Box::new(x)),
            Self::ScalarU32(x) => Self::Scalar(Box::new(x)),
            Self::ScalarU64(x) => Self::Scalar(Box::new(x)),
            Self::DevSlice(_) | Self::Scalar(_) | Self::Usize(_) => self,
        }
    }
}
/// Requests handled by [`NvrtcActor`]. Every variant carries a `oneshot` reply channel on which
/// the result is sent.
pub enum NvrtcMsg {
    /// Compile `src` with `opts` and resolve `kernel_name`. The work runs inside the actor's
    /// handler and uses the actor's in-memory module map.
    Compile {
        src: String,
        kernel_name: String,
        opts: NvrtcOpts,
        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
    },
    /// Like `Compile`, but the work runs on `tokio::task::spawn_blocking` so the handler returns
    /// immediately. It uses a fresh module map rather than the actor's.
    CompileAsync {
        src: String,
        kernel_name: String,
        opts: NvrtcOpts,
        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
    },
    /// Launch a previously compiled `kernel` with `args` under launch configuration `cfg`.
    Launch {
        kernel: KernelHandle,
        args: Vec<KernelArg>,
        cfg: LaunchConfig,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
/// Actor that compiles CUDA source with NVRTC and launches the resulting kernels.
///
/// Construct it through [`NvrtcActor::props`], [`NvrtcActor::props_with_cache`] or
/// [`NvrtcActor::mock_props`].
pub struct NvrtcActor {
    inner: NvrtcInner,
}
/// Shareable handle to a loaded `CudaModule`, stored in the actor's module map.
///
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
struct SendModule(Arc<CudaModule>);

// SAFETY: NOTE(review): these impls assert that a `CudaModule` may be moved to and shared
// between threads, but the justification is not recorded here. Confirm against cudarc's own
// thread-safety guarantees for `CudaModule` before relying on them.
unsafe impl Send for SendModule {}
unsafe impl Sync for SendModule {}
/// Backing state of an [`NvrtcActor`].
enum NvrtcInner {
    /// Driver-backed mode.
    Real {
        /// CUDA context that compiled PTX is loaded into.
        ctx: Arc<cudarc::driver::CudaContext>,
        /// Stream that kernels are launched on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy passed to `envelope::run_kernel` for each launch.
        completion: Arc<dyn CompletionStrategy>,
        /// Source of the generation stamped on handles and checked before launch.
        state: Arc<DeviceState>,
        /// Modules loaded by synchronous `Compile` requests, keyed by a source-derived hash.
        modules: Mutex<HashMap<u64, SendModule>>,
        /// Optional persistent PTX cache shared by all compile paths.
        disk_cache: Option<Arc<NvrtcCache>>,
    },
    /// Test mode: every request is answered with `GpuError::Unrecoverable`.
    Mock,
}
impl NvrtcActor {
    /// Props for a driver-backed actor that uses the default on-disk NVRTC cache.
    ///
    /// If `NvrtcCache::new()` fails, the actor runs without a disk cache. `_allocator` is
    /// accepted for signature compatibility but is not used.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
        ctx: Arc<cudarc::driver::CudaContext>,
    ) -> Props<Self> {
        let disk_cache = match NvrtcCache::new() {
            Ok(cache) => Some(Arc::new(cache)),
            Err(_) => None,
        };
        Self::props_with_cache(stream, completion, state, ctx, disk_cache)
    }

    /// Props for a driver-backed actor with an explicit (possibly absent) disk cache.
    ///
    /// Every actor instance created from these props starts with an empty in-memory module map.
    pub fn props_with_cache(
        stream: Arc<cudarc::driver::CudaStream>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
        ctx: Arc<cudarc::driver::CudaContext>,
        disk_cache: Option<Arc<NvrtcCache>>,
    ) -> Props<Self> {
        Props::create(move || {
            let inner = NvrtcInner::Real {
                ctx: Arc::clone(&ctx),
                stream: Arc::clone(&stream),
                completion: Arc::clone(&completion),
                state: Arc::clone(&state),
                modules: Mutex::new(HashMap::new()),
                disk_cache: disk_cache.clone(),
            };
            NvrtcActor { inner }
        })
    }

    /// Props for a mock actor that rejects every request with `GpuError::Unrecoverable`.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| Self {
            inner: NvrtcInner::Mock,
        })
    }
}
/// Message handling for [`NvrtcActor`].
///
/// In mock mode every request is answered with `GpuError::Unrecoverable`. In real mode:
/// - `Compile` calls `handle_compile` directly inside this async handler, using the actor's
///   module map. NOTE(review): the NVRTC compile therefore blocks the thread running the
///   handler until it finishes.
/// - `CompileAsync` moves the work onto `tokio::task::spawn_blocking` with a fresh, empty module
///   map. It neither consults nor populates the actor's in-memory module cache; the disk cache,
///   if any, is still shared. The spawned task is detached.
/// - `Launch` is delegated to `handle_launch`.
#[async_trait]
impl Actor for NvrtcActor {
    type Msg = NvrtcMsg;
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: NvrtcMsg) {
        match &self.inner {
            NvrtcInner::Mock => match msg {
                NvrtcMsg::Compile { reply, .. } | NvrtcMsg::CompileAsync { reply, .. } => {
                    // A failed send only means the requester stopped waiting; ignore it.
                    let _ = reply.send(Err(GpuError::Unrecoverable(
                        "NvrtcActor in mock mode".into(),
                    )));
                }
                NvrtcMsg::Launch { reply, .. } => {
                    let _ = reply.send(Err(GpuError::Unrecoverable(
                        "NvrtcActor in mock mode".into(),
                    )));
                }
            },
            NvrtcInner::Real {
                ctx,
                stream,
                completion,
                state,
                modules,
                disk_cache,
            } => match msg {
                NvrtcMsg::Compile {
                    src,
                    kernel_name,
                    opts,
                    reply,
                } => {
                    // Synchronous path: shares the actor's module map.
                    let _ = reply.send(handle_compile(
                        ctx,
                        state,
                        modules,
                        disk_cache.as_ref(),
                        src,
                        kernel_name,
                        opts,
                    ));
                }
                NvrtcMsg::CompileAsync {
                    src,
                    kernel_name,
                    opts,
                    reply,
                } => {
                    // Clone the shared handles so the blocking task owns them ('static).
                    let ctx_c = ctx.clone();
                    let state_c = state.clone();
                    let cache_c = disk_cache.clone();
                    tokio::task::spawn_blocking(move || {
                        // The actor's `modules` map is not reachable from this task, so a
                        // throwaway map is used; modules loaded here are not cached in memory.
                        let local: Mutex<HashMap<u64, SendModule>> = Mutex::new(HashMap::new());
                        let res = handle_compile(
                            &ctx_c,
                            &state_c,
                            &local,
                            cache_c.as_ref(),
                            src,
                            kernel_name,
                            opts,
                        );
                        let _ = reply.send(res);
                    });
                }
                NvrtcMsg::Launch {
                    kernel,
                    args,
                    cfg,
                    reply,
                } => {
                    // `handle_launch` takes ownership of `reply` and answers it.
                    handle_launch(stream, completion, state, kernel, args, cfg, reply);
                }
            },
        }
    }
}
/// Hashes kernel source text with the standard library's `DefaultHasher`.
///
/// Intended only for in-process identification. `DefaultHasher`'s algorithm is not guaranteed
/// stable across Rust releases, so the value should not be persisted. The persisted cache key
/// uses `hash_source` instead.
fn hash_src(src: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    Hash::hash(src, &mut hasher);
    hasher.finish()
}
/// Compiles `src` and returns a launchable handle for `kernel_name`, reusing cached work where
/// possible.
///
/// Lookup order:
/// 1. `modules` holds modules already loaded into `ctx`. It is keyed by a hash of the source
///    text *together with* the effective compile flags ([`NvrtcOpts::build_flags`]). Keying on
///    the source alone would hand back a module built with an earlier call's options
///    (different arch, `-D` defines, fast-math, ...) whenever the same source is recompiled.
/// 2. `disk_cache` holds PTX persisted under an [`NvrtcCacheKey`] (source hash, compute
///    capability, options hash).
/// 3. Otherwise NVRTC compiles the source. The resulting PTX is written to `disk_cache` on a
///    best-effort basis; failures are logged at debug level.
///
/// Any module loaded from PTX is recorded in `modules`, even if `kernel_name` then fails to
/// resolve. The handle is stamped with `state.generation()`. On an in-memory hit the handle
/// carries no PTX/cubin bytes, because only the loaded module is retained.
///
/// # Errors
/// Returns `GpuError::LibraryError` if cached PTX is not valid UTF-8, NVRTC compilation fails,
/// the module cannot be loaded, or `kernel_name` is not found in the module.
fn handle_compile(
    ctx: &Arc<cudarc::driver::CudaContext>,
    state: &Arc<DeviceState>,
    modules: &Mutex<HashMap<u64, SendModule>>,
    disk_cache: Option<&Arc<NvrtcCache>>,
    src: String,
    kernel_name: String,
    opts: NvrtcOpts,
) -> Result<KernelHandle, GpuError> {
    let src_hash = hash_src(&src);
    let opts_flags = opts.build_flags();
    // A loaded module depends on both the source and the options it was compiled with.
    let module_key = module_cache_key(&src, &opts_flags);
    let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
    let cache_key = NvrtcCacheKey {
        source_hash: hash_source(&src),
        arch,
        options_hash: hash_options(&opts_flags),
    };
    let lowered_names = build_lowered_names(&opts.name_expressions);

    // Clone the module out so the map lock is released before calling into the driver.
    let loaded = modules.lock().get(&module_key).cloned();
    if let Some(m) = loaded {
        let func = m
            .0
            .load_function(&kernel_name)
            .map_err(|e| GpuError::LibraryError {
                lib: LIB,
                msg: format!("load_function {kernel_name}: {e}"),
            })?;
        return Ok(KernelHandle {
            func: Arc::new(func),
            generation: state.generation(),
            src_hash,
            name: kernel_name,
            lowered_names: Arc::new(lowered_names),
            ptx: None,
            cubin: None,
        });
    }

    let mut ptx_bytes: Option<Vec<u8>> = None;
    let mut cubin_bytes: Option<Vec<u8>> = None;
    if let Some(cache) = disk_cache {
        if let Some(entry) = cache.get(cache_key) {
            ptx_bytes = Some(entry.ptx.clone());
            cubin_bytes = entry.cubin.clone();
        }
    }

    let ptx: Ptx = if let Some(bytes) = &ptx_bytes {
        let s = String::from_utf8(bytes.clone()).map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("nvrtc cache: invalid UTF-8 PTX: {e}"),
        })?;
        Ptx::from_src(s)
    } else {
        let compiled = compile_ptx_with_opts(&src, opts.into_cudarc()).map_err(|e| {
            GpuError::LibraryError {
                lib: LIB,
                msg: format!("compile_ptx: {e}"),
            }
        })?;
        let bytes_v = compiled.to_src().into_bytes();
        ptx_bytes = Some(bytes_v.clone());
        if let Some(cache) = disk_cache {
            let cached = CachedKernel::new(bytes_v, cubin_bytes.clone());
            if let Err(e) = cache.insert(cache_key, cached) {
                tracing::debug!(?e, "nvrtc disk cache insert failed (non-fatal)");
            }
        }
        compiled
    };

    let module = ctx.load_module(ptx).map_err(|e| GpuError::LibraryError {
        lib: LIB,
        msg: format!("load_module: {e}"),
    })?;
    modules
        .lock()
        .insert(module_key, SendModule(module.clone()));
    let func = module
        .load_function(&kernel_name)
        .map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("load_function {kernel_name}: {e}"),
        })?;
    Ok(KernelHandle {
        func: Arc::new(func),
        generation: state.generation(),
        src_hash,
        name: kernel_name,
        lowered_names: Arc::new(lowered_names),
        ptx: ptx_bytes.map(Arc::new),
        cubin: cubin_bytes.map(Arc::new),
    })
}

/// Key for the in-process module map: hashes the source text together with the effective
/// compile flags, so different options never alias to the same loaded module.
fn module_cache_key(src: &str, flags: &[String]) -> u64 {
    use std::hash::{Hash, Hasher};

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    src.hash(&mut hasher);
    flags.hash(&mut hasher);
    hasher.finish()
}
/// Builds the name-expression → lowered-name table stored on a `KernelHandle`.
///
/// Currently every expression maps to itself. No compiler-side name lowering is queried here.
/// Duplicate expressions collapse into a single entry.
fn build_lowered_names(exprs: &[String]) -> HashMap<String, String> {
    let mut table = HashMap::with_capacity(exprs.len());
    for expr in exprs {
        table.insert(expr.clone(), expr.clone());
    }
    table
}
pub fn compile_to_ptx(
src: &str,
opts: NvrtcOpts,
disk_cache: Option<&NvrtcCache>,
) -> Result<(Vec<u8>, Option<Vec<u8>>), GpuError> {
let opts_flags = opts.build_flags();
let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
let cache_key = NvrtcCacheKey {
source_hash: hash_source(src),
arch,
options_hash: hash_options(&opts_flags),
};
if let Some(cache) = disk_cache {
if let Some(hit) = cache.get(cache_key) {
return Ok((hit.ptx.clone(), hit.cubin.clone()));
}
}
let compiled =
compile_ptx_with_opts(src, opts.into_cudarc()).map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("compile_ptx: {e}"),
})?;
let ptx = compiled.to_src().into_bytes();
let cubin: Option<Vec<u8>> = None;
if let Some(cache) = disk_cache {
let cached = CachedKernel::new(ptx.clone(), cubin.clone());
let _ = cache.insert(cache_key, cached);
}
Ok((ptx, cubin))
}
/// Directory used by a default-constructed `NvrtcCache`, or `None` if `NvrtcCache::new()`
/// fails.
///
/// NOTE(review): this constructs a cache instance to find out, so any side effects of
/// `NvrtcCache::new()` also happen here.
pub fn default_disk_cache_path() -> Option<PathBuf> {
    match NvrtcCache::new() {
        Ok(cache) => Some(cache.dir().to_path_buf()),
        Err(_) => None,
    }
}
/// Launches `kernel` on `stream` with `args` under launch configuration `cfg`. The outcome is
/// reported on `reply`.
///
/// Steps:
/// 1. If the kernel was created under a different device-state generation, replies with
///    `GpuError::GpuRefStale` and returns.
/// 2. Folds deprecated `KernelArg` aliases into their canonical forms.
/// 3. Calls `validate()` on every `DevSlice` argument and collects the returned owner values.
///    The first validation error is sent on `reply` and nothing is launched.
/// 4. Hands `reply` and a launch closure to `envelope::run_kernel`; how and when the closure
///    runs is decided there.
///
/// On a successful launch the closure returns the validated owners, the function handle and
/// the argument list. NOTE(review): presumably this is so `run_kernel` keeps them alive until
/// the kernel completes; confirm against `envelope::run_kernel`.
fn handle_launch(
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    state: &Arc<DeviceState>,
    kernel: KernelHandle,
    args: Vec<KernelArg>,
    cfg: LaunchConfig,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    // Handles from an older context generation must not be launched.
    if kernel.generation != state.generation() {
        let _ = reply.send(Err(GpuError::GpuRefStale(
            "nvrtc kernel from prior context generation",
        )));
        return;
    }
    let args: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();
    // Values returned by `validate()`; moved into the launch closure's result below.
    let mut gpu_owners: Vec<Box<dyn std::any::Any + Send>> = Vec::new();
    for arg in &args {
        if let KernelArg::DevSlice(b) = arg {
            match b.validate() {
                Ok(owner) => gpu_owners.push(owner),
                Err(e) => {
                    let _ = reply.send(Err(e));
                    return;
                }
            }
        }
    }
    let func = kernel.func.clone();
    let stream_clone = stream.clone();
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let mut builder = stream_clone.launch_builder(&func);
        // Arguments are pushed in the order supplied by the caller.
        for arg in args.iter() {
            match arg {
                KernelArg::DevSlice(b) => b.push(&mut builder)?,
                KernelArg::Scalar(b) => {
                    b.push(&mut builder);
                }
                KernelArg::Usize(v) => {
                    builder.arg(v);
                }
                // Every deprecated alias was folded by `canonicalize()` above.
                #[allow(deprecated)]
                KernelArg::DevSliceF32(_)
                | KernelArg::DevSliceF64(_)
                | KernelArg::DevSliceI32(_)
                | KernelArg::DevSliceU32(_)
                | KernelArg::DevSliceU8(_)
                | KernelArg::ScalarF32(_)
                | KernelArg::ScalarF64(_)
                | KernelArg::ScalarI32(_)
                | KernelArg::ScalarU32(_)
                | KernelArg::ScalarU64(_) => unreachable!("canonicalize() folds these arms"),
            }
        }
        // SAFETY: NOTE(review): nothing here checks the pushed arguments against the kernel's
        // parameter list. Soundness relies on the caller supplying `args` in the order and types
        // `kernel` expects, and on `cfg` being valid for it. TODO: document this contract at
        // the public API boundary.
        let res = unsafe { builder.launch(cfg) };
        match res {
            Ok(_) => Ok((gpu_owners, func, args)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("launch: {e}"),
            }),
        }
    });
}
// Host-only unit tests: argument canonicalization, option-flag rendering, lowered-name tables
// and message construction. None of them need a GPU or invoke NVRTC.
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical variants must pass through `canonicalize` unchanged and in order.
    #[test]
    fn launch_args_collapse_compile() {
        let args: Vec<KernelArg> = vec![
            KernelArg::Scalar(Box::new(1.0f32)),
            KernelArg::Scalar(Box::new(42i32)),
            KernelArg::Usize(128),
        ];
        assert_eq!(args.len(), 3);
        let canon: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();
        assert_eq!(canon.len(), 3);
        let mut n_scalar = 0;
        let mut n_usize = 0;
        for a in &canon {
            match a {
                KernelArg::Scalar(_) => n_scalar += 1,
                KernelArg::Usize(_) => n_usize += 1,
                _ => panic!("unexpected variant"),
            }
        }
        assert_eq!((n_scalar, n_usize), (2, 1));
    }

    // Deprecated scalar aliases still build and fold into `KernelArg::Scalar`.
    #[test]
    fn deprecated_aliases_still_construct() {
        #[allow(deprecated)]
        let aliases = vec![
            KernelArg::ScalarF32(1.0),
            KernelArg::ScalarF64(2.0),
            KernelArg::ScalarI32(3),
            KernelArg::ScalarU32(4),
            KernelArg::ScalarU64(5),
        ];
        for a in aliases {
            let c = a.canonicalize();
            assert!(matches!(c, KernelArg::Scalar(_)));
        }
    }

    // `with_lto` sets the flag and `build_flags` emits `-dlto` only when it is set.
    #[test]
    fn lto_option_round_trip() {
        let opts = NvrtcOpts::default().with_lto();
        assert!(opts.lto, "with_lto sets the lto flag");
        let flags = opts.build_flags();
        assert!(
            flags.iter().any(|f| f == "-dlto"),
            "lto opt must emit `-dlto`, got {flags:?}"
        );
        let none = NvrtcOpts::default();
        assert!(!none.build_flags().iter().any(|f| f == "-dlto"));
    }

    // Name expressions are recorded and, for now, map to themselves.
    #[test]
    fn name_expression_round_trip() {
        let opts = NvrtcOpts::default()
            .with_name_expression("my_kernel<float, 256>")
            .with_name_expression("my_kernel<double, 128>");
        assert_eq!(opts.name_expressions.len(), 2);
        let lowered = build_lowered_names(&opts.name_expressions);
        assert_eq!(lowered.len(), 2);
        assert_eq!(
            lowered.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        assert_eq!(
            lowered.get("my_kernel<double, 128>").map(|s| s.as_str()),
            Some("my_kernel<double, 128>")
        );
        let arc = Arc::new(lowered);
        assert_eq!(
            arc.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        assert!(arc.get("never_registered").is_none());
        let empty = build_lowered_names(&[]);
        assert!(empty.is_empty());
    }

    // Constructing a `CompileAsync` message only; nothing is compiled.
    #[test]
    fn async_compile_request_constructs() {
        let (tx, _rx) = oneshot::channel::<Result<KernelHandle, GpuError>>();
        let msg = NvrtcMsg::CompileAsync {
            src: "extern \"C\" __global__ void k() {}".into(),
            kernel_name: "k".into(),
            opts: NvrtcOpts::default().with_lto().with_cpp_std(CppStd::Cpp17),
            reply: tx,
        };
        match msg {
            NvrtcMsg::CompileAsync {
                src, kernel_name, ..
            } => {
                assert!(src.contains("__global__"));
                assert_eq!(kernel_name, "k");
            }
            _ => panic!("expected CompileAsync variant"),
        }
    }

    // Every architecture maps to the expected NVRTC name, compute capability and flag.
    #[test]
    fn arch_selection_emits_correct_flag() {
        let cases = [
            (SmArch::Sm80, "compute_80", 80),
            (SmArch::Sm86, "compute_86", 86),
            (SmArch::Sm89, "compute_89", 89),
            (SmArch::Sm90, "compute_90", 90),
            (SmArch::Sm90a, "compute_90a", 90),
            (SmArch::Sm100, "compute_100", 100),
            (SmArch::Sm120, "compute_120", 120),
        ];
        for (arch, expect_flag, expect_cc) in cases {
            assert_eq!(arch.nvrtc_flag(), expect_flag);
            assert_eq!(arch.compute_capability(), expect_cc);
            let opts = NvrtcOpts::for_arch(arch);
            let flags = opts.build_flags();
            let want = format!("--gpu-architecture={expect_flag}");
            assert!(
                flags.iter().any(|f| f == &want),
                "arch {arch:?} must emit `{want}`, got {flags:?}"
            );
        }
    }

    // Every C++ standard renders its `--std=` flag.
    #[test]
    fn cpp_std_emits_flag() {
        for (s, want) in [
            (CppStd::Cpp14, "--std=c++14"),
            (CppStd::Cpp17, "--std=c++17"),
            (CppStd::Cpp20, "--std=c++20"),
        ] {
            let opts = NvrtcOpts::default().with_cpp_std(s);
            let flags = opts.build_flags();
            assert!(
                flags.iter().any(|f| f == want),
                "{s:?} must emit `{want}`, got {flags:?}"
            );
        }
    }
}