use crate::error::{Error, Result};
use std::path::PathBuf;
use std::process::Command;
/// Well-known CUDA toolkit install roots, probed by `find_nvcc` only after the
/// `NVCC` / `CUDA_HOME` / `CUDA_PATH` environment variables and `PATH` have failed.
///
/// Each root is checked for `<root>/bin/nvcc` (`nvcc.exe` on Windows), in list order,
/// so earlier entries take precedence. On Windows the immediate subdirectories of each
/// existing root are searched as well (presumably versioned folders such as `v12.2`
/// under the NVIDIA GPU Computing Toolkit directory — confirm against target installs).
const CUDA_SEARCH_PATHS: &[&str] = &[
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"/usr",
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA",
"C:/CUDA",
];
/// Paths and version information describing a CUDA toolkit installation,
/// as derived from the location of its `nvcc` compiler.
#[derive(Debug, Clone)]
pub struct CudaToolkit {
/// Path to the `nvcc` compiler binary this description was built from.
pub nvcc_path: PathBuf,
/// `<root>/include`, where `<root>` is the grandparent directory of `nvcc_path`.
/// Derived from the layout only; existence is not checked.
pub include_dir: PathBuf,
/// Library directory beneath the toolkit root (e.g. `lib/x64` on Windows, `lib64`
/// on other platforms). Derived from the layout; not guaranteed to exist.
pub lib_dir: PathBuf,
/// Release string parsed from the `release` line of `nvcc --version` output,
/// or `None` if `nvcc` could not be run or the line was not found.
pub version: Option<String>,
}
impl CudaToolkit {
    /// Detects the CUDA toolkit installed on this machine.
    ///
    /// `nvcc` is located with `find_nvcc`, and all other paths are derived from it
    /// exactly as [`CudaToolkit::from_nvcc_path`] does.
    ///
    /// # Errors
    ///
    /// Returns `Error::NvccNotFound` if no `nvcc` binary can be found, or
    /// `Error::CudaToolkitNotFound` if the toolkit root cannot be derived from it.
    pub fn detect() -> Result<Self> {
        Self::from_nvcc_path(find_nvcc()?)
    }

    /// Builds a toolkit description from an explicit path to `nvcc`.
    ///
    /// The toolkit root is taken to be the grandparent of `nvcc_path`
    /// (i.e. `nvcc` is expected at `<root>/bin/nvcc`). Headers are expected in
    /// `<root>/include`. Libraries are expected in `<root>/lib/x64` on Windows and
    /// `<root>/lib64` elsewhere, falling back to `<root>/lib` when only that exists.
    /// The version is read from `nvcc --version` on a best-effort basis.
    ///
    /// # Errors
    ///
    /// Returns `Error::NvccNotFound` if `nvcc_path` does not exist, or
    /// `Error::CudaToolkitNotFound` if it has no grandparent directory.
    pub fn from_nvcc_path(nvcc_path: PathBuf) -> Result<Self> {
        if !nvcc_path.exists() {
            return Err(Error::NvccNotFound(nvcc_path.display().to_string()));
        }
        let cuda_root = nvcc_path
            .parent()
            .and_then(|bin_dir| bin_dir.parent())
            .ok_or_else(|| Error::CudaToolkitNotFound(nvcc_path.clone()))?;
        let include_dir = cuda_root.join("include");
        let lib_dir = Self::lib_dir_for_root(cuda_root);
        let version = detect_cuda_version(&nvcc_path);
        Ok(Self {
            nvcc_path,
            include_dir,
            lib_dir,
            version,
        })
    }

    /// Returns the SM architecture numbers (e.g. `80` for `sm_80`) reported by
    /// `nvcc --list-gpu-code`.
    ///
    /// Returns an empty vector if `nvcc` cannot be run or reports nothing parseable.
    pub fn supported_architectures(&self) -> Vec<usize> {
        Command::new(&self.nvcc_path)
            .arg("--list-gpu-code")
            .output()
            .map(|output| parse_gpu_codes(&String::from_utf8_lossy(&output.stdout)))
            .unwrap_or_default()
    }

    /// Chooses the library directory for a toolkit rooted at `cuda_root`.
    ///
    /// Windows always uses `lib/x64`. Elsewhere `lib64` is preferred; `lib` is used
    /// only when `lib64` is absent and `lib` exists (assumed to cover layouts such as
    /// conda environments — verify against target installs). When neither exists the
    /// `lib64` path is returned unchanged, as before.
    fn lib_dir_for_root(cuda_root: &std::path::Path) -> PathBuf {
        if cfg!(target_os = "windows") {
            return cuda_root.join("lib").join("x64");
        }
        let lib64 = cuda_root.join("lib64");
        let lib = cuda_root.join("lib");
        if !lib64.exists() && lib.exists() {
            lib
        } else {
            lib64
        }
    }
}
/// Locates the `nvcc` binary.
///
/// Candidates are tried in this order and the first existing one wins:
/// 1. the `NVCC` environment variable (a full path to the binary);
/// 2. `nvcc` on `PATH`;
/// 3. `$CUDA_HOME/bin/<nvcc>`;
/// 4. `$CUDA_PATH/bin/<nvcc>`;
/// 5. `<base>/bin/<nvcc>` for each entry of `CUDA_SEARCH_PATHS`, and on Windows also
///    `<base>/<subdir>/bin/nvcc.exe`, preferring the highest-numbered subdirectory.
///
/// `<nvcc>` is `nvcc.exe` on Windows and `nvcc` elsewhere.
///
/// # Errors
///
/// Returns `Error::NvccNotFound` when none of the candidates exist.
fn find_nvcc() -> Result<PathBuf> {
    // `var_os` (unlike `var`) also accepts non-UTF-8 paths.
    if let Some(path) = std::env::var_os("NVCC").map(PathBuf::from) {
        if path.exists() {
            return Ok(path);
        }
    }
    if let Ok(path) = which::which("nvcc") {
        return Ok(path);
    }
    let exe = nvcc_file_name();
    // Both variables point at a toolkit root; the executable name must follow the host
    // platform for either one to be usable on Windows and Unix alike.
    for var in ["CUDA_HOME", "CUDA_PATH"] {
        if let Some(root) = std::env::var_os(var) {
            let nvcc = PathBuf::from(root).join("bin").join(exe);
            if nvcc.exists() {
                return Ok(nvcc);
            }
        }
    }
    for &base_path in CUDA_SEARCH_PATHS {
        let base = PathBuf::from(base_path);
        let nvcc = base.join("bin").join(exe);
        if nvcc.exists() {
            return Ok(nvcc);
        }
        if cfg!(target_os = "windows") {
            if let Some(nvcc) = newest_versioned_nvcc(&base, exe) {
                return Ok(nvcc);
            }
        }
    }
    Err(Error::NvccNotFound(
        "No nvcc found in PATH or standard locations".to_string(),
    ))
}

