use std::env;
use std::path::{Path, PathBuf};
fn main() {
println!("cargo:rerun-if-changed=build.rs");
#[cfg(feature = "vendored")]
let include_path = build_vendored();
#[cfg(not(feature = "vendored"))]
let include_path = find_system_library();
generate_bindings(&include_path);
}
/// Configures and builds the in-tree hmll C library with CMake, then emits the
/// cargo directives needed to link it statically, together with the optional
/// static dependencies that match the enabled crate features.
///
/// Returns the `include` directory of the C project, which holds `hmll/hmll.h`.
///
/// # Panics
///
/// Panics if the crate manifest directory is not nested at least three levels
/// below a parent directory (the assumed hmll project root).
#[cfg(feature = "vendored")]
fn build_vendored() -> PathBuf {
    /// Maps a boolean to the `ON`/`OFF` spelling CMake options expect.
    fn on_off(enabled: bool) -> &'static str {
        if enabled {
            "ON"
        } else {
            "OFF"
        }
    }

    // The Rust crate lives three directories below the C project root.
    let manifest_dir = PathBuf::from(
        env::var("CARGO_MANIFEST_DIR").expect("cargo always sets CARGO_MANIFEST_DIR"),
    );
    let project_root = manifest_dir
        .ancestors()
        .nth(3)
        .expect("the crate must be nested three directories below the hmll project root")
        .to_path_buf();
    println!("cargo:rerun-if-changed=../../..");

    // Cargo normally reports only "debug" or "release" here; the other arms are
    // kept as a safe mapping for any other value.
    let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
    let cmake_build_type = match profile.as_str() {
        "debug" => "Debug",
        "release" | "bench" => "Release",
        _ => "RelWithDebInfo",
    };

    // A build script is compiled for the *host*, so `cfg(target_os)` would
    // describe the build machine. Ask cargo for the real target instead.
    let target_is_linux = env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("linux");
    let use_io_uring = cfg!(feature = "io_uring") && target_is_linux;

    let mut cmake_config = cmake::Config::new(&project_root);
    cmake_config
        .profile(cmake_build_type)
        .define("HMLL_BUILD_STATIC", "ON")
        .define("HMLL_BUILD_EXAMPLES", "OFF")
        .define("HMLL_BUILD_TESTS", "OFF")
        .define("HMLL_ENABLE_PYTHON", "OFF")
        .define("HMLL_ENABLE_IO_URING", on_off(use_io_uring))
        .define("HMLL_ENABLE_SAFETENSORS", on_off(cfg!(feature = "safetensors")))
        .define("HMLL_ENABLE_CUDA", on_off(cfg!(feature = "cuda")))
        .build_target("libhmll");
    let dst = cmake_config.build();

    // With an explicit build target the artifacts stay in the CMake build tree.
    println!("cargo:rustc-link-search=native={}/build", dst.display());
    println!("cargo:rustc-link-lib=static=libhmll");

    if use_io_uring {
        // liburing is fetched by CMake; link the static archive it produced.
        println!(
            "cargo:rustc-link-search=native={}/build/_deps/liburing-src/src",
            dst.display()
        );
        println!("cargo:rustc-link-lib=static=uring");
    }

    #[cfg(feature = "safetensors")]
    {
        // yyjson is fetched by CMake when safetensors support is enabled.
        println!(
            "cargo:rustc-link-search=native={}/build/_deps/yyjson-build",
            dst.display()
        );
        println!("cargo:rustc-link-lib=static=yyjson");
    }

    #[cfg(feature = "cuda")]
    link_cuda();

    project_root.join("include")
}
#[cfg(not(feature = "vendored"))]
fn find_system_library() -> PathBuf {
let library = pkg_config::Config::new()
.atleast_version("0.1.0")
.probe("hmll")
.expect(
"Could not find system hmll library. \
Either install libhmll or enable the 'vendored' feature to build from source.",
);
#[cfg(all(target_os = "linux", feature = "io_uring"))]
{
if pkg_config::probe_library("liburing").is_err() {
println!("cargo:rustc-link-lib=uring");
}
}
#[cfg(feature = "safetensors")]
{
let _ = pkg_config::probe_library("yyjson");
}
#[cfg(feature = "cuda")]
link_cuda();
library
.include_paths
.first()
.cloned()
.unwrap_or_else(|| PathBuf::from("/usr/include"))
}
/// Emits link-search paths for the CUDA toolkit. It then links the `cudart`
/// and `cuda` libraries dynamically.
///
/// The toolkit root comes from `CUDA_PATH`, then `CUDA_HOME`. When neither is
/// set, `/usr/local/cuda` and then `/opt/cuda` are searched. Both the `lib64`
/// and `lib` subdirectories of each root are added.
#[cfg(feature = "cuda")]
fn link_cuda() {
    // main() prints `rerun-if-changed`, which turns off cargo's default
    // "rerun on any change" heuristic. The env vars read below must be declared
    // explicitly, or changing them would not trigger a rebuild.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");

    let roots: Vec<String> = match env::var("CUDA_PATH").or_else(|_| env::var("CUDA_HOME")) {
        Ok(root) => vec![root],
        Err(_) => vec!["/usr/local/cuda".to_string(), "/opt/cuda".to_string()],
    };
    for root in &roots {
        println!("cargo:rustc-link-search=native={}/lib64", root);
        println!("cargo:rustc-link-search=native={}/lib", root);
    }

    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cuda");
}
fn generate_bindings(include_path: &Path) {
let builder = bindgen::Builder::default()
.header(include_path.join("hmll/hmll.h").to_str().unwrap())
.clang_arg(format!("-I{}", include_path.display()))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.allowlist_function("hmll_.*")
.allowlist_type("hmll_.*")
.allowlist_var("HMLL_.*")
.derive_debug(true)
.derive_default(true)
.derive_copy(true)
.derive_eq(true)
.derive_hash(true)
.no_partialeq("hmll_loader")
.no_partialeq("__sFILE")
.impl_debug(true)
.prepend_enum_name(false)
.size_t_is_usize(true)
.layout_tests(false)
.rustified_enum("hmll_status_code")
.rustified_enum("hmll_device")
.rustified_enum("hmll_loader_kind")
.rustified_enum("hmll_dtype");
#[cfg(feature = "safetensors")]
let builder = builder
.clang_arg("-D__HMLL_SAFETENSORS_ENABLED__=1")
.clang_arg("-D__HMLL_TENSORS_ENABLED__=1");
#[cfg(feature = "cuda")]
let builder = builder.clang_arg("-D__HMLL_CUDA_ENABLED__=1");
let bindings = builder.generate().expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}