Skip to main content

bindgen_cuda/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3use rayon::prelude::*;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::io::Write;
7use std::path::{Path, PathBuf};
8use std::str::FromStr;
9
/// Errors returned by this crate.
///
/// The enum currently has no variants: failures in this crate are reported by
/// panicking, so a value of this type can never be constructed. It still implements
/// [`std::error::Error`] and [`std::fmt::Display`] so that results can be propagated
/// with `?` from a build script returning `Result<(), Box<dyn std::error::Error>>`.
#[derive(Debug)]
pub enum Error {}

impl std::fmt::Display for Error {
    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The type is uninhabited, so there is never a value to format.
        match *self {}
    }
}

impl std::error::Error for Error {}
13
/// Core builder to set up the bindings options.
///
/// Create one with [`Builder::default`], which reads the build-script environment,
/// customize it with the builder methods, then call [`Builder::build_ptx`] or
/// [`Builder::build_lib`].
#[derive(Debug)]
pub struct Builder {
    // Optional CUDA installation root; `build_ptx` passes `<root>/include` to nvcc with
    // `-I` (falling back to `/usr/local/cuda/include` when unset).
    cuda_root: Option<PathBuf>,
    // The `.cu` files to compile.
    kernel_paths: Vec<PathBuf>,
    // Extra paths emitted as `cargo:rerun-if-changed`; `build_lib` also rebuilds when
    // any of them is newer than its output.
    watch: Vec<PathBuf>,
    // Header files; `build_ptx` copies them into `out_dir` and passes their directories
    // to nvcc with `-I`.
    include_paths: Vec<PathBuf>,
    // Target compute capability, used as `--gpu-architecture=sm_<compute_cap>`.
    compute_cap: Option<usize>,
    // Directory receiving the generated PTX and object files.
    out_dir: PathBuf,
    // Extra arguments appended to every nvcc compile invocation.
    extra_args: Vec<&'static str>,
}
25
26impl Default for Builder {
27    fn default() -> Self {
28        // Use only physical cores for rayon.
29        // Builds can be super consuming and exhaust resources quite fast
30        // like when building flash attention kernels
31        let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
32            |_| num_cpus::get_physical(),
33            |s| usize::from_str(&s).expect("RAYON_NUM_THREADS is not set to a valid integer"),
34        );
35
36        rayon::ThreadPoolBuilder::new()
37            .num_threads(num_cpus)
38            .build_global()
39            .expect("build rayon global threadpool");
40
41        let out_dir = std::env::var("OUT_DIR").expect("Expected OUT_DIR environement variable to be present, is this running within `build.rs`?").into();
42
43        let cuda_root = cuda_include_dir();
44        let kernel_paths = default_kernels().unwrap_or_default();
45        let include_paths = default_include().unwrap_or_default();
46        let extra_args = vec![];
47        let watch = vec![];
48        let compute_cap = compute_cap().ok();
49        Self {
50            cuda_root,
51            kernel_paths,
52            watch,
53            include_paths,
54            extra_args,
55            compute_cap,
56            out_dir,
57        }
58    }
59}
60
/// Helper struct to create a rust file when building PTX files.
///
/// Returned by [`Builder::build_ptx`]; call [`Bindings::write`] to generate the file.
pub struct Bindings {
    // Whether `write` should generate the file at all. `build_ptx` sets it when at least
    // one kernel was compiled, or when the output directory holds more `.ptx` files than
    // there are kernels.
    write: bool,
    // Kernel source paths; the file stem of each one names a generated constant.
    paths: Vec<PathBuf>,
}
66
67fn default_kernels() -> Option<Vec<PathBuf>> {
68    Some(
69        glob::glob("src/**/*.cu")
70            .ok()?
71            .map(|p| p.expect("Invalid path"))
72            .collect(),
73    )
74}
75fn default_include() -> Option<Vec<PathBuf>> {
76    Some(
77        glob::glob("src/**/*.cuh")
78            .ok()?
79            .map(|p| p.expect("Invalid path"))
80            .collect(),
81    )
82}
83
84impl Builder {
85    /// Force to use a given compute capability
86    pub fn set_compute_cap(&mut self, cap: usize) {
87        self.compute_cap = Some(cap);
88    }
89
90    /// Returns the detected CUDA compute capability, if available.
91    pub fn get_compute_cap(&self) -> Option<usize> {
92        self.compute_cap
93    }
94    /// Setup the kernel paths. All path must be set at once and be valid files.
95    /// ```no_run
96    /// let builder = bindgen_cuda::Builder::default().kernel_paths(vec!["src/mykernel.cu"]);
97    /// ```
98    pub fn kernel_paths<P: Into<PathBuf>>(mut self, paths: Vec<P>) -> Self {
99        let paths: Vec<_> = paths.into_iter().map(|p| p.into()).collect();
100        let inexistent_paths: Vec<_> = paths.iter().filter(|f| !f.exists()).collect();
101        if !inexistent_paths.is_empty() {
102            panic!("Kernels paths do not exist {inexistent_paths:?}");
103        }
104        self.kernel_paths = paths;
105        self
106    }
107
108    /// Setup the paths that the lib depend on but does not need to build
109    /// ```no_run
110    /// let builder =
111    /// bindgen_cuda::Builder::default().watch(vec!["kernels/"]);
112    /// ```
113    pub fn watch<T, P>(mut self, paths: T) -> Self
114    where
115        T: IntoIterator<Item = P>,
116        P: Into<PathBuf>,
117    {
118        let paths: Vec<_> = paths.into_iter().map(|p| p.into()).collect();
119        let inexistent_paths: Vec<_> = paths.iter().filter(|f| !f.exists()).collect();
120        if !inexistent_paths.is_empty() {
121            panic!("Kernels paths do not exist {inexistent_paths:?}");
122        }
123        self.watch = paths;
124        self
125    }
126
127    /// Setup the kernel paths. All path must be set at once and be valid files.
128    /// ```no_run
129    /// let builder = bindgen_cuda::Builder::default().include_paths(vec!["src/mykernel.cuh"]);
130    /// ```
131    pub fn include_paths<P: Into<PathBuf>>(mut self, paths: Vec<P>) -> Self {
132        self.include_paths = paths.into_iter().map(|p| p.into()).collect();
133        self
134    }
135
136    /// Setup the kernels with a glob.
137    /// ```no_run
138    /// let builder = bindgen_cuda::Builder::default().kernel_paths_glob("src/**/*.cu");
139    /// ```
140    pub fn kernel_paths_glob(mut self, glob: &str) -> Self {
141        self.kernel_paths = glob::glob(glob)
142            .expect("Invalid blob")
143            .map(|p| p.expect("Invalid path"))
144            .collect();
145        self
146    }
147
148    /// Setup the include files with a glob.
149    /// ```no_run
150    /// let builder = bindgen_cuda::Builder::default().kernel_paths_glob("src/**/*.cuh");
151    /// ```
152    pub fn include_paths_glob(mut self, glob: &str) -> Self {
153        self.include_paths = glob::glob(glob)
154            .expect("Invalid blob")
155            .map(|p| p.expect("Invalid path"))
156            .collect();
157        self
158    }
159
160    /// Modifies the output directory.
161    /// By default this is
162    /// [OUT_DIR](https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-build-scripts)
163    /// ```no_run
164    /// let builder = bindgen_cuda::Builder::default().out_dir("out/");
165    /// ```
166    pub fn out_dir<P: Into<PathBuf>>(mut self, out_dir: P) -> Self {
167        self.out_dir = out_dir.into();
168        self
169    }
170
171    /// Sets up extra nvcc compile arguments.
172    /// ```no_run
173    /// let builder = bindgen_cuda::Builder::default().arg("--expt-relaxed-constexpr");
174    /// ```
175    pub fn arg(mut self, arg: &'static str) -> Self {
176        self.extra_args.push(arg);
177        self
178    }
179
180    /// Forces the cuda root to a specific directory.
181    /// By default all standard directories will be visited.
182    /// ```no_run
183    /// let builder = bindgen_cuda::Builder::default().cuda_root("/usr/local/cuda");
184    /// ```
185    pub fn cuda_root<P>(&mut self, path: P)
186    where
187        P: Into<PathBuf>,
188    {
189        self.cuda_root = Some(path.into());
190    }
191
192    /// Consumes the builder and create a lib in the out_dir.
193    /// It then needs to be linked against in your `build.rs`
194    /// ```no_run
195    /// let builder = bindgen_cuda::Builder::default().build_lib("libflash.a");
196    /// println!("cargo:rustc-link-lib=flash");
197    /// ```
198    pub fn build_lib<P>(&self, out_file: P)
199    where
200        P: Into<PathBuf>,
201    {
202        let out_file = out_file.into();
203        let compute_cap = self.compute_cap.expect("Failed to get compute_cap");
204        let out_dir = self.out_dir.clone();
205        for path in &self.watch {
206            println!("cargo:rerun-if-changed={}", path.display());
207        }
208        let cu_files: Vec<_> = self
209            .kernel_paths
210            .iter()
211            .map(|f| {
212                let mut s = DefaultHasher::new();
213                f.display().to_string().hash(&mut s);
214                let hash = s.finish();
215                let mut obj_file = out_dir.join(format!(
216                    "{}-{:x}",
217                    f.file_stem()
218                        .expect("kernels paths should include a filename")
219                        .to_string_lossy(),
220                    hash
221                ));
222                obj_file.set_extension("o");
223                (f, obj_file)
224            })
225            .collect();
226        let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
227        let should_compile = if let Ok(out_modified) = out_modified {
228            let kernel_modified = self.kernel_paths.iter().any(|entry| {
229                let in_modified = entry
230                    .metadata()
231                    .expect("kernel {entry} should exist")
232                    .modified()
233                    .expect("kernel modified to be accessible");
234                in_modified.duration_since(out_modified).is_ok()
235            });
236            let watch_modified = self.watch.iter().any(|entry| {
237                let in_modified = entry
238                    .metadata()
239                    .expect("watched file {entry} should exist")
240                    .modified()
241                    .expect("watch modified should be accessible");
242                in_modified.duration_since(out_modified).is_ok()
243            });
244            kernel_modified || watch_modified
245        } else {
246            true
247        };
248        let ccbin_env = std::env::var("NVCC_CCBIN");
249        let nvcc_binary = if std::path::Path::new("/usr/local/cuda/bin/nvcc").exists() {
250            "/usr/local/cuda/bin/nvcc"
251        } else {
252            "nvcc"
253        };
254        if should_compile {
255            cu_files
256            .par_iter()
257            .map(|(cu_file, obj_file)| {
258                let mut command = std::process::Command::new(nvcc_binary);
259                command
260                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
261                    .arg("-c")
262                    .args(["-o", obj_file.to_str().expect("valid outfile")])
263                    .args(["--default-stream", "per-thread"])
264                    .args(&self.extra_args);
265                if let Ok(ccbin_path) = &ccbin_env {
266                    command
267                        .arg("-allow-unsupported-compiler")
268                        .args(["-ccbin", ccbin_path]);
269                }
270                command.arg(cu_file);
271                let output = command
272                    .spawn()
273                    .expect("failed spawning nvcc")
274                    .wait_with_output().expect("capture nvcc output");
275                if !output.status.success() {
276                    panic!(
277                        "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
278                        &command,
279                        String::from_utf8_lossy(&output.stdout),
280                        String::from_utf8_lossy(&output.stderr)
281                    )
282                }
283                Ok(())
284            })
285            .collect::<Result<(), std::io::Error>>().expect("compile files correctly");
286            let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
287            let mut command = std::process::Command::new(nvcc_binary);
288            command
289                .arg("--lib")
290                .args([
291                    "-o",
292                    out_file.to_str().expect("library file {out_file} to exist"),
293                ])
294                .args(obj_files);
295            let output = command
296                .spawn()
297                .expect("failed spawning nvcc")
298                .wait_with_output()
299                .expect("Run nvcc");
300            if !output.status.success() {
301                panic!(
302                    "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
303                    &command,
304                    String::from_utf8_lossy(&output.stdout),
305                    String::from_utf8_lossy(&output.stderr)
306                )
307            }
308        }
309    }
310
311    /// Consumes the builder and outputs 1 ptx file for each kernels
312    /// found.
313    /// This function returns [`Bindings`] which can then be unused
314    /// to create a rust source file that will include those kernels.
315    /// ```no_run
316    /// let bindings = bindgen_cuda::Builder::default().build_ptx().unwrap();
317    /// bindings.write("src/lib.rs").unwrap();
318    /// ```
319    pub fn build_ptx(&self) -> Result<Bindings, Error> {
320        let mut cuda_include_dir = PathBuf::from("/usr/local/cuda/include");
321        if let Some(cuda_root) = &self.cuda_root {
322            cuda_include_dir = cuda_root.join("include");
323            println!(
324                "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
325                cuda_include_dir.display()
326            );
327        };
328        let compute_cap = self.compute_cap.expect("Could not find compute_cap");
329
330        let out_dir = self.out_dir.clone();
331
332        let mut include_paths = self.include_paths.clone();
333        for path in &mut include_paths {
334            println!("cargo:rerun-if-changed={}", path.display());
335            let destination =
336                out_dir.join(path.file_name().expect("include path to have filename"));
337            std::fs::copy(path.clone(), destination).expect("copy include headers");
338            // remove the filename from the path so it's just the directory
339            path.pop();
340        }
341
342        include_paths.sort();
343        include_paths.dedup();
344
345        #[allow(unused)]
346        let mut include_options: Vec<String> = include_paths
347            .into_iter()
348            .map(|s| {
349                "-I".to_string()
350                    + &s.into_os_string()
351                        .into_string()
352                        .expect("include option to be valid string")
353            })
354            .collect::<Vec<_>>();
355        include_options.push(format!("-I{}", cuda_include_dir.display()));
356
357        let ccbin_env = std::env::var("NVCC_CCBIN");
358        let nvcc_binary = if std::path::Path::new("/usr/local/cuda/bin/nvcc").exists() {
359            "/usr/local/cuda/bin/nvcc"
360        } else {
361            "nvcc"
362        };
363        println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
364        for path in &self.watch {
365            println!("cargo:rerun-if-changed={}", path.display());
366        }
367        let children = self.kernel_paths
368            .par_iter()
369            .flat_map(|p| {
370                println!("cargo:rerun-if-changed={}", p.display());
371                let mut output = p.clone();
372                output.set_extension("ptx");
373                let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().expect("kernel to have a filename"));
374
375                let ignore = if let Ok(metadata) = output_filename.metadata() {
376                    let out_modified = metadata.modified().expect("modified to be accessible");
377                    let in_modified = p.metadata().expect("input to have metadata").modified().expect("input metadata to be accessible");
378                    out_modified.duration_since(in_modified).is_ok()
379                } else {
380                    false
381                };
382                if ignore {
383                    None
384                } else {
385                    let mut command = std::process::Command::new(nvcc_binary);
386                    command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
387                        .arg("--ptx")
388                        .args(["--default-stream", "per-thread"])
389                        .args(["--output-directory", &out_dir.display().to_string()])
390                        .args(&self.extra_args)
391                        .args(&include_options);
392                    if let Ok(ccbin_path) = &ccbin_env {
393                        command
394                            .arg("-allow-unsupported-compiler")
395                            .args(["-ccbin", ccbin_path]);
396                    }
397                    command.arg(p);
398                    Some((p, format!("{command:?}"), command.spawn()
399                        .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
400                }
401            })
402            .collect::<Vec<_>>();
403
404        let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{0}/**/*.ptx", out_dir.display()))
405            .expect("valid glob")
406            .map(|p| p.expect("valid path for PTX"))
407            .collect();
408        // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
409        // some old ones
410        let write = !children.is_empty() || self.kernel_paths.len() < ptx_paths.len();
411        for (kernel_path, command, child) in children {
412            let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
413            assert!(
414                output.status.success(),
415                "nvcc error while compiling {kernel_path:?}:\n\n# CLI {command} \n\n# stdout\n{:#}\n\n# stderr\n{:#}",
416                String::from_utf8_lossy(&output.stdout),
417                String::from_utf8_lossy(&output.stderr)
418            );
419        }
420        Ok(Bindings {
421            write,
422            paths: self.kernel_paths.clone(),
423        })
424    }
425}
426
427impl Bindings {
428    /// Writes a helper rust file that will include the PTX sources as
429    /// `const KERNEL_NAME` making it easier to interact with the PTX sources.
430    pub fn write<P>(&self, out: P) -> Result<(), Error>
431    where
432        P: AsRef<Path>,
433    {
434        if self.write {
435            let mut file = std::fs::File::create(out).expect("Create lib in {out}");
436            for kernel_path in &self.paths {
437                let name = kernel_path
438                    .file_stem()
439                    .expect("kernel to have stem")
440                    .to_str()
441                    .expect("kernel path to be valid");
442                file.write_all(
443                format!(
444                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
445                    name.to_uppercase().replace('.', "_"),
446                    name
447                )
448                .as_bytes(),
449                )
450                .expect("write to {out}");
451                file.write_all(&[b'\n']).expect("write to {out}");
452            }
453        }
454        Ok(())
455    }
456}
457
/// Locates a CUDA installation root: the first candidate containing `include/cuda.h`.
///
/// Candidates are tried in this order:
/// 1. the `CUDA_PATH`, `CUDA_ROOT`, `CUDA_TOOLKIT_ROOT_DIR` and `CUDNN_LIB` environment
///    variables, if set;
/// 2. a list of conventional install locations.
///
/// Despite the name, the returned path is the root, not its `include` subdirectory.
///
/// With the `ci-check` feature enabled, the lookup is skipped and the placeholder root
/// `ci` is returned.
fn cuda_include_dir() -> Option<PathBuf> {
    // NOTE: copied from cudarc build.rs.
    let env_vars = [
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDNN_LIB",
    ];
    let env_vars = env_vars
        .iter()
        .map(std::env::var)
        .filter_map(Result::ok)
        .map(PathBuf::from);

    let roots = [
        "/usr",
        "/usr/local/cuda",
        "/opt/cuda",
        "/usr/lib/cuda",
        "C:/Program Files/NVIDIA GPU Computing Toolkit",
        "C:/CUDA",
    ];

    println!("cargo:info={roots:?}");

    // `cfg!` instead of `#[cfg]` keeps a tail expression in both configurations. The
    // previous `#[cfg(feature = "ci-check")] let root = ...;` left the function without
    // a return value, so the crate did not compile with that feature.
    if cfg!(feature = "ci-check") {
        return Some(PathBuf::from("ci"));
    }

    env_vars
        .chain(roots.iter().map(PathBuf::from))
        .find(|path| path.join("include").join("cuda.h").is_file())
}
495
496fn compute_cap() -> Result<usize, Error> {
497    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
498
499    // Try to parse compute caps from env
500    let compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
501        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
502        compute_cap_str
503            .parse::<usize>()
504            .expect("Could not parse code")
505    } else {
506        // Use nvidia-smi to get the current compute cap
507        let out = std::process::Command::new("nvidia-smi")
508                .arg("--query-gpu=compute_cap")
509                .arg("--format=csv")
510                .output()
511                .expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
512        let out = std::str::from_utf8(&out.stdout).expect("stdout is not a utf8 string");
513        let mut lines = out.lines();
514        assert_eq!(lines.next().expect("missing line in stdout"), "compute_cap");
515        let cap = lines
516            .next()
517            .expect("missing line in stdout")
518            .replace('.', "");
519        let cap = cap.parse::<usize>().expect("cannot parse as int {cap}");
520        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
521        cap
522    };
523    let nvcc_binary = if std::path::Path::new("/usr/local/cuda/bin/nvcc").exists() {
524        "/usr/local/cuda/bin/nvcc"
525    } else {
526        "nvcc"
527    };
528    // Grab available GPU codes from nvcc and select the highest one
529    let (supported_nvcc_codes, max_nvcc_code) = {
530        let out = std::process::Command::new(nvcc_binary)
531                .arg("--list-gpu-code")
532                .output()
533                .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
534        let out = std::str::from_utf8(&out.stdout).expect("valid utf-8 nvcc output");
535
536        let out = out.lines().collect::<Vec<&str>>();
537        let mut codes = Vec::with_capacity(out.len());
538        for code in out {
539            let code = code.split('_').collect::<Vec<&str>>();
540            if !code.is_empty() && code.contains(&"sm") {
541                if let Ok(num) = code[1].parse::<usize>() {
542                    codes.push(num);
543                }
544            }
545        }
546        codes.sort();
547        let max_nvcc_code = *codes.last().expect("no gpu codes parsed from nvcc");
548        (codes, max_nvcc_code)
549    };
550
551    // Check that nvcc supports the asked compute caps
552    if !supported_nvcc_codes.contains(&compute_cap) {
553        panic!(
554            "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
555        );
556    }
557    if compute_cap > max_nvcc_code {
558        panic!(
559            "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
560        );
561    }
562
563    Ok(compute_cap)
564}