use std::env;
use std::path::PathBuf;
/// Build-script entry point: compiles the crate's CUDA kernels.
///
/// Two artifacts are written to Cargo's `OUT_DIR`:
///
/// 1. `ptx.rs` — the bindings produced by `bindgen_cuda`'s PTX build, with
///    every line mentioning one of the MoE kernel names filtered out.
/// 2. `libmoe.a` — a library built from the three `src/moe/*.cu` sources,
///    which is then linked together with `cudart` (and `stdc++` on
///    non-MSVC targets).
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, if the PTX build or writing the bindings
/// fails, or if `ptx.rs` cannot be read back and rewritten.
fn main() {
    println!("cargo::rerun-if-changed=build.rs");
    println!("cargo::rerun-if-changed=src/compatibility.cuh");
    println!("cargo::rerun-if-changed=src/cuda_utils.cuh");
    println!("cargo::rerun-if-changed=src/binary_op_macros.cuh");

    let out_dir = PathBuf::from(
        env::var("OUT_DIR").expect("OUT_DIR must be set by Cargo when running a build script"),
    );
    let is_target_msvc = env::var("TARGET").is_ok_and(|target| target.contains("msvc"));

    // Both compilations start from the same compiler flags; build each
    // builder from one place so the two cannot drift apart.
    let new_builder = || {
        bindgen_cuda::Builder::default()
            .arg("--expt-relaxed-constexpr")
            .arg("-std=c++17")
            .arg("-O3")
    };

    // Step 1: PTX bindings.
    let ptx_path = out_dir.join("ptx.rs");
    let bindings = new_builder()
        .build_ptx()
        .expect("failed to compile CUDA kernels to PTX");
    bindings
        .write(&ptx_path)
        .expect("failed to write generated PTX bindings");
    // Drop the MoE entries from the generated bindings.
    // NOTE(review): presumably because those kernels are provided by the
    // static library built below instead of as PTX — confirm.
    remove_lines(&ptx_path, &["MOE_GGUF", "MOE_WMMA", "MOE_WMMA_GGUF"]);

    // Step 2: MoE kernels compiled into a library.
    let moe_builder = if is_target_msvc {
        // NOTE(review): presumably needed so MSVC's math headers expose the
        // M_* constants — confirm against the kernel sources.
        new_builder().arg("-D_USE_MATH_DEFINES")
    } else {
        // Forward -fPIC to the host compiler on non-MSVC targets.
        new_builder().arg("-Xcompiler").arg("-fPIC")
    };
    let moe_builder = moe_builder.kernel_paths(vec![
        "src/moe/moe_gguf.cu",
        "src/moe/moe_wmma.cu",
        "src/moe/moe_wmma_gguf.cu",
    ]);
    moe_builder.build_lib(out_dir.join("libmoe.a"));

    // Link directives for the library produced above and its dependencies.
    println!("cargo:rustc-link-search={}", out_dir.display());
    println!("cargo:rustc-link-lib=moe");
    println!("cargo:rustc-link-lib=dylib=cudart");
    if !is_target_msvc {
        println!("cargo:rustc-link-lib=stdc++");
    }
}
/// Rewrites `file` in place, keeping only the lines that contain none of
/// the substrings in `patterns`.
///
/// Surviving lines are written back separated by a single `\n`, with no
/// newline after the last one; line terminators of the input (including
/// `\r\n`) are therefore not preserved.
///
/// # Panics
///
/// Panics if `file` cannot be read as UTF-8 text or cannot be written.
fn remove_lines<P: AsRef<std::path::Path>>(file: P, patterns: &[&str]) {
    let target = file.as_ref();
    let original = std::fs::read_to_string(target).unwrap();

    let mut kept = String::with_capacity(original.len());
    // Tracked separately from `kept.is_empty()` so that a surviving empty
    // first line still gets a separator after it.
    let mut needs_separator = false;
    for line in original.lines() {
        let excluded = patterns.iter().any(|needle| line.contains(*needle));
        if excluded {
            continue;
        }
        if needs_separator {
            kept.push('\n');
        }
        kept.push_str(line);
        needs_separator = true;
    }

    std::fs::write(target, kept).unwrap();
}