/// File name of the `nvcc` executable on the host platform.
fn nvcc_file_name() -> &'static str {
    if cfg!(target_os = "windows") {
        "nvcc.exe"
    } else {
        "nvcc"
    }
}

/// Returns `<base>/<child>/bin/<exe>` for the highest-versioned child directory that
/// contains the binary, or `None` if `base` is unreadable or no child has it.
///
/// `read_dir` yields entries in an unspecified order, so the children are sorted
/// (numeric version components descending, then name descending) to make the choice
/// deterministic and to prefer e.g. `v12.2` over `v11.8`.
fn newest_versioned_nvcc(base: &std::path::Path, exe: &str) -> Option<PathBuf> {
    let mut children: Vec<PathBuf> = std::fs::read_dir(base)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .collect();
    children.sort_by_cached_key(|dir| std::cmp::Reverse((version_key(dir), dir.clone())));
    children
        .into_iter()
        .map(|dir| dir.join("bin").join(exe))
        .find(|nvcc| nvcc.exists())
}

/// Numeric components of a directory's file name (`"v12.2"` -> `[12, 2]`), so that
/// versions compare numerically (`v10.0` > `v9.2`) rather than lexically.
fn version_key(dir: &std::path::Path) -> Vec<u64> {
    dir.file_name()
        .map(|name| {
            name.to_string_lossy()
                .split(|c: char| !c.is_ascii_digit())
                .filter_map(|part| part.parse().ok())
                .collect()
        })
        .unwrap_or_default()
}
/// Extracts SM architecture numbers from `nvcc --list-gpu-code` output.
///
/// Each line is expected to look like `sm_80`. Surrounding whitespace is ignored, and
/// a trailing non-digit suffix (e.g. `sm_90a`) is dropped so only the leading number
/// is kept. Lines that do not start with `sm_` followed by at least one digit are
/// skipped. The result is sorted ascending and deduplicated.
fn parse_gpu_codes(output: &str) -> Vec<usize> {
    let mut codes: Vec<usize> = output
        .lines()
        .filter_map(|line| {
            let rest = line.trim().strip_prefix("sm_")?;
            // ASCII digits are single-byte, so this index is always a char boundary.
            let digits_end = rest
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(rest.len());
            rest[..digits_end].parse().ok()
        })
        .collect();
    codes.sort_unstable();
    codes.dedup();
    codes
}
/// Runs `nvcc --version` and returns the toolkit release it reports.
///
/// Returns `None` when `nvcc` cannot be executed or its output has no usable
/// `release` line. Accepts `&Path`, so existing `&PathBuf` callers still compile
/// via deref coercion.
fn detect_cuda_version(nvcc_path: &std::path::Path) -> Option<String> {
    let output = Command::new(nvcc_path).arg("--version").output().ok()?;
    parse_release_version(&String::from_utf8_lossy(&output.stdout))
}

/// Extracts the version from the first line of the form `... release <version>, ...`.
///
/// The version is the trimmed text between `release` and the next comma (or end of
/// line). Lines whose version text is empty are skipped.
fn parse_release_version(version_output: &str) -> Option<String> {
    version_output.lines().find_map(|line| {
        let (_, after) = line.split_once("release")?;
        let version = after.split(',').next()?.trim();
        (!version.is_empty()).then(|| version.to_string())
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Typical `--list-gpu-code` output yields every listed architecture, in order.
    #[test]
    fn test_parse_gpu_codes() {
        let output = "sm_52\nsm_60\nsm_70\nsm_75\nsm_80\nsm_86\nsm_89\nsm_90";
        let codes = parse_gpu_codes(output);
        assert!(codes.contains(&80));
        assert!(codes.contains(&90));
        assert_eq!(codes, vec![52, 60, 70, 75, 80, 86, 89, 90]);
    }

    /// The parser's contract is a sorted, duplicate-free list.
    #[test]
    fn test_parse_gpu_codes_sorts_and_dedups() {
        let codes = parse_gpu_codes("sm_90\nsm_52\nsm_90\nsm_75\n");
        assert_eq!(codes, vec![52, 75, 90]);
    }

    /// Empty output and lines that are not `sm_*` codes produce nothing.
    #[test]
    fn test_parse_gpu_codes_ignores_non_sm_lines() {
        assert!(parse_gpu_codes("").is_empty());
        assert!(parse_gpu_codes("compute_80\nnot a code\n").is_empty());
    }
}