Skip to main content

baracuda_forge/
builder.rs

1//! Main kernel builder implementation.
2
3use crate::compute_cap::{ComputeCapability, GpuArch};
4use crate::dependency::DependencyManager;
5use crate::error::{Error, Result};
6use crate::hash::{hash_args, hash_paths, BuildCache};
7use crate::parallel::ParallelConfig;
8use crate::source::SourceSelector;
9use crate::toolkit::CudaToolkit;
10
11use rayon::prelude::*;
12use std::collections::hash_map::DefaultHasher;
13use std::hash::{Hash, Hasher};
14use std::io::Write;
15use std::path::{Path, PathBuf};
16use std::process::Command;
17use std::sync::atomic::{AtomicBool, Ordering};
18
19/// Main builder for CUDA kernel compilation.
20#[derive(Debug)]
21pub struct KernelBuilder {
22    toolkit: Option<CudaToolkit>,
23    compute_cap: ComputeCapability,
24    sources: SourceSelector,
25    dependencies: DependencyManager,
26    parallel: ParallelConfig,
27    out_dir: PathBuf,
28    extra_args: Vec<String>,
29    incremental: bool,
30    /// Explicit C++ standard for `-std=`. `None` means auto-select from the
31    /// detected toolkit version: c++20 for CUDA >= 12.0, c++17 otherwise.
32    cpp_std: Option<String>,
33}
34
35impl Default for KernelBuilder {
36    fn default() -> Self {
37        let out_dir = std::env::var("OUT_DIR")
38            .map(PathBuf::from)
39            .unwrap_or_else(|_| PathBuf::from("target/debug"));
40
41        Self {
42            toolkit: None,
43            compute_cap: ComputeCapability::default(),
44            sources: SourceSelector::default(),
45            dependencies: DependencyManager::default(),
46            parallel: ParallelConfig::default(),
47            out_dir,
48            extra_args: Vec::new(),
49            incremental: true,
50            cpp_std: None,
51        }
52    }
53}
54
55impl KernelBuilder {
56    /// Create a new kernel builder with default settings.
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    // ========== Source Selection ==========
62
63    /// Add a directory to search for `.cu` files (recursive).
64    ///
65    /// ```no_run
66    /// # use baracuda_forge::KernelBuilder;
67    /// KernelBuilder::new().source_dir("src/kernels");
68    /// ```
69    pub fn source_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
70        self.sources = self.sources.add_directory(dir);
71        self
72    }
73
74    /// Add specific kernel files.
75    ///
76    /// ```no_run
77    /// # use baracuda_forge::KernelBuilder;
78    /// KernelBuilder::new().source_files(["src/kernels/hello.cu", "src/kernels/world.cu"]);
79    /// ```
80    pub fn source_files<I, P>(mut self, files: I) -> Self
81    where
82        I: IntoIterator<Item = P>,
83        P: AsRef<Path>,
84    {
85        self.sources = self.sources.add_files(files);
86        self
87    }
88
89    /// Add kernel files matching a glob pattern.
90    ///
91    /// ```no_run
92    /// # use baracuda_forge::KernelBuilder;
93    /// KernelBuilder::new().source_glob("src/**/*.cu");
94    /// ```
95    pub fn source_glob(mut self, pattern: &str) -> Self {
96        self.sources = self.sources.add_glob(pattern);
97        self
98    }
99
100    /// Exclude files matching patterns.
101    pub fn exclude(mut self, patterns: &[&str]) -> Self {
102        self.sources = self.sources.exclude(patterns);
103        self
104    }
105
106    /// Add paths to watch for changes (headers, etc.).
107    pub fn watch<I, P>(mut self, paths: I) -> Self
108    where
109        I: IntoIterator<Item = P>,
110        P: AsRef<Path>,
111    {
112        self.sources = self.sources.watch(paths);
113        self
114    }
115
116    // ========== Compute Capability ==========
117
118    /// Set the default compute capability (numeric, auto-selects suffix for sm_90+).
119    pub fn compute_cap(mut self, cap: usize) -> Self {
120        self.compute_cap = self.compute_cap.with_default(cap);
121        self
122    }
123
124    /// Set the default compute capability with explicit arch string (e.g., `"90a"`, `"100a"`).
125    pub fn compute_cap_arch(mut self, arch: &str) -> Self {
126        self.compute_cap = self.compute_cap.with_default_arch(arch);
127        self
128    }
129
130    /// Set compute cap override for specific kernels (numeric).
131    ///
132    /// Pattern can use wildcards: `"sm90_*.cu"`, `"*_hopper.cu"`.
133    ///
134    /// ```no_run
135    /// # use baracuda_forge::KernelBuilder;
136    /// KernelBuilder::new()
137    ///     .source_glob("src/**/*.cu")
138    ///     .with_compute_override("sm90_*.cu", 90)   // Hopper kernels
139    ///     .with_compute_override("sm80_*.cu", 80);  // Ampere kernels
140    /// ```
141    pub fn with_compute_override(mut self, pattern: &str, cap: usize) -> Self {
142        self.compute_cap = self.compute_cap.with_override(pattern, cap);
143        self
144    }
145
146    /// Set compute cap override with explicit arch string.
147    pub fn with_compute_override_arch(mut self, pattern: &str, arch: &str) -> Self {
148        self.compute_cap = self.compute_cap.with_override_arch(pattern, arch);
149        self
150    }
151
152    /// Get the current default compute capability (base number only).
153    pub fn get_compute_cap(&self) -> Option<usize> {
154        self.compute_cap.get_default().ok().map(|a| a.base)
155    }
156
157    /// Set compute capability (mutable reference version).
158    pub fn set_compute_cap(&mut self, cap: usize) {
159        self.compute_cap = ComputeCapability::new().with_default(cap);
160    }
161
162    /// Require explicit compute capability (fail fast if not set).
163    ///
164    /// Use this for Docker builds or CI environments where `nvidia-smi` is
165    /// unavailable. The build fails immediately if `CUDA_COMPUTE_CAP` is not
166    /// set and no compute capability was explicitly configured.
167    ///
168    /// ```no_run
169    /// # use baracuda_forge::KernelBuilder;
170    /// # fn build() -> Result<(), baracuda_forge::Error> {
171    /// // In a Docker build, fail at build time if CUDA_COMPUTE_CAP wasn't
172    /// // baked into the image:
173    /// KernelBuilder::new()
174    ///     .require_explicit_compute_cap()?
175    ///     .source_dir("src/kernels")
176    ///     .build_lib("libkernels.a")?;
177    /// # Ok(()) }
178    /// ```
179    pub fn require_explicit_compute_cap(self) -> Result<Self> {
180        if self.compute_cap.get_default().is_ok() {
181            return Ok(self);
182        }
183
184        if std::env::var("CUDA_COMPUTE_CAP").is_ok() {
185            return Ok(self);
186        }
187
188        Err(Error::ComputeCapDetectionFailed(
189            "Explicit compute capability required but not set. \
190            Either call .compute_cap(N) on the builder or set CUDA_COMPUTE_CAP environment variable. \
191            This is required for Docker builds where nvidia-smi is unavailable.".to_string()
192        ))
193    }
194
195    // ========== External Dependencies ==========
196
197    /// Add CUTLASS dependency.
198    ///
199    /// `commit` pins a specific CUTLASS commit hash. Pass `None` to use the
200    /// built-in default. When the consuming crate also depends on
201    /// `baracuda-cutlass-sys`, that crate's pinned version wins automatically
202    /// via cargo's `links` mechanism — forge then skips its own git fetch.
203    ///
204    /// ```no_run
205    /// # use baracuda_forge::KernelBuilder;
206    /// # fn build() -> Result<(), baracuda_forge::Error> {
207    /// KernelBuilder::new()
208    ///     .source_dir("src/kernels")
209    ///     .with_cutlass(None)
210    ///     .arg("-DUSE_CUTLASS")
211    ///     .build_lib("libkernels.a")?;
212    /// # Ok(()) }
213    /// ```
214    pub fn with_cutlass(mut self, commit: Option<&str>) -> Self {
215        self.dependencies = self.dependencies.with_cutlass(commit);
216        self
217    }
218
219    /// Add a custom git dependency.
220    ///
221    /// If `recurse_submodules` is false, clone/fetch adds `--no-recurse-submodules`.
222    pub fn with_git_dependency(
223        mut self,
224        name: &str,
225        repo: &str,
226        commit: &str,
227        include_paths: Vec<&str>,
228        extra_paths: Vec<&str>,
229        recurse_submodules: bool,
230    ) -> Self {
231        self.dependencies = self.dependencies.with_git_dependency(
232            name,
233            repo,
234            commit,
235            include_paths,
236            extra_paths,
237            recurse_submodules,
238        );
239        self
240    }
241
242    /// Fetch a configured git dependency and return its checkout root.
243    pub fn fetch_git_dependency(&self, name: &str) -> Result<PathBuf> {
244        self.dependencies.fetch_dependency(name, &self.out_dir)
245    }
246
247    /// Add a local include path.
248    pub fn include_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
249        self.dependencies = self.dependencies.with_local_include(path);
250        self
251    }
252
253    // ========== Parallel Configuration ==========
254
255    /// Set the percentage of available threads to use (0.0 - 1.0).
256    pub fn thread_percentage(mut self, percentage: f32) -> Self {
257        self.parallel = self.parallel.with_percentage(percentage);
258        self
259    }
260
261    /// Set the maximum number of threads.
262    pub fn max_threads(mut self, max: usize) -> Self {
263        self.parallel = self.parallel.with_max_threads(max);
264        self
265    }
266
267    /// Set patterns for files that should use nvcc's `--threads=N` flag.
268    pub fn nvcc_thread_patterns<S: AsRef<str>>(
269        mut self,
270        patterns: &[S],
271        num_nvcc_threads: usize,
272    ) -> Self {
273        self.parallel = self
274            .parallel
275            .with_nvcc_thread_patterns(patterns, num_nvcc_threads);
276        self
277    }
278
279    // ========== Build Configuration ==========
280
281    /// Set the output directory.
282    pub fn out_dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
283        self.out_dir = dir.into();
284        self
285    }
286
287    /// Add an extra nvcc argument.
288    pub fn arg(mut self, arg: &str) -> Self {
289        self.extra_args.push(arg.to_string());
290        self
291    }
292
293    /// Add multiple extra nvcc arguments.
294    pub fn args<I, S>(mut self, args: I) -> Self
295    where
296        I: IntoIterator<Item = S>,
297        S: AsRef<str>,
298    {
299        for arg in args {
300            self.extra_args.push(arg.as_ref().to_string());
301        }
302        self
303    }
304
305    /// Disable incremental builds.
306    pub fn no_incremental(mut self) -> Self {
307        self.incremental = false;
308        self
309    }
310
311    /// Set explicit CUDA toolkit path.
312    pub fn cuda_root<P: AsRef<Path>>(mut self, path: P) -> Self {
313        if let Ok(toolkit) = CudaToolkit::from_nvcc_path(path.as_ref().join("bin").join("nvcc")) {
314            self.toolkit = Some(toolkit);
315        }
316        self
317    }
318
319    /// Set the C++ standard passed to nvcc as `-std=<standard>`.
320    ///
321    /// Pass values like `"c++17"`, `"c++20"`. When unset (the default), the
322    /// builder selects automatically from the detected toolkit version:
323    /// `c++20` for CUDA >= 12.0, `c++17` otherwise.
324    ///
325    /// If your `extra_args` already contains a `-std=` argument, this method's
326    /// value is ignored (your explicit `-std=` wins).
327    ///
328    /// ```no_run
329    /// # use baracuda_forge::KernelBuilder;
330    /// // Force c++17 even on CUDA 12+, e.g. for code that must compile
331    /// // against both 11.x and 12.x toolkits:
332    /// KernelBuilder::new().cpp_std("c++17");
333    /// ```
334    pub fn cpp_std(mut self, standard: &str) -> Self {
335        self.cpp_std = Some(standard.to_string());
336        self
337    }
338
339    // ========== Build Methods ==========
340
341    /// Build a static library from all kernel sources.
342    ///
343    /// `out_file` is typically `format!("{}/libkernels.a", env!("OUT_DIR"))`.
344    /// Pair with `cargo:rustc-link-search` and `cargo:rustc-link-lib` to wire
345    /// the library into the resulting Rust binary.
346    ///
347    /// ```no_run
348    /// # use baracuda_forge::KernelBuilder;
349    /// let out_dir = std::env::var("OUT_DIR").unwrap();
350    /// KernelBuilder::new()
351    ///     .source_dir("src/kernels")
352    ///     .arg("-O3")
353    ///     .build_lib(format!("{out_dir}/libkernels.a"))
354    ///     .unwrap();
355    /// println!("cargo:rustc-link-search={out_dir}");
356    /// println!("cargo:rustc-link-lib=kernels");
357    /// ```
358    pub fn build_lib<P: Into<PathBuf>>(&self, out_file: P) -> Result<()> {
359        let out_file = out_file.into();
360
361        let toolkit = match &self.toolkit {
362            Some(t) => t.clone(),
363            None => CudaToolkit::detect()?,
364        };
365
366        let _ = self.parallel.init_thread_pool();
367
368        println!(
369            "cargo:warning=Using {} threads for compilation",
370            self.parallel.thread_count()
371        );
372
373        std::fs::create_dir_all(&self.out_dir)?;
374
375        let kernel_files = self.sources.resolve()?;
376        if kernel_files.is_empty() {
377            println!("cargo:warning=No kernel files found");
378            return Ok(());
379        }
380
381        for file in &kernel_files {
382            println!("cargo:rerun-if-changed={}", file.display());
383        }
384        for path in self.sources.watch_paths() {
385            println!("cargo:rerun-if-changed={}", path.display());
386        }
387        println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
388        println!("cargo:rerun-if-env-changed=NVCC");
389        println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
390
391        let dep_args = self.dependencies.fetch_all(&self.out_dir)?;
392
393        let mut cache = if self.incremental {
394            BuildCache::load(&self.out_dir)
395        } else {
396            BuildCache::default()
397        };
398
399        let cpp_std_arg = self.resolve_cpp_std_arg(&toolkit);
400
401        let mut all_args = Vec::new();
402        if let Some(std_arg) = &cpp_std_arg {
403            all_args.push(std_arg.clone());
404        }
405        all_args.extend(self.extra_args.iter().cloned());
406        all_args.extend(dep_args.clone());
407        let args_hash = hash_args(&all_args);
408
409        let watch_hash = hash_paths(self.sources.watch_paths());
410
411        let mut compile_jobs: Vec<(PathBuf, PathBuf, GpuArch)> = Vec::new();
412        let mut all_obj_files: Vec<PathBuf> = Vec::new();
413
414        for kernel_file in &kernel_files {
415            let filename = kernel_file
416                .file_name()
417                .and_then(|n| n.to_str())
418                .unwrap_or("");
419            let gpu_arch = self.compute_cap.get_for_file(filename)?;
420
421            let obj_file = self.object_file_path(kernel_file);
422            all_obj_files.push(obj_file.clone());
423
424            if self.incremental
425                && !cache.needs_rebuild(
426                    kernel_file,
427                    &obj_file,
428                    &gpu_arch.to_nvcc_arch(),
429                    &args_hash,
430                    &watch_hash,
431                )
432            {
433                continue;
434            }
435
436            compile_jobs.push((kernel_file.clone(), obj_file, gpu_arch));
437        }
438
439        if compile_jobs.is_empty() && out_file.exists() {
440            println!("cargo:warning=All library kernels up-to-date, skipping compilation");
441            return Ok(());
442        }
443
444        println!(
445            "cargo:warning=Compiling {} of {} kernels",
446            compile_jobs.len(),
447            kernel_files.len()
448        );
449
450        let target = std::env::var("TARGET").ok();
451        let is_msvc = target.as_ref().is_some_and(|t| t.contains("msvc"));
452        let ccbin_env = std::env::var("NVCC_CCBIN").ok();
453        let nvcc_threads = self.parallel.nvcc_threads();
454
455        let had_error = AtomicBool::new(false);
456
457        compile_jobs.par_iter().try_for_each(
458            |(kernel_file, obj_file, gpu_arch)| -> Result<()> {
459                if had_error.load(Ordering::Relaxed) {
460                    return Ok(());
461                }
462
463                let gencode_arg = gpu_arch.to_gencode_arg();
464
465                let mut command = Command::new(&toolkit.nvcc_path);
466                command
467                    .arg(&gencode_arg)
468                    .arg("-c")
469                    .arg("-o")
470                    .arg(obj_file)
471                    .args(["--default-stream", "per-thread"]);
472
473                if let Some(std_arg) = &cpp_std_arg {
474                    command.arg(std_arg);
475                }
476
477                for arg in &self.extra_args {
478                    command.arg(arg);
479                }
480
481                for arg in &dep_args {
482                    command.arg(arg);
483                }
484
485                if self.dependencies.has_cutlass() {
486                    command.arg("-DUSE_CUTLASS");
487                }
488
489                if let Some(ccbin) = &ccbin_env {
490                    command
491                        .arg("-allow-unsupported-compiler")
492                        .args(["-ccbin", ccbin]);
493                }
494
495                if !is_msvc {
496                    command.arg("-Xcompiler").arg("-fPIC");
497                } else {
498                    command.arg("-D_USE_MATH_DEFINES");
499                    msvc_cccl_args(&mut command);
500                }
501
502                if let Some(threads) = nvcc_threads {
503                    let filename = kernel_file.to_string_lossy();
504                    if self.parallel.should_use_nvcc_threads(&filename) {
505                        command.arg(format!("--threads={}", threads));
506                    }
507                }
508
509                command.arg(kernel_file);
510
511                let output = command
512                    .spawn()
513                    .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc: {}", e)))?
514                    .wait_with_output()
515                    .map_err(|e| Error::CompilationFailed {
516                        path: kernel_file.clone(),
517                        message: e.to_string(),
518                    })?;
519
520                if !output.status.success() {
521                    had_error.store(true, Ordering::Relaxed);
522                    return Err(Error::CompilationFailed {
523                        path: kernel_file.clone(),
524                        message: format!(
525                            "nvcc error:\n{}\n{}",
526                            String::from_utf8_lossy(&output.stdout),
527                            String::from_utf8_lossy(&output.stderr)
528                        ),
529                    });
530                }
531
532                Ok(())
533            },
534        )?;
535
536        if self.incremental {
537            for (kernel_file, obj_file, gpu_arch) in &compile_jobs {
538                cache.update(
539                    kernel_file,
540                    obj_file,
541                    &gpu_arch.to_nvcc_arch(),
542                    &args_hash,
543                    &watch_hash,
544                )?;
545            }
546            cache.save(&self.out_dir)?;
547        }
548
549        // Linking with many objects can exceed Windows' 32 KiB
550        // command-line limit (one large baracuda-kernels-sys build
551        // pushed past this with 217 .cu files, and `nvcc --lib`
552        // doesn't accept response files — it errors with
553        // `Don't know what to do with '@file'`).
554        //
555        // On MSVC hosts we work around this by invoking the MSVC
556        // archiver (`lib.exe`) directly with the object list passed
557        // via a response file (`@file`), which `lib.exe` natively
558        // supports for arguments of arbitrary length. nvcc's `--lib`
559        // would have shelled out to `lib.exe` anyway, so this
560        // preserves the same archive format. We discover `lib.exe`
561        // by querying `vswhere.exe` for the most recent MSVC install
562        // (matching the layout the Rust MSVC toolchain target uses).
563        //
564        // On non-MSVC hosts argv limits are much higher (~2 MiB on
565        // Linux) so we keep the simpler nvcc --lib path there.
566        if is_msvc {
567            archive_with_msvc_lib(&out_file, &all_obj_files, &self.out_dir)?;
568        } else {
569            let mut command = Command::new(&toolkit.nvcc_path);
570            command
571                .arg("--lib")
572                .arg("-o")
573                .arg(&out_file)
574                .args(&all_obj_files);
575
576            let output = command
577                .spawn()
578                .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc for linking: {}", e)))?
579                .wait_with_output()
580                .map_err(|e| Error::LinkingFailed(e.to_string()))?;
581
582            if !output.status.success() {
583                return Err(Error::LinkingFailed(format!(
584                    "nvcc linking error:\n{}\n{}",
585                    String::from_utf8_lossy(&output.stdout),
586                    String::from_utf8_lossy(&output.stderr)
587                )));
588            }
589        }
590
591        Ok(())
592    }
593
594    /// Build PTX files from all kernel sources.
595    ///
596    /// Each `.cu` source produces a `<stem>.ptx` text file in the configured
597    /// `out_dir`. The returned [`PtxOutput`] can write a Rust source file
598    /// that exposes each PTX as a `pub const &str` for runtime loading via
599    /// `baracuda-driver`'s `Module::load_ptx`.
600    ///
601    /// ```no_run
602    /// # use baracuda_forge::KernelBuilder;
603    /// # fn build() -> Result<(), baracuda_forge::Error> {
604    /// let output = KernelBuilder::new()
605    ///     .source_glob("src/**/*.cu")
606    ///     .build_ptx()?;
607    /// output.write("src/kernels.rs")?;
608    /// # Ok(()) }
609    /// ```
610    pub fn build_ptx(&self) -> Result<PtxOutput> {
611        let toolkit = match &self.toolkit {
612            Some(t) => t.clone(),
613            None => CudaToolkit::detect()?,
614        };
615
616        let _ = self.parallel.init_thread_pool();
617        std::fs::create_dir_all(&self.out_dir)?;
618
619        let kernel_files = self.sources.resolve()?;
620
621        println!(
622            "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
623            toolkit.include_dir.display()
624        );
625
626        for file in &kernel_files {
627            println!("cargo:rerun-if-changed={}", file.display());
628        }
629        for path in self.sources.watch_paths() {
630            println!("cargo:rerun-if-changed={}", path.display());
631        }
632        println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
633        println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
634
635        let dep_args = self.dependencies.fetch_all(&self.out_dir)?;
636        let ccbin_env = std::env::var("NVCC_CCBIN").ok();
637        let is_msvc = std::env::var("TARGET").ok().is_some_and(|t| t.contains("msvc"));
638        let nvcc_threads = self.parallel.nvcc_threads();
639        let watch_hash = hash_paths(self.sources.watch_paths());
640        let mut cache = BuildCache::load(&self.out_dir);
641
642        let cpp_std_arg = self.resolve_cpp_std_arg(&toolkit);
643
644        let mut all_args = Vec::new();
645        if let Some(std_arg) = &cpp_std_arg {
646            all_args.push(std_arg.clone());
647        }
648        all_args.extend(self.extra_args.iter().cloned());
649        all_args.extend(dep_args.clone());
650        let args_hash = hash_args(&all_args);
651
652        let mut compile_jobs = Vec::new();
653        for kernel_file in &kernel_files {
654            let filename = kernel_file
655                .file_name()
656                .and_then(|n| n.to_str())
657                .unwrap_or("");
658            let gpu_arch = self.compute_cap.get_for_file(filename)?;
659
660            let output_file = self
661                .out_dir
662                .join(kernel_file.with_extension("ptx").file_name().unwrap());
663
664            if self.incremental
665                && !cache.needs_rebuild(
666                    kernel_file,
667                    &output_file,
668                    &gpu_arch.to_nvcc_arch(),
669                    &args_hash,
670                    &watch_hash,
671                )
672            {
673                continue;
674            }
675
676            compile_jobs.push((kernel_file, output_file, gpu_arch));
677        }
678
679        if compile_jobs.is_empty() {
680            println!("cargo:warning=All PTX kernels up-to-date, skipping compilation");
681            return Ok(PtxOutput {
682                paths: kernel_files,
683                out_dir: self.out_dir.clone(),
684            });
685        }
686
687        println!(
688            "cargo:warning=Compiling {} of {} PTX kernels",
689            compile_jobs.len(),
690            kernel_files.len()
691        );
692
693        compile_jobs.par_iter().try_for_each(
694            |(kernel_file, _output_file, gpu_arch)| -> Result<()> {
695                let gencode_arg = gpu_arch.to_gencode_arg();
696
697                let mut command = Command::new(&toolkit.nvcc_path);
698                command
699                    .arg(&gencode_arg)
700                    .arg("--ptx")
701                    .args(["--default-stream", "per-thread"])
702                    .args(["--output-directory", &self.out_dir.to_string_lossy()]);
703
704                if let Some(std_arg) = &cpp_std_arg {
705                    command.arg(std_arg);
706                }
707
708                for arg in &self.extra_args {
709                    command.arg(arg);
710                }
711                for arg in &dep_args {
712                    command.arg(arg);
713                }
714                if let Some(ccbin) = &ccbin_env {
715                    command
716                        .arg("-allow-unsupported-compiler")
717                        .args(["-ccbin", ccbin]);
718                }
719
720                if is_msvc {
721                    msvc_cccl_args(&mut command);
722                }
723
724                if let Some(threads) = nvcc_threads {
725                    let file_path = kernel_file.to_string_lossy();
726                    if self.parallel.should_use_nvcc_threads(&file_path) {
727                        command.arg(format!("--threads={}", threads));
728                    }
729                }
730
731                command.arg(kernel_file);
732
733                let output = command
734                    .spawn()
735                    .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc: {}", e)))?
736                    .wait_with_output()
737                    .map_err(|e| Error::CompilationFailed {
738                        path: kernel_file.to_path_buf(),
739                        message: e.to_string(),
740                    })?;
741
742                if !output.status.success() {
743                    return Err(Error::CompilationFailed {
744                        path: kernel_file.to_path_buf(),
745                        message: format!(
746                            "nvcc error:\n{}\n{}",
747                            String::from_utf8_lossy(&output.stdout),
748                            String::from_utf8_lossy(&output.stderr)
749                        ),
750                    });
751                }
752
753                Ok(())
754            },
755        )?;
756
757        if self.incremental {
758            for kernel_file in &kernel_files {
759                let filename = kernel_file
760                    .file_name()
761                    .and_then(|n| n.to_str())
762                    .unwrap_or("");
763                let gpu_arch = self.compute_cap.get_for_file(filename)?;
764                let output_file = self
765                    .out_dir
766                    .join(kernel_file.with_extension("ptx").file_name().unwrap());
767
768                cache.update(
769                    kernel_file,
770                    &output_file,
771                    &gpu_arch.to_nvcc_arch(),
772                    &args_hash,
773                    &watch_hash,
774                )?;
775            }
776            cache.save(&self.out_dir)?;
777        }
778
779        Ok(PtxOutput {
780            paths: kernel_files,
781            out_dir: self.out_dir.clone(),
782        })
783    }
784
785    /// Resolve the `-std=<standard>` argument: explicit override, then auto from
786    /// toolkit version, then `None` if neither is available *and* the user
787    /// already supplied a `-std=` in `extra_args`.
788    fn resolve_cpp_std_arg(&self, toolkit: &CudaToolkit) -> Option<String> {
789        if self.extra_args.iter().any(|a| a.starts_with("-std=")) {
790            return None;
791        }
792
793        if let Some(s) = &self.cpp_std {
794            return Some(format!("-std={s}"));
795        }
796
797        let standard = match toolkit.version {
798            Some((major, _)) if major >= 12 => "c++20",
799            _ => "c++17",
800        };
801        Some(format!("-std={standard}"))
802    }
803
804    fn object_file_path(&self, kernel_file: &Path) -> PathBuf {
805        let mut hasher = DefaultHasher::new();
806        kernel_file.display().to_string().hash(&mut hasher);
807        let hash = hasher.finish();
808
809        let stem = kernel_file
810            .file_stem()
811            .and_then(|s| s.to_str())
812            .unwrap_or("kernel");
813
814        self.out_dir.join(format!("{}-{:x}.o", stem, hash))
815    }
816}
817
818/// Append the nvcc flags CCCL-heavy translation units require on MSVC hosts.
819///
820/// CUDA 12.5+ (and every CUDA 13.x, including the 13.3 the Fuel team hit)
821/// bundles a CCCL whose `<cuda/std/__cccl/preprocessor.h>` opens with a hard
822/// `#error` (MSVC `fatal error C1189`) when the host `cl.exe` is driving its
823/// legacy *traditional* preprocessor:
824///
825/// ```text
826/// #if defined(_MSC_VER) && !defined(__clang__)
827/// #  if (!defined(_MSVC_TRADITIONAL) || _MSVC_TRADITIONAL == 1) \
828///     && !defined(CCCL_IGNORE_MSVC_TRADITIONAL_PREPROCESSOR_WARNING)
829/// #    error MSVC/cl.exe with traditional preprocessor is used ...
830/// ```
831///
832/// CUTLASS (here) and cub/thrust (in `baracuda-kernels-sys`) pull this header
833/// in transitively, so *every* CCCL-touching `.cu` fails to compile. We pass
834/// `-Xcompiler /Zc:preprocessor`, which flips `cl.exe` to its standard-
835/// conforming preprocessor (defining `_MSVC_TRADITIONAL=0`). That is both the
836/// fix CCCL's own message recommends and the one CUTLASS's variadic-macro-heavy
837/// headers actually need — unlike defining
838/// `CCCL_IGNORE_MSVC_TRADITIONAL_PREPROCESSOR_WARNING`, which only silences the
839/// guard while leaving the non-conformant preprocessor (and its latent macro-
840/// expansion bugs) in place. Verified against nvcc 13.3 + MSVC 19.5x: the flag
841/// clears the error in both the host and device front-end passes.
842///
843/// No-op gate: callers invoke this only on MSVC targets. `/Zc:preprocessor`
844/// needs VS 2019 16.5+, which every CUDA-12/13-supported MSVC comfortably
845/// exceeds, so it is always safe to pass here.
fn msvc_cccl_args(command: &mut Command) {
    // Forward `/Zc:preprocessor` to the host compiler through `-Xcompiler`.
    command.args(["-Xcompiler", "/Zc:preprocessor"]);
}
849
850/// Locate the MSVC archiver (`lib.exe`) at build time.
851///
852/// We prefer the host-architecture build that matches `target_arch`
853/// (since that's what nvcc would have picked). The discovery walks
854/// the same paths cc-rs / cargo's MSVC setup walks, in priority order:
855///
856/// 1. `BARACUDA_FORGE_LIB_EXE` env var (escape hatch).
857/// 2. `vswhere.exe` (the canonical MS-supplied locator).
858/// 3. PATH (`lib.exe` may already be on PATH in dev-shell builds).
859///
860/// Returns the absolute path on success; an `Err` describing what
861/// was tried otherwise.
862fn find_msvc_lib_exe() -> Result<PathBuf> {
863    if let Ok(p) = std::env::var("BARACUDA_FORGE_LIB_EXE") {
864        let pb = PathBuf::from(&p);
865        if pb.exists() {
866            return Ok(pb);
867        }
868        return Err(Error::LinkingFailed(format!(
869            "BARACUDA_FORGE_LIB_EXE points to a non-existent file: {}",
870            p
871        )));
872    }
873
874    // Probe vswhere at its fixed install location.
875    let vswhere = PathBuf::from(
876        r"C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe",
877    );
878    if vswhere.exists() {
879        let output = Command::new(&vswhere)
880            .args([
881                "-latest",
882                "-products",
883                "*",
884                "-requires",
885                "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
886                "-property",
887                "installationPath",
888            ])
889            .output()
890            .ok();
891        if let Some(out) = output {
892            if out.status.success() {
893                let install_path = String::from_utf8_lossy(&out.stdout).trim().to_string();
894                if !install_path.is_empty() {
895                    let install = PathBuf::from(&install_path);
896                    let version_file = install.join(
897                        r"VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt",
898                    );
899                    if let Ok(ver) = std::fs::read_to_string(&version_file) {
900                        let ver = ver.trim();
901                        // Match host arch — assume x64 (current target arch on
902                        // every CUDA-supported Windows host today).
903                        let lib = install
904                            .join("VC")
905                            .join("Tools")
906                            .join("MSVC")
907                            .join(ver)
908                            .join("bin")
909                            .join("Hostx64")
910                            .join("x64")
911                            .join("lib.exe");
912                        if lib.exists() {
913                            return Ok(lib);
914                        }
915                    }
916                }
917            }
918        }
919    }
920
921    // Fall back to PATH lookup.
922    if Command::new("lib.exe").arg("/?").output().is_ok() {
923        return Ok(PathBuf::from("lib.exe"));
924    }
925
926    Err(Error::LinkingFailed(
927        "could not locate MSVC `lib.exe`. Set `BARACUDA_FORGE_LIB_EXE` to its \
928         full path, or run from a Visual Studio Developer Command Prompt so \
929         `lib.exe` is on PATH."
930            .to_string(),
931    ))
932}
933
934/// Invoke `lib.exe` to assemble a static archive from `obj_files`,
935/// passing the object list via a response file (`@file`) to avoid
936/// Windows' command-line-length limit.
937fn archive_with_msvc_lib(
938    out_file: &Path,
939    obj_files: &[PathBuf],
940    out_dir: &Path,
941) -> Result<()> {
942    let lib_exe = find_msvc_lib_exe()?;
943
944    let response_file = out_dir.join(".lib_response.txt");
945    {
946        let mut f = std::fs::File::create(&response_file).map_err(|e| {
947            Error::LinkingFailed(format!(
948                "failed to create lib response file {}: {}",
949                response_file.display(),
950                e
951            ))
952        })?;
953        for obj in obj_files {
954            // Each path on its own line, surrounded by quotes so any
955            // embedded spaces survive lib.exe's response-file parser.
956            writeln!(f, "\"{}\"", obj.display()).map_err(|e| {
957                Error::LinkingFailed(format!(
958                    "failed to write lib response file: {}",
959                    e
960                ))
961            })?;
962        }
963    }
964
965    let mut command = Command::new(&lib_exe);
966    command
967        .arg("/NOLOGO")
968        .arg(format!("/OUT:{}", out_file.display()))
969        .arg(format!("@{}", response_file.display()));
970
971    let output = command
972        .spawn()
973        .map_err(|e| {
974            Error::NvccNotFound(format!(
975                "Failed to spawn {} for linking: {}",
976                lib_exe.display(),
977                e
978            ))
979        })?
980        .wait_with_output()
981        .map_err(|e| Error::LinkingFailed(e.to_string()))?;
982
983    if !output.status.success() {
984        return Err(Error::LinkingFailed(format!(
985            "{} archiving error:\n{}\n{}",
986            lib_exe.display(),
987            String::from_utf8_lossy(&output.stdout),
988            String::from_utf8_lossy(&output.stderr)
989        )));
990    }
991
992    Ok(())
993}
994
/// Output from PTX compilation.
///
/// Holds the per-kernel paths that [`PtxOutput::write`] turns into `const`
/// declarations.
#[derive(Debug)]
pub struct PtxOutput {
    /// One entry per kernel; the file stem of each path names the generated
    /// constant and the `.ptx` file it includes.
    paths: Vec<PathBuf>,
    /// Output directory associated with this compilation. It is not read by
    /// any code visible alongside this type, hence the `dead_code` allowance.
    #[allow(dead_code)]
    out_dir: PathBuf,
}
1001
1002impl PtxOutput {
1003    /// Write a Rust source file with `const` declarations for each PTX file.
1004    pub fn write<P: AsRef<Path>>(&self, out: P) -> Result<()> {
1005        let mut file = std::fs::File::create(out.as_ref())?;
1006
1007        for kernel_path in &self.paths {
1008            let name = kernel_path
1009                .file_stem()
1010                .and_then(|s| s.to_str())
1011                .unwrap_or("KERNEL");
1012
1013            writeln!(
1014                file,
1015                r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
1016                name.to_uppercase().replace(['.', '-'], "_"),
1017                name
1018            )?;
1019        }
1020
1021        Ok(())
1022    }
1023}
1024
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `CudaToolkit` with placeholder paths and the given `version`.
    fn toolkit_with_version(version: Option<(u32, u32)>) -> CudaToolkit {
        let placeholder = PathBuf::from("/dev/null");
        CudaToolkit {
            nvcc_path: placeholder.clone(),
            include_dir: placeholder.clone(),
            lib_dir: placeholder,
            version,
        }
    }

    /// A CUDA 12.x toolkit with no explicit standard resolves to C++20.
    #[test]
    fn cpp_std_auto_selects_cpp20_for_cuda_12() {
        let toolkit = toolkit_with_version(Some((12, 6)));
        let builder = KernelBuilder::new();
        let resolved = builder.resolve_cpp_std_arg(&toolkit);
        assert_eq!(resolved.as_deref(), Some("-std=c++20"));
    }

    /// A CUDA 11.x toolkit with no explicit standard resolves to C++17.
    #[test]
    fn cpp_std_auto_selects_cpp17_for_cuda_11() {
        let toolkit = toolkit_with_version(Some((11, 8)));
        let builder = KernelBuilder::new();
        let resolved = builder.resolve_cpp_std_arg(&toolkit);
        assert_eq!(resolved.as_deref(), Some("-std=c++17"));
    }

    /// An undetected toolkit version resolves to C++17.
    #[test]
    fn cpp_std_auto_selects_cpp17_when_version_unknown() {
        let toolkit = toolkit_with_version(None);
        let builder = KernelBuilder::new();
        let resolved = builder.resolve_cpp_std_arg(&toolkit);
        assert_eq!(resolved.as_deref(), Some("-std=c++17"));
    }

    /// A standard set through `cpp_std` beats the version-based choice.
    #[test]
    fn cpp_std_explicit_override_wins() {
        let toolkit = toolkit_with_version(Some((12, 6)));
        let builder = KernelBuilder::new().cpp_std("c++17");
        let resolved = builder.resolve_cpp_std_arg(&toolkit);
        assert_eq!(resolved.as_deref(), Some("-std=c++17"));
    }

    /// A `-std=` passed as a raw extra argument suppresses the automatic flag.
    #[test]
    fn cpp_std_extra_arg_disables_auto() {
        let toolkit = toolkit_with_version(Some((12, 6)));
        let builder = KernelBuilder::new().arg("-std=c++14");
        let resolved = builder.resolve_cpp_std_arg(&toolkit);
        assert_eq!(resolved, None);
    }
}