extern crate bindgen;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use std::thread;
use std::time::Duration;
fn get_xgboost_version() -> String {
env::var("XGBOOST_VERSION").unwrap_or_else(|_| "3.1.1".to_string())
}
fn get_header_checksums() -> HashMap<&'static str, (&'static str, &'static str)> {
let mut checksums = HashMap::new();
checksums.insert(
"3.1.1",
(
"c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2",
"8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a",
),
);
checksums.insert(
"3.0.5",
(
"2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc",
"47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7",
),
);
checksums.insert(
"2.1.4",
(
"b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd",
"525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce",
),
);
checksums.insert(
"1.7.6",
(
"145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d",
"b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935",
),
);
checksums.insert(
"1.4.2",
(
"3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc",
"e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7",
),
);
checksums
}
fn compute_sha256(data: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(data);
format!("{:x}", hasher.finalize())
}
fn verify_checksum(
data: &[u8],
expected: &str,
filename: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let actual = compute_sha256(data);
if actual != expected {
return Err(format!(
"SHA256 checksum mismatch for {}:\n Expected: {}\n Got: {}",
filename, expected, actual
)
.into());
}
println!("cargo:warning=✓ Verified SHA256 for {}", filename);
Ok(())
}
fn parse_version(version: &str) -> (u32, u32, u32) {
let parts: Vec<&str> = version.split('.').collect();
let major = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0);
let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
let patch = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0);
(major, minor, patch)
}
fn emit_version_cfg_flags(version: &str) {
let (major, minor, _patch) = parse_version(version);
if major > 1 || (major == 1 && minor >= 4) {
println!("cargo:rustc-cfg=xgboost_thread_safe");
println!(
"cargo:warning=XGBoost version {} supports thread-safe predictions",
version
);
} else {
println!(
"cargo:warning=XGBoost version {} does NOT support thread-safe predictions",
version
);
}
}
fn get_platform_info() -> (String, String) {
let target = env::var("TARGET").unwrap();
let os = if target.contains("apple-darwin") {
"darwin"
} else if target.contains("linux") {
"linux"
} else if target.contains("windows") {
"windows"
} else {
panic!("Unsupported target: {}", target);
};
let arch = if target.contains("x86_64") {
"x86_64"
} else if target.contains("aarch64") || target.contains("arm64") {
"aarch64"
} else if target.contains("i686") || target.contains("i586") {
"i686"
} else {
panic!("Unsupported architecture for target: {}", target);
};
(os.to_string(), arch.to_string())
}
fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
let version = get_xgboost_version();
let checksums = get_header_checksums();
let (c_api_expected, base_expected) = checksums.get(version.as_str()).ok_or_else(|| {
format!(
"No known SHA256 checksums for XGBoost version {}. \
Please verify this version manually or add checksums to build.rs",
version
)
})?;
let include_dir = out_dir.join("include/xgboost");
fs::create_dir_all(&include_dir)?;
let c_api_path = include_dir.join("c_api.h");
download_and_verify_file(
&format!(
"https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/c_api.h",
version
),
&c_api_path,
c_api_expected,
"c_api.h",
)?;
let base_path = include_dir.join("base.h");
download_and_verify_file(
&format!(
"https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/base.h",
version
),
&base_path,
base_expected,
"base.h",
)?;
Ok(())
}
fn download_and_verify_file(
url: &str,
dest_path: &Path,
expected_sha256: &str,
filename: &str,
) -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:warning=Downloading {} from: {}", filename, url);
let response = ureq::get(url).call()?;
let status = response.status();
if !(200..300).contains(&status) {
return Err(format!("Failed to download {}: HTTP {}", filename, status).into());
}
let mut buffer = Vec::new();
response.into_reader().read_to_end(&mut buffer)?;
verify_checksum(&buffer, expected_sha256, filename)?;
let mut file = fs::File::create(dest_path)?;
file.write_all(&buffer)?;
Ok(())
}
fn download_with_retry(url: &str, max_retries: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let mut last_error = None;
for attempt in 0..max_retries {
if attempt > 0 {
let backoff = Duration::from_millis(100 * 2_u64.pow(attempt));
println!(
"cargo:warning=Retry attempt {} after {:?}",
attempt + 1,
backoff
);
thread::sleep(backoff);
}
match ureq::get(url).call() {
Ok(response) => {
let status = response.status();
if !(200..300).contains(&status) {
last_error = Some(format!("HTTP {}", status));
continue;
}
let mut buffer = Vec::new();
if let Err(e) = response.into_reader().read_to_end(&mut buffer) {
last_error = Some(e.to_string());
continue;
}
return Ok(buffer);
}
Err(e) => {
last_error = Some(e.to_string());
}
}
}
Err(format!(
"Failed to download after {} attempts. Last error: {}",
max_retries,
last_error.unwrap_or_else(|| "Unknown error".to_string())
)
.into())
}
fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
let (os, arch) = get_platform_info();
let version = get_xgboost_version();
let (major, minor, _patch) = parse_version(&version);
let wheel_filename = match (os.as_str(), arch.as_str()) {
("linux", "x86_64") => {
let manylinux_tag = if major >= 3 {
"manylinux_2_28"
} else if major == 1 && minor == 4 {
"manylinux2010"
} else {
"manylinux2014"
};
format!("xgboost-{}-py3-none-{}_x86_64.whl", version, manylinux_tag)
}
("linux", "aarch64") => {
let manylinux_tag = if major >= 3 {
"manylinux_2_28"
} else {
"manylinux2014"
};
format!("xgboost-{}-py3-none-{}_aarch64.whl", version, manylinux_tag)
}
("darwin", "x86_64") => {
if major >= 3 {
format!("xgboost-{}-py3-none-macosx_10_15_x86_64.whl", version)
} else if major == 1 && minor == 4 {
format!("xgboost-{}-py3-none-macosx_10_14_x86_64.macosx_10_15_x86_64.macosx_11_0_x86_64.whl", version)
} else {
format!("xgboost-{}-py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64.whl", version)
}
}
("darwin", "aarch64") => {
if major == 1 && minor < 5 {
return Err(format!(
"XGBoost {} does not have macOS arm64 support. Minimum version for arm64 is 1.5.0",
version
).into());
}
format!("xgboost-{}-py3-none-macosx_12_0_arm64.whl", version)
}
("windows", "x86_64") => format!("xgboost-{}-py3-none-win_amd64.whl", version),
_ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()),
};
let lib_filename = match os.as_str() {
"windows" => "xgboost.dll",
"darwin" => "libxgboost.dylib",
_ => "libxgboost.so",
};
let wheel_dir = out_dir.join("wheel");
let lib_dir = out_dir.join("libs");
fs::create_dir_all(&wheel_dir)?;
fs::create_dir_all(&lib_dir)?;
let wheel_path = wheel_dir.join(&wheel_filename);
let lib_dest_path = lib_dir.join(lib_filename);
if lib_dest_path.exists() {
println!(
"cargo:warning=Using cached XGBoost library at: {}",
lib_dest_path.display()
);
return Ok(());
}
let wheel_buffer = if wheel_path.exists() {
println!(
"cargo:warning=Using cached wheel at: {}",
wheel_path.display()
);
fs::read(&wheel_path)?
} else {
let download_url = format!(
"https://files.pythonhosted.org/packages/py3/x/xgboost/{}",
wheel_filename
);
println!(
"cargo:warning=Downloading XGBoost wheel from: {}",
download_url
);
let buffer = download_with_retry(&download_url, 3)?;
let temp_path = wheel_path.with_extension("tmp");
{
let mut temp_file = fs::File::create(&temp_path)?;
temp_file.write_all(&buffer)?;
temp_file.sync_all()?;
}
fs::rename(&temp_path, &wheel_path)?;
println!("cargo:warning=✓ Downloaded and cached wheel");
buffer
};
println!("cargo:warning=Extracting library from wheel");
let cursor = io::Cursor::new(wheel_buffer);
let mut archive = zip::ZipArchive::new(cursor)?;
let mut found = false;
let mut searched_paths = Vec::new();
for i in 0..archive.len() {
let mut file = archive.by_index(i)?;
let file_path = file.name().to_string();
if file_path.contains(".dll") || file_path.contains(".so") || file_path.contains(".dylib") {
searched_paths.push(file_path.clone());
}
if file_path.ends_with(lib_filename)
|| file_path.ends_with(&format!("/{}", lib_filename))
|| file_path.ends_with(&format!("\\{}", lib_filename))
{
println!("cargo:warning=Found library at: {}", file_path);
let temp_dest_path = lib_dest_path.with_extension("tmp");
{
let mut dest = fs::File::create(&temp_dest_path)?;
io::copy(&mut file, &mut dest)?;
dest.sync_all()?;
}
fs::rename(&temp_dest_path, &lib_dest_path)?;
found = true;
break;
}
}
if !found {
let error_msg = if searched_paths.is_empty() {
format!(
"Library file {} not found in wheel. No library files found at all.",
lib_filename
)
} else {
format!(
"Library file {} not found in wheel. Found these library files instead: {:?}",
lib_filename, searched_paths
)
};
return Err(error_msg.into());
}
println!(
"cargo:warning=✓ Successfully extracted XGBoost library to: {}",
lib_dir.display()
);
Ok(())
}
fn main() {
println!("cargo:rustc-check-cfg=cfg(xgboost_thread_safe)");
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let xgb_include_root = out_dir.join("include");
let version = get_xgboost_version();
emit_version_cfg_flags(&version);
if let Err(e) = download_xgboost_headers(&out_dir) {
eprintln!("Failed to download XGBoost headers: {}", e);
panic!("Cannot proceed without headers");
}
if let Err(e) = download_and_extract_wheel(&out_dir) {
eprintln!("Failed to download and extract wheel: {}", e);
panic!("Cannot proceed without compiled library");
}
let bindings = bindgen::Builder::default()
.header("wrapper.h")
.clang_arg(format!("-I{}", xgb_include_root.display()))
.allowlist_function("XGB.*")
.allowlist_function("XGD.*")
.allowlist_type("BoosterHandle")
.allowlist_type("DMatrixHandle")
.allowlist_type("bst_ulong")
.size_t_is_usize(true)
.generate_comments(false)
.generate()
.expect("Unable to generate bindings.");
bindings
.write_to_file(out_dir.join("bindings.rs"))
.expect("Couldn't write bindings.");
let (os, _arch) = get_platform_info();
let lib_filename = match os.as_str() {
"windows" => "xgboost.dll",
"darwin" => "libxgboost.dylib",
_ => "libxgboost.so",
};
let lib_source_path = out_dir.join("libs").join(lib_filename);
let target_dir = out_dir
.ancestors()
.find(|p| p.ends_with("target"))
.unwrap()
.join(env::var("PROFILE").unwrap());
let lib_dest_path = target_dir.join(lib_filename);
fs::copy(&lib_source_path, &lib_dest_path).expect("Failed to copy library to target directory");
if os == "darwin" {
use std::process::Command;
let _ = Command::new("install_name_tool")
.arg("-id")
.arg(format!("@loader_path/{}", lib_filename))
.arg(&lib_source_path)
.status();
let _ = Command::new("install_name_tool")
.arg("-id")
.arg(format!("@loader_path/{}", lib_filename))
.arg(&lib_dest_path)
.status();
} else if os == "linux" {
use std::process::Command;
let _ = Command::new("patchelf")
.arg("--set-soname")
.arg(lib_filename)
.arg(&lib_source_path)
.output();
let _ = Command::new("patchelf")
.arg("--set-soname")
.arg(lib_filename)
.arg(&lib_dest_path)
.output();
}
let lib_search_path = out_dir.join("libs");
println!(
"cargo:rustc-link-search=native={}",
lib_search_path.display()
);
match os.as_str() {
"darwin" => {
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../..");
println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../..");
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}",
lib_search_path.display()
);
if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
target_root.display()
);
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}/release",
target_root.display()
);
}
}
"linux" => {
println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN");
println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../..");
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}",
lib_search_path.display()
);
if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
target_root.display()
);
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}/release",
target_root.display()
);
}
}
_ => {} }
println!("cargo:rustc-link-lib=dylib=xgboost");
}