1use crate::compute_cap::{ComputeCapability, GpuArch};
4use crate::dependency::DependencyManager;
5use crate::error::{Error, Result};
6use crate::hash::{hash_args, hash_paths, BuildCache};
7use crate::parallel::ParallelConfig;
8use crate::source::SourceSelector;
9use crate::toolkit::CudaToolkit;
10
11use rayon::prelude::*;
12use std::collections::hash_map::DefaultHasher;
13use std::hash::{Hash, Hasher};
14use std::io::Write;
15use std::path::{Path, PathBuf};
16use std::process::Command;
17use std::sync::atomic::{AtomicBool, Ordering};
18
/// Builder that compiles CUDA kernel sources with `nvcc`, either into a
/// static library (`build_lib`) or into PTX files (`build_ptx`).
#[derive(Debug)]
pub struct KernelBuilder {
    /// Explicitly chosen toolkit (set by `cuda_root`); `None` means
    /// `CudaToolkit::detect()` is called when a build starts.
    toolkit: Option<CudaToolkit>,
    /// Default compute capability plus per-file overrides.
    compute_cap: ComputeCapability,
    /// Which kernel files to compile and which extra paths to watch.
    sources: SourceSelector,
    /// External dependencies; `fetch_all` turns them into extra nvcc arguments.
    dependencies: DependencyManager,
    /// Thread-pool sizing and per-file `nvcc --threads` configuration.
    parallel: ParallelConfig,
    /// Directory for objects, PTX files, fetched dependencies and the build cache.
    out_dir: PathBuf,
    /// User-supplied arguments appended to every nvcc compile invocation.
    extra_args: Vec<String>,
    /// When true, kernels the build cache reports as up to date are skipped.
    incremental: bool,
    /// C++ standard override (e.g. `"c++17"`), emitted as `-std=<value>`.
    cpp_std: Option<String>,
}
34
35impl Default for KernelBuilder {
36 fn default() -> Self {
37 let out_dir = std::env::var("OUT_DIR")
38 .map(PathBuf::from)
39 .unwrap_or_else(|_| PathBuf::from("target/debug"));
40
41 Self {
42 toolkit: None,
43 compute_cap: ComputeCapability::default(),
44 sources: SourceSelector::default(),
45 dependencies: DependencyManager::default(),
46 parallel: ParallelConfig::default(),
47 out_dir,
48 extra_args: Vec::new(),
49 incremental: true,
50 cpp_std: None,
51 }
52 }
53}
54
55impl KernelBuilder {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn source_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
70 self.sources = self.sources.add_directory(dir);
71 self
72 }
73
74 pub fn source_files<I, P>(mut self, files: I) -> Self
81 where
82 I: IntoIterator<Item = P>,
83 P: AsRef<Path>,
84 {
85 self.sources = self.sources.add_files(files);
86 self
87 }
88
89 pub fn source_glob(mut self, pattern: &str) -> Self {
96 self.sources = self.sources.add_glob(pattern);
97 self
98 }
99
100 pub fn exclude(mut self, patterns: &[&str]) -> Self {
102 self.sources = self.sources.exclude(patterns);
103 self
104 }
105
106 pub fn watch<I, P>(mut self, paths: I) -> Self
108 where
109 I: IntoIterator<Item = P>,
110 P: AsRef<Path>,
111 {
112 self.sources = self.sources.watch(paths);
113 self
114 }
115
116 pub fn compute_cap(mut self, cap: usize) -> Self {
120 self.compute_cap = self.compute_cap.with_default(cap);
121 self
122 }
123
124 pub fn compute_cap_arch(mut self, arch: &str) -> Self {
126 self.compute_cap = self.compute_cap.with_default_arch(arch);
127 self
128 }
129
130 pub fn with_compute_override(mut self, pattern: &str, cap: usize) -> Self {
142 self.compute_cap = self.compute_cap.with_override(pattern, cap);
143 self
144 }
145
146 pub fn with_compute_override_arch(mut self, pattern: &str, arch: &str) -> Self {
148 self.compute_cap = self.compute_cap.with_override_arch(pattern, arch);
149 self
150 }
151
152 pub fn get_compute_cap(&self) -> Option<usize> {
154 self.compute_cap.get_default().ok().map(|a| a.base)
155 }
156
157 pub fn set_compute_cap(&mut self, cap: usize) {
159 self.compute_cap = ComputeCapability::new().with_default(cap);
160 }
161
162 pub fn require_explicit_compute_cap(self) -> Result<Self> {
180 if self.compute_cap.get_default().is_ok() {
181 return Ok(self);
182 }
183
184 if std::env::var("CUDA_COMPUTE_CAP").is_ok() {
185 return Ok(self);
186 }
187
188 Err(Error::ComputeCapDetectionFailed(
189 "Explicit compute capability required but not set. \
190 Either call .compute_cap(N) on the builder or set CUDA_COMPUTE_CAP environment variable. \
191 This is required for Docker builds where nvidia-smi is unavailable.".to_string()
192 ))
193 }
194
195 pub fn with_cutlass(mut self, commit: Option<&str>) -> Self {
215 self.dependencies = self.dependencies.with_cutlass(commit);
216 self
217 }
218
219 pub fn with_git_dependency(
223 mut self,
224 name: &str,
225 repo: &str,
226 commit: &str,
227 include_paths: Vec<&str>,
228 extra_paths: Vec<&str>,
229 recurse_submodules: bool,
230 ) -> Self {
231 self.dependencies = self.dependencies.with_git_dependency(
232 name,
233 repo,
234 commit,
235 include_paths,
236 extra_paths,
237 recurse_submodules,
238 );
239 self
240 }
241
242 pub fn fetch_git_dependency(&self, name: &str) -> Result<PathBuf> {
244 self.dependencies.fetch_dependency(name, &self.out_dir)
245 }
246
247 pub fn include_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
249 self.dependencies = self.dependencies.with_local_include(path);
250 self
251 }
252
253 pub fn thread_percentage(mut self, percentage: f32) -> Self {
257 self.parallel = self.parallel.with_percentage(percentage);
258 self
259 }
260
261 pub fn max_threads(mut self, max: usize) -> Self {
263 self.parallel = self.parallel.with_max_threads(max);
264 self
265 }
266
267 pub fn nvcc_thread_patterns<S: AsRef<str>>(
269 mut self,
270 patterns: &[S],
271 num_nvcc_threads: usize,
272 ) -> Self {
273 self.parallel = self
274 .parallel
275 .with_nvcc_thread_patterns(patterns, num_nvcc_threads);
276 self
277 }
278
279 pub fn out_dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
283 self.out_dir = dir.into();
284 self
285 }
286
287 pub fn arg(mut self, arg: &str) -> Self {
289 self.extra_args.push(arg.to_string());
290 self
291 }
292
293 pub fn args<I, S>(mut self, args: I) -> Self
295 where
296 I: IntoIterator<Item = S>,
297 S: AsRef<str>,
298 {
299 for arg in args {
300 self.extra_args.push(arg.as_ref().to_string());
301 }
302 self
303 }
304
305 pub fn no_incremental(mut self) -> Self {
307 self.incremental = false;
308 self
309 }
310
311 pub fn cuda_root<P: AsRef<Path>>(mut self, path: P) -> Self {
313 if let Ok(toolkit) = CudaToolkit::from_nvcc_path(path.as_ref().join("bin").join("nvcc")) {
314 self.toolkit = Some(toolkit);
315 }
316 self
317 }
318
319 pub fn cpp_std(mut self, standard: &str) -> Self {
335 self.cpp_std = Some(standard.to_string());
336 self
337 }
338
339 pub fn build_lib<P: Into<PathBuf>>(&self, out_file: P) -> Result<()> {
359 let out_file = out_file.into();
360
361 let toolkit = match &self.toolkit {
362 Some(t) => t.clone(),
363 None => CudaToolkit::detect()?,
364 };
365
366 let _ = self.parallel.init_thread_pool();
367
368 println!(
369 "cargo:warning=Using {} threads for compilation",
370 self.parallel.thread_count()
371 );
372
373 std::fs::create_dir_all(&self.out_dir)?;
374
375 let kernel_files = self.sources.resolve()?;
376 if kernel_files.is_empty() {
377 println!("cargo:warning=No kernel files found");
378 return Ok(());
379 }
380
381 for file in &kernel_files {
382 println!("cargo:rerun-if-changed={}", file.display());
383 }
384 for path in self.sources.watch_paths() {
385 println!("cargo:rerun-if-changed={}", path.display());
386 }
387 println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
388 println!("cargo:rerun-if-env-changed=NVCC");
389 println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
390
391 let dep_args = self.dependencies.fetch_all(&self.out_dir)?;
392
393 let mut cache = if self.incremental {
394 BuildCache::load(&self.out_dir)
395 } else {
396 BuildCache::default()
397 };
398
399 let cpp_std_arg = self.resolve_cpp_std_arg(&toolkit);
400
401 let mut all_args = Vec::new();
402 if let Some(std_arg) = &cpp_std_arg {
403 all_args.push(std_arg.clone());
404 }
405 all_args.extend(self.extra_args.iter().cloned());
406 all_args.extend(dep_args.clone());
407 let args_hash = hash_args(&all_args);
408
409 let watch_hash = hash_paths(self.sources.watch_paths());
410
411 let mut compile_jobs: Vec<(PathBuf, PathBuf, GpuArch)> = Vec::new();
412 let mut all_obj_files: Vec<PathBuf> = Vec::new();
413
414 for kernel_file in &kernel_files {
415 let filename = kernel_file
416 .file_name()
417 .and_then(|n| n.to_str())
418 .unwrap_or("");
419 let gpu_arch = self.compute_cap.get_for_file(filename)?;
420
421 let obj_file = self.object_file_path(kernel_file);
422 all_obj_files.push(obj_file.clone());
423
424 if self.incremental
425 && !cache.needs_rebuild(
426 kernel_file,
427 &obj_file,
428 &gpu_arch.to_nvcc_arch(),
429 &args_hash,
430 &watch_hash,
431 )
432 {
433 continue;
434 }
435
436 compile_jobs.push((kernel_file.clone(), obj_file, gpu_arch));
437 }
438
439 if compile_jobs.is_empty() && out_file.exists() {
440 println!("cargo:warning=All library kernels up-to-date, skipping compilation");
441 return Ok(());
442 }
443
444 println!(
445 "cargo:warning=Compiling {} of {} kernels",
446 compile_jobs.len(),
447 kernel_files.len()
448 );
449
450 let target = std::env::var("TARGET").ok();
451 let is_msvc = target.as_ref().is_some_and(|t| t.contains("msvc"));
452 let ccbin_env = std::env::var("NVCC_CCBIN").ok();
453 let nvcc_threads = self.parallel.nvcc_threads();
454
455 let had_error = AtomicBool::new(false);
456
457 compile_jobs.par_iter().try_for_each(
458 |(kernel_file, obj_file, gpu_arch)| -> Result<()> {
459 if had_error.load(Ordering::Relaxed) {
460 return Ok(());
461 }
462
463 let gencode_arg = gpu_arch.to_gencode_arg();
464
465 let mut command = Command::new(&toolkit.nvcc_path);
466 command
467 .arg(&gencode_arg)
468 .arg("-c")
469 .arg("-o")
470 .arg(obj_file)
471 .args(["--default-stream", "per-thread"]);
472
473 if let Some(std_arg) = &cpp_std_arg {
474 command.arg(std_arg);
475 }
476
477 for arg in &self.extra_args {
478 command.arg(arg);
479 }
480
481 for arg in &dep_args {
482 command.arg(arg);
483 }
484
485 if self.dependencies.has_cutlass() {
486 command.arg("-DUSE_CUTLASS");
487 }
488
489 if let Some(ccbin) = &ccbin_env {
490 command
491 .arg("-allow-unsupported-compiler")
492 .args(["-ccbin", ccbin]);
493 }
494
495 if !is_msvc {
496 command.arg("-Xcompiler").arg("-fPIC");
497 } else {
498 command.arg("-D_USE_MATH_DEFINES");
499 msvc_cccl_args(&mut command);
500 }
501
502 if let Some(threads) = nvcc_threads {
503 let filename = kernel_file.to_string_lossy();
504 if self.parallel.should_use_nvcc_threads(&filename) {
505 command.arg(format!("--threads={}", threads));
506 }
507 }
508
509 command.arg(kernel_file);
510
511 let output = command
512 .spawn()
513 .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc: {}", e)))?
514 .wait_with_output()
515 .map_err(|e| Error::CompilationFailed {
516 path: kernel_file.clone(),
517 message: e.to_string(),
518 })?;
519
520 if !output.status.success() {
521 had_error.store(true, Ordering::Relaxed);
522 return Err(Error::CompilationFailed {
523 path: kernel_file.clone(),
524 message: format!(
525 "nvcc error:\n{}\n{}",
526 String::from_utf8_lossy(&output.stdout),
527 String::from_utf8_lossy(&output.stderr)
528 ),
529 });
530 }
531
532 Ok(())
533 },
534 )?;
535
536 if self.incremental {
537 for (kernel_file, obj_file, gpu_arch) in &compile_jobs {
538 cache.update(
539 kernel_file,
540 obj_file,
541 &gpu_arch.to_nvcc_arch(),
542 &args_hash,
543 &watch_hash,
544 )?;
545 }
546 cache.save(&self.out_dir)?;
547 }
548
549 if is_msvc {
567 archive_with_msvc_lib(&out_file, &all_obj_files, &self.out_dir)?;
568 } else {
569 let mut command = Command::new(&toolkit.nvcc_path);
570 command
571 .arg("--lib")
572 .arg("-o")
573 .arg(&out_file)
574 .args(&all_obj_files);
575
576 let output = command
577 .spawn()
578 .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc for linking: {}", e)))?
579 .wait_with_output()
580 .map_err(|e| Error::LinkingFailed(e.to_string()))?;
581
582 if !output.status.success() {
583 return Err(Error::LinkingFailed(format!(
584 "nvcc linking error:\n{}\n{}",
585 String::from_utf8_lossy(&output.stdout),
586 String::from_utf8_lossy(&output.stderr)
587 )));
588 }
589 }
590
591 Ok(())
592 }
593
594 pub fn build_ptx(&self) -> Result<PtxOutput> {
611 let toolkit = match &self.toolkit {
612 Some(t) => t.clone(),
613 None => CudaToolkit::detect()?,
614 };
615
616 let _ = self.parallel.init_thread_pool();
617 std::fs::create_dir_all(&self.out_dir)?;
618
619 let kernel_files = self.sources.resolve()?;
620
621 println!(
622 "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
623 toolkit.include_dir.display()
624 );
625
626 for file in &kernel_files {
627 println!("cargo:rerun-if-changed={}", file.display());
628 }
629 for path in self.sources.watch_paths() {
630 println!("cargo:rerun-if-changed={}", path.display());
631 }
632 println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
633 println!("cargo:rerun-if-env-changed=NVCC_CCBIN");
634
635 let dep_args = self.dependencies.fetch_all(&self.out_dir)?;
636 let ccbin_env = std::env::var("NVCC_CCBIN").ok();
637 let is_msvc = std::env::var("TARGET").ok().is_some_and(|t| t.contains("msvc"));
638 let nvcc_threads = self.parallel.nvcc_threads();
639 let watch_hash = hash_paths(self.sources.watch_paths());
640 let mut cache = BuildCache::load(&self.out_dir);
641
642 let cpp_std_arg = self.resolve_cpp_std_arg(&toolkit);
643
644 let mut all_args = Vec::new();
645 if let Some(std_arg) = &cpp_std_arg {
646 all_args.push(std_arg.clone());
647 }
648 all_args.extend(self.extra_args.iter().cloned());
649 all_args.extend(dep_args.clone());
650 let args_hash = hash_args(&all_args);
651
652 let mut compile_jobs = Vec::new();
653 for kernel_file in &kernel_files {
654 let filename = kernel_file
655 .file_name()
656 .and_then(|n| n.to_str())
657 .unwrap_or("");
658 let gpu_arch = self.compute_cap.get_for_file(filename)?;
659
660 let output_file = self
661 .out_dir
662 .join(kernel_file.with_extension("ptx").file_name().unwrap());
663
664 if self.incremental
665 && !cache.needs_rebuild(
666 kernel_file,
667 &output_file,
668 &gpu_arch.to_nvcc_arch(),
669 &args_hash,
670 &watch_hash,
671 )
672 {
673 continue;
674 }
675
676 compile_jobs.push((kernel_file, output_file, gpu_arch));
677 }
678
679 if compile_jobs.is_empty() {
680 println!("cargo:warning=All PTX kernels up-to-date, skipping compilation");
681 return Ok(PtxOutput {
682 paths: kernel_files,
683 out_dir: self.out_dir.clone(),
684 });
685 }
686
687 println!(
688 "cargo:warning=Compiling {} of {} PTX kernels",
689 compile_jobs.len(),
690 kernel_files.len()
691 );
692
693 compile_jobs.par_iter().try_for_each(
694 |(kernel_file, _output_file, gpu_arch)| -> Result<()> {
695 let gencode_arg = gpu_arch.to_gencode_arg();
696
697 let mut command = Command::new(&toolkit.nvcc_path);
698 command
699 .arg(&gencode_arg)
700 .arg("--ptx")
701 .args(["--default-stream", "per-thread"])
702 .args(["--output-directory", &self.out_dir.to_string_lossy()]);
703
704 if let Some(std_arg) = &cpp_std_arg {
705 command.arg(std_arg);
706 }
707
708 for arg in &self.extra_args {
709 command.arg(arg);
710 }
711 for arg in &dep_args {
712 command.arg(arg);
713 }
714 if let Some(ccbin) = &ccbin_env {
715 command
716 .arg("-allow-unsupported-compiler")
717 .args(["-ccbin", ccbin]);
718 }
719
720 if is_msvc {
721 msvc_cccl_args(&mut command);
722 }
723
724 if let Some(threads) = nvcc_threads {
725 let file_path = kernel_file.to_string_lossy();
726 if self.parallel.should_use_nvcc_threads(&file_path) {
727 command.arg(format!("--threads={}", threads));
728 }
729 }
730
731 command.arg(kernel_file);
732
733 let output = command
734 .spawn()
735 .map_err(|e| Error::NvccNotFound(format!("Failed to spawn nvcc: {}", e)))?
736 .wait_with_output()
737 .map_err(|e| Error::CompilationFailed {
738 path: kernel_file.to_path_buf(),
739 message: e.to_string(),
740 })?;
741
742 if !output.status.success() {
743 return Err(Error::CompilationFailed {
744 path: kernel_file.to_path_buf(),
745 message: format!(
746 "nvcc error:\n{}\n{}",
747 String::from_utf8_lossy(&output.stdout),
748 String::from_utf8_lossy(&output.stderr)
749 ),
750 });
751 }
752
753 Ok(())
754 },
755 )?;
756
757 if self.incremental {
758 for kernel_file in &kernel_files {
759 let filename = kernel_file
760 .file_name()
761 .and_then(|n| n.to_str())
762 .unwrap_or("");
763 let gpu_arch = self.compute_cap.get_for_file(filename)?;
764 let output_file = self
765 .out_dir
766 .join(kernel_file.with_extension("ptx").file_name().unwrap());
767
768 cache.update(
769 kernel_file,
770 &output_file,
771 &gpu_arch.to_nvcc_arch(),
772 &args_hash,
773 &watch_hash,
774 )?;
775 }
776 cache.save(&self.out_dir)?;
777 }
778
779 Ok(PtxOutput {
780 paths: kernel_files,
781 out_dir: self.out_dir.clone(),
782 })
783 }
784
785 fn resolve_cpp_std_arg(&self, toolkit: &CudaToolkit) -> Option<String> {
789 if self.extra_args.iter().any(|a| a.starts_with("-std=")) {
790 return None;
791 }
792
793 if let Some(s) = &self.cpp_std {
794 return Some(format!("-std={s}"));
795 }
796
797 let standard = match toolkit.version {
798 Some((major, _)) if major >= 12 => "c++20",
799 _ => "c++17",
800 };
801 Some(format!("-std={standard}"))
802 }
803
804 fn object_file_path(&self, kernel_file: &Path) -> PathBuf {
805 let mut hasher = DefaultHasher::new();
806 kernel_file.display().to_string().hash(&mut hasher);
807 let hash = hasher.finish();
808
809 let stem = kernel_file
810 .file_stem()
811 .and_then(|s| s.to_str())
812 .unwrap_or("kernel");
813
814 self.out_dir.join(format!("{}-{:x}.o", stem, hash))
815 }
816}
817
/// Adds MSVC-specific host-compiler flags to an `nvcc` invocation.
///
/// Forwards `/Zc:preprocessor` (MSVC's standards-conforming preprocessor) to
/// the host compiler through `-Xcompiler`. Judging by this function's name,
/// the flag is needed by the CCCL headers shipped with the CUDA toolkit.
fn msvc_cccl_args(command: &mut Command) {
    command.args(["-Xcompiler", "/Zc:preprocessor"]);
}
849
850fn find_msvc_lib_exe() -> Result<PathBuf> {
863 if let Ok(p) = std::env::var("BARACUDA_FORGE_LIB_EXE") {
864 let pb = PathBuf::from(&p);
865 if pb.exists() {
866 return Ok(pb);
867 }
868 return Err(Error::LinkingFailed(format!(
869 "BARACUDA_FORGE_LIB_EXE points to a non-existent file: {}",
870 p
871 )));
872 }
873
874 let vswhere = PathBuf::from(
876 r"C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe",
877 );
878 if vswhere.exists() {
879 let output = Command::new(&vswhere)
880 .args([
881 "-latest",
882 "-products",
883 "*",
884 "-requires",
885 "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
886 "-property",
887 "installationPath",
888 ])
889 .output()
890 .ok();
891 if let Some(out) = output {
892 if out.status.success() {
893 let install_path = String::from_utf8_lossy(&out.stdout).trim().to_string();
894 if !install_path.is_empty() {
895 let install = PathBuf::from(&install_path);
896 let version_file = install.join(
897 r"VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt",
898 );
899 if let Ok(ver) = std::fs::read_to_string(&version_file) {
900 let ver = ver.trim();
901 let lib = install
904 .join("VC")
905 .join("Tools")
906 .join("MSVC")
907 .join(ver)
908 .join("bin")
909 .join("Hostx64")
910 .join("x64")
911 .join("lib.exe");
912 if lib.exists() {
913 return Ok(lib);
914 }
915 }
916 }
917 }
918 }
919 }
920
921 if Command::new("lib.exe").arg("/?").output().is_ok() {
923 return Ok(PathBuf::from("lib.exe"));
924 }
925
926 Err(Error::LinkingFailed(
927 "could not locate MSVC `lib.exe`. Set `BARACUDA_FORGE_LIB_EXE` to its \
928 full path, or run from a Visual Studio Developer Command Prompt so \
929 `lib.exe` is on PATH."
930 .to_string(),
931 ))
932}
933
934fn archive_with_msvc_lib(
938 out_file: &Path,
939 obj_files: &[PathBuf],
940 out_dir: &Path,
941) -> Result<()> {
942 let lib_exe = find_msvc_lib_exe()?;
943
944 let response_file = out_dir.join(".lib_response.txt");
945 {
946 let mut f = std::fs::File::create(&response_file).map_err(|e| {
947 Error::LinkingFailed(format!(
948 "failed to create lib response file {}: {}",
949 response_file.display(),
950 e
951 ))
952 })?;
953 for obj in obj_files {
954 writeln!(f, "\"{}\"", obj.display()).map_err(|e| {
957 Error::LinkingFailed(format!(
958 "failed to write lib response file: {}",
959 e
960 ))
961 })?;
962 }
963 }
964
965 let mut command = Command::new(&lib_exe);
966 command
967 .arg("/NOLOGO")
968 .arg(format!("/OUT:{}", out_file.display()))
969 .arg(format!("@{}", response_file.display()));
970
971 let output = command
972 .spawn()
973 .map_err(|e| {
974 Error::NvccNotFound(format!(
975 "Failed to spawn {} for linking: {}",
976 lib_exe.display(),
977 e
978 ))
979 })?
980 .wait_with_output()
981 .map_err(|e| Error::LinkingFailed(e.to_string()))?;
982
983 if !output.status.success() {
984 return Err(Error::LinkingFailed(format!(
985 "{} archiving error:\n{}\n{}",
986 lib_exe.display(),
987 String::from_utf8_lossy(&output.stdout),
988 String::from_utf8_lossy(&output.stderr)
989 )));
990 }
991
992 Ok(())
993}
994
/// Result of `KernelBuilder::build_ptx`: the kernel sources that were
/// compiled to PTX. `write` turns it into `include_str!` bindings.
#[derive(Debug)]
pub struct PtxOutput {
    /// Kernel source paths. Each kernel's PTX is `<file stem>.ptx`.
    paths: Vec<PathBuf>,
    /// Directory the PTX files were written to. Not read by `write`, which
    /// locates files through `env!("OUT_DIR")`; kept for future use.
    #[allow(dead_code)]
    out_dir: PathBuf,
}
1001
1002impl PtxOutput {
1003 pub fn write<P: AsRef<Path>>(&self, out: P) -> Result<()> {
1005 let mut file = std::fs::File::create(out.as_ref())?;
1006
1007 for kernel_path in &self.paths {
1008 let name = kernel_path
1009 .file_stem()
1010 .and_then(|s| s.to_str())
1011 .unwrap_or("KERNEL");
1012
1013 writeln!(
1014 file,
1015 r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
1016 name.to_uppercase().replace(['.', '-'], "_"),
1017 name
1018 )?;
1019 }
1020
1021 Ok(())
1022 }
1023}
1024
#[cfg(test)]
mod tests {
    use super::*;

