use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;
#[derive(Clone, Debug)]
pub enum GpuTarget {
Nvidia,
Amd,
}
impl GpuTarget {
fn triple(&self) -> &str {
match self {
GpuTarget::Nvidia => "nvptx64-nvidia-cuda",
GpuTarget::Amd => "amdgcn-amd-amdhsa",
}
}
#[allow(dead_code)]
fn asm_extension(&self) -> &str {
match self {
GpuTarget::Nvidia => "s", GpuTarget::Amd => "s", }
}
}
pub struct WarpBuilder {
kernel_crate: PathBuf,
toolchain: String,
release: bool,
features: Vec<String>,
target: GpuTarget,
sm_arch: String,
}
impl WarpBuilder {
pub fn new(kernel_crate_path: impl Into<PathBuf>) -> Self {
WarpBuilder {
kernel_crate: kernel_crate_path.into(),
toolchain: "nightly".to_string(),
release: true,
features: Vec::new(),
target: GpuTarget::Nvidia,
sm_arch: "sm_70".to_string(),
}
}
pub fn toolchain(mut self, toolchain: impl Into<String>) -> Self {
self.toolchain = toolchain.into();
self
}
pub fn debug(mut self) -> Self {
self.release = false;
self
}
pub fn target(mut self, target: GpuTarget) -> Self {
self.target = target;
self
}
pub fn sm_arch(mut self, arch: impl Into<String>) -> Self {
self.sm_arch = arch.into();
self
}
pub fn feature(mut self, feature: impl Into<String>) -> Self {
self.features.push(feature.into());
self
}
pub fn build(self) -> Result<BuildResult, BuildError> {
let manifest_dir =
env::var("CARGO_MANIFEST_DIR").map_err(|_| BuildError::NotInBuildScript)?;
let out_dir = env::var("OUT_DIR").map_err(|_| BuildError::NotInBuildScript)?;
let kernel_dir = Path::new(&manifest_dir).join(&self.kernel_crate);
if !kernel_dir.exists() {
return Err(BuildError::KernelCrateNotFound(kernel_dir));
}
emit_rerun_if_changed(&kernel_dir);
let mut cmd = Command::new("cargo");
cmd.arg("rustc")
.arg("--target")
.arg(self.target.triple())
.arg("-Z")
.arg("build-std=core")
.current_dir(&kernel_dir);
if self.release {
cmd.arg("--release");
}
for feat in &self.features {
cmd.arg("--features").arg(feat);
}
cmd.arg("--")
.arg("--emit=asm")
.arg("-C")
.arg(format!("target-cpu={}", self.sm_arch));
cmd.env("RUSTUP_TOOLCHAIN", &self.toolchain);
cmd.env_remove("RUSTC");
cmd.env("RUSTFLAGS", "--cfg warp_kernel_build");
let output = cmd
.output()
.map_err(|e| BuildError::CargoFailed(format!("Failed to run cargo: {}", e)))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(BuildError::CompilationFailed(stderr.into_owned()));
}
let profile = if self.release { "release" } else { "debug" };
let target_dir = kernel_dir
.join("target")
.join(self.target.triple())
.join(profile);
let ptx_path = find_ptx_file(&target_dir, &kernel_dir)?;
let ptx_content = std::fs::read_to_string(&ptx_path)
.map_err(|e| BuildError::PtxReadFailed(format!("{}: {}", ptx_path.display(), e)))?;
let kernels = parse_kernel_entries(&ptx_content);
let out_ptx = Path::new(&out_dir).join("kernels.ptx");
std::fs::write(&out_ptx, &ptx_content)
.map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_ptx.display(), e)))?;
let out_rs = Path::new(&out_dir).join("kernels.rs");
let rs_content = generate_rust_module(&self.kernel_crate, &kernels);
std::fs::write(&out_rs, &rs_content)
.map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_rs.display(), e)))?;
Ok(BuildResult {
ptx_path: out_ptx,
module_path: out_rs,
kernel_names: kernels,
})
}
}
pub struct BuildResult {
pub ptx_path: PathBuf,
pub module_path: PathBuf,
pub kernel_names: Vec<String>,
}
#[derive(Debug)]
pub enum BuildError {
NotInBuildScript,
KernelCrateNotFound(PathBuf),
CargoFailed(String),
CompilationFailed(String),
PtxNotFound(String),
PtxReadFailed(String),
WriteFailed(String),
}
impl std::fmt::Display for BuildError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BuildError::NotInBuildScript => write!(
f,
"Not running in a build script (CARGO_MANIFEST_DIR not set)"
),
BuildError::KernelCrateNotFound(p) => {
write!(f, "Kernel crate not found: {}", p.display())
}
BuildError::CargoFailed(s) => write!(f, "Cargo invocation failed: {}", s),
BuildError::CompilationFailed(s) => write!(f, "Kernel compilation failed:\n{}", s),
BuildError::PtxNotFound(s) => write!(f, "PTX file not found: {}", s),
BuildError::PtxReadFailed(s) => write!(f, "Failed to read PTX: {}", s),
BuildError::WriteFailed(s) => write!(f, "Failed to write output: {}", s),
}
}
}
impl std::error::Error for BuildError {}
fn parse_kernel_entries(ptx: &str) -> Vec<String> {
ptx.lines()
.filter_map(|line| {
let trimmed = line.trim();
if trimmed.starts_with(".visible .entry ") {
let rest = trimmed.strip_prefix(".visible .entry ")?;
let name = rest.split('(').next()?.trim();
if !name.is_empty() {
Some(name.to_string())
} else {
None
}
} else {
None
}
})
.collect()
}
fn generate_rust_module(kernel_crate: &Path, kernels: &[String]) -> String {
let mut code = String::new();
code.push_str(&format!(
"// Auto-generated by warp-types-builder. Do not edit.\n\
// Kernel crate: {}\n\
// Kernel count: {}\n\n",
kernel_crate.display(),
kernels.len(),
));
code.push_str(
"/// Raw PTX assembly for all kernels in this module.\n\
pub const KERNEL_PTX: &str = include_str!(concat!(env!(\"OUT_DIR\"), \"/kernels.ptx\"));\n\n"
);
code.push_str(
"/// Loaded GPU kernel functions.\n\
///\n\
/// Created by [`Kernels::load`], which parses the PTX and extracts\n\
/// each kernel entry point as a ready-to-launch [`CudaFunction`].\n\
///\n\
/// # Available kernels\n",
);
for name in kernels {
code.push_str(&format!("/// - `{}` \n", name));
}
code.push_str(
"pub struct Kernels {\n\
_module: ::std::sync::Arc<::cudarc::driver::CudaModule>,\n",
);
for name in kernels {
code.push_str(&format!(
" /// Kernel: `{name}`\n\
pub {name}: ::cudarc::driver::CudaFunction,\n",
name = name,
));
}
code.push_str("}\n\n");
code.push_str(
"impl Kernels {\n\
/// Load all kernels from the compiled PTX.\n\
///\n\
/// Parses the embedded PTX assembly, loads it as a CUDA module,\n\
/// and extracts each kernel entry point by name.\n\
pub fn load(ctx: &::std::sync::Arc<::cudarc::driver::CudaContext>) -> \
::std::result::Result<Self, Box<dyn ::std::error::Error>> {\n\
let ptx = ::cudarc::nvrtc::Ptx::from_src(KERNEL_PTX.to_string());\n\
let module = ctx.load_module(ptx)?;\n",
);
for name in kernels {
code.push_str(&format!(
" let {name} = module.load_function(\"{name}\")?;\n",
name = name,
));
}
code.push_str(" let _module = module;\n");
code.push_str(" Ok(Kernels {\n _module,\n");
for name in kernels {
code.push_str(&format!(" {},\n", name));
}
code.push_str(" })\n }\n}\n");
code
}
fn emit_rerun_if_changed(kernel_dir: &Path) {
println!(
"cargo:rerun-if-changed={}",
kernel_dir.join("Cargo.toml").display()
);
let src_dir = kernel_dir.join("src");
if src_dir.exists() {
emit_rerun_recursive(&src_dir);
}
}
fn emit_rerun_recursive(dir: &Path) {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
emit_rerun_recursive(&path);
} else {
println!("cargo:rerun-if-changed={}", path.display());
}
}
}
}
fn find_ptx_file(target_dir: &Path, kernel_dir: &Path) -> Result<PathBuf, BuildError> {
let cargo_toml = kernel_dir.join("Cargo.toml");
let content = std::fs::read_to_string(&cargo_toml).map_err(|e| {
BuildError::PtxNotFound(format!("Can't read {}: {}", cargo_toml.display(), e))
})?;
let crate_name = {
let mut in_package = false;
content
.lines()
.find_map(|line| {
let line = line.trim();
if line.starts_with('[') {
in_package = line == "[package]";
return None;
}
if in_package && line.starts_with("name") {
let val = line.split('=').nth(1)?.trim().trim_matches('"');
return Some(val.replace('-', "_"));
}
None
})
.unwrap_or_else(|| "kernels".to_string())
};
let candidates = [
target_dir.join(format!("{}.s", crate_name)),
target_dir.join(format!("lib{}.s", crate_name)),
target_dir.join(format!("{}.ptx", crate_name)),
target_dir.join("deps").join(format!("{}.s", crate_name)),
target_dir.join("deps").join(format!("lib{}.s", crate_name)),
];
for path in &candidates {
if path.exists() {
return Ok(path.clone());
}
}
let deps = target_dir.join("deps");
if let Ok(entries) = std::fs::read_dir(&deps) {
for entry in entries.flatten() {
let p = entry.path();
if p.extension().is_some_and(|e| e == "s") {
let fname = p
.file_stem()
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_default();
if fname.starts_with(&crate_name)
&& !fname.starts_with("core-")
&& !fname.starts_with("compiler_builtins-")
{
return Ok(p);
}
}
}
}
if let Ok(entries) = std::fs::read_dir(&deps) {
let mut candidates: Vec<PathBuf> = entries
.flatten()
.map(|e| e.path())
.filter(|p| {
p.extension().is_some_and(|e| e == "s") && {
let fname = p
.file_stem()
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_default();
!fname.starts_with("core-")
&& !fname.starts_with("compiler_builtins-")
&& !fname.starts_with("warp_types-")
}
})
.collect();
candidates.sort();
if let Some(p) = candidates.into_iter().next() {
return Ok(p);
}
}
Err(BuildError::PtxNotFound(format!(
"No .s/.ptx file found in {}. Crate name: '{}'. Checked: {:?}",
target_dir.display(),
crate_name,
candidates
.iter()
.map(|c| c.display().to_string())
.collect::<Vec<_>>()
)))
}