use anyhow::Result;
use std::{env, fs, io, path::Path};
/// PyPI package names whose wheels this build script downloads.
///
/// `run` iterates over this list and hands each entry to `fetch_and_extract`,
/// which looks the package up at `https://pypi.org/pypi/<name>/json`.
const PACKAGES: &[&str] = &[
"nvidia-cuda-runtime-cu12",
"nvidia-cudnn-cu12",
"nvidia-cublas-cu12",
"nvidia-cufft-cu12",
];
/// Build-script entry point.
///
/// Emits the `cargo:rerun-if-*` directives and then delegates to `run`. If
/// `run` fails, the error is printed to stderr in its debug form and the
/// script panics so the build does not continue silently.
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=CUDA_LIBS_DIR");
    // `run` derives the output directory from this variable; without this
    // directive a changed value would not trigger a re-run and the libraries
    // would stay in the previous location.
    println!("cargo:rerun-if-env-changed=CARGO_WORKSPACE_DIR");
    if let Err(err) = run() {
        eprintln!("cuda-rt build script failed: {err:?}");
        panic!("Failed to prepare CUDA runtime libraries");
    }
}
fn run() -> Result<()> {
let workspace = env::var("CARGO_WORKSPACE_DIR")?;
let profile = env::var("PROFILE")?;
let out_dir = Path::new(&workspace).join("target").join(profile);
fs::create_dir_all(&out_dir)?;
let platform_tag = current_platform_tag()?;
for pkg in PACKAGES {
fetch_and_extract(pkg, platform_tag, &out_dir)?;
}
println!("cargo:rustc-link-search=native={}", out_dir.display());
Ok(())
}
/// Returns the wheel platform tag used to pick a download for this build.
///
/// * x86_64 Windows -> `"win_amd64"`
/// * x86_64 Linux   -> `"manylinux"`
///
/// # Errors
///
/// Returns an error for every other OS/architecture combination. Both
/// supported tags denote x86_64 binaries, so each branch checks the
/// architecture as well as the OS.
///
/// NOTE(review): `cfg!` inside a build script is evaluated for the platform
/// the script itself is compiled for. When cross-compiling, the
/// `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH` environment variables
/// describe the real target — confirm whether cross builds must be supported.
fn current_platform_tag() -> Result<&'static str> {
    if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
        Ok("win_amd64")
    } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) {
        Ok("manylinux")
    } else {
        anyhow::bail!("unsupported platform for CUDA runtime bundling");
    }
}
fn fetch_and_extract(pkg: &str, platform_tag: &str, out_dir: &Path) -> Result<()> {
let meta_url = format!("https://pypi.org/pypi/{pkg}/json");
let mut resp = ureq::get(&meta_url).call()?;
let json: serde_json::Value = resp.body_mut().with_config().read_json()?;
let files = json
.get("urls")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow::anyhow!("bad json: urls"))?;
let mut chosen: Option<(String, String)> = None; for f in files {
let filename = f.get("filename").and_then(|v| v.as_str()).unwrap_or("");
let file_url = f.get("url").and_then(|v| v.as_str()).unwrap_or("");
if !filename.ends_with(".whl") {
continue;
}
if filename.contains(platform_tag)
|| (platform_tag == "manylinux" && filename.contains("x86_64"))
{
chosen = Some((file_url.to_string(), filename.to_string()));
break;
}
}
let (wheel_url, wheel_name) = chosen.ok_or_else(|| anyhow::anyhow!("no suitable wheel"))?;
println!("Fetching {wheel_name}...");
let mut resp = ureq::get(&wheel_url).call()?;
let bytes = resp
.body_mut()
.with_config()
.limit(1 * 1024 * 1024 * 1024)
.read_to_vec()?;
extract_from_wheel(&bytes, out_dir)
}
fn extract_from_wheel(bytes: &[u8], out_dir: &Path) -> Result<()> {
let reader = std::io::Cursor::new(bytes);
let mut zip = zip::ZipArchive::new(reader)?;
let mut copied = 0usize;
for i in 0..zip.len() {
let mut file = zip.by_index(i)?;
let name = file.name().to_string();
let lname = name.to_ascii_lowercase();
let is_target = if cfg!(target_os = "windows") {
lname.ends_with(".dll") && lname.contains("nvidia")
} else {
lname.contains(".so") && lname.contains("nvidia")
};
if !is_target {
continue;
}
let fname = std::path::Path::new(&name)
.file_name()
.and_then(|s| s.to_str())
.ok_or_else(|| anyhow::anyhow!("bad filename"))?;
let mut out = fs::File::create(out_dir.join(fname))?;
io::copy(&mut file, &mut out)?;
println!("Copied {fname}");
copied += 1;
}
if copied == 0 {
anyhow::bail!("no CUDA libraries found in wheel");
}
Ok(())
}