use std::path::{Path, PathBuf};
fn main() {
    build_c_bridge();
    #[cfg(feature = "std")]
    build_cxx_bridge();
    generate_bindings();
    link_executorch();

    // Re-run when the bundled C++ sources or third-party headers change.
    println!("cargo::rerun-if-changed={}", cpp_dir().to_str().unwrap());
    println!(
        "cargo::rerun-if-changed={}",
        third_party_dir().to_str().unwrap()
    );

    // `deny_warnings` is an opt-in cfg, enabled by EXECUTORCH_RS_DENY_WARNINGS=1.
    println!("cargo::rerun-if-env-changed=EXECUTORCH_RS_DENY_WARNINGS");
    // Only declare the cfg to check-cfg when rustc reports minor version >= 80;
    // when the version can't be determined the declaration is skipped.
    // The `cargo::` prefix matches every other instruction in this script.
    if matches!(rustc_version(), Some(minor) if minor >= 80) {
        println!("cargo::rustc-check-cfg=cfg(deny_warnings)");
    }
    if std::env::var("EXECUTORCH_RS_DENY_WARNINGS").as_deref() == Ok("1") {
        println!("cargo::rustc-cfg=deny_warnings");
    }
}
/// Compiles `cpp/executorch_rs/c_bridge.cpp` into a static library named
/// `executorch_rs_c_bridge_<crate version>`, using the shared settings from
/// `common_cc` and the include paths from `cpp_includes`.
fn build_c_bridge() {
    let bridge_source = cpp_dir().join("executorch_rs").join("c_bridge.cpp");
    // The crate version is baked into the library name.
    let library_name = format!("executorch_rs_c_bridge_{}", env!("CARGO_PKG_VERSION"));

    let mut build = cc::Build::new();
    common_cc(&mut build);
    build.file(bridge_source).includes(cpp_includes());
    build.compile(&library_name);
}
/// Generates and compiles the `cxx` bridges (only with the `std` feature).
///
/// `src/cxx_bridge/core.rs` is always bridged. `module.rs` and `tensor_ptr.rs`
/// are added when the `module` / `tensor-ptr` features are enabled. The bridges
/// are compiled together with `cpp/executorch_rs/cxx_bridge.cpp` into
/// `executorch_rs_cxx_bridge_<crate version>`.
#[cfg(feature = "std")]
fn build_cxx_bridge() {
    let feature_bridges = [
        (cfg!(feature = "module"), "src/cxx_bridge/module.rs"),
        (cfg!(feature = "tensor-ptr"), "src/cxx_bridge/tensor_ptr.rs"),
    ];
    let bridge_files: Vec<&str> = std::iter::once("src/cxx_bridge/core.rs")
        .chain(
            feature_bridges
                .iter()
                .filter(|(enabled, _)| *enabled)
                .map(|&(_, file)| file),
        )
        .collect();

    let glue_source = cpp_dir().join("executorch_rs").join("cxx_bridge.cpp");
    let mut build = cxx_build::bridges(bridge_files);
    common_cc(&mut build);
    build.file(glue_source).includes(cpp_includes());
    build.compile(&format!(
        "executorch_rs_cxx_bridge_{}",
        env!("CARGO_PKG_VERSION")
    ));
}
/// Applies the compiler settings shared by every bridge library:
/// C++ mode with `-std=c++17`, `cpp_link_stdlib(None)`, and one `-D` define per
/// entry of `cpp_defines()`.
fn common_cc(builder: &mut cc::Build) {
    // NOTE(review): per the `cc` docs, `None` disables automatic linking of the
    // C++ standard library; presumably the final binary supplies it — confirm.
    builder.cpp(true).std("c++17").cpp_link_stdlib(None);
    for define in cpp_defines() {
        builder.define(define, None);
    }
}
/// Generates Rust FFI bindings for `cpp/executorch_rs/c_bridge.h` and writes them
/// to `$OUT_DIR/executorch_bindings.rs`.
///
/// A wrapper header `$OUT_DIR/c_bindings.h` is created first. It `#define`s every
/// entry of `cpp_defines()` before including `c_bridge.h`, so bindgen sees the
/// same feature macros as the compiled bridge. Only declarations from
/// `c_bridge.h` itself are allowlisted.
///
/// # Panics
/// Panics if `OUT_DIR` is unset, if the wrapper header cannot be written, or if
/// bindgen fails to generate or write the bindings.
fn generate_bindings() {
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR must be set by cargo"));

    let bindings_h = out_dir.join("c_bindings.h");
    let mut bindings_h_content = String::from("#pragma once\n");
    for define in cpp_defines() {
        bindings_h_content.push_str(&format!("#define {define}\n"));
    }
    bindings_h_content.push_str("#include \"executorch_rs/c_bridge.h\"\n");
    // Rewrite the wrapper only when its content changed, so its mtime is not
    // bumped on every build-script run (bindgen's CargoCallbacks may emit
    // rerun-if-changed for the headers it processes).
    let existing = std::fs::read_to_string(&bindings_h).ok();
    if existing.as_deref() != Some(bindings_h_content.as_str()) {
        std::fs::write(&bindings_h, &bindings_h_content).expect("Unable to write c_bindings.h");
    }

    let builder = bindgen::Builder::default()
        .clang_arg(format!("-I{}", cpp_dir().to_str().unwrap()))
        .use_core()
        .generate_cstr(true)
        .header(bindings_h.to_str().unwrap())
        .allowlist_file(format!(
            "{}/executorch_rs/c_bridge.h",
            cpp_dir().to_str().unwrap()
        ))
        .default_enum_style(bindgen::EnumVariation::Rust {
            non_exhaustive: false,
        })
        .no_copy(".*")
        .manually_drop_union(".*")
        // These are treated as opaque blobs rather than having their fields
        // exposed. NOTE(review): presumably these mirror C++ storage types — confirm.
        .opaque_type("EValueStorage")
        .opaque_type("TensorStorage")
        .opaque_type("TensorImpl")
        .opaque_type("Program")
        .opaque_type("TensorInfo")
        .opaque_type("MethodMeta")
        .opaque_type("Method")
        .opaque_type("BufferDataLoader")
        .opaque_type("FileDataLoader")
        .opaque_type("MmapDataLoader")
        .opaque_type("MemoryAllocator")
        .opaque_type("HierarchicalAllocator")
        .opaque_type("MemoryManager")
        .opaque_type("OptionalTensorStorage")
        .opaque_type("ETDumpGen")
        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()));
    let bindings = builder.generate().expect("Unable to generate bindings");
    bindings
        .write_to_file(out_dir.join("executorch_bindings.rs"))
        .expect("Couldn't write bindings!");
}
/// Emits the cfg and link directives for the prebuilt executorch static libraries.
///
/// - `EXECUTORCH_RS_LINK=0` turns linking off. Otherwise the `link_cxx` cfg is
///   set, and it is declared to check-cfg when rustc reports minor version >= 80.
/// - Nothing is linked when `DOCS_RS` is set or linking is disabled.
/// - `EXECUTORCH_RS_EXECUTORCH_LIB_DIR` (optional; a warning is printed if
///   missing) is added as a native search path. For each enabled feature
///   library, the matching subdirectory is added as well.
/// - All libraries are linked statically with `+whole-archive`.
fn link_executorch() {
    println!("cargo::rerun-if-env-changed=EXECUTORCH_RS_EXECUTORCH_LIB_DIR");
    println!("cargo::rerun-if-env-changed=EXECUTORCH_RS_LINK");

    let link_enabled = !matches!(std::env::var("EXECUTORCH_RS_LINK").as_deref(), Ok("0"));
    if matches!(rustc_version(), Some(minor) if minor >= 80) {
        println!("cargo::rustc-check-cfg=cfg(link_cxx)");
    }
    if link_enabled {
        println!("cargo::rustc-cfg=link_cxx");
    }

    let building_docs = std::env::var("DOCS_RS").is_ok();
    if building_docs || !link_enabled {
        return;
    }

    let libs_dir = std::env::var("EXECUTORCH_RS_EXECUTORCH_LIB_DIR").ok();
    match &libs_dir {
        Some(dir) => println!("cargo::rustc-link-search=native={dir}"),
        None => println!("cargo::warning=EXECUTORCH_RS_EXECUTORCH_LIB_DIR is not set, can't locate executorch static libs"),
    }

    for core_lib in ["executorch", "executorch_core"] {
        println!("cargo::rustc-link-lib=static:+whole-archive={core_lib}");
    }

    // (feature enabled?, subdirectory of the lib dir, static library name)
    let feature_libs = [
        (cfg!(feature = "data-loader"), "extension/data_loader/", "extension_data_loader"),
        (cfg!(feature = "module"), "extension/module/", "extension_module_static"),
        (cfg!(feature = "tensor-ptr"), "extension/tensor/", "extension_tensor"),
        (cfg!(feature = "etdump"), "devtools/", "etdump"),
    ];
    for (enabled, subdir, lib) in feature_libs {
        if !enabled {
            continue;
        }
        if let Some(dir) = &libs_dir {
            println!("cargo::rustc-link-search=native={dir}/{subdir}");
        }
        println!("cargo::rustc-link-lib=static:+whole-archive={lib}");
    }
}
fn cpp_dir() -> PathBuf {
Path::new(&env!("CARGO_MANIFEST_DIR")).join("cpp")
}
fn third_party_dir() -> PathBuf {
Path::new(&env!("CARGO_MANIFEST_DIR")).join("third-party")
}
/// Include directories for compiling the bridges: the crate's `cpp/` dir, the
/// `third-party/` dir, and the c10 headers.
///
/// The c10 location is taken from `EXECUTORCH_RS_C10_HEADERS_DIR` if set.
/// Otherwise it defaults to the copy under
/// `third-party/executorch/runtime/core/portable_type/c10`. Cargo is told to
/// re-run the build script when that variable changes.
///
/// # Panics
/// Panics if the resolved c10 directory does not exist.
fn cpp_includes() -> Vec<PathBuf> {
    // The override changes the compiled bridges, so it must invalidate the build.
    println!("cargo::rerun-if-env-changed=EXECUTORCH_RS_C10_HEADERS_DIR");
    let third_party_dir = third_party_dir();
    let c10_dir = std::env::var_os("EXECUTORCH_RS_C10_HEADERS_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|| third_party_dir.join("executorch/runtime/core/portable_type/c10"));
    assert!(
        c10_dir.exists(),
        "C10 directory does not exist: {}",
        c10_dir.display()
    );
    // The borrow in the closure above has ended, so `third_party_dir` can be moved.
    vec![cpp_dir(), third_party_dir, c10_dir]
}
/// Preprocessor macros shared by the bridge compilation (`common_cc`) and the
/// bindgen wrapper header (`generate_bindings`).
///
/// `C10_USING_CUSTOM_GENERATED_MACROS` always comes first. It is followed by one
/// `EXECUTORCH_RS_*` macro per enabled cargo feature, in a fixed order.
fn cpp_defines() -> Vec<&'static str> {
    const ALWAYS_DEFINED: &str = "C10_USING_CUSTOM_GENERATED_MACROS";
    let feature_defines: [(bool, &'static str); 5] = [
        (cfg!(feature = "data-loader"), "EXECUTORCH_RS_DATA_LOADER"),
        (cfg!(feature = "module"), "EXECUTORCH_RS_MODULE"),
        (cfg!(feature = "tensor-ptr"), "EXECUTORCH_RS_TENSOR_PTR"),
        (cfg!(feature = "etdump"), "EXECUTORCH_RS_ETDUMP"),
        (cfg!(feature = "std"), "EXECUTORCH_RS_STD"),
    ];

    let mut defines = Vec::with_capacity(1 + feature_defines.len());
    defines.push(ALWAYS_DEFINED);
    defines.extend(
        feature_defines
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|&(_, define)| define),
    );
    defines
}
/// Returns the minor version of the `rustc` named by the `RUSTC` env var.
///
/// Runs `$RUSTC --version` and parses output like `rustc 1.80.0 (...)` into `80`.
/// Returns `None` when `RUSTC` is unset, the command cannot be spawned, the
/// output is not UTF-8, the major version is not 1, or the minor version does
/// not parse.
fn rustc_version() -> Option<u32> {
    let rustc_path = std::env::var_os("RUSTC")?;
    let raw_stdout = std::process::Command::new(rustc_path)
        .arg("--version")
        .output()
        .ok()?
        .stdout;
    let banner = String::from_utf8(raw_stdout).ok()?;
    // "rustc 1.80.0 (...)" -> "80.0 (...)"; the minor is everything up to the next '.'.
    let after_major = banner.strip_prefix("rustc 1.")?;
    after_major.split('.').next()?.parse().ok()
}