1use std::{
16 env, fs,
17 path::{Path, PathBuf},
18 process::Command,
19};
20
21use rayon::prelude::*;
22use sha2::{Digest, Sha256};
23use tempfile::tempdir_in;
24
/// Headers embedded into this crate at compile time for Metal builds.
///
/// Each entry is `(file name, file contents)`. `compile_metal` hands this
/// table to `cached_compile`, which writes every entry into
/// `$OUT_DIR/_sys_/<file name>`, adds that directory to the `-I` search path
/// of every `xcrun metal` invocation, and hashes the written files into the
/// kernel cache key.
const METAL_INCS: &[(&str, &str)] = &[
    ("fp.h", include_str!("../kernels/metal/fp.h")),
    ("fpext.h", include_str!("../kernels/metal/fpext.h")),
];
29
/// The kind of kernel sources a `KernelBuild` compiles, which selects the
/// toolchain used by `KernelBuild::compile`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum KernelType {
    /// C++ sources, compiled with the host C++ compiler through the `cc` crate.
    Cpp,
    /// CUDA sources, compiled with `nvcc` through the `cc` crate.
    Cuda,
    /// Metal sources, compiled with `xcrun metal` and linked into a `.metallib`.
    Metal,
}
37
38pub struct KernelBuild {
39 kernel_type: KernelType,
40 flags: Vec<String>,
41 files: Vec<PathBuf>,
42 inc_dirs: Vec<PathBuf>,
43 deps: Vec<PathBuf>,
44}
45
46impl KernelBuild {
47 pub fn new(kernel_type: KernelType) -> Self {
48 Self {
49 kernel_type,
50 flags: Vec::new(),
51 files: Vec::new(),
52 inc_dirs: Vec::new(),
53 deps: Vec::new(),
54 }
55 }
56
57 pub fn include<P: AsRef<Path>>(&mut self, dir: P) -> &mut KernelBuild {
59 self.inc_dirs.push(dir.as_ref().to_path_buf());
60 self
61 }
62
63 pub fn flag(&mut self, flag: &str) -> &mut KernelBuild {
65 self.flags.push(flag.to_string());
66 self
67 }
68
69 pub fn file<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
71 self.files.push(p.as_ref().to_path_buf());
72 self
73 }
74
75 pub fn files<P>(&mut self, p: P) -> &mut KernelBuild
77 where
78 P: IntoIterator,
79 P::Item: AsRef<Path>,
80 {
81 for file in p.into_iter() {
82 self.file(file);
83 }
84 self
85 }
86
87 pub fn file_opt<P: AsRef<Path>>(&mut self, _p: P, _opt: usize) -> &mut KernelBuild {
89 self
90 }
91
92 pub fn files_opt<P>(&mut self, _p: P, _opt: usize) -> &mut KernelBuild
94 where
95 P: IntoIterator,
96 P::Item: AsRef<Path>,
97 {
98 self
99 }
100
101 pub fn dep<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
103 self.deps.push(p.as_ref().to_path_buf());
104 self
105 }
106
107 pub fn deps<P>(&mut self, p: P) -> &mut KernelBuild
109 where
110 P: IntoIterator,
111 P::Item: AsRef<Path>,
112 {
113 for file in p.into_iter() {
114 self.dep(file);
115 }
116 self
117 }
118
119 pub fn compile(&mut self, output: &str) {
120 println!("cargo:rerun-if-env-changed=RISC0_SKIP_BUILD_KERNELS");
121 for src in self.files.iter() {
122 rerun_if_changed(src);
123 }
124 for dep in self.deps.iter() {
125 rerun_if_changed(dep);
126 }
127 match &self.kernel_type {
128 KernelType::Cpp => self.compile_cpp(output),
129 KernelType::Cuda => self.compile_cuda(output),
130 KernelType::Metal => self.compile_metal(output),
131 }
132 }
133
134 fn compile_cpp(&mut self, output: &str) {
135 if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
136 return;
137 }
138
139 cc::Build::new()
142 .cpp(true)
143 .debug(false)
144 .files(&self.files)
145 .includes(&self.inc_dirs)
146 .flag_if_supported("/std:c++17")
147 .flag_if_supported("-std=c++17")
148 .flag_if_supported("-fno-var-tracking")
149 .flag_if_supported("-fno-var-tracking-assignments")
150 .flag_if_supported("-g0")
151 .compile(output);
152 }
153
154 fn compile_cuda(&mut self, output: &str) {
155 println!("cargo:rerun-if-env-changed=NVCC_APPEND_FLAGS");
156 println!("cargo:rerun-if-env-changed=NVCC_PREPEND_FLAGS");
157 println!("cargo:rerun-if-env-changed=RISC0_CUDART_LINKAGE");
158 println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
159
160 for inc_dir in self.inc_dirs.iter() {
161 rerun_if_changed(inc_dir);
162 }
163
164 if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
165 let out_dir = env::var("OUT_DIR").map(PathBuf::from).unwrap();
166 let out_path = out_dir.join(format!("lib{output}-skip.a"));
167 fs::OpenOptions::new()
168 .create(true)
169 .truncate(true)
170 .write(true)
171 .open(&out_path)
172 .unwrap();
173 println!("cargo:{}={}", output, out_path.display());
174 return;
175 }
176
177 let mut build = cc::Build::new();
178
179 for file in self.files.iter() {
180 build.file(file);
181 }
182
183 for inc in self.inc_dirs.iter() {
184 build.include(inc);
185 }
186
187 for flag in self.flags.iter() {
188 build.flag(flag);
189 }
190
191 if env::var_os("NVCC_PREPEND_FLAGS").is_none() && env::var_os("NVCC_APPEND_FLAGS").is_none()
192 {
193 build.flag("-arch=native");
194 }
195
196 let cudart = env::var("RISC0_CUDART_LINKAGE").unwrap_or("static".to_string());
197
198 build
199 .cuda(true)
200 .cudart(&cudart)
201 .debug(false)
202 .ccbin(env::var("NVCC_CCBIN").is_err())
203 .flag("-diag-suppress=177")
204 .flag("-diag-suppress=2922")
205 .flag("-Xcudafe")
206 .flag("--display_error_number")
207 .flag("-Xcompiler")
208 .flag("-Wno-missing-braces,-Wno-unused-function")
209 .compile(output);
210 }
211
212 fn compile_metal(&mut self, output: &str) {
213 let target = env::var("TARGET").unwrap();
214 let sdk_name = if target.ends_with("ios") {
215 "iphoneos"
216 } else if target.ends_with("ios-sim") {
217 "iphonesimulator"
218 } else if target.ends_with("darwin") {
219 "macosx"
220 } else {
221 panic!("unsupported target: {target}")
222 };
223
224 self.cached_compile(
225 output,
226 "metallib",
227 METAL_INCS,
228 &[],
229 &[sdk_name.to_string()],
230 |out_dir, out_path, sys_inc_dir, _flags| {
231 let files: Vec<_> = self.files.iter().map(|x| x.as_path()).collect();
232
233 let air_paths: Vec<_> = files
234 .into_par_iter()
235 .map(|src| {
236 let air_path = out_dir.join(src).with_extension("").with_extension("air");
237 if let Some(parent) = air_path.parent() {
238 fs::create_dir_all(parent).unwrap();
239 }
240 let mut cmd = Command::new("xcrun");
241 cmd.args(["--sdk", sdk_name]);
242 cmd.arg("metal");
243 cmd.arg("-o").arg(&air_path);
244 cmd.arg("-c").arg(src);
245 cmd.arg("-I").arg(sys_inc_dir);
246 cmd.arg("-Wno-unused-variable");
247 for inc_dir in self.inc_dirs.iter() {
248 cmd.arg("-I").arg(inc_dir);
249 }
250 println!("Running: {cmd:?}");
251 let status = cmd.status().unwrap();
252 if !status.success() {
253 panic!("Could not build metal kernels");
254 }
255 air_path
256 })
257 .collect();
258
259 let result = Command::new("xcrun")
260 .args(["--sdk", sdk_name])
261 .arg("metallib")
262 .args(air_paths)
263 .arg("-o")
264 .arg(out_path)
265 .status()
266 .unwrap();
267 if !result.success() {
268 panic!("Could not build metal kernels");
269 }
270 },
271 );
272 }
273
274 fn cached_compile<F: Fn(&Path, &Path, &Path, &[String])>(
275 &self,
276 output: &str,
277 extension: &str,
278 assets: &[(&str, &str)],
279 flags: &[String],
280 tags: &[String],
281 inner: F,
282 ) {
283 let out_dir = env::var("OUT_DIR").map(PathBuf::from).unwrap();
284 if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
285 let out_path = out_dir
286 .join("skip-".to_string() + output)
287 .with_extension(extension);
288 fs::OpenOptions::new()
289 .create(true)
290 .truncate(true)
291 .write(true)
292 .open(&out_path)
293 .unwrap();
294 println!("cargo:{}={}", output, out_path.display());
295 return;
296 }
297
298 let out_path = out_dir.join(output).with_extension(extension);
299 let sys_inc_dir = out_dir.join("_sys_");
300
301 let cache_dir = risc0_cache();
302 if !cache_dir.is_dir() {
303 fs::create_dir_all(&cache_dir).unwrap();
304 }
305
306 let temp_dir = tempdir_in(&cache_dir).unwrap();
307 let mut hasher = Hasher::new();
308 for flag in flags {
309 hasher.add_flag(flag);
310 }
311 for tag in tags {
312 hasher.add_flag(tag);
313 }
314 for src in self.files.iter() {
315 hasher.add_file(src);
316 }
317 for (name, contents) in assets {
318 let path = sys_inc_dir.join(name);
319 if let Some(parent) = path.parent() {
320 fs::create_dir_all(parent).unwrap();
321 }
322 fs::write(&path, contents).unwrap();
323 hasher.add_file(path);
324 }
325 for dep in self.deps.iter() {
326 hasher.add_file(dep);
327 }
328 let digest = hasher.finalize();
329 let cache_path = cache_dir.join(digest).with_extension(extension);
330 if !cache_path.is_file() {
331 let tmp_dir = temp_dir.path();
332 let tmp_path = tmp_dir.join(output).with_extension(extension);
333 inner(tmp_dir, &tmp_path, &sys_inc_dir, flags);
334 fs::rename(tmp_path, &cache_path).unwrap();
335 }
336 fs::copy(cache_path, &out_path).unwrap();
337
338 println!("cargo:{}={}", output, out_path.display());
339 }
340}
341
342fn risc0_cache() -> PathBuf {
343 directories::ProjectDirs::from("com.risczero", "RISC Zero", "risc0")
344 .unwrap()
345 .cache_dir()
346 .into()
347}
348
349struct Hasher {
350 sha: Sha256,
351}
352
353impl Hasher {
354 pub fn new() -> Self {
355 Self { sha: Sha256::new() }
356 }
357
358 pub fn add_flag(&mut self, flag: &str) {
359 self.sha.update(flag);
360 }
361
362 pub fn add_file<P: AsRef<Path>>(&mut self, path: P) {
363 let bytes = fs::read(path).unwrap();
364 self.sha.update(bytes);
365 }
366
367 pub fn finalize(self) -> String {
368 hex::encode(self.sha.finalize())
369 }
370}
371
/// Tell Cargo to re-run the build script whenever `path` is modified.
fn rerun_if_changed<P: AsRef<Path>>(path: P) {
    let watched = path.as_ref().display();
    println!("cargo:rerun-if-changed={watched}");
}