use std::{
env, fs,
path::{Path, PathBuf},
process::Command,
};
use sha2::{Digest, Sha256};
use tempfile::tempdir_in;
const CUDA_INCS: &[(&str, &str)] = &[
("fp.h", include_str!("../kernels/cuda/fp.h")),
("fp4.h", include_str!("../kernels/cuda/fp4.h")),
];
const METAL_INCS: &[(&str, &str)] = &[
("fp.h", include_str!("../kernels/metal/fp.h")),
("fp4.h", include_str!("../kernels/metal/fp4.h")),
];
#[derive(Eq, PartialEq, Hash)]
pub enum KernelType {
Cpp,
Cuda,
Metal,
}
pub struct KernelBuild {
kernel_type: KernelType,
flags: Vec<String>,
files: Vec<PathBuf>,
inc_dirs: Vec<PathBuf>,
deps: Vec<PathBuf>,
}
impl KernelBuild {
pub fn new(kernel_type: KernelType) -> Self {
Self {
kernel_type,
flags: Vec::new(),
files: Vec::new(),
inc_dirs: Vec::new(),
deps: Vec::new(),
}
}
pub fn include<P: AsRef<Path>>(&mut self, dir: P) -> &mut KernelBuild {
self.inc_dirs.push(dir.as_ref().to_path_buf());
self
}
pub fn flag(&mut self, flag: &str) -> &mut KernelBuild {
self.flags.push(flag.to_string());
self
}
pub fn file<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
self.files.push(p.as_ref().to_path_buf());
self
}
pub fn files<P>(&mut self, p: P) -> &mut KernelBuild
where
P: IntoIterator,
P::Item: AsRef<Path>,
{
for file in p.into_iter() {
self.file(file);
}
self
}
pub fn dep<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
self.deps.push(p.as_ref().to_path_buf());
self
}
pub fn deps<P>(&mut self, p: P) -> &mut KernelBuild
where
P: IntoIterator,
P::Item: AsRef<Path>,
{
for file in p.into_iter() {
self.dep(file);
}
self
}
pub fn compile(&mut self, output: &str) {
for src in self.files.iter() {
println!("cargo:rerun-if-changed={}", src.display());
}
for dep in self.deps.iter() {
println!("cargo:rerun-if-changed={}", dep.display());
}
match &self.kernel_type {
KernelType::Cpp => self.compile_cpp(output),
KernelType::Cuda => self.compile_cuda(output),
KernelType::Metal => self.compile_metal(output),
}
}
fn compile_cpp(&mut self, output: &str) {
cc::Build::new()
.cpp(true)
.debug(false)
.files(&self.files)
.flag_if_supported("/std:c++17")
.flag_if_supported("-std=c++17")
.flag_if_supported("-fno-var-tracking")
.flag_if_supported("-fno-var-tracking-assignments")
.flag_if_supported("-g0")
.compile(output);
}
fn compile_cuda(&mut self, output: &str) {
self.cached_compile(
output,
"fatbin",
CUDA_INCS,
|_out_dir, out_path, sys_inc_dir| {
println!("cargo:rerun-if-env-changed=RISC0_CUDA_OPT");
println!("cargo:rerun-if-env-changed=NVCC_PREPEND_FLAGS");
println!("cargo:rerun-if-env-changed=NVCC_APPEND_FLAGS");
let mut cmd = Command::new("nvcc");
cmd.arg("--fatbin");
cmd.arg("-o").arg(out_path);
cmd.args(self.files.iter());
cmd.arg("-I").arg(sys_inc_dir);
let ptx_opt_level = env::var("RISC0_CUDA_OPT").unwrap_or_else(|_| "1".to_string());
cmd.arg(format!("--ptxas-options=-O{ptx_opt_level}"));
for inc_dir in self.inc_dirs.iter() {
cmd.arg("-I").arg(inc_dir);
}
let status = cmd
.status()
.expect("Failed to run 'nvcc', do you have the CUDA toolkit installed?");
if !status.success() {
panic!("Failed to build CUDA kernel: {}", output);
}
},
);
}
fn compile_metal(&mut self, output: &str) {
self.cached_compile(
output,
"metallib",
METAL_INCS,
|out_dir, out_path, sys_inc_dir| {
let mut air_paths = vec![];
for src in self.files.iter() {
let out_path = out_dir.join(src).with_extension("").with_extension("air");
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent).unwrap();
}
let mut cmd = Command::new("xcrun");
cmd.args(["--sdk", "macosx"]);
cmd.arg("metal");
cmd.arg("-o").arg(&out_path);
cmd.arg("-c").arg(src);
cmd.arg("-I").arg(sys_inc_dir);
for inc_dir in self.inc_dirs.iter() {
cmd.arg("-I").arg(inc_dir);
}
println!("Running: {:?}", cmd);
let status = cmd.status().unwrap();
if !status.success() {
panic!("Could not build metal kernels");
}
air_paths.push(out_path);
}
let result = Command::new("xcrun")
.args(["--sdk", "macosx"])
.arg("metallib")
.args(air_paths)
.arg("-o")
.arg(out_path)
.status()
.unwrap();
if !result.success() {
panic!("Could not build metal kernels");
}
},
);
}
fn cached_compile<F: Fn(&Path, &Path, &Path)>(
&self,
output: &str,
extension: &str,
assets: &[(&str, &str)],
inner: F,
) {
let out_dir = env::var("OUT_DIR").map(PathBuf::from).unwrap();
let out_path = out_dir.join(output).with_extension(extension);
let sys_inc_dir = out_dir.join("_sys_");
let cache_dir = risc0_cache();
if !cache_dir.is_dir() {
fs::create_dir_all(&cache_dir).unwrap();
}
let temp_dir = tempdir_in(&cache_dir).unwrap();
let mut hasher = Hasher::new();
for src in self.files.iter() {
hasher.add_file(src);
}
for (name, contents) in assets {
let path = sys_inc_dir.join(name);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).unwrap();
}
fs::write(&path, contents).unwrap();
hasher.add_file(path);
}
for dep in self.deps.iter() {
hasher.add_file(dep);
}
let digest = hasher.finalize();
let cache_path = cache_dir.join(digest).with_extension(extension);
if !cache_path.is_file() {
let tmp_dir = temp_dir.path();
let tmp_path = tmp_dir.join(output).with_extension(extension);
inner(tmp_dir, &tmp_path, &sys_inc_dir);
fs::rename(tmp_path, &cache_path).unwrap();
}
fs::copy(cache_path, &out_path).unwrap();
println!("cargo:{}={}", output, out_path.display());
}
}
fn risc0_cache() -> PathBuf {
directories::ProjectDirs::from("com.risczero", "RISC Zero", "risc0")
.unwrap()
.cache_dir()
.into()
}
struct Hasher {
sha: Sha256,
}
impl Hasher {
pub fn new() -> Self {
Self { sha: Sha256::new() }
}
pub fn add_file<P: AsRef<Path>>(&mut self, path: P) {
let bytes = fs::read(path).unwrap();
self.sha.update(bytes);
}
pub fn finalize(self) -> String {
hex::encode(self.sha.finalize())
}
}