use cudaforge::{KernelBuilder, Result};
use std::env;
use std::path::PathBuf;
fn main() -> Result<()> {
println!("cargo::rerun-if-changed=build.rs");
println!("cargo::rerun-if-changed=src/compatibility.cuh");
println!("cargo::rerun-if-changed=src/cuda_utils.cuh");
println!("cargo::rerun-if-changed=src/binary_op_macros.cuh");
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let ptx_path = out_dir.join("ptx.rs");
let bindings = KernelBuilder::new()
.source_dir("src") .exclude(&["moe_*.cu"]) .arg("--expt-relaxed-constexpr")
.arg("-std=c++17")
.arg("-O3")
.build_ptx()?;
bindings.write(&ptx_path)?;
let mut moe_builder = KernelBuilder::default()
.source_files(vec![
"src/moe/moe_gguf.cu",
"src/moe/moe_wmma.cu",
"src/moe/moe_wmma_gguf.cu",
])
.arg("--expt-relaxed-constexpr")
.arg("-std=c++17")
.arg("-O3");
let compute_cap = cudaforge::detect_compute_cap()
.map(|arch| arch.base())
.unwrap_or(80);
if compute_cap < 80 {
moe_builder = moe_builder.arg("-DNO_BF16_KERNEL");
}
let mut is_target_msvc = false;
if let Ok(target) = std::env::var("TARGET") {
if target.contains("msvc") {
is_target_msvc = true;
moe_builder = moe_builder.arg("-D_USE_MATH_DEFINES");
}
}
if !is_target_msvc {
moe_builder = moe_builder.arg("-Xcompiler").arg("-fPIC");
}
moe_builder.build_lib(out_dir.join("libmoe.a"))?;
println!("cargo:rustc-link-search={}", out_dir.display());
println!("cargo:rustc-link-lib=moe");
println!("cargo:rustc-link-lib=dylib=cudart");
if !is_target_msvc {
println!("cargo:rustc-link-lib=stdc++");
}
Ok(())
}