baracuda_forge/
toolkit.rs1use crate::error::{Error, Result};
7use std::path::PathBuf;
8use std::process::Command;
9
/// Locations and version of a CUDA toolkit installation.
///
/// Values are produced by [`CudaToolkit::detect`] or
/// [`CudaToolkit::from_nvcc_path`].
#[derive(Clone, Debug)]
pub struct CudaToolkit {
    /// Path to the `nvcc` compiler executable.
    pub nvcc_path: PathBuf,
    /// Directory holding the toolkit headers.
    pub include_dir: PathBuf,
    /// Directory holding the toolkit libraries.
    pub lib_dir: PathBuf,
    /// Toolkit version as a pair of numbers (presumably `(major, minor)`),
    /// or `None` when it could not be determined.
    pub version: Option<(u32, u32)>,
}
22
23impl CudaToolkit {
24 pub fn detect() -> Result<Self> {
30 let install = baracuda_build::detect_cuda();
31 let nvcc_path = baracuda_build::find_nvcc().ok_or_else(|| {
32 Error::NvccNotFound(
33 "No nvcc found via $NVCC, CUDA install dirs, or $PATH".to_string(),
34 )
35 })?;
36
37 if let Some(install) = install {
38 let version = install.version;
39 return Ok(Self {
40 nvcc_path,
41 include_dir: install.include,
42 lib_dir: install.lib,
43 version,
44 });
45 }
46
47 Self::from_nvcc_path(nvcc_path)
48 }
49
50 pub fn from_nvcc_path(nvcc_path: PathBuf) -> Result<Self> {
52 if !nvcc_path.exists() {
53 return Err(Error::NvccNotFound(nvcc_path.display().to_string()));
54 }
55
56 let cuda_root = nvcc_path
57 .parent()
58 .and_then(|p| p.parent())
59 .ok_or_else(|| Error::CudaToolkitNotFound(nvcc_path.clone()))?;
60
61 let include_dir = cuda_root.join("include");
62 let lib_dir = if cfg!(target_os = "windows") {
63 cuda_root.join("lib").join("x64")
64 } else {
65 cuda_root.join("lib64")
66 };
67
68 let version = nvcc_version(&nvcc_path);
69
70 Ok(Self {
71 nvcc_path,
72 include_dir,
73 lib_dir,
74 version,
75 })
76 }
77
78 pub fn supported_architectures(&self) -> Vec<usize> {
80 let output = Command::new(&self.nvcc_path)
81 .arg("--list-gpu-code")
82 .output();
83
84 if let Ok(output) = output {
85 let stdout = String::from_utf8_lossy(&output.stdout);
86 parse_gpu_codes(&stdout)
87 } else {
88 Vec::new()
89 }
90 }
91}
92
93fn nvcc_version(nvcc_path: &PathBuf) -> Option<(u32, u32)> {
94 let output = Command::new(nvcc_path).arg("--version").output().ok()?;
95 let stdout = String::from_utf8_lossy(&output.stdout);
96 baracuda_build::parse_nvcc_version(&stdout)
97}
98
/// Extracts numeric GPU codes from `nvcc --list-gpu-code` style output.
///
/// Each line is trimmed and must start with the `sm_` prefix; the text
/// between that prefix and the next `_` (or end of line) is parsed as an
/// unsigned integer. Lines that do not match, or whose number does not parse
/// (for example suffixed variants such as `sm_90a`), are skipped.
///
/// Returns the codes in ascending order with duplicates removed.
fn parse_gpu_codes(output: &str) -> Vec<usize> {
    let mut codes: Vec<usize> = output
        .lines()
        .filter_map(|line| {
            // Anchor on the prefix: merely containing an "sm" segment
            // somewhere in the line is not enough to be a GPU code.
            let rest = line.trim().strip_prefix("sm_")?;
            rest.split('_').next()?.parse().ok()
        })
        .collect();
    codes.sort_unstable();
    codes.dedup();
    codes
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A typical `nvcc --list-gpu-code` listing must yield the expected
    /// numeric codes.
    #[test]
    fn test_parse_gpu_codes() {
        let listing = "sm_52\nsm_60\nsm_70\nsm_75\nsm_80\nsm_86\nsm_89\nsm_90";
        let parsed = parse_gpu_codes(listing);
        for expected in [80usize, 90].iter() {
            assert!(parsed.contains(expected));
        }
    }
}