/// Build script entry point.
///
/// With the `cuda` feature enabled, compiles the CUDA sources listed in
/// `lib_files` with nvcc (through `bindgen_cuda`) into a static library named
/// `mistralrsquant` inside `OUT_DIR`. It then emits the cargo directives that
/// link that library, `cudart`, and the target's C++ standard library. Without
/// the `cuda` feature this function does nothing.
///
/// If `CUDA_NVCC_FLAGS` is set, its value is passed to nvcc right after
/// `--compiler-options`. It is read with `option_env!`, so the value is the one
/// seen when this build script was compiled.
fn main() {
    #[cfg(feature = "cuda")]
    {
        use std::path::PathBuf;

        // `option_env!` yields an `Option<&'static str>`; the inner value is
        // handed directly to `Builder::arg` below.
        const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS");

        println!("cargo:rerun-if-changed=build.rs");

        // Cargo always provides both variables to build scripts; read each once.
        let build_dir =
            PathBuf::from(std::env::var("OUT_DIR").expect("cargo sets OUT_DIR for build scripts"));
        let target = std::env::var("TARGET").expect("cargo sets TARGET for build scripts");

        let lib_files = vec!["kernels/bitsandbytes/dequant.cu"];
        for lib_file in &lib_files {
            println!("cargo:rerun-if-changed={lib_file}");
        }

        let mut builder = bindgen_cuda::Builder::default()
            .kernel_paths(lib_files)
            .out_dir(build_dir.clone())
            .arg("-std=c++17")
            .arg("-O3")
            // `-U` undefines these `__CUDA_NO_*` macros (presumably so the
            // half/bfloat16 operators and conversions they guard stay enabled).
            .arg("-U__CUDA_NO_HALF_OPERATORS__")
            .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
            .arg("-U__CUDA_NO_HALF2_OPERATORS__")
            .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
            .arg("--expt-relaxed-constexpr")
            .arg("--expt-extended-lambda")
            .arg("--use_fast_math")
            .arg("--verbose");
        if let Some(flags) = CUDA_NVCC_FLAGS {
            builder = builder.arg("--compiler-options").arg(flags);
        }

        // MSVC expects `name.lib`; other toolchains expect `libname.a`.
        let out_file = if target.contains("msvc") {
            build_dir.join("mistralrsquant.lib")
        } else {
            build_dir.join("libmistralrsquant.a")
        };
        builder.build_lib(out_file);

        println!("cargo:rustc-link-search={}", build_dir.display());
        println!("cargo:rustc-link-lib=mistralrsquant");
        println!("cargo:rustc-link-lib=dylib=cudart");

        // The kernels are compiled as C++17, so the target's C++ runtime is
        // linked explicitly. No directive is emitted for MSVC targets
        // (presumably that toolchain links its runtime by default).
        let cxx_stdlib = if target.contains("msvc") {
            None
        } else if target.contains("apple")
            || target.contains("freebsd")
            || target.contains("openbsd")
        {
            Some("c++")
        } else if target.contains("android") {
            Some("c++_shared")
        } else {
            Some("stdc++")
        };
        if let Some(lib) = cxx_stdlib {
            println!("cargo:rustc-link-lib=dylib={lib}");
        }
    }
}