use std::env;
/// Build-script entry point.
///
/// Emits `feature="nightly"` when building with a nightly rustup toolchain,
/// declares the `kernel_support` check-cfg, and compiles the C SIMD kernels in
/// `src/simd/` that match the target architecture/OS. Windows targets skip
/// kernel compilation entirely (with a Cargo warning).
///
/// # Errors
///
/// Returns a message (which fails the build) when:
/// - `RUSTUP_TOOLCHAIN` is set but cannot be read (e.g. it is not valid Unicode);
/// - `CARGO_CFG_TARGET_ARCH` / `CARGO_CFG_TARGET_OS` cannot be read;
/// - a kernel that is mandatory for the target fails to compile
///   (NEON on aarch64 macOS/Linux, AVX2 on x86_64, LSX/LASX on loongarch64);
/// - the `fp16kernels` feature is enabled on a target with no f16 kernels.
fn main() -> Result<(), String> {
    // An unset RUSTUP_TOOLCHAIN (e.g. a toolchain not managed by rustup) is treated as stable.
    let rust_toolchain = match env::var("RUSTUP_TOOLCHAIN") {
        Ok(toolchain) => toolchain,
        Err(env::VarError::NotPresent) => "stable".to_string(),
        Err(err) => return Err(format!("Unable to read RUSTUP_TOOLCHAIN: {}", err)),
    };
    if rust_toolchain.starts_with("nightly") {
        println!("cargo:rustc-cfg=feature=\"nightly\"");
    }
    println!("cargo::rustc-check-cfg=cfg(kernel_support, values(\"avx512\"))");
    // Once any rerun-if-changed is emitted, Cargo reruns this script only when the
    // listed files (or the build script itself) change.
    println!("cargo:rerun-if-changed=src/simd/f16.c");
    println!("cargo:rerun-if-changed=src/simd/dist_table.c");

    // Cargo sets both variables for every build script; the error path only
    // matters if this script is run outside Cargo.
    let target_arch = env::var("CARGO_CFG_TARGET_ARCH")
        .map_err(|err| format!("Unable to read CARGO_CFG_TARGET_ARCH: {}", err))?;
    let target_os = env::var("CARGO_CFG_TARGET_OS")
        .map_err(|err| format!("Unable to read CARGO_CFG_TARGET_OS: {}", err))?;

    if target_os == "windows" {
        println!(
            "cargo:warning=fp16 kernels are not supported on Windows. Skipping compilation of kernels."
        );
        return Ok(());
    }

    match (target_arch.as_str(), target_os.as_str()) {
        ("aarch64", "macos") => build_required_f16_kernel("neon", &["-mtune=apple-m1"])?,
        ("aarch64", "linux") => build_required_f16_kernel("neon", &["-march=armv8.2-a+fp16"])?,
        ("x86_64", _) => build_x86_64_kernels()?,
        ("loongarch64", _) => {
            build_required_f16_kernel("lsx", &["-mlsx"])?;
            build_required_f16_kernel("lasx", &["-mlasx"])?;
        }
        // Only an error when f16 kernels were actually requested; without the
        // feature there is nothing target-specific that must be built here.
        _ if cfg!(feature = "fp16kernels") => {
            return Err(format!(
                "Unable to build f16 kernels on target_arch `{}` (target_os `{}`). Please use x86_64, aarch64 (macOS or Linux) or loongarch64, or remove the fp16kernels feature",
                target_arch, target_os
            ));
        }
        _ => {}
    }
    Ok(())
}

/// Builds an f16 kernel variant the target cannot do without, reporting a
/// compiler failure as a build error (instead of panicking).
fn build_required_f16_kernel(suffix: &str, flags: &[&str]) -> Result<(), String> {
    build_f16_with_flags(suffix, flags).map_err(|err| {
        format!(
            "Unable to build {} f16 kernels (flags: {:?}). Received error: {}",
            suffix, flags, err
        )
    })
}

