use crate::error::{Error, Result};
use std::collections::HashMap;
use std::process::Command;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GpuArch {
pub base: usize,
pub suffix: Option<String>,
}
impl GpuArch {
pub fn new(base: usize) -> Self {
Self { base, suffix: None }
}
pub fn with_suffix(base: usize, suffix: &str) -> Self {
Self {
base,
suffix: Some(suffix.to_string()),
}
}
pub fn parse(s: &str) -> Result<Self> {
let s = s.trim().to_lowercase();
let s = s.strip_prefix("sm_").unwrap_or(&s);
let (num_part, explicit_suffix) = if s.ends_with('a') {
(&s[..s.len() - 1], Some("a".to_string()))
} else {
(s.as_ref(), None)
};
let base = num_part.parse::<usize>().map_err(|_| {
Error::ComputeCapDetectionFailed(format!("Invalid compute capability: {}", s))
})?;
let base = if base < 100 && base < 20 {
base * 10
} else {
base
};
if explicit_suffix.is_some() {
Ok(Self {
base,
suffix: explicit_suffix,
})
} else {
Ok(Self::auto_suffix(base))
}
}
pub fn auto_suffix(base: usize) -> Self {
match base {
b if b >= 90 => Self::with_suffix(b, "a"),
b => Self::new(b),
}
}
pub fn to_nvcc_arch(&self) -> String {
match &self.suffix {
Some(s) => format!("sm_{}{}", self.base, s),
None => format!("sm_{}", self.base),
}
}
pub fn to_gencode_arg(&self) -> String {
let compute = match &self.suffix {
Some(s) => format!("compute_{}{}", self.base, s),
None => format!("compute_{}", self.base),
};
let sm = self.to_nvcc_arch();
format!("-gencode=arch={},code={}", compute, sm)
}
pub fn base(&self) -> usize {
self.base
}
}
impl std::fmt::Display for GpuArch {
    /// Writes the capability number immediately followed by the suffix, if any
    /// (e.g. `90a` or `80`), without an `sm_` prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.base)?;
        if let Some(tail) = self.suffix.as_deref() {
            f.write_str(tail)?;
        }
        Ok(())
    }
}
impl From<usize> for GpuArch {
fn from(base: usize) -> Self {
Self::auto_suffix(base)
}
}
/// Chooses the GPU architecture to compile each CUDA source file for.
///
/// Resolution order in [`ComputeCapability::get_for_file`]:
/// 1. an override whose pattern matches the filename,
/// 2. the configured default architecture,
/// 3. auto-detection via [`detect_compute_cap`].
#[derive(Debug, Clone, Default)]
pub struct ComputeCapability {
    /// Architecture used when no override matches; `None` falls back to detection.
    default_cap: Option<GpuArch>,
    /// Filename pattern (exact name or `*` wildcard pattern) → architecture.
    overrides: HashMap<String, GpuArch>,
}

impl ComputeCapability {
    /// Creates an empty configuration (no default, no overrides).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the default architecture from a numeric capability (e.g. `90`), choosing the
    /// suffix with [`GpuArch::auto_suffix`].
    pub fn with_default(mut self, cap: usize) -> Self {
        self.default_cap = Some(GpuArch::auto_suffix(cap));
        self
    }

    /// Sets the default architecture from a string accepted by [`GpuArch::parse`]
    /// (e.g. `"90a"`, `"sm_80"`).
    ///
    /// If `arch` cannot be parsed, it is ignored and the builder is returned unchanged.
    pub fn with_default_arch(mut self, arch: &str) -> Self {
        if let Ok(gpu_arch) = GpuArch::parse(arch) {
            self.default_cap = Some(gpu_arch);
        }
        self
    }

    /// Registers a per-file override from a numeric capability, choosing the suffix with
    /// [`GpuArch::auto_suffix`]. A later call with the same `pattern` replaces the earlier one.
    pub fn with_override(mut self, pattern: &str, cap: usize) -> Self {
        self.overrides
            .insert(pattern.to_string(), GpuArch::auto_suffix(cap));
        self
    }

