Skip to main content

baracuda_forge/
compute_cap.rs

1//! Compute capability detection and management.
2
3use crate::error::{Error, Result};
4use std::collections::HashMap;
5use std::process::Command;
6
/// GPU architecture specification, used to build nvcc `sm_*` / `-gencode` flags.
///
/// An architecture is a base compute capability (e.g. `80` → `sm_80`) plus an
/// optional suffix appended verbatim after it (e.g. `90a`, `100a`, `120f`).
///
/// Supports both numeric (80, 90) and string-based (90a, 100a) formats.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GpuArch {
    /// Base compute capability as `major * 10 + minor` (e.g., 90, 100, 120).
    pub base: usize,
    /// Optional target suffix appended after `base` (e.g. `"a"` in `sm_90a`,
    /// `"f"` in `sm_120f`); `None` means the plain `sm_<base>` target.
    ///
    /// In nvcc naming, `a` selects the architecture-specific target and `f` the
    /// family-specific target (CUDA 12.9+) — see the nvcc GPU-architecture docs.
    pub suffix: Option<String>,
}
17
18impl GpuArch {
19    /// Create a new GPU architecture from base number.
20    pub fn new(base: usize) -> Self {
21        Self { base, suffix: None }
22    }
23
24    /// Create a new GPU architecture with suffix (e.g., 90a, 100a).
25    pub fn with_suffix(base: usize, suffix: &str) -> Self {
26        Self {
27            base,
28            suffix: Some(suffix.to_string()),
29        }
30    }
31
32    /// Parse from string like "90", "90a", "100a", "sm_90a".
33    ///
34    /// If no suffix is provided (e.g., "90"), auto-suffix is applied for sm_90+.
35    /// To explicitly disable the suffix, use the numeric API directly.
36    pub fn parse(s: &str) -> Result<Self> {
37        let s = s.trim().to_lowercase();
38        let s = s.strip_prefix("sm_").unwrap_or(&s);
39
40        let (num_part, explicit_suffix) = if let Some(stripped) = s.strip_suffix('f') {
41            (stripped, Some("f".to_string()))
42        } else if let Some(stripped) = s.strip_suffix('a') {
43            (stripped, Some("a".to_string()))
44        } else {
45            (s, None)
46        };
47
48        let base = num_part.parse::<usize>().map_err(|_| {
49            Error::ComputeCapDetectionFailed(format!("Invalid compute capability: {}", s))
50        })?;
51
52        let base = if base < 20 { base * 10 } else { base };
53
54        if explicit_suffix.is_some() {
55            Ok(Self {
56                base,
57                suffix: explicit_suffix,
58            })
59        } else {
60            Ok(Self::auto_suffix(base))
61        }
62    }
63
64    /// Create GPU arch with auto-detected suffix for newer architectures.
65    ///
66    /// - `>= sm_120` gets "f" suffix for f8f6f4.mma instructions (CUDA 12.9+).
67    /// - `>= sm_90` gets "a" suffix for async / accelerated features.
68    pub fn auto_suffix(base: usize) -> Self {
69        match base {
70            b if b >= 120 => Self::with_suffix(b, "f"),
71            b if b >= 90 => Self::with_suffix(b, "a"),
72            b => Self::new(b),
73        }
74    }
75
76    /// Get the nvcc `--gpu-architecture` string (e.g., "sm_90a", "sm_80").
77    pub fn to_nvcc_arch(&self) -> String {
78        match &self.suffix {
79            Some(s) => format!("sm_{}{}", self.base, s),
80            None => format!("sm_{}", self.base),
81        }
82    }
83
84    /// Get the nvcc `-gencode` argument
85    /// (e.g., `-gencode=arch=compute_90a,code=sm_90a`).
86    pub fn to_gencode_arg(&self) -> String {
87        let compute = match &self.suffix {
88            Some(s) => format!("compute_{}{}", self.base, s),
89            None => format!("compute_{}", self.base),
90        };
91        let sm = self.to_nvcc_arch();
92        format!("-gencode=arch={},code={}", compute, sm)
93    }
94
95    /// Get the base compute capability number.
96    pub fn base(&self) -> usize {
97        self.base
98    }
99}
100
101impl std::fmt::Display for GpuArch {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match &self.suffix {
104            Some(s) => write!(f, "{}{}", self.base, s),
105            None => write!(f, "{}", self.base),
106        }
107    }
108}
109
110impl From<usize> for GpuArch {
111    fn from(base: usize) -> Self {
112        Self::auto_suffix(base)
113    }
114}
115
/// Compute capability configuration.
///
/// Holds an optional default GPU architecture plus per-file overrides keyed by
/// a filename pattern (exact name or `*` wildcard). When neither applies to a
/// file, the architecture is auto-detected (see
/// [`ComputeCapability::get_for_file`]).
#[derive(Debug, Clone, Default)]
pub struct ComputeCapability {
    /// Architecture for files without a matching override; `None` means auto-detect.
    default_cap: Option<GpuArch>,
    /// Per-file overrides: filename pattern → architecture.
    overrides: HashMap<String, GpuArch>,
}
122
123impl ComputeCapability {
124    /// Create new compute capability config with auto-detection.
125    pub fn new() -> Self {
126        Self::default()
127    }
128
129    /// Set default compute capability (numeric, auto-selects suffix).
130    pub fn with_default(mut self, cap: usize) -> Self {
131        self.default_cap = Some(GpuArch::auto_suffix(cap));
132        self
133    }
134
135    /// Set default compute capability with explicit arch string (e.g., "90a", "100a").
136    pub fn with_default_arch(mut self, arch: &str) -> Self {
137        if let Ok(gpu_arch) = GpuArch::parse(arch) {
138            self.default_cap = Some(gpu_arch);
139        }
140        self
141    }
142
143    /// Add compute cap override for files matching pattern (numeric).
144    pub fn with_override(mut self, pattern: &str, cap: usize) -> Self {
145        self.overrides
146            .insert(pattern.to_string(), GpuArch::auto_suffix(cap));
147        self
148    }
149
150    /// Add compute cap override with explicit arch string.
151    pub fn with_override_arch(mut self, pattern: &str, arch: &str) -> Self {
152        if let Ok(gpu_arch) = GpuArch::parse(arch) {
153            self.overrides.insert(pattern.to_string(), gpu_arch);
154        }
155        self
156    }
157
158    /// Get GPU arch for a specific file.
159    ///
160    /// Priority:
161    /// 1. Per-file override matching pattern.
162    /// 2. Default compute cap.
163    /// 3. Auto-detected from `CUDA_COMPUTE_CAP` env var.
164    /// 4. Auto-detected from `nvidia-smi`.
165    pub fn get_for_file(&self, filename: &str) -> Result<GpuArch> {
166        for (pattern, arch) in &self.overrides {
167            if matches_pattern(filename, pattern) {
168                return Ok(arch.clone());
169            }
170        }
171
172        if let Some(arch) = &self.default_cap {
173            return Ok(arch.clone());
174        }
175
176        detect_compute_cap()
177    }
178
179    /// Get the default GPU architecture.
180    pub fn get_default(&self) -> Result<GpuArch> {
181        if let Some(arch) = &self.default_cap {
182            return Ok(arch.clone());
183        }
184        detect_compute_cap()
185    }
186
187    /// Check if any overrides are configured.
188    pub fn has_overrides(&self) -> bool {
189        !self.overrides.is_empty()
190    }
191}
192
193/// Detect compute capability from system.
194///
195/// Priority:
196/// 1. `CUDA_COMPUTE_CAP` environment variable (supports "90", "90a", "100a").
197/// 2. `nvidia-smi` query.
198pub fn detect_compute_cap() -> Result<GpuArch> {
199    if let Ok(cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
200        return GpuArch::parse(&cap_str);
201    }
202    detect_from_nvidia_smi()
203}
204
205fn detect_from_nvidia_smi() -> Result<GpuArch> {
206    let output = Command::new("nvidia-smi")
207        .args(["--query-gpu=compute_cap", "--format=csv"])
208        .output();
209
210    match output {
211        Ok(output) if output.status.success() => {
212            let stdout = String::from_utf8_lossy(&output.stdout);
213            parse_nvidia_smi_output(&stdout)
214        }
215        Ok(output) => Err(Error::ComputeCapDetectionFailed(format!(
216            "nvidia-smi failed: {}. \
217            If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90).",
218            String::from_utf8_lossy(&output.stderr)
219        ))),
220        Err(e) => Err(Error::ComputeCapDetectionFailed(format!(
221            "Failed to run nvidia-smi: {}. \
222            If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90). \
223            GPU is not accessible during 'docker build' - only during 'docker run --gpus all'.",
224            e
225        ))),
226    }
227}
228
229fn parse_nvidia_smi_output(output: &str) -> Result<GpuArch> {
230    let line = output.lines().nth(1).ok_or_else(|| {
231        Error::ComputeCapDetectionFailed("Unexpected nvidia-smi output".to_string())
232    })?;
233
234    let cap = line.trim().parse::<f32>().map_err(|_| {
235        Error::ComputeCapDetectionFailed(format!("Failed to parse compute_cap: {}", line))
236    })?;
237
238    let base = (cap * 10.0) as usize;
239    Ok(GpuArch::auto_suffix(base))
240}
241
/// Return whether `filename` matches an override `pattern`.
///
/// A pattern without `*` must equal `filename` exactly. Each `*` matches any
/// (possibly empty) run of characters, anywhere in the pattern and any number
/// of times: `sm90_*.cu`, `*_hopper.cu`, `*attn*` and `a*b*c` all work.
/// Matching is case-sensitive and covers the whole string.
fn matches_pattern(filename: &str, pattern: &str) -> bool {
    if filename == pattern {
        return true;
    }
    if !pattern.contains('*') {
        return false;
    }

    // A pattern containing `*` splits into >= 2 pieces: an anchored prefix,
    // zero or more floating middle pieces, and an anchored suffix.
    let pieces: Vec<&str> = pattern.split('*').collect();
    let (prefix, middle, suffix) = match pieces.as_slice() {
        [prefix, middle @ .., suffix] => (*prefix, middle, *suffix),
        // Unreachable given the `contains('*')` check above.
        _ => return false,
    };

    let mut remaining = match filename.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    // Consume each middle piece at its leftmost occurrence. Leftmost is always
    // safe because it leaves the most room for the pieces that follow.
    for &piece in middle {
        match remaining.find(piece) {
            Some(idx) => remaining = &remaining[idx + piece.len()..],
            None => return false,
        }
    }

    // The suffix must fit in what is left, so it can never overlap the prefix
    // or a middle piece (e.g. `ab*ba` must not match `aba`).
    remaining.ends_with(suffix)
}
265
266/// Get GPU architecture string for nvcc (e.g., "sm_90a" or "sm_80").
267///
268/// Convenience function. For more control, use [`GpuArch`] directly.
269pub fn get_gpu_arch_string(compute_cap: usize) -> String {
270    GpuArch::auto_suffix(compute_cap).to_nvcc_arch()
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern() {
        let cases: &[(&str, &str, bool)] = &[
            ("kernel.cu", "kernel.cu", true),
            ("sm90_kernel.cu", "sm90_*.cu", true),
            ("kernel_hopper.cu", "*_hopper.cu", true),
            ("prefix_middle_suffix.cu", "prefix_*.cu", true),
            ("other.cu", "sm90_*.cu", false),
        ];
        for &(filename, pattern, expected) in cases {
            assert_eq!(
                matches_pattern(filename, pattern),
                expected,
                "matches_pattern({:?}, {:?})",
                filename,
                pattern
            );
        }
    }

    #[test]
    fn test_gpu_arch_string() {
        let cases = [(80, "sm_80"), (90, "sm_90a"), (100, "sm_100a"), (120, "sm_120f")];
        for &(cap, expected) in cases.iter() {
            assert_eq!(get_gpu_arch_string(cap), expected);
        }
    }

    #[test]
    fn test_gpu_arch_parse() {
        let sm90 = GpuArch::parse("90a").unwrap();
        assert_eq!((sm90.base, sm90.suffix.as_deref()), (90, Some("a")));
        assert_eq!(sm90.to_nvcc_arch(), "sm_90a");

        let sm100 = GpuArch::parse("100a").unwrap();
        assert_eq!(sm100.base, 100);
        assert_eq!(sm100.to_nvcc_arch(), "sm_100a");

        let sm120 = GpuArch::parse("sm_120f").unwrap();
        assert_eq!(sm120.base, 120);
        assert_eq!(sm120.to_nvcc_arch(), "sm_120f");

        let sm80 = GpuArch::parse("80").unwrap();
        assert_eq!(sm80.base, 80);
        assert!(sm80.suffix.is_none());
        assert_eq!(sm80.to_nvcc_arch(), "sm_80");
    }

    #[test]
    fn test_gpu_arch_auto_suffix() {
        let cases = [
            (80, "sm_80"),
            (89, "sm_89"),
            (90, "sm_90a"),
            (100, "sm_100a"),
            (120, "sm_120f"),
        ];
        for &(cap, expected) in cases.iter() {
            assert_eq!(GpuArch::auto_suffix(cap).to_nvcc_arch(), expected);
        }
    }

    #[test]
    fn test_gpu_arch_gencode() {
        let cases = [
            (75, "-gencode=arch=compute_75,code=sm_75"),
            (90, "-gencode=arch=compute_90a,code=sm_90a"),
            (120, "-gencode=arch=compute_120f,code=sm_120f"),
        ];
        for &(cap, expected) in cases.iter() {
            assert_eq!(GpuArch::auto_suffix(cap).to_gencode_arg(), expected);
        }
    }
}