use std::env;
use std::path::PathBuf;
use cc::Build;
/// Generates Rust FFI bindings for `./binding.h` with bindgen and writes them
/// to `bindings.rs` inside `out_path`.
///
/// The `tokenCallback` function is excluded from the generated bindings
/// (presumably because it is provided elsewhere — confirm against the crate).
///
/// # Panics
///
/// Panics if bindgen cannot generate the bindings or the output file cannot
/// be written.
fn compile_bindings(out_path: &PathBuf) {
    let builder = bindgen::Builder::default()
        .header("./binding.h")
        .blocklist_function("tokenCallback")
        .parse_callbacks(Box::new(bindgen::CargoCallbacks));

    let generated = builder.generate().expect("Unable to generate bindings");

    let destination = out_path.join("bindings.rs");
    generated
        .write_to_file(&destination)
        .expect("Couldn't write bindings!");
}
/// Configures both builds for the CLBlast/OpenCL backend.
///
/// Adds the `GGML_USE_CLBLAST` define to the C build (`cx`) and the C++ build
/// (`cxx`), registers `ggml-opencl.cpp` as a source of the C++ build, and
/// emits the Cargo link directives for OpenCL and CLBlast on Linux and macOS.
/// On any other OS no link directive is emitted.
fn compile_opencl(cx: &mut Build, cxx: &mut Build) {
    const CLBLAST_DEFINE: &str = "-DGGML_USE_CLBLAST";

    cx.flag(CLBLAST_DEFINE);
    cxx.flag(CLBLAST_DEFINE);
    cxx.file("./llama.cpp/ggml-opencl.cpp");

    // OpenCL is a plain library on Linux but a framework on macOS.
    let opencl_directive = if cfg!(target_os = "linux") {
        Some("cargo:rustc-link-lib=OpenCL")
    } else if cfg!(target_os = "macos") {
        Some("cargo:rustc-link-lib=framework=OpenCL")
    } else {
        None
    };

    if let Some(directive) = opencl_directive {
        println!("{}", directive);
        println!("cargo:rustc-link-lib=clblast");
    }
}
/// Configures the C build (`cx`) to use OpenBLAS.
///
/// Adds the `GGML_USE_OPENBLAS` define, adds the usual OpenBLAS header
/// directories to the include path, and tells Cargo to link `openblas`.
fn compile_openblas(cx: &mut Build) {
    cx.flag("-DGGML_USE_OPENBLAS")
        // Headers from a local (source) install.
        .include("/usr/local/include/openblas")
        // Headers from a system package; previously the local path was
        // listed twice, so this location was never searched.
        .include("/usr/include/openblas");
    println!("cargo:rustc-link-lib=openblas");
}
/// Configures the C build (`cx`) to use BLIS as the BLAS implementation.
///
/// Adds the `GGML_USE_OPENBLAS` define (the same define the OpenBLAS path
/// uses), adds the usual BLIS header directories to the include path, and
/// tells Cargo to search `/usr/local/lib` and link `blis`.
fn compile_blis(cx: &mut Build) {
    cx.flag("-DGGML_USE_OPENBLAS")
        // Headers from a local (source) install.
        .include("/usr/local/include/blis")
        // Headers from a system package; previously the local path was
        // listed twice, so this location was never searched.
        .include("/usr/include/blis");
    println!("cargo:rustc-link-search=native=/usr/local/lib");
    println!("cargo:rustc-link-lib=blis");
}
/// Compiles `ggml-cuda.cu` with `nvcc` into the `ggml-cuda` library and emits
/// the Cargo directives needed to link against the CUDA libraries.
///
/// `cxx_flags` is a whitespace-separated list of flags that is passed to
/// `nvcc` after the nvcc-specific flags.
///
/// Three preprocessor defines can be tuned through environment variables;
/// when a variable is unset its default is used:
///
/// | variable                   | default | define                      |
/// |----------------------------|---------|-----------------------------|
/// | `LLAMA_CUDA_DMMV_X`        | `32`    | `GGML_CUDA_DMMV_X`          |
/// | `LLAMA_CUDA_DMMV_Y`        | `1`     | `GGML_CUDA_DMMV_Y`          |
/// | `LLAMA_CUDA_KQUANTS_ITER`  | `2`     | `K_QUANTS_PER_ITERATION`    |
///
/// If `CUDA_PATH` is set, `$CUDA_PATH/targets/x86_64-linux/lib` is added to
/// the library search path in addition to the fixed locations.
fn compile_cuda(cxx_flags: &str) {
    // (environment variable, default value, preprocessor define)
    const TUNABLES: &[(&str, &str, &str)] = &[
        ("LLAMA_CUDA_DMMV_X", "32", "-DGGML_CUDA_DMMV_X"),
        ("LLAMA_CUDA_DMMV_Y", "1", "-DGGML_CUDA_DMMV_Y"),
        ("LLAMA_CUDA_KQUANTS_ITER", "2", "-DK_QUANTS_PER_ITERATION"),
    ];
    const CUDA_LIBS: &[&str] = &[
        "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
    ];
    const NVCC_FLAGS: &[&str] = &["--forward-unknown-to-host-compiler", "-arch=native"];

    println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
    println!("cargo:rustc-link-search=native=/opt/cuda/lib64");

    // Ask Cargo to re-run the build script when the CUDA location changes.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    if let Ok(cuda_path) = std::env::var("CUDA_PATH") {
        println!(
            "cargo:rustc-link-search=native={}/targets/x86_64-linux/lib",
            cuda_path
        );
    }

    for lib in CUDA_LIBS {
        println!("cargo:rustc-link-lib={}", lib);
    }

    let mut nvcc = cc::Build::new();

    for nvcc_flag in NVCC_FLAGS {
        nvcc.flag(nvcc_flag);
    }
    for cxx_flag in cxx_flags.split_whitespace() {
        nvcc.flag(cxx_flag);
    }

    for &(var, default, define) in TUNABLES {
        // A changed tunable must trigger a rebuild of the CUDA objects.
        println!("cargo:rerun-if-env-changed={}", var);
        let value = std::env::var(var).unwrap_or_else(|_| default.to_string());
        nvcc.flag(&format!("{}={}", define, value));
    }

    nvcc.compiler("nvcc")
        .file("./llama.cpp/ggml-cuda.cu")
        .flag("-Wno-pedantic")
        // `include` expects a directory; the header lives in ./llama.cpp.
        .include("./llama.cpp")
        .compile("ggml-cuda");
}
/// Compiles `ggml.c` as C into the `ggml` library.
///
/// `cx_flags` is a whitespace-separated list of compiler flags; each one is
/// added to `cx` before compilation. `./llama.cpp` is added to the include
/// path.
fn compile_ggml(cx: &mut Build, cx_flags: &str) {
    cx_flags.split_whitespace().for_each(|flag| {
        cx.flag(flag);
    });

    cx.include("./llama.cpp");
    cx.file("./llama.cpp/ggml.c");
    cx.cpp(false);
    cx.compile("ggml");
}
/// Compiles the llama.cpp sources and `binding.cpp` as C++ into the `binding`
/// library, adding previously built ggml object files to the build.
///
/// * `cxx_flags` — whitespace-separated compiler flags added to `cxx`.
/// * `out_path` — directory expected to contain `llama.cpp/ggml.o` (and the
///   backend object, if any).
/// * `ggml_type` — backend name; when non-empty, `llama.cpp/ggml-<type>.o`
///   under `out_path` is added as well.
fn compile_llama(cxx: &mut Build, cxx_flags: &str, out_path: &PathBuf, ggml_type: &str) {
    cxx_flags.split_whitespace().for_each(|flag| {
        cxx.flag(flag);
    });

    // The core ggml object is always added.
    cxx.object(out_path.join("llama.cpp/ggml.o"));

    // A backend-specific object is added only when a backend was selected.
    if !ggml_type.is_empty() {
        let backend_obj = out_path.join(format!("llama.cpp/ggml-{}.o", ggml_type));
        cxx.object(backend_obj);
    }

    cxx.shared_flag(true);
    cxx.file("./llama.cpp/examples/common.cpp");
    cxx.file("./llama.cpp/llama.cpp");
    cxx.file("./binding.cpp");
    cxx.cpp(true);
    cxx.compile("binding");
}
/// Build-script entry point.
///
/// Generates the bindings into `OUT_DIR`, then builds ggml (C) and the
/// llama.cpp binding (C++). At most one of the `opencl`, `openblas` and `blis`
/// features is applied, in that order of priority; the `cuda` feature is
/// handled independently and, when enabled, determines the backend object
/// handed to the final C++ build.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set.
fn main() {
    let out_dir = env::var("OUT_DIR").expect("No out dir found");
    let out_path = PathBuf::from(out_dir);
    compile_bindings(&out_path);

    let mut cx_flags = String::from("-Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -march=native -mtune=native");
    let mut cxx_flags = String::from("-Wall -Wdeprecated-declarations -Wunused-but-set-variable -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -march=native -mtune=native");
    if cfg!(target_os = "linux") {
        cx_flags += " -pthread";
        cxx_flags += " -fPIC -pthread";
    }

    let mut cx = cc::Build::new();
    let mut cxx = cc::Build::new();
    cxx.include("./llama.cpp/examples").include("./llama.cpp");

    // Name of the backend whose object file is added to the C++ build;
    // empty means no backend object.
    let mut ggml_type = "";

    if cfg!(feature = "opencl") {
        compile_opencl(&mut cx, &mut cxx);
        ggml_type = "opencl";
    } else if cfg!(feature = "openblas") {
        compile_openblas(&mut cx);
    } else if cfg!(feature = "blis") {
        compile_blis(&mut cx);
    }

    let use_cuda = cfg!(feature = "cuda");
    if use_cuda {
        // CUDA overrides whatever backend name was chosen above.
        ggml_type = "cuda";
        cx_flags += " -DGGML_USE_CUBLAS";
        cxx_flags += " -DGGML_USE_CUBLAS";

        let mut cuda_includes = vec![
            String::from("/usr/local/cuda/include"),
            String::from("/opt/cuda/include"),
        ];
        if let Ok(cuda_path) = std::env::var("CUDA_PATH") {
            cuda_includes.push(format!("{}/targets/x86_64-linux/include", cuda_path));
        }
        for dir in &cuda_includes {
            cx.include(dir);
            cxx.include(dir);
        }
    }

    compile_ggml(&mut cx, &cx_flags);
    if use_cuda {
        compile_cuda(&cxx_flags);
    }
    compile_llama(&mut cxx, &cxx_flags, &out_path, ggml_type);
}