    /// Registers a per-file override from a string accepted by [`GpuArch::parse`].
    ///
    /// If `arch` cannot be parsed, it is ignored and the builder is returned unchanged.
    pub fn with_override_arch(mut self, pattern: &str, arch: &str) -> Self {
        if let Ok(gpu_arch) = GpuArch::parse(arch) {
            self.overrides.insert(pattern.to_string(), gpu_arch);
        }
        self
    }

    /// Returns the architecture to compile `filename` for.
    ///
    /// If several override patterns match, the most specific one wins:
    /// - an exact filename match comes first;
    /// - then the pattern with the most non-`*` characters;
    /// - remaining ties go to the lexicographically smallest pattern.
    ///
    /// If no override matches, this behaves like [`ComputeCapability::get_default`].
    ///
    /// # Errors
    /// Propagates the error from [`detect_compute_cap`] when no override or default applies.
    pub fn get_for_file(&self, filename: &str) -> Result<GpuArch> {
        // `HashMap` iteration order varies between processes, so taking the first match
        // would make the chosen architecture non-reproducible when patterns overlap.
        let best = self
            .overrides
            .iter()
            .filter(|&(pattern, _)| matches_pattern(filename, pattern))
            .max_by_key(|&(pattern, _)| {
                let literal_chars = pattern.chars().filter(|&c| c != '*').count();
                (
                    pattern.as_str() == filename,
                    literal_chars,
                    std::cmp::Reverse(pattern.as_str()),
                )
            });
        match best {
            Some((_, arch)) => Ok(arch.clone()),
            None => self.get_default(),
        }
    }

    /// Returns the configured default architecture, or falls back to [`detect_compute_cap`].
    ///
    /// # Errors
    /// Propagates the error from [`detect_compute_cap`] when no default is configured.
    pub fn get_default(&self) -> Result<GpuArch> {
        if let Some(arch) = &self.default_cap {
            return Ok(arch.clone());
        }
        detect_compute_cap()
    }