/// Builds the x86_64 kernels: the optional AVX-512 fp16 kernels and AVX-512
/// `dist_table` (failures become Cargo warnings), then the mandatory AVX2 fp16 kernels.
fn build_x86_64_kernels() -> Result<(), String> {
    // NOTE(review): `kernel_support="avx512"` is emitted when *either* AVX-512
    // library builds (and also when build_f16_with_flags skips compilation because
    // fp16kernels is disabled). Confirm the Rust code gated on this cfg does not
    // assume both libraries are linked.
    match build_f16_with_flags("avx512", &["-march=sapphirerapids", "-mavx512fp16"]) {
        Err(err) => println!(
            "cargo:warning=Skipping build of AVX-512 fp16 kernels. Error: {}",
            err
        ),
        Ok(()) => println!("cargo:rustc-cfg=kernel_support=\"avx512\""),
    }
    // NOTE(review): `-march=native` tunes for the *build host's* CPU, so the
    // library may use instructions the deployment machine lacks (e.g. binaries
    // built on CI). Consider an explicit -march instead.
    match build_dist_table_with_flags("avx512", &["-march=native"]) {
        Err(err) => println!(
            "cargo:warning=Skipping build of AVX-512 dist_table. Error: {}",
            err
        ),
        Ok(()) => println!("cargo:rustc-cfg=kernel_support=\"avx512\""),
    }
    build_f16_with_flags("avx2", &["-march=haswell"]).map_err(|err| {
        format!("Unable to build AVX2 f16 kernels. Please use Clang >= 6 or GCC >= 12 or remove the fp16kernels feature. Received error: {}", err)
    })
}
/// Compiles `src/simd/f16.c` as C17 into a library named `f16_<suffix>`.
///
/// The common flags (`-ffast-math -funroll-loops -O3 -Wall -Wextra`) come first,
/// followed by `-DSUFFIX=_<suffix>` and then the caller-supplied `flags`.
/// If the `fp16kernels` feature is not enabled, nothing is compiled: a Cargo
/// warning is printed and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the `cc::Error` produced when the C compiler invocation fails.
fn build_f16_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
    if !cfg!(feature = "fp16kernels") {
        println!(
            "cargo:warning=fp16kernels feature is not enabled, skipping build of fp16 kernels"
        );
        return Ok(());
    }

    let suffix_define = format!("-DSUFFIX=_{}", suffix);
    let library_name = format!("f16_{}", suffix);
    let base_flags = [
        "-ffast-math",
        "-funroll-loops",
        "-O3",
        "-Wall",
        "-Wextra",
        suffix_define.as_str(),
    ];

    let mut build = cc::Build::new();
    build.std("c17").file("src/simd/f16.c");
    for base in base_flags {
        build.flag(base);
    }
    flags.iter().for_each(|extra| {
        build.flag(extra);
    });
    build.try_compile(&library_name)
}
/// Compiles `src/simd/dist_table.c` as C17 into a library named `dist_table_<suffix>`.
///
/// `-mavx512bw` is always passed, whatever `suffix` is; the caller's `flags`
/// are appended after the common flags and `-DSUFFIX=_<suffix>`. Unlike the
/// f16 builder, this function has no feature check and always compiles.
///
/// # Errors
///
/// Returns the `cc::Error` produced when the C compiler invocation fails.
fn build_dist_table_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
    let define_suffix = format!("-DSUFFIX=_{}", suffix);
    let common_flags = [
        "-funroll-loops",
        "-O3",
        "-Wall",
        "-Wextra",
        "-mavx512bw",
        define_suffix.as_str(),
    ];

    let mut build = cc::Build::new();
    build.std("c17").file("src/simd/dist_table.c");
    common_flags.iter().chain(flags.iter()).for_each(|flag| {
        build.flag(flag);
    });

    let library_name = format!("dist_table_{}", suffix);
    build.try_compile(&library_name)
}