risc0_build_kernel/
lib.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    env, fs,
17    path::{Path, PathBuf},
18    process::Command,
19};
20
21use rayon::prelude::*;
22use sha2::{Digest, Sha256};
23use tempfile::tempdir_in;
24
25const METAL_INCS: &[(&str, &str)] = &[
26    ("fp.h", include_str!("../kernels/metal/fp.h")),
27    ("fpext.h", include_str!("../kernels/metal/fpext.h")),
28];
29
/// The kind of kernel source a [`KernelBuild`] compiles, which selects the
/// toolchain used by `KernelBuild::compile`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum KernelType {
    /// Host C++ sources, built with the `cc` crate.
    Cpp,
    /// CUDA sources, built with `nvcc` through the `cc` crate's CUDA mode.
    Cuda,
    /// Metal shader sources, built with `xcrun metal` / `xcrun metallib`.
    Metal,
}
37
38pub struct KernelBuild {
39    kernel_type: KernelType,
40    flags: Vec<String>,
41    files: Vec<PathBuf>,
42    inc_dirs: Vec<PathBuf>,
43    deps: Vec<PathBuf>,
44}
45
46impl KernelBuild {
47    pub fn new(kernel_type: KernelType) -> Self {
48        Self {
49            kernel_type,
50            flags: Vec::new(),
51            files: Vec::new(),
52            inc_dirs: Vec::new(),
53            deps: Vec::new(),
54        }
55    }
56
57    /// Add a directory to the `-I` or include path for headers
58    pub fn include<P: AsRef<Path>>(&mut self, dir: P) -> &mut KernelBuild {
59        self.inc_dirs.push(dir.as_ref().to_path_buf());
60        self
61    }
62
63    /// Add an arbitrary flag to the invocation of the compiler
64    pub fn flag(&mut self, flag: &str) -> &mut KernelBuild {
65        self.flags.push(flag.to_string());
66        self
67    }
68
69    /// Add a file which will be compiled
70    pub fn file<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
71        self.files.push(p.as_ref().to_path_buf());
72        self
73    }
74
75    /// Add files which will be compiled
76    pub fn files<P>(&mut self, p: P) -> &mut KernelBuild
77    where
78        P: IntoIterator,
79        P::Item: AsRef<Path>,
80    {
81        for file in p.into_iter() {
82            self.file(file);
83        }
84        self
85    }
86
87    /// Add a file which will be compiled
88    pub fn file_opt<P: AsRef<Path>>(&mut self, _p: P, _opt: usize) -> &mut KernelBuild {
89        self
90    }
91
92    /// Add files which will be compiled
93    pub fn files_opt<P>(&mut self, _p: P, _opt: usize) -> &mut KernelBuild
94    where
95        P: IntoIterator,
96        P::Item: AsRef<Path>,
97    {
98        self
99    }
100
101    /// Add a dependency
102    pub fn dep<P: AsRef<Path>>(&mut self, p: P) -> &mut KernelBuild {
103        self.deps.push(p.as_ref().to_path_buf());
104        self
105    }
106
107    /// Add dependencies
108    pub fn deps<P>(&mut self, p: P) -> &mut KernelBuild
109    where
110        P: IntoIterator,
111        P::Item: AsRef<Path>,
112    {
113        for file in p.into_iter() {
114            self.dep(file);
115        }
116        self
117    }
118
119    pub fn compile(&mut self, output: &str) {
120        println!("cargo:rerun-if-env-changed=RISC0_SKIP_BUILD_KERNELS");
121        for src in self.files.iter() {
122            rerun_if_changed(src);
123        }
124        for dep in self.deps.iter() {
125            rerun_if_changed(dep);
126        }
127        match &self.kernel_type {
128            KernelType::Cpp => self.compile_cpp(output),
129            KernelType::Cuda => self.compile_cuda(output),
130            KernelType::Metal => self.compile_metal(output),
131        }
132    }
133
134    fn compile_cpp(&mut self, output: &str) {
135        if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
136            return;
137        }
138
139        // It's *highly* recommended to install `sccache` and use this combined with
140        // `RUSTC_WRAPPER=/path/to/sccache` to speed up rebuilds of C++ kernels
141        cc::Build::new()
142            .cpp(true)
143            .debug(false)
144            .files(&self.files)
145            .includes(&self.inc_dirs)
146            .flag_if_supported("/std:c++17")
147            .flag_if_supported("-std=c++17")
148            .flag_if_supported("-fno-var-tracking")
149            .flag_if_supported("-fno-var-tracking-assignments")
150            .flag_if_supported("-g0")
151            .compile(output);
152    }
153
154    fn compile_cuda(&mut self, output: &str) {
155        println!("cargo:rerun-if-env-changed=NVCC_APPEND_FLAGS");
156        println!("cargo:rerun-if-env-changed=NVCC_PREPEND_FLAGS");
157        println!("cargo:rerun-if-env-changed=RISC0_CUDART_LINKAGE");
158        println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
159
160        for inc_dir in self.inc_dirs.iter() {
161            rerun_if_changed(inc_dir);
162        }
163
164        if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
165            let out_dir = env::var("OUT_DIR").map(PathBuf::from).unwrap();
166            let out_path = out_dir.join(format!("lib{output}-skip.a"));
167            fs::OpenOptions::new()
168                .create(true)
169                .truncate(true)
170                .write(true)
171                .open(&out_path)
172                .unwrap();
173            println!("cargo:{}={}", output, out_path.display());
174            return;
175        }
176
177        let mut build = cc::Build::new();
178
179        for file in self.files.iter() {
180            build.file(file);
181        }
182
183        for inc in self.inc_dirs.iter() {
184            build.include(inc);
185        }
186
187        for flag in self.flags.iter() {
188            build.flag(flag);
189        }
190
191        if env::var_os("NVCC_PREPEND_FLAGS").is_none() && env::var_os("NVCC_APPEND_FLAGS").is_none()
192        {
193            build.flag("-arch=native");
194        }
195
196        let cudart = env::var("RISC0_CUDART_LINKAGE").unwrap_or("static".to_string());
197
198        build
199            .cuda(true)
200            .cudart(&cudart)
201            .debug(false)
202            .ccbin(env::var("NVCC_CCBIN").is_err())
203            .flag("-diag-suppress=177")
204            .flag("-diag-suppress=2922")
205            .flag("-Xcudafe")
206            .flag("--display_error_number")
207            .flag("-Xcompiler")
208            .flag("-Wno-missing-braces,-Wno-unused-function")
209            .compile(output);
210    }
211
212    fn compile_metal(&mut self, output: &str) {
213        let target = env::var("TARGET").unwrap();
214        let sdk_name = if target.ends_with("ios") {
215            "iphoneos"
216        } else if target.ends_with("ios-sim") {
217            "iphonesimulator"
218        } else if target.ends_with("darwin") {
219            "macosx"
220        } else {
221            panic!("unsupported target: {target}")
222        };
223
224        self.cached_compile(
225            output,
226            "metallib",
227            METAL_INCS,
228            &[],
229            &[sdk_name.to_string()],
230            |out_dir, out_path, sys_inc_dir, _flags| {
231                let files: Vec<_> = self.files.iter().map(|x| x.as_path()).collect();
232
233                let air_paths: Vec<_> = files
234                    .into_par_iter()
235                    .map(|src| {
236                        let air_path = out_dir.join(src).with_extension("").with_extension("air");
237                        if let Some(parent) = air_path.parent() {
238                            fs::create_dir_all(parent).unwrap();
239                        }
240                        let mut cmd = Command::new("xcrun");
241                        cmd.args(["--sdk", sdk_name]);
242                        cmd.arg("metal");
243                        cmd.arg("-o").arg(&air_path);
244                        cmd.arg("-c").arg(src);
245                        cmd.arg("-I").arg(sys_inc_dir);
246                        cmd.arg("-Wno-unused-variable");
247                        for inc_dir in self.inc_dirs.iter() {
248                            cmd.arg("-I").arg(inc_dir);
249                        }
250                        println!("Running: {cmd:?}");
251                        let status = cmd.status().unwrap();
252                        if !status.success() {
253                            panic!("Could not build metal kernels");
254                        }
255                        air_path
256                    })
257                    .collect();
258
259                let result = Command::new("xcrun")
260                    .args(["--sdk", sdk_name])
261                    .arg("metallib")
262                    .args(air_paths)
263                    .arg("-o")
264                    .arg(out_path)
265                    .status()
266                    .unwrap();
267                if !result.success() {
268                    panic!("Could not build metal kernels");
269                }
270            },
271        );
272    }
273
274    fn cached_compile<F: Fn(&Path, &Path, &Path, &[String])>(
275        &self,
276        output: &str,
277        extension: &str,
278        assets: &[(&str, &str)],
279        flags: &[String],
280        tags: &[String],
281        inner: F,
282    ) {
283        let out_dir = env::var("OUT_DIR").map(PathBuf::from).unwrap();
284        if env::var("RISC0_SKIP_BUILD_KERNELS").is_ok() {
285            let out_path = out_dir
286                .join("skip-".to_string() + output)
287                .with_extension(extension);
288            fs::OpenOptions::new()
289                .create(true)
290                .truncate(true)
291                .write(true)
292                .open(&out_path)
293                .unwrap();
294            println!("cargo:{}={}", output, out_path.display());
295            return;
296        }
297
298        let out_path = out_dir.join(output).with_extension(extension);
299        let sys_inc_dir = out_dir.join("_sys_");
300
301        let cache_dir = risc0_cache();
302        if !cache_dir.is_dir() {
303            fs::create_dir_all(&cache_dir).unwrap();
304        }
305
306        let temp_dir = tempdir_in(&cache_dir).unwrap();
307        let mut hasher = Hasher::new();
308        for flag in flags {
309            hasher.add_flag(flag);
310        }
311        for tag in tags {
312            hasher.add_flag(tag);
313        }
314        for src in self.files.iter() {
315            hasher.add_file(src);
316        }
317        for (name, contents) in assets {
318            let path = sys_inc_dir.join(name);
319            if let Some(parent) = path.parent() {
320                fs::create_dir_all(parent).unwrap();
321            }
322            fs::write(&path, contents).unwrap();
323            hasher.add_file(path);
324        }
325        for dep in self.deps.iter() {
326            hasher.add_file(dep);
327        }
328        let digest = hasher.finalize();
329        let cache_path = cache_dir.join(digest).with_extension(extension);
330        if !cache_path.is_file() {
331            let tmp_dir = temp_dir.path();
332            let tmp_path = tmp_dir.join(output).with_extension(extension);
333            inner(tmp_dir, &tmp_path, &sys_inc_dir, flags);
334            fs::rename(tmp_path, &cache_path).unwrap();
335        }
336        fs::copy(cache_path, &out_path).unwrap();
337
338        println!("cargo:{}={}", output, out_path.display());
339    }
340}
341
342fn risc0_cache() -> PathBuf {
343    directories::ProjectDirs::from("com.risczero", "RISC Zero", "risc0")
344        .unwrap()
345        .cache_dir()
346        .into()
347}
348
349struct Hasher {
350    sha: Sha256,
351}
352
353impl Hasher {
354    pub fn new() -> Self {
355        Self { sha: Sha256::new() }
356    }
357
358    pub fn add_flag(&mut self, flag: &str) {
359        self.sha.update(flag);
360    }
361
362    pub fn add_file<P: AsRef<Path>>(&mut self, path: P) {
363        let bytes = fs::read(path).unwrap();
364        self.sha.update(bytes);
365    }
366
367    pub fn finalize(self) -> String {
368        hex::encode(self.sha.finalize())
369    }
370}
371
/// Asks cargo to re-run the build script whenever `path` changes.
fn rerun_if_changed<P: AsRef<Path>>(path: P) {
    let watched = path.as_ref();
    println!("cargo:rerun-if-changed={}", watched.display());
}