use std::process::Command;
/// An apt-based Linux distribution described in the terms used to build an
/// NVIDIA CUDA apt repository path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AptDistro {
    /// Distribution identifier, e.g. `ubuntu` or `debian`.
    pub id: String,
    /// Version with dots removed, e.g. `2404` for Ubuntu 24.04.
    pub version_id: String,
    /// Repository architecture directory, e.g. `x86_64` or `sbsa`.
    pub arch: String,
}

impl AptDistro {
    /// Returns the repository path segment: `id` immediately followed by
    /// `version_id` (for example `ubuntu2404`).
    pub fn repo_segment(&self) -> String {
        [self.id.as_str(), self.version_id.as_str()].concat()
    }
}
/// Detects an Ubuntu/Debian-family host and describes it as an [`AptDistro`].
///
/// Reads `ID`, `ID_LIKE` and `VERSION_ID` from `/etc/os-release` (dots are
/// removed from `VERSION_ID`, so `24.04` becomes `2404`) and asks
/// `dpkg --print-architecture` for the package architecture, translating
/// `amd64` to `x86_64` and `arm64` to `sbsa`; other values pass through.
///
/// Returns `None` when:
/// * the target OS is not Linux;
/// * `/etc/os-release` is unreadable or lacks `ID` or `VERSION_ID`;
/// * neither `ID` nor any `ID_LIKE` token is `ubuntu` or `debian`;
/// * `dpkg` cannot be run, exits unsuccessfully, or prints no architecture.
///
/// NOTE(review): for derivatives accepted only via `ID_LIKE` (e.g. an `ID`
/// other than ubuntu/debian), `id` and `version_id` are the derivative's own
/// values, which may not name an existing NVIDIA repository — confirm intended.
pub fn detect_apt_distro() -> Option<AptDistro> {
    if !cfg!(target_os = "linux") {
        return None;
    }
    let release = std::fs::read_to_string("/etc/os-release").ok()?;
    let mut id = None;
    let mut id_like = String::new();
    let mut version_id = None;
    for line in release.lines() {
        let line = line.trim();
        // os-release allows blank lines and `#` comments; skip them (and any
        // other line without `=`) rather than abandoning detection entirely.
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        // Values may be wrapped in double or single quotes.
        let value = value
            .trim()
            .trim_matches(|c: char| c == '"' || c == '\'')
            .trim();
        match key.trim() {
            "ID" => id = Some(value.to_string()),
            "ID_LIKE" => id_like = value.to_string(),
            // Repository segments use the version without dots (`2404`).
            "VERSION_ID" => version_id = Some(value.replace('.', "")),
            _ => {}
        }
    }
    let id = id?;
    let version_id = version_id?;
    let supported = matches!(id.as_str(), "ubuntu" | "debian")
        || id_like
            .split_whitespace()
            .any(|t| t == "ubuntu" || t == "debian");
    if !supported {
        return None;
    }
    let dpkg = Command::new("dpkg")
        .arg("--print-architecture")
        .output()
        .ok()?;
    // A failing dpkg would otherwise yield an empty architecture string.
    if !dpkg.status.success() {
        return None;
    }
    let raw_arch = String::from_utf8(dpkg.stdout).ok()?;
    let arch = match raw_arch.trim() {
        "" => return None,
        "amd64" => "x86_64",
        "arm64" => "sbsa",
        other => other,
    }
    .to_string();
    Some(AptDistro {
        id,
        version_id,
        arch,
    })
}
/// Returns the CUDA major version reported by the installed NVIDIA driver.
///
/// First runs `nvidia-smi --query --display=COMPUTE`. If that exits
/// unsuccessfully, or its output has no parseable `CUDA Version` field, falls
/// back to plain `nvidia-smi`, whose table header also carries
/// `CUDA Version: X.Y`.
///
/// Returns `None` if `nvidia-smi` cannot be launched or neither output
/// reports a version.
pub fn driver_cuda_major() -> Option<u32> {
    let query = Command::new("nvidia-smi")
        .arg("--query")
        .arg("--display=COMPUTE")
        .output()
        .ok()?;
    if query.status.success() {
        let detailed = String::from_utf8_lossy(&query.stdout);
        if let Some(major) = parse_smi_cuda_version(&detailed) {
            return Some(major);
        }
    }
    // The query form failed or did not mention a version: try the default table.
    let plain = Command::new("nvidia-smi").output().ok()?;
    parse_smi_cuda_version(&String::from_utf8_lossy(&plain.stdout))
}
/// Extracts the major component of the first `CUDA Version` field in
/// `nvidia-smi` output.
///
/// Handles both the table-header form (`CUDA Version: 13.0`) and the
/// `--query` form (`CUDA Version      : 12.6`): any run of colons and
/// whitespace after the label is skipped, then the leading ASCII digits are
/// parsed. Returns `None` when the label is missing or is not followed by a
/// number that fits in a `u32`.
fn parse_smi_cuda_version(s: &str) -> Option<u32> {
    const LABEL: &str = "CUDA Version";
    let (_, after_label) = s.split_once(LABEL)?;
    let value = after_label.trim_start_matches(|c: char| c == ':' || c.is_whitespace());
    let digits_end = value
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(value.len());
    value[..digits_end].parse().ok()
}
/// Reports whether the dynamic linker cache lists a `libnvrtc.so*` library.
///
/// Runs `ldconfig -p` and searches its output for `libnvrtc.so`. The bare
/// command name is tried first, then `/sbin/ldconfig` and
/// `/usr/sbin/ldconfig`, because the sbin directories are frequently absent
/// from a non-root user's `PATH` (e.g. on Debian). Without this fallback the
/// check would report `false` even after a successful install.
///
/// Always `false` on non-Linux targets or when no `ldconfig` can be launched.
pub fn nvrtc_present() -> bool {
    if !cfg!(target_os = "linux") {
        return false;
    }
    ["ldconfig", "/sbin/ldconfig", "/usr/sbin/ldconfig"]
        .iter()
        .copied()
        .find_map(|program| Command::new(program).arg("-p").output().ok())
        .map(|out| String::from_utf8_lossy(&out.stdout).contains("libnvrtc.so"))
        .unwrap_or(false)
}
/// Builds a shell command that refreshes apt metadata and installs an NVRTC
/// package compatible with a driver that supports CUDA `driver_major`.
///
/// Candidate packages are tried newest-first, and only those whose CUDA major
/// is `<= driver_major` are offered. NVRTC from a newer CUDA major than the
/// driver supports generates code the driver cannot load (see NVIDIA's CUDA
/// compatibility guide).
///
/// Each package is attempted at most once and its errors are silenced. The
/// chain ends in `|| true`, so the command itself never fails; callers must
/// verify the library afterwards.
///
/// If no known package is compatible (drivers older than CUDA 11), the
/// command is the no-op `true`.
pub fn nvrtc_install_command(driver_major: u32) -> String {
    // (CUDA major, apt package name), newest first.
    const CANDIDATES: &[(u32, &str)] = &[
        (13, "cuda-nvrtc-13-0"),
        (12, "cuda-nvrtc-12-6"),
        (12, "cuda-nvrtc-12-4"),
        (12, "cuda-nvrtc-12-0"),
        (12, "libnvrtc12"),
        (11, "cuda-nvrtc-11-8"),
    ];
    let installs: Vec<String> = CANDIDATES
        .iter()
        .filter(|(major, _)| *major <= driver_major)
        .map(|(_, pkg)| {
            format!("DEBIAN_FRONTEND=noninteractive apt-get install -yqq {pkg} 2>/dev/null")
        })
        .collect();
    if installs.is_empty() {
        return "true".to_string();
    }
    // `&&` and `||` share precedence in sh: if `apt-get update` fails the first
    // install is skipped but the remaining fallbacks still run.
    format!("apt-get update -qq && {} || true", installs.join(" || "))
}
/// Installs CUDA NVRTC through apt, configuring NVIDIA's repository first
/// when it does not appear to be set up.
///
/// The repository counts as configured when
/// `/usr/share/keyrings/cuda-archive-keyring.gpg` exists, or when any entry
/// in `/etc/apt/sources.list.d` has `cuda` in its name. Otherwise the
/// `cuda-keyring` .deb for `distro` is downloaded with `curl`, installed with
/// `dpkg -i`, and its temporary directory is removed on exit even if a step
/// fails.
///
/// The package chain comes from `nvrtc_install_command(driver_major)`. The
/// script runs via `run_privileged_script`.
///
/// # Errors
/// * The repo segment or architecture contains characters outside
///   `[A-Za-z0-9._-]`; they are spliced into a root shell command.
/// * The privileged script fails.
/// * `libnvrtc.so` is still not visible to `ldconfig` afterwards.
#[cfg(target_os = "linux")]
pub fn install_nvrtc(distro: &AptDistro, driver_major: u32) -> anyhow::Result<()> {
    use anyhow::{anyhow, bail, Context};
    // NOTE(review): presumably the cuda-keyring package provides both this
    // keyring and a `cuda*` apt source; finding either is treated as "already
    // configured" so no duplicate source is added — confirm.
    let keyring_installed = std::path::Path::new("/usr/share/keyrings/cuda-archive-keyring.gpg")
        .exists()
        || std::fs::read_dir("/etc/apt/sources.list.d")
            .map(|d| {
                d.flatten()
                    .any(|e| e.file_name().to_string_lossy().contains("cuda"))
            })
            .unwrap_or(false);
    let install_tail = nvrtc_install_command(driver_major);
    let script = if keyring_installed {
        install_tail
    } else {
        let segment = distro.repo_segment();
        let arch = distro.arch.as_str();
        // Both values are interpolated into a single-quoted URL inside a shell
        // that runs as root, so reject anything that could break the quoting.
        let is_url_safe = |s: &str| {
            !s.is_empty()
                && s
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '_' | '-'))
        };
        if !is_url_safe(segment.as_str()) || !is_url_safe(arch) {
            bail!(
                "refusing to build CUDA repo URL from unexpected distro values {segment:?} / {arch:?}"
            );
        }
        let keyring_url = format!(
            "https://developer.download.nvidia.com/compute/cuda/repos/{segment}/{arch}/cuda-keyring_1.1-1_all.deb"
        );
        // The EXIT trap removes the temp dir even when `set -e` aborts early.
        format!(
            "set -e; \
             tmp=$(mktemp -d); \
             trap 'rm -rf \"$tmp\"' EXIT; \
             curl -fsSL --retry 3 -o \"$tmp/cuda-keyring.deb\" '{keyring_url}'; \
             dpkg -i \"$tmp/cuda-keyring.deb\"; \
             {install_tail}"
        )
    };
    run_privileged_script(&script).context("CUDA NVRTC install failed")?;
    // The install chain always exits 0, so verify the outcome explicitly.
    if !nvrtc_present() {
        return Err(anyhow!(
            "apt completed but libnvrtc.so still not visible to ldconfig — \
             the package may not match your driver's CUDA major version ({driver_major}). \
             See https://developer.nvidia.com/cuda-toolkit-archive"
        ));
    }
    Ok(())
}
#[cfg(target_os = "linux")]
fn run_privileged_script(script: &str) -> anyhow::Result<()> {
use anyhow::bail;
let is_root = unsafe { libc::geteuid() } == 0;
let mut cmd = if is_root {
let mut c = Command::new("sh");
c.arg("-c").arg(script);
c
} else {
let mut c = Command::new("sudo");
c.arg("-E").arg("sh").arg("-c").arg(script);
c
};
let status = cmd.status()?;
if !status.success() {
bail!("privileged command exited with {status}");
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default `nvidia-smi` table reports the version in its header row.
    #[test]
    fn parses_smi_header_format() {
        let header = "+-----------------------------------------------------------------------------+\n\
                      | NVIDIA-SMI 550.127.05 Driver Version: 550.127.05 CUDA Version: 13.0 |\n\
                      |-------------------------------+----------------------+----------------------+\n";
        let parsed = parse_smi_cuda_version(header);
        assert_eq!(parsed, Some(13));
    }

    /// `nvidia-smi --query` separates label and value with ` : `.
    #[test]
    fn parses_smi_query_format() {
        let log = "==============NVSMI LOG==============\n\nCUDA Version : 12.6\n";
        let parsed = parse_smi_cuda_version(log);
        assert_eq!(parsed, Some(12));
    }

    /// Output without the label yields no version.
    #[test]
    fn ignores_garbage() {
        let input = "nothing relevant here";
        assert!(parse_smi_cuda_version(input).is_none());
    }

    /// For a CUDA 13 driver the 13.x package is attempted before 12.x.
    #[test]
    fn install_command_picks_matching_package_first() {
        let cmd = nvrtc_install_command(13);
        let position = |pkg: &str| {
            cmd.find(pkg)
                .unwrap_or_else(|| panic!("{pkg} missing from install command"))
        };
        assert!(position("cuda-nvrtc-13-0") < position("cuda-nvrtc-12-6"));
    }

    /// The repo segment is the id immediately followed by the dotless version.
    #[test]
    fn distro_repo_segment_is_concatenated() {
        let distro = AptDistro {
            id: String::from("ubuntu"),
            version_id: "2404".to_owned(),
            arch: "x86_64".into(),
        };
        assert_eq!(distro.repo_segment(), "ubuntu2404");
    }
}