Skip to main content

baracuda_forge/
toolkit.rs

//! CUDA toolkit auto-detection.
//!
//! Thin wrapper around [`baracuda_build::detect_cuda`] that adds nvcc-specific
//! conveniences (architecture listing) that the runtime `-sys` crates don't need.
5
6use crate::error::{Error, Result};
7use std::path::PathBuf;
8use std::process::Command;
9
/// Description of a located CUDA toolkit: where `nvcc`, the headers and the
/// libraries live, plus the toolkit version when it could be determined.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CudaToolkit {
    /// Path to the `nvcc` binary.
    pub nvcc_path: PathBuf,
    /// CUDA include directory.
    pub include_dir: PathBuf,
    /// CUDA library directory.
    pub lib_dir: PathBuf,
    /// CUDA version as `(major, minor)` (e.g. `(12, 6)`), or `None` if it
    /// could not be determined.
    pub version: Option<(u32, u32)>,
}
22
23impl CudaToolkit {
24    /// Auto-detect CUDA toolkit installation.
25    ///
26    /// Defers to [`baracuda_build::detect_cuda`] for install discovery, then
27    /// resolves nvcc via [`baracuda_build::find_nvcc`] (which adds `$NVCC` /
28    /// `$PATH` lookup as fallbacks).
29    pub fn detect() -> Result<Self> {
30        let install = baracuda_build::detect_cuda();
31        let nvcc_path = baracuda_build::find_nvcc().ok_or_else(|| {
32            Error::NvccNotFound(
33                "No nvcc found via $NVCC, CUDA install dirs, or $PATH".to_string(),
34            )
35        })?;
36
37        if let Some(install) = install {
38            let version = install.version;
39            return Ok(Self {
40                nvcc_path,
41                include_dir: install.include,
42                lib_dir: install.lib,
43                version,
44            });
45        }
46
47        Self::from_nvcc_path(nvcc_path)
48    }
49
50    /// Create toolkit from explicit nvcc path.
51    pub fn from_nvcc_path(nvcc_path: PathBuf) -> Result<Self> {
52        if !nvcc_path.exists() {
53            return Err(Error::NvccNotFound(nvcc_path.display().to_string()));
54        }
55
56        let cuda_root = nvcc_path
57            .parent()
58            .and_then(|p| p.parent())
59            .ok_or_else(|| Error::CudaToolkitNotFound(nvcc_path.clone()))?;
60
61        let include_dir = cuda_root.join("include");
62        let lib_dir = if cfg!(target_os = "windows") {
63            cuda_root.join("lib").join("x64")
64        } else {
65            cuda_root.join("lib64")
66        };
67
68        let version = nvcc_version(&nvcc_path);
69
70        Ok(Self {
71            nvcc_path,
72            include_dir,
73            lib_dir,
74            version,
75        })
76    }
77
78    /// Get supported GPU architectures by querying nvcc.
79    pub fn supported_architectures(&self) -> Vec<usize> {
80        let output = Command::new(&self.nvcc_path)
81            .arg("--list-gpu-code")
82            .output();
83
84        if let Ok(output) = output {
85            let stdout = String::from_utf8_lossy(&output.stdout);
86            parse_gpu_codes(&stdout)
87        } else {
88            Vec::new()
89        }
90    }
91}
92
93fn nvcc_version(nvcc_path: &PathBuf) -> Option<(u32, u32)> {
94    let output = Command::new(nvcc_path).arg("--version").output().ok()?;
95    let stdout = String::from_utf8_lossy(&output.stdout);
96    baracuda_build::parse_nvcc_version(&stdout)
97}
98
/// Extract real-architecture numbers from `nvcc --list-gpu-code` output.
///
/// Each line is trimmed and must have the form `sm_<digits>` (e.g. `sm_86`).
/// Anything else is skipped, including:
///
/// - `compute_*` entries
/// - suffixed variants such as `sm_90a`, which are not a bare number
/// - blank lines
///
/// The result is sorted ascending with duplicates removed.
fn parse_gpu_codes(output: &str) -> Vec<usize> {
    let mut codes: Vec<usize> = output
        .lines()
        // Anchor on a leading `sm_` token; trimming tolerates indentation,
        // trailing spaces and stray carriage returns.
        .filter_map(|line| line.trim().strip_prefix("sm_"))
        .filter_map(|num| num.parse().ok())
        .collect();
    codes.sort_unstable();
    codes.dedup();
    codes
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A typical `--list-gpu-code` listing parses to every architecture, in order.
    #[test]
    fn test_parse_gpu_codes() {
        let output = "sm_52\nsm_60\nsm_70\nsm_75\nsm_80\nsm_86\nsm_89\nsm_90";
        let codes = parse_gpu_codes(output);
        assert_eq!(codes, vec![52, 60, 70, 75, 80, 86, 89, 90]);
    }

    /// Out-of-order and repeated entries come back sorted and deduplicated.
    #[test]
    fn test_parse_gpu_codes_sorts_and_dedups() {
        let output = "sm_90\nsm_52\nsm_90\nsm_52";
        assert_eq!(parse_gpu_codes(output), vec![52, 90]);
    }

    /// `compute_*` entries, blank lines and suffixed variants are ignored.
    #[test]
    fn test_parse_gpu_codes_ignores_non_sm_lines() {
        let output = "compute_70\nsm_80\n\nsm_90a\ncompute_90";
        assert_eq!(parse_gpu_codes(output), vec![80]);
    }

    /// No output means no architectures.
    #[test]
    fn test_parse_gpu_codes_empty() {
        assert!(parse_gpu_codes("").is_empty());
    }
}