extern crate curl;
extern crate flate2;
extern crate pkg_config;
extern crate semver;
extern crate tar;
use std::env;
use std::error::Error;
use std::fs::{self, File};
use std::io;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::process::{self, Command};
use curl::easy::Easy;
use flate2::read::GzDecoder;
use semver::Version;
use tar::Archive;
use zip::ZipArchive;
/// Base name of the TensorFlow framework shared library passed to `rustc-link-lib`.
const FRAMEWORK_LIBRARY: &str = "tensorflow_framework";
/// Base name of the main TensorFlow shared library passed to `rustc-link-lib`.
const LIBRARY: &str = "tensorflow";
/// Git repository cloned when building TensorFlow from source.
const REPOSITORY: &str = "https://github.com/tensorflow/tensorflow.git";
/// Bazel label (without the shared-library suffix) of the framework library.
const FRAMEWORK_TARGET: &str = "tensorflow:libtensorflow_framework";
/// Bazel label (without the shared-library suffix) of the main library.
const TARGET: &str = "tensorflow:libtensorflow";
/// Release version used in the prebuilt-binary download URL.
const VERSION: &str = "1.15.0";
/// Git tag checked out for source builds; NOTE(review): keep in sync with `VERSION`.
const TAG: &str = "v1.15.0";
/// Oldest Bazel version accepted by `check_bazel`.
const MIN_BAZEL: &str = "0.5.4";
/// Unwraps a `Result`/`Option`, panicking on failure.
macro_rules! ok {
    ($expression:expr) => {
        $expression.unwrap()
    };
}
/// Reads a required environment variable, panicking if it is unset or not UTF-8.
macro_rules! get {
    ($name:expr) => {
        ok!(env::var($name))
    };
}
/// Prints a message to stdout prefixed with this build script's name and the invoking line.
/// `$fmt` must be a string literal (it is spliced into `concat!`).
macro_rules! log {
    ($fmt:expr) => {
        println!(concat!("libtensorflow-sys/build.rs:{}: ", $fmt), line!())
    };
    ($fmt:expr, $($arg:tt)*) => {
        println!(
            concat!("libtensorflow-sys/build.rs:{}: ", $fmt),
            line!(),
            $($arg)*
        )
    };
}
/// Logs `name = <Debug value>` for a local variable.
macro_rules! log_var {
    ($var:ident) => {
        log!(concat!(stringify!($var), " = {:?}"), $var)
    };
}
/// Build-script entry point.
///
/// Does nothing when the `private-docs-rs` feature is enabled. Otherwise it first looks for an
/// already-installed library (Windows `PATH`, then pkg-config). Failing that, it downloads a
/// prebuilt binary for x86_64 Linux/macOS/Windows, or builds from source.
/// `TF_RUST_BUILD_FROM_SRC=true` forces the source build.
fn main() {
    if cfg!(feature = "private-docs-rs") {
        log!("Returning early because private-docs-rs feature was enabled");
        return;
    }
    // `||` short-circuits, so pkg-config is only probed when the Windows check fails.
    if check_windows_lib() || pkg_config::probe_library(LIBRARY).is_ok() {
        log!("Returning early because {} was already found", LIBRARY);
        return;
    }
    let force_src = env::var("TF_RUST_BUILD_FROM_SRC")
        .map(|flag| flag == "true")
        .unwrap_or(false);
    let os = target_os();
    // Evaluated lazily so `target_arch()` is only consulted when not forcing a source build.
    let prebuilt_available =
        || target_arch() == "x86_64" && ["linux", "macos", "windows"].contains(&os.as_str());
    if force_src || !prebuilt_available() {
        build_from_src();
    } else {
        install_prebuilt();
    }
}
/// Returns the target architecture Cargo exposes to build scripts via `CARGO_CFG_TARGET_ARCH`.
///
/// # Panics
/// Panics if the variable is unset or not valid UTF-8.
fn target_arch() -> String {
    env::var("CARGO_CFG_TARGET_ARCH").unwrap()
}
/// Returns the target operating system Cargo exposes via `CARGO_CFG_TARGET_OS`.
///
/// # Panics
/// Panics if the variable is unset or not valid UTF-8.
fn target_os() -> String {
    env::var("CARGO_CFG_TARGET_OS").unwrap()
}
/// File-name prefix of shared libraries on the target: none on Windows, `lib` elsewhere.
fn dll_prefix() -> &'static str {
    if target_os() == "windows" {
        ""
    } else {
        "lib"
    }
}
/// File-name extension (with leading dot) of shared libraries on the target OS.
fn dll_suffix() -> &'static str {
    let os = target_os();
    if os == "windows" {
        ".dll"
    } else if os == "macos" {
        ".dylib"
    } else {
        ".so"
    }
}
fn check_windows_lib() -> bool {
if target_os() != "windows" {
return false;
}
let windows_lib: &str = &format!("{}.lib", LIBRARY);
if let Ok(path) = env::var("PATH") {
for p in path.split(";") {
let path = Path::new(p).join(windows_lib);
if path.exists() {
println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
println!("cargo:rustc-link-search=native={}", p);
return true;
}
}
}
false
}
/// Strips `suffix` from the end of `value` in place.
///
/// `value` is left untouched when it does not end with `suffix`.
fn remove_suffix(value: &mut String, suffix: &str) {
    if !value.ends_with(suffix) {
        return;
    }
    let kept = value.len() - suffix.len();
    value.truncate(kept);
}
/// Returns `true` when `path`'s final extension (text after the last `.` of the file name)
/// is exactly `extension`, compared case-sensitively.
///
/// Paths without an extension, or whose extension is not valid UTF-8, yield `false`.
fn has_extension<P: AsRef<Path>>(path: P, extension: &str) -> bool {
    path.as_ref()
        .extension()
        .and_then(|os_ext| os_ext.to_str())
        .map_or(false, |ext| ext == extension)
}
/// Unpacks the gzip-compressed tarball at `archive_path` into `extract_to`.
///
/// # Panics
/// Panics, naming the paths involved, if the archive cannot be opened or unpacked.
fn extract_tar_gz<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
    let archive_path = archive_path.as_ref();
    let extract_to = extract_to.as_ref();
    let file = File::open(archive_path).unwrap_or_else(|e| {
        panic!(
            "Unable to open libtensorflow archive {}: {}",
            archive_path.display(),
            e
        )
    });
    let mut archive = Archive::new(GzDecoder::new(file));
    archive.unpack(extract_to).unwrap_or_else(|e| {
        panic!(
            "Failed to unpack {} into {}: {}",
            archive_path.display(),
            extract_to.display(),
            e
        )
    });
}
/// Extracts the entries of the zip archive at `archive_path` whose names start with `lib`
/// into `extract_to`, creating directories as needed. Other entries are skipped.
///
/// Output paths come from the entry's `sanitized_name()`.
///
/// # Panics
/// Panics if the archive cannot be read or any file/directory cannot be created or written.
fn extract_zip<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
    let destination = extract_to.as_ref();
    fs::create_dir_all(destination).expect("Failed to create output path for zip archive.");
    let archive_file =
        File::open(archive_path).expect("Unable to open libtensorflow zip archive.");
    let mut archive = ZipArchive::new(archive_file).unwrap();
    for index in 0..archive.len() {
        let mut entry = archive.by_index(index).unwrap();
        if !entry.name().starts_with("lib") {
            continue;
        }
        let target = destination.join(entry.sanitized_name());
        if entry.is_dir() {
            fs::create_dir_all(&target)
                .expect("Failed to create output directory when unpacking archive.");
            continue;
        }
        match target.parent() {
            Some(parent) if !parent.exists() => {
                fs::create_dir_all(parent)
                    .expect("Failed to create parent directory for extracted file.");
            }
            _ => {}
        }
        let mut output = File::create(&target).unwrap();
        io::copy(&mut entry, &mut output).unwrap();
    }
}
/// Unpacks `archive_path` into `extract_to`.
///
/// Paths ending in `.zip` go to the zip extractor. Anything else is treated as a
/// gzip-compressed tarball.
fn extract<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
    let is_zip = has_extension(&archive_path, "zip");
    if !is_zip {
        extract_tar_gz(archive_path, extract_to);
        return;
    }
    extract_zip(archive_path, extract_to);
}
fn install_prebuilt() {
let os = match &target_os() as &str {
"macos" => "darwin".to_string(),
x => x.to_string(),
};
let proc_type = if cfg!(feature = "tensorflow_gpu") {
"gpu"
} else {
"cpu"
};
let windows = target_os() == "windows";
let ext = if windows { ".zip" } else { ".tar.gz" };
let binary_url = format!(
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}{}",
proc_type,
os,
target_arch(),
VERSION,
ext
);
log_var!(binary_url);
let short_file_name = binary_url.split("/").last().unwrap();
let mut base_name = short_file_name.to_string();
remove_suffix(&mut base_name, ext);
log_var!(base_name);
let download_dir = match env::var("TF_RUST_DOWNLOAD_DIR") {
Ok(s) => PathBuf::from(s),
Err(_) => PathBuf::from(&get!("OUT_DIR")),
};
if !download_dir.exists() {
fs::create_dir(&download_dir).unwrap();
}
let file_name = download_dir.join(short_file_name);
log_var!(file_name);
if !file_name.exists() {
let f = File::create(&file_name).unwrap();
let mut writer = BufWriter::new(f);
let mut easy = Easy::new();
easy.url(&binary_url).unwrap();
easy.write_function(move |data| Ok(writer.write(data).unwrap()))
.unwrap();
easy.perform().unwrap();
let response_code = easy.response_code().unwrap();
if response_code != 200 {
panic!(
"Unexpected response code {} for {}",
response_code, binary_url
);
}
}
let unpacked_dir = download_dir.join(base_name);
let lib_dir = unpacked_dir.join("lib");
let framework_library_file = format!("{}{}{}", dll_prefix(), FRAMEWORK_LIBRARY, dll_suffix());
let library_file = format!("{}{}{}", dll_prefix(), LIBRARY, dll_suffix());
let framework_library_full_path = lib_dir.join(&framework_library_file);
let library_full_path = lib_dir.join(&library_file);
let download_required =
(!windows && !framework_library_full_path.exists()) || !library_full_path.exists();
if download_required {
extract(file_name, &unpacked_dir);
}
if target_os() != "windows" {
println!("cargo:rustc-link-lib=dylib={}", FRAMEWORK_LIBRARY);
}
println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
let output = PathBuf::from(&get!("OUT_DIR"));
let framework_files = std::fs::read_dir(lib_dir).unwrap();
for library_entry in framework_files.filter_map(Result::ok) {
let library_full_path = library_entry.path();
let new_library_full_path = output.join(&library_full_path.file_name().unwrap());
if new_library_full_path.exists() {
log!(
"{} already exists. Removing",
new_library_full_path.display()
);
fs::remove_file(&new_library_full_path).unwrap();
}
log!(
"Copying {} to {}...",
library_full_path.display(),
new_library_full_path.display()
);
fs::copy(&library_full_path, &new_library_full_path).unwrap();
}
println!("cargo:rustc-link-search={}", output.display());
}
fn build_from_src() {
let dll_suffix = dll_suffix();
let framework_target = FRAMEWORK_TARGET.to_string() + &dll_suffix;
let target = TARGET.to_string() + &dll_suffix;
let output = PathBuf::from(&get!("OUT_DIR"));
log_var!(output);
let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join(format!("target/source-{}", TAG));
log_var!(source);
let lib_dir = output.join(format!("lib-{}", TAG));
log_var!(lib_dir);
if lib_dir.exists() {
log!("Directory {:?} already exists", lib_dir);
} else {
log!("Creating directory {:?}", lib_dir);
fs::create_dir(lib_dir.clone()).unwrap();
}
let framework_library_path = lib_dir.join(format!("lib{}.so", FRAMEWORK_LIBRARY));
log_var!(framework_library_path);
let library_path = lib_dir.join(format!("lib{}.so", LIBRARY));
log_var!(library_path);
if library_path.exists() && framework_library_path.exists() {
log!(
"{:?} and {:?} already exist, not building",
library_path,
framework_library_path
);
} else {
if let Err(e) = check_bazel() {
println!(
"cargo:error=Bazel must be installed at version {} or greater. (Error: {})",
MIN_BAZEL, e
);
process::exit(1);
}
let framework_target_path = &framework_target.replace(":", "/");
log_var!(framework_target_path);
let target_path = &TARGET.replace(":", "/");
log_var!(target_path);
if !Path::new(&source.join(".git")).exists() {
run("git", |command| {
command
.arg("clone")
.arg(format!("--branch={}", TAG))
.arg("--recursive")
.arg(REPOSITORY)
.arg(&source)
});
}
let configure_hint_file_pb = source.join(".rust-configured");
let configure_hint_file = Path::new(&configure_hint_file_pb);
if !configure_hint_file.exists() {
run("bash", |command| {
command
.current_dir(&source)
.env(
"TF_NEED_CUDA",
if cfg!(feature = "tensorflow_gpu") {
"1"
} else {
"0"
},
)
.arg("-c")
.arg("yes ''|./configure")
});
File::create(configure_hint_file).unwrap();
}
let bazel_args_string = if let Ok(args) = env::var("TF_RUST_BAZEL_OPTS") {
args
} else {
"".to_string()
};
run("bazel", |command| {
command
.current_dir(&source)
.arg("build")
.arg(format!("--jobs={}", get!("NUM_JOBS")))
.arg("--compilation_mode=opt")
.arg("--copt=-march=native")
.args(bazel_args_string.split_whitespace())
.arg(&target)
});
let framework_target_bazel_bin = source.join("bazel-bin").join(framework_target_path);
log!(
"Copying {:?} to {:?}",
framework_target_bazel_bin,
framework_library_path
);
fs::copy(framework_target_bazel_bin, framework_library_path).unwrap();
let target_bazel_bin = source.join("bazel-bin").join(target_path);
log!("Copying {:?} to {:?}", target_bazel_bin, library_path);
fs::copy(target_bazel_bin, library_path).unwrap();
}
println!("cargo:rustc-link-lib=dylib={}", FRAMEWORK_LIBRARY);
println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
println!("cargo:rustc-link-search={}", lib_dir.display());
}
/// Spawns the program `name` and waits for it. Before spawning, `configure` adds
/// arguments, environment variables, or a working directory.
///
/// `configure` is called exactly once, so any `FnOnce` closure is accepted (every `FnMut`
/// closure the previous signature took still qualifies).
///
/// # Panics
/// Panics, naming the command, if it cannot be spawned (e.g. not installed) or exits with a
/// non-success status.
fn run<F>(name: &str, configure: F)
where
    F: FnOnce(&mut Command) -> &mut Command,
{
    let mut command = Command::new(name);
    let configured = configure(&mut command);
    log!("Executing {:?}", configured);
    let status = match configured.status() {
        Ok(status) => status,
        Err(e) => panic!("failed to execute {:?}: {}", configured, e),
    };
    if !status.success() {
        panic!("failed to execute {:?} ({})", configured, status);
    }
    log!("Command {:?} finished successfully", configured);
}
/// Verifies that `bazel` is installed at version `MIN_BAZEL` or newer.
///
/// Runs `bazel version` and parses the first whitespace-separated token after each
/// `Build label:` line. A single trailing `-` on that token is ignored.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `bazel` cannot be run, or its stdout is not UTF-8;
/// - no `Build label:` line exists, or a label line holds no parseable semver version;
/// - the installed version is older than `MIN_BAZEL`.
fn check_bazel() -> Result<(), Box<dyn Error>> {
    const LABEL_PREFIX: &str = "Build label:";
    let mut command = Command::new("bazel");
    command.arg("version");
    log!("Executing {:?}", command);
    let out = command.output()?;
    log!("Command {:?} finished successfully", command);
    let stdout = String::from_utf8(out.stdout)?;
    // Loop-invariant: parse the minimum once.
    let want = Version::parse(MIN_BAZEL)?;
    let mut found_version = false;
    for line in stdout.lines() {
        if !line.starts_with(LABEL_PREFIX) {
            continue;
        }
        found_version = true;
        // A malformed label line is reported as an error rather than panicking.
        let mut version_str = line[LABEL_PREFIX.len()..]
            .split_whitespace()
            .next()
            .ok_or_else(|| format!("Could not find a version number in line {:?}", line))?;
        if version_str.ends_with('-') {
            version_str = &version_str[..version_str.len() - 1];
        }
        let version = Version::parse(version_str)?;
        if version < want {
            return Err(format!(
                "Installed version {} is less than required version {}",
                version_str, MIN_BAZEL
            )
            .into());
        }
    }
    if !found_version {
        return Err("Did not find version number in `bazel version` output.".into());
    }
    Ok(())
}