    /// Returns `true` if at least one per-file override has been registered.
    pub fn has_overrides(&self) -> bool {
        !self.overrides.is_empty()
    }
}
/// Detects the compute capability to build for.
///
/// Resolution order:
/// 1. The `CUDA_COMPUTE_CAP` environment variable, if it is set to a non-blank value.
///    The value is parsed with [`GpuArch::parse`].
/// 2. Otherwise, querying `nvidia-smi`.
///
/// # Errors
/// Returns an error in either of these cases:
/// - `CUDA_COMPUTE_CAP` is set but cannot be parsed;
/// - `nvidia-smi` detection fails.
pub fn detect_compute_cap() -> Result<GpuArch> {
    match std::env::var("CUDA_COMPUTE_CAP") {
        // A blank value (e.g. `ENV CUDA_COMPUTE_CAP=${ARG}` with the build arg unset) is
        // treated as "not set" instead of failing with an empty "Invalid compute capability".
        Ok(cap_str) if !cap_str.trim().is_empty() => GpuArch::parse(&cap_str),
        _ => detect_from_nvidia_smi(),
    }
}
fn detect_from_nvidia_smi() -> Result<GpuArch> {
let output = Command::new("nvidia-smi")
.args(["--query-gpu=compute_cap", "--format=csv"])
.output();
match output {
Ok(output) if output.status.success() => {
let stdout = String::from_utf8_lossy(&output.stdout);
parse_nvidia_smi_output(&stdout)
}
Ok(output) => Err(Error::ComputeCapDetectionFailed(format!(
"nvidia-smi failed: {}. \
If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90).",
String::from_utf8_lossy(&output.stderr)
))),
Err(e) => Err(Error::ComputeCapDetectionFailed(format!(
"Failed to run nvidia-smi: {}. \
If building in Docker, set CUDA_COMPUTE_CAP environment variable (e.g., CUDA_COMPUTE_CAP=90). \
GPU is not accessible during 'docker build' - only during 'docker run --gpus all'.",
e
))),
}
}
fn parse_nvidia_smi_output(output: &str) -> Result<GpuArch> {
let line = output.lines().nth(1).ok_or_else(|| {
Error::ComputeCapDetectionFailed("Unexpected nvidia-smi output".to_string())
})?;
let cap = line.trim().parse::<f32>().map_err(|_| {
Error::ComputeCapDetectionFailed(format!("Failed to parse compute_cap: {}", line))
})?;
let base = (cap * 10.0) as usize;
Ok(GpuArch::auto_suffix(base))
}
/// Returns `true` if `filename` matches `pattern`.
///
/// A pattern without `*` must equal the filename exactly. Each `*` matches any run of
/// characters (including none), so `sm90_*.cu`, `*_hopper.cu`, `*hopper*` and
/// `sm90_*_fp8*` are all supported. No other wildcard characters are recognized.
fn matches_pattern(filename: &str, pattern: &str) -> bool {
    let (prefix, tail) = match pattern.split_once('*') {
        Some(pair) => pair,
        None => return filename == pattern,
    };
    let mut rest = match filename.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    // `tail` is everything after the first '*': its last segment must end the filename,
    // and any segments in between must appear in order without overlapping.
    let (middle, suffix) = tail.rsplit_once('*').unwrap_or(("", tail));
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            // Leftmost match leaves the most room for the remaining segments.
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    // Checking against `rest` (after the prefix) prevents prefix/suffix overlap.
    rest.ends_with(suffix)
}
/// Returns the `nvcc` architecture name (e.g. `sm_80`, `sm_90a`) for a numeric compute
/// capability, with the suffix chosen by the `usize` → [`GpuArch`] conversion.
pub fn get_gpu_arch_string(compute_cap: usize) -> String {
    let arch = GpuArch::from(compute_cap);
    arch.to_nvcc_arch()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern() {
        // (filename, pattern, expected)
        let cases = [
            ("kernel.cu", "kernel.cu", true),
            ("sm90_kernel.cu", "sm90_*.cu", true),
            ("kernel_hopper.cu", "*_hopper.cu", true),
            ("prefix_middle_suffix.cu", "prefix_*.cu", true),
            ("other.cu", "sm90_*.cu", false),
        ];
        for (filename, pattern, expected) in cases {
            assert_eq!(
                matches_pattern(filename, pattern),
                expected,
                "matches_pattern({:?}, {:?})",
                filename,
                pattern
            );
        }
    }

    #[test]
    fn test_gpu_arch_string() {
        let expectations: [(usize, &str); 4] = [
            (80, "sm_80"),
            (90, "sm_90a"),
            (100, "sm_100a"),
            (120, "sm_120a"),
        ];
        for (cap, nvcc) in expectations {
            assert_eq!(get_gpu_arch_string(cap), nvcc);
        }
    }

    #[test]
    fn test_gpu_arch_parse() {
        // (input, expected base, expected suffix, expected nvcc arch)
        let cases: [(&str, usize, Option<&str>, &str); 4] = [
            ("90a", 90, Some("a"), "sm_90a"),
            ("100a", 100, Some("a"), "sm_100a"),
            ("sm_120a", 120, Some("a"), "sm_120a"),
            ("80", 80, None, "sm_80"),
        ];
        for (input, base, suffix, nvcc) in cases {
            let arch = GpuArch::parse(input).unwrap();
            assert_eq!(arch.base, base);
            assert_eq!(arch.suffix.as_deref(), suffix);
            assert_eq!(arch.to_nvcc_arch(), nvcc);
        }
    }

    #[test]
    fn test_gpu_arch_auto_suffix() {
        let expectations: [(usize, &str); 4] = [
            (80, "sm_80"),
            (89, "sm_89"),
            (90, "sm_90a"),
            (100, "sm_100a"),
        ];
        for (cap, nvcc) in expectations {
            assert_eq!(GpuArch::auto_suffix(cap).to_nvcc_arch(), nvcc);
        }
    }

    #[test]
    fn test_gpu_arch_gencode() {
        let expectations: [(usize, &str); 6] = [
            (75, "-gencode=arch=compute_75,code=sm_75"),
            (80, "-gencode=arch=compute_80,code=sm_80"),
            (89, "-gencode=arch=compute_89,code=sm_89"),
            (90, "-gencode=arch=compute_90a,code=sm_90a"),
            (100, "-gencode=arch=compute_100a,code=sm_100a"),
            (120, "-gencode=arch=compute_120a,code=sm_120a"),
        ];
        for (cap, gencode) in expectations {
            assert_eq!(GpuArch::auto_suffix(cap).to_gencode_arg(), gencode);
        }
    }
}