extern crate bindgen;
use std::env;
use std::path::{Path, PathBuf};
use std::fs;
use std::io;
/// Returns the LightGBM release to fetch, e.g. `"4.6.0"`.
///
/// Taken from the `LIGHTGBM_VERSION` environment variable, falling back to
/// `4.6.0` when the variable is unset, not valid Unicode, or blank.
///
/// Surrounding whitespace and a single leading `v`/`V` are stripped. Every
/// download URL in this script already prefixes the version with `v`
/// (`.../v{version}/...`), so passing `v4.6.0` through unchanged would yield
/// `vv4.6.0` and a 404.
fn get_lightgbm_version() -> String {
    const DEFAULT_VERSION: &str = "4.6.0";
    match env::var("LIGHTGBM_VERSION") {
        Ok(raw) => {
            let trimmed = raw.trim();
            let bare = trimmed
                .strip_prefix('v')
                .or_else(|| trimmed.strip_prefix('V'))
                .unwrap_or(trimmed);
            if bare.is_empty() {
                DEFAULT_VERSION.to_string()
            } else {
                bare.to_string()
            }
        }
        Err(_) => DEFAULT_VERSION.to_string(),
    }
}
/// Classifies the Cargo `TARGET` triple into an `(os, arch)` pair.
///
/// `os` is one of `"darwin"`, `"linux"`, `"windows"`; `arch` is one of
/// `"x86_64"`, `"aarch64"`, `"i686"`. Each table is scanned top to bottom and
/// the first entry whose substring occurs in the triple wins.
///
/// # Panics
/// If `TARGET` is unset, or the triple matches no OS entry (checked first) or
/// no architecture entry.
fn get_platform_info() -> (String, String) {
    const OS_TABLE: [(&str, &str); 3] = [
        ("apple-darwin", "darwin"),
        ("linux", "linux"),
        ("windows", "windows"),
    ];
    const ARCH_TABLE: [(&[&str], &str); 3] = [
        (&["x86_64"], "x86_64"),
        (&["aarch64", "arm64"], "aarch64"),
        (&["i686", "i586"], "i686"),
    ];

    let triple = env::var("TARGET").unwrap();

    let os = match OS_TABLE.iter().find(|(needle, _)| triple.contains(needle)) {
        Some(&(_, name)) => name,
        None => panic!("Unsupported target: {}", triple),
    };
    let arch = match ARCH_TABLE
        .iter()
        .find(|(needles, _)| needles.iter().any(|n| triple.contains(n)))
    {
        Some(&(_, name)) => name,
        None => panic!("Unsupported architecture for target: {}", triple),
    };

    (String::from(os), String::from(arch))
}
/// Fetches the LightGBM C API headers for the configured version into
/// `<out_dir>/include/LightGBM`.
///
/// `c_api.h` and `export.h` are mandatory. For each of them, a request error,
/// a non-2xx status, or a filesystem error is returned to the caller.
///
/// `arrow.h` is then attempted. Only if that succeeds is `arrow.tpp` attempted
/// too. Failures for these two optional files are reported as cargo warnings
/// and otherwise ignored.
fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let version = get_lightgbm_version();
    let header_dir = out_dir.join("include/LightGBM");
    fs::create_dir_all(&header_dir)?;

    // All headers sit at the same path inside the tagged source tree.
    let url_for = |file: &str| -> String {
        format!(
            "https://raw.githubusercontent.com/microsoft/LightGBM/v{}/include/LightGBM/{}",
            version, file
        )
    };

    // Mandatory headers: bail out on the first problem.
    for &name in ["c_api.h", "export.h"].iter() {
        let url = url_for(name);
        println!("cargo:warning=Downloading {} from: {}", name, url);
        let response = ureq::get(&url).call()?;
        let status = response.status();
        if !(200..300).contains(&status) {
            return Err(format!("Failed to download {}: HTTP {}", name, status).into());
        }
        let mut out = fs::File::create(header_dir.join(name))?;
        io::copy(&mut response.into_reader(), &mut out)?;
    }

    // Optional Arrow support header; without it there is no point fetching arrow.tpp.
    let arrow_h_url = url_for("arrow.h");
    println!("cargo:warning=Attempting to download arrow.h from: {}", arrow_h_url);
    let arrow_h = match ureq::get(&arrow_h_url).call() {
        Ok(resp) if (200..300).contains(&resp.status()) => resp,
        _ => {
            println!("cargo:warning=arrow.h not available for this version (optional, only in v4.2.0+)");
            return Ok(());
        }
    };
    let mut arrow_h_file = fs::File::create(header_dir.join("arrow.h"))?;
    io::copy(&mut arrow_h.into_reader(), &mut arrow_h_file)?;
    println!("cargo:warning=Successfully downloaded arrow.h");

    let arrow_tpp_url = url_for("arrow.tpp");
    println!("cargo:warning=Attempting to download arrow.tpp from: {}", arrow_tpp_url);
    match ureq::get(&arrow_tpp_url).call() {
        Ok(resp) if (200..300).contains(&resp.status()) => {
            let mut arrow_tpp_file = fs::File::create(header_dir.join("arrow.tpp"))?;
            io::copy(&mut resp.into_reader(), &mut arrow_tpp_file)?;
            println!("cargo:warning=Successfully downloaded arrow.tpp");
        }
        _ => println!("cargo:warning=arrow.tpp not available for this version (optional)"),
    }

    Ok(())
}
/// Downloads the prebuilt LightGBM shared library for the target OS into
/// `<out_dir>/libs/`.
///
/// The asset is `lib_lightgbm.so` (Linux), `lib_lightgbm.dylib` (macOS) or
/// `lib_lightgbm.dll` (Windows). It comes from the GitHub release tagged
/// `v<version>`, where the version is resolved by `get_lightgbm_version`.
/// The target architecture is not consulted.
/// NOTE(review): confirm the release assets match non-x86_64 targets.
///
/// The HTTP request is made before any file is touched. The body is streamed
/// into `<name>.part` and renamed over the final path only after the copy
/// succeeds. As a result, a failed request or interrupted transfer never leaves
/// an empty or truncated library behind, and never clobbers one from a
/// previous build.
///
/// # Errors
/// Returns an error for an unsupported OS, a request failure, a non-2xx
/// status, or any filesystem error.
fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let (os, _arch) = get_platform_info();
    let version = get_lightgbm_version();

    let lib_filename = match os.as_str() {
        "linux" => "lib_lightgbm.so",
        "darwin" => "lib_lightgbm.dylib",
        "windows" => "lib_lightgbm.dll",
        _ => return Err(format!("Unsupported platform: {}", os).into()),
    };
    let download_url = format!(
        "https://github.com/microsoft/LightGBM/releases/download/v{}/{}",
        version, lib_filename
    );
    println!(
        "cargo:warning=Downloading LightGBM v{} library from: {}",
        version, download_url
    );

    let lib_dir = out_dir.join("libs");
    fs::create_dir_all(&lib_dir)?;
    let lib_path = lib_dir.join(lib_filename);

    // Request first: creating the destination beforehand would leave an empty
    // (or truncated, if one already existed) library whenever the request fails.
    let response = ureq::get(&download_url).call()?;
    let status = response.status();
    if !(200..300).contains(&status) {
        return Err(format!("Failed to download library: HTTP {}", status).into());
    }

    // Stream into a temporary sibling so a partial transfer is never mistaken
    // for a complete library by later build steps.
    let partial_path = lib_dir.join(format!("{}.part", lib_filename));
    let copied = fs::File::create(&partial_path).and_then(|mut dest| {
        io::copy(&mut response.into_reader(), &mut dest).map(|_| ())
    });
    if let Err(e) = copied {
        // Best-effort cleanup; the copy error is the one worth reporting.
        let _ = fs::remove_file(&partial_path);
        return Err(e.into());
    }
    fs::rename(&partial_path, &lib_path)?;

    println!(
        "cargo:warning=Downloaded LightGBM library to: {}",
        lib_path.display()
    );
    Ok(())
}
/// Build-script entry point.
///
/// 1. Downloads the LightGBM headers and prebuilt shared library into `OUT_DIR`.
/// 2. Generates `OUT_DIR/bindings.rs` from `wrapper.h` with bindgen, restricted
///    to the `LGBM_*` functions, the handle / Arrow types and the
///    `C_API_DTYPE_*` constants.
/// 3. Copies the shared library into the directory that holds this build's
///    final artifacts.
/// 4. Emits link-search, rpath (macOS/Linux) and link-lib directives.
///
/// # Panics
/// Panics on any download, binding-generation or copy failure; a build script
/// cannot recover from these.
fn main() {
    // Cargo already reruns a build script whose own source changed. The other
    // inputs are wrapper.h and LIGHTGBM_VERSION. Declaring them stops the
    // default "rerun on any package file change", which re-downloaded
    // everything. It also makes a LIGHTGBM_VERSION change trigger a fresh
    // download, which it previously did not.
    // NOTE(review): if wrapper.h includes other local headers, list them here too.
    println!("cargo:rerun-if-changed=wrapper.h");
    println!("cargo:rerun-if-env-changed=LIGHTGBM_VERSION");

    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let lgbm_include_root = out_dir.join("include");

    if let Err(e) = download_lightgbm_headers(&out_dir) {
        eprintln!("Failed to download LightGBM headers: {}", e);
        panic!("Cannot proceed without headers");
    }
    if let Err(e) = download_compiled_library(&out_dir) {
        eprintln!("Failed to download compiled library: {}", e);
        panic!("Cannot proceed without compiled library");
    }

    // The headers are C++ (arrow.tpp uses templates), so parse as C++11 and keep
    // std:: / Arrow container internals out of the generated bindings.
    let bindings = bindgen::Builder::default()
        .header("wrapper.h")
        .clang_arg(format!("-I{}", lgbm_include_root.display()))
        .clang_arg("-xc++")
        .clang_arg("-std=c++11")
        .allowlist_function("LGBM_.*")
        .allowlist_type("BoosterHandle")
        .allowlist_type("DatasetHandle")
        .allowlist_type("FastConfigHandle")
        .allowlist_type("ArrowArray")
        .allowlist_type("ArrowSchema")
        .allowlist_var("C_API_DTYPE_.*")
        .opaque_type("ArrowArray")
        .opaque_type("ArrowSchema")
        .blocklist_type("std::.*")
        .blocklist_type("ArrowTable")
        .blocklist_type("ArrowChunkedArray")
        .blocklist_type(".*_Tp.*")
        .blocklist_type(".*_Pred.*")
        .size_t_is_usize(true)
        .rustfmt_bindings(true)
        .generate()
        .expect("Unable to generate bindings.");
    bindings
        .write_to_file(out_dir.join("bindings.rs"))
        .expect("Couldn't write bindings.");

    let (os, _arch) = get_platform_info();
    let lib_filename = match os.as_str() {
        "windows" => "lib_lightgbm.dll",
        "darwin" => "lib_lightgbm.dylib",
        _ => "lib_lightgbm.so",
    };
    let lib_search_path = out_dir.join("libs");
    let lib_source_path = lib_search_path.join(lib_filename);

    // OUT_DIR is laid out as `<target-dir>/[<triple>/]<profile-dir>/build/<pkg>-<hash>/out`,
    // so the artifact directory is three levels up.
    // This works with a custom CARGO_TARGET_DIR (not necessarily named "target"),
    // with `--target <triple>` builds and with custom profiles. The old
    // `find("target")` + PROFILE approach broke in all three cases: PROFILE is
    // only ever "debug" or "release".
    let profile_dir = out_dir
        .ancestors()
        .nth(3)
        .expect("OUT_DIR is not nested inside a Cargo profile directory");
    let lib_dest_path = profile_dir.join(lib_filename);
    fs::copy(&lib_source_path, &lib_dest_path).unwrap_or_else(|e| {
        panic!(
            "Failed to copy library to target directory ({} -> {}): {}",
            lib_source_path.display(),
            lib_dest_path.display(),
            e
        )
    });

    println!(
        "cargo:rustc-link-search=native={}",
        lib_search_path.display()
    );
    // Runtime search paths: next to the executable, two levels up, and the
    // absolute download directory inside OUT_DIR.
    match os.as_str() {
        "darwin" => {
            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../..");
            println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display());
            if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
                println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display());
                println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display());
            }
        }
        "linux" => {
            println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN");
            println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../..");
            println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display());
        }
        _ => {}
    }

    // NOTE(review): on windows-msvc, linking `lib_lightgbm` needs an import
    // library (`lib_lightgbm.lib`) in the search path, but only the DLL is
    // downloaded above. Confirm this links on MSVC.
    println!("cargo:rustc-link-lib=dylib=lib_lightgbm");
}