1use crate::error::{Error, Result};
4use std::collections::HashMap;
5use std::process::Command;
6
/// A GPU architecture identifier: a numeric base plus an optional suffix.
///
/// The base is the number rendered after `sm_` / `compute_` (for example
/// `90`); the suffix, when present, is appended verbatim after it (for
/// example `a` or `f`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GpuArch {
    /// Numeric part of the architecture, e.g. `90` in `sm_90a`.
    pub base: usize,
    /// Optional suffix appended after the base, e.g. `"a"` in `sm_90a`.
    pub suffix: Option<String>,
}
17
18impl GpuArch {
19 pub fn new(base: usize) -> Self {
21 Self { base, suffix: None }
22 }
23
24 pub fn with_suffix(base: usize, suffix: &str) -> Self {
26 Self {
27 base,
28 suffix: Some(suffix.to_string()),
29 }
30 }
31
32 pub fn parse(s: &str) -> Result<Self> {
37 let s = s.trim().to_lowercase();
38 let s = s.strip_prefix("sm_").unwrap_or(&s);
39
40 let (num_part, explicit_suffix) = if let Some(stripped) = s.strip_suffix('f') {
41 (stripped, Some("f".to_string()))
42 } else if let Some(stripped) = s.strip_suffix('a') {
43 (stripped, Some("a".to_string()))
44 } else {
45 (s, None)
46 };
47
48 let base = num_part.parse::<usize>().map_err(|_| {
49 Error::ComputeCapDetectionFailed(format!("Invalid compute capability: {}", s))
50 })?;
51
52 let base = if base < 20 { base * 10 } else { base };
53
54 if explicit_suffix.is_some() {
55 Ok(Self {
56 base,
57 suffix: explicit_suffix,
58 })
59 } else {
60 Ok(Self::auto_suffix(base))
61 }
62 }
63
64 pub fn auto_suffix(base: usize) -> Self {
69 match base {
70 b if b >= 120 => Self::with_suffix(b, "f"),
71 b if b >= 90 => Self::with_suffix(b, "a"),
72 b => Self::new(b),
73 }
74 }
75
76 pub fn to_nvcc_arch(&self) -> String {
78 match &self.suffix {
79 Some(s) => format!("sm_{}{}", self.base, s),
80 None => format!("sm_{}", self.base),
81 }
82 }
83
84 pub fn to_gencode_arg(&self) -> String {
87 let compute = match &self.suffix {
88 Some(s) => format!("compute_{}{}", self.base, s),
89 None => format!("compute_{}", self.base),
90 };
91 let sm = self.to_nvcc_arch();
92 format!("-gencode=arch={},code={}", compute, sm)
93 }
94
95 pub fn base(&self) -> usize {
97 self.base
98 }
99}
100
101impl std::fmt::Display for GpuArch {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match &self.suffix {
104 Some(s) => write!(f, "{}{}", self.base, s),
105 None => write!(f, "{}", self.base),
106 }
107 }
108}
109
110impl From<usize> for GpuArch {
111 fn from(base: usize) -> Self {
112 Self::auto_suffix(base)
113 }
114}
115
116#[derive(Debug, Clone, Default)]
118pub struct ComputeCapability {
119 default_cap: Option<GpuArch>,
120 overrides: HashMap<String, GpuArch>,
121}
122
123impl ComputeCapability {
124 pub fn new() -> Self {
126 Self::default()
127 }
128
129 pub fn with_default(mut self, cap: usize) -> Self {
131 self.default_cap = Some(GpuArch::auto_suffix(cap));
132 self
133 }
134
135 pub fn with_default_arch(mut self, arch: &str) -> Self {
137 if let Ok(gpu_arch) = GpuArch::parse(arch) {
138 self.default_cap = Some(gpu_arch);
139 }
140 self
141 }
142
143 pub fn with_override(mut self, pattern: &str, cap: usize) -> Self {
145 self.overrides
146 .insert(pattern.to_string(), GpuArch::auto_suffix(cap));
147 self
148 }
149
150 pub fn with_override_arch(mut self, pattern: &str, arch: &str) -> Self {
152 if let Ok(gpu_arch) = GpuArch::parse(arch) {
153 self.overrides.insert(pattern.to_string(), gpu_arch);
154 }
155 self
156 }
157
158 pub fn get_for_file(&self, filename: &str) -> Result<GpuArch> {
166 for (pattern, arch) in &self.overrides {
167 if matches_pattern(filename, pattern) {
168 return Ok(arch.clone());
169 }
170 }
171
172 if let Some(arch) = &self.default_cap {
173 return Ok(arch.clone());
174 }
175
176 detect_compute_cap()
177 }
178
179 pub fn get_default(&self) -> Result<GpuArch> {
181 if let Some(arch) = &self.default_cap {
182 return Ok(arch.clone());
183 }
184 detect_compute_cap()
185 }
186
187 pub fn has_overrides(&self) -> bool {
189 !self.overrides.is_empty()
190 }
191}
192
193pub fn detect_compute_cap() -> Result<GpuArch> {
199 if let Ok(cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
200 return GpuArch::parse(&cap_str);
201 }
202 detect_from_nvidia_smi()
203}
204
205fn detect_from_nvidia_smi() -> Result<GpuArch> {
206 let output = Command::new("nvidia-smi")
207 .args(["--query-gpu=compute_cap", "--format=csv"])
208 .output();
209
210 match output {
211 Ok(output) if output.status.success() => {
212 let stdout = String::from_utf8_lossy(&output.stdout);
213 parse_nvidia_smi_output(&stdout)
214 }
215 Ok(output) => Err(Error::ComputeCapDetectionFailed(format!(
216 "nvidia-smi failed: {}. \
217 If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90).",
218 String::from_utf8_lossy(&output.stderr)
219 ))),
220 Err(e) => Err(Error::ComputeCapDetectionFailed(format!(
221 "Failed to run nvidia-smi: {}. \
222 If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90). \
223 GPU is not accessible during 'docker build' - only during 'docker run --gpus all'.",
224 e
225 ))),
226 }
227}
228
229fn parse_nvidia_smi_output(output: &str) -> Result<GpuArch> {
230 let line = output.lines().nth(1).ok_or_else(|| {
231 Error::ComputeCapDetectionFailed("Unexpected nvidia-smi output".to_string())
232 })?;
233
234 let cap = line.trim().parse::<f32>().map_err(|_| {
235 Error::ComputeCapDetectionFailed(format!("Failed to parse compute_cap: {}", line))
236 })?;
237
238 let base = (cap * 10.0) as usize;
239 Ok(GpuArch::auto_suffix(base))
240}
241
/// Returns `true` when `filename` matches `pattern`.
///
/// A pattern matches when it is identical to the filename, or when it
/// contains one or more `*` wildcards and the filename can be produced by
/// replacing each `*` with any (possibly empty) run of characters. All other
/// characters are compared literally; `*` is the only special character.
///
/// The text before the first `*` must be a prefix of the filename, the text
/// after the last `*` must be a suffix of what remains (so prefix and suffix
/// can never overlap), and any literal pieces between wildcards must occur
/// in order in the part left in between.
fn matches_pattern(filename: &str, pattern: &str) -> bool {
    if filename == pattern {
        return true;
    }

    let mut pieces = pattern.split('*');

    // `split` always yields at least one piece: the text before the first `*`.
    let prefix = match pieces.next() {
        Some(prefix) => prefix,
        None => return false,
    };
    // No second piece means the pattern has no wildcard, and the exact
    // comparison above has already failed.
    let suffix = match pieces.next_back() {
        Some(suffix) => suffix,
        None => return false,
    };

    let after_prefix = match filename.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    // Stripping the suffix from what is left after the prefix guarantees the
    // two cannot share characters (e.g. "aba" does not match "ab*ba").
    let mut remaining = match after_prefix.strip_suffix(suffix) {
        Some(rest) => rest,
        None => return false,
    };

    // Literal pieces between wildcards: take the leftmost occurrence of each,
    // which leaves the most room for the pieces that follow.
    for piece in pieces {
        match remaining.find(piece) {
            Some(at) => remaining = &remaining[at + piece.len()..],
            None => return false,
        }
    }

    true
}
265
266pub fn get_gpu_arch_string(compute_cap: usize) -> String {
270 GpuArch::auto_suffix(compute_cap).to_nvcc_arch()
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact, prefix-wildcard, suffix-wildcard and non-matching patterns.
    #[test]
    fn test_matches_pattern() {
        let accepted = [
            ("kernel.cu", "kernel.cu"),
            ("sm90_kernel.cu", "sm90_*.cu"),
            ("kernel_hopper.cu", "*_hopper.cu"),
            ("prefix_middle_suffix.cu", "prefix_*.cu"),
        ];
        for &(filename, pattern) in accepted.iter() {
            assert!(
                matches_pattern(filename, pattern),
                "{} should match {}",
                filename,
                pattern
            );
        }

        assert!(!matches_pattern("other.cu", "sm90_*.cu"));
    }

    /// `get_gpu_arch_string` applies the automatic suffix rules.
    #[test]
    fn test_gpu_arch_string() {
        let expected = [
            (80, "sm_80"),
            (90, "sm_90a"),
            (100, "sm_100a"),
            (120, "sm_120f"),
        ];
        for &(cap, arch) in expected.iter() {
            assert_eq!(get_gpu_arch_string(cap), arch);
        }
    }

    /// Parsing with explicit suffixes, an `sm_` prefix, and a bare number.
    #[test]
    fn test_gpu_arch_parse() {
        let hopper = GpuArch::parse("90a").unwrap();
        assert_eq!(hopper.base, 90);
        assert_eq!(hopper.suffix, Some("a".to_string()));
        assert_eq!(hopper.to_nvcc_arch(), "sm_90a");

        let blackwell = GpuArch::parse("100a").unwrap();
        assert_eq!(blackwell.base, 100);
        assert_eq!(blackwell.to_nvcc_arch(), "sm_100a");

        let prefixed = GpuArch::parse("sm_120f").unwrap();
        assert_eq!(prefixed.base, 120);
        assert_eq!(prefixed.to_nvcc_arch(), "sm_120f");

        let ampere = GpuArch::parse("80").unwrap();
        assert_eq!(ampere.base, 80);
        assert_eq!(ampere.suffix, None);
        assert_eq!(ampere.to_nvcc_arch(), "sm_80");
    }

    /// Suffix thresholds: none below 90, `a` from 90, `f` from 120.
    #[test]
    fn test_gpu_arch_auto_suffix() {
        let expected = [
            (80, "sm_80"),
            (89, "sm_89"),
            (90, "sm_90a"),
            (100, "sm_100a"),
            (120, "sm_120f"),
        ];
        for &(base, arch) in expected.iter() {
            assert_eq!(GpuArch::auto_suffix(base).to_nvcc_arch(), arch);
        }
    }

    /// `-gencode` arguments carry the suffix in both `arch` and `code`.
    #[test]
    fn test_gpu_arch_gencode() {
        let expected = [
            (75, "-gencode=arch=compute_75,code=sm_75"),
            (90, "-gencode=arch=compute_90a,code=sm_90a"),
            (120, "-gencode=arch=compute_120f,code=sm_120f"),
        ];
        for &(base, arg) in expected.iter() {
            assert_eq!(GpuArch::auto_suffix(base).to_gencode_arg(), arg);
        }
    }
}