    /// Toolkit description whose only meaningful field is the CUDA version.
    /// `resolve_cpp_std_arg` reads only `version`, so the paths are placeholders.
    fn toolkit_with_version(version: Option<(u32, u32)>) -> CudaToolkit {
        let placeholder = PathBuf::from("/dev/null");
        CudaToolkit {
            nvcc_path: placeholder.clone(),
            include_dir: placeholder.clone(),
            lib_dir: placeholder,
            version,
        }
    }

    /// The `-std=` flag `builder` would pass to nvcc for a toolkit of `version`.
    fn std_flag(builder: &KernelBuilder, version: Option<(u32, u32)>) -> Option<String> {
        builder.resolve_cpp_std_arg(&toolkit_with_version(version))
    }

    #[test]
    fn cpp_std_auto_selects_cpp20_for_cuda_12() {
        let flag = std_flag(&KernelBuilder::new(), Some((12, 6)));
        assert_eq!(flag.as_deref(), Some("-std=c++20"));
    }

    #[test]
    fn cpp_std_auto_selects_cpp17_for_cuda_11() {
        let flag = std_flag(&KernelBuilder::new(), Some((11, 8)));
        assert_eq!(flag.as_deref(), Some("-std=c++17"));
    }

    #[test]
    fn cpp_std_auto_selects_cpp17_when_version_unknown() {
        let flag = std_flag(&KernelBuilder::new(), None);
        assert_eq!(flag.as_deref(), Some("-std=c++17"));
    }

    #[test]
    fn cpp_std_explicit_override_wins() {
        let builder = KernelBuilder::new().cpp_std("c++17");
        let flag = std_flag(&builder, Some((12, 6)));
        assert_eq!(flag.as_deref(), Some("-std=c++17"));
    }

    #[test]
    fn cpp_std_extra_arg_disables_auto() {
        let builder = KernelBuilder::new().arg("-std=c++14");
        assert_eq!(std_flag(&builder, Some((12, 6))), None);
    }
}
1072}