use crate::error::{Error, Result};
use std::path::PathBuf;
use std::process::Command;
/// Locations and version information for a CUDA toolkit installation.
#[derive(Debug, Clone)]
pub struct CudaToolkit {
    /// Path to the `nvcc` compiler binary.
    pub nvcc_path: PathBuf,
    /// Directory holding the CUDA headers.
    pub include_dir: PathBuf,
    /// Directory holding the CUDA libraries.
    pub lib_dir: PathBuf,
    /// Toolkit version, if it could be determined.
    /// NOTE(review): presumably `(major, minor)` as produced by
    /// `baracuda_build::parse_nvcc_version` — confirm.
    pub version: Option<(u32, u32)>,
}
impl CudaToolkit {
    /// Locates the CUDA toolkit on this machine.
    ///
    /// `nvcc` is resolved with `baracuda_build::find_nvcc`. If
    /// `baracuda_build::detect_cuda` also reports an installation, its include
    /// and library directories are used. Otherwise everything is derived from
    /// the `nvcc` location via [`CudaToolkit::from_nvcc_path`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::NvccNotFound`] if no `nvcc` binary can be located.
    /// On the fallback path, returns any error from
    /// [`CudaToolkit::from_nvcc_path`].
    pub fn detect() -> Result<Self> {
        let install = baracuda_build::detect_cuda();
        let nvcc_path = baracuda_build::find_nvcc().ok_or_else(|| {
            Error::NvccNotFound(
                "No nvcc found via $NVCC, CUDA install dirs, or $PATH".to_string(),
            )
        })?;
        match install {
            Some(install) => {
                // The detected install may not know its own version. In that
                // case, ask nvcc directly, just as the fallback path does.
                let version = install.version.or_else(|| nvcc_version(&nvcc_path));
                Ok(Self {
                    nvcc_path,
                    include_dir: install.include,
                    lib_dir: install.lib,
                    version,
                })
            }
            None => Self::from_nvcc_path(nvcc_path),
        }
    }

    /// Builds a toolkit description from an explicit `nvcc` path.
    ///
    /// The toolkit root is taken to be two directories above `nvcc_path`
    /// (i.e. `<root>/bin/nvcc`).
    /// - Headers are expected in `<root>/include`.
    /// - On Windows, libraries are expected in `<root>/lib/x64`.
    /// - Elsewhere, libraries are expected in `<root>/lib64`, falling back to
    ///   `<root>/lib` when only the latter exists.
    ///
    /// The version is obtained by running `nvcc --version`; it is `None` if
    /// that fails or cannot be parsed.
    ///
    /// # Errors
    ///
    /// - [`Error::NvccNotFound`] if `nvcc_path` does not exist.
    /// - [`Error::CudaToolkitNotFound`] if it has fewer than two ancestor
    ///   directories.
    pub fn from_nvcc_path(nvcc_path: PathBuf) -> Result<Self> {
        if !nvcc_path.exists() {
            return Err(Error::NvccNotFound(nvcc_path.display().to_string()));
        }
        let cuda_root = nvcc_path
            .parent()
            .and_then(|p| p.parent())
            .ok_or_else(|| Error::CudaToolkitNotFound(nvcc_path.clone()))?;
        let include_dir = cuda_root.join("include");
        let lib_dir = if cfg!(target_os = "windows") {
            cuda_root.join("lib").join("x64")
        } else {
            // Standard NVIDIA installs use lib64. Some layouts (e.g. conda
            // packages) only ship lib. Keep lib64 unless it is missing and lib
            // exists, so the result never gets worse than before.
            let lib64 = cuda_root.join("lib64");
            let lib = cuda_root.join("lib");
            if !lib64.is_dir() && lib.is_dir() {
                lib
            } else {
                lib64
            }
        };
        let version = nvcc_version(&nvcc_path);
        Ok(Self {
            nvcc_path,
            include_dir,
            lib_dir,
            version,
        })
    }

    /// Returns the SM architectures this `nvcc` can generate code for.
    ///
    /// Runs `nvcc --list-gpu-code` and parses the `sm_XX` entries into a
    /// sorted, de-duplicated list (e.g. `[80, 86, 90]`). An empty vector
    /// means the list could not be obtained: `nvcc` failed to start or exited
    /// with an error.
    pub fn supported_architectures(&self) -> Vec<usize> {
        match Command::new(&self.nvcc_path)
            .arg("--list-gpu-code")
            .output()
        {
            // Older nvcc releases reject `--list-gpu-code`. Treat a failing
            // exit status as "unknown" instead of parsing whatever was printed.
            Ok(output) if output.status.success() => {
                parse_gpu_codes(&String::from_utf8_lossy(&output.stdout))
            }
            _ => Vec::new(),
        }
    }
}
fn nvcc_version(nvcc_path: &PathBuf) -> Option<(u32, u32)> {
let output = Command::new(nvcc_path).arg("--version").output().ok()?;
let stdout = String::from_utf8_lossy(&output.stdout);
baracuda_build::parse_nvcc_version(&stdout)
}
/// Extracts the numeric SM versions from `nvcc --list-gpu-code` output.
///
/// Each line of the form `sm_<number>` contributes `<number>`; surrounding
/// whitespace is ignored. All other lines are skipped, including:
/// - blank lines,
/// - `compute_*` entries,
/// - lines where `sm_` is not the prefix,
/// - lines whose suffix is not a plain number, such as `sm_90a`.
///
/// The result is sorted ascending with duplicates removed.
fn parse_gpu_codes(output: &str) -> Vec<usize> {
    let mut codes: Vec<usize> = output
        .lines()
        // Require `sm_` as the actual prefix. The old guard accepted "sm"
        // anywhere in the line while always reading the second component.
        .filter_map(|line| line.trim().strip_prefix("sm_"))
        .filter_map(|number| number.parse().ok())
        .collect();
    codes.sort_unstable();
    codes.dedup();
    codes
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A typical `--list-gpu-code` listing parses to every listed SM version,
    /// in ascending order.
    #[test]
    fn test_parse_gpu_codes() {
        let output = "sm_52\nsm_60\nsm_70\nsm_75\nsm_80\nsm_86\nsm_89\nsm_90";
        let codes = parse_gpu_codes(output);
        assert_eq!(codes, vec![52, 60, 70, 75, 80, 86, 89, 90]);
    }

    /// Unordered input with repeats and CRLF line endings is normalized.
    #[test]
    fn test_parse_gpu_codes_sorts_and_dedups() {
        let codes = parse_gpu_codes("sm_90\nsm_52\nsm_90\r\nsm_70\r\n");
        assert_eq!(codes, vec![52, 70, 90]);
    }

    /// Virtual architectures, suffixed variants and unrelated text are ignored.
    #[test]
    fn test_parse_gpu_codes_skips_non_numeric_and_foreign_lines() {
        let codes = parse_gpu_codes("compute_80\nsm_90a\n\nnvcc warning\nsm_80");
        assert_eq!(codes, vec![80]);
    }

    /// No output means no architectures.
    #[test]
    fn test_parse_gpu_codes_empty() {
        assert!(parse_gpu_codes("").is_empty());
    }
}