torch-build 0.1.0

Utilities for linking against the libtorch FFI interface.
Documentation
use crate::{
    env::{CUDA_HOME, CUDNN_HOME, LIBTORCH, LIBTORCH_CXX11_ABI, OUT_DIR, ROCM_HOME},
    library::{Api, CudaApi, CudaSplitApi, HipApi, Library},
};
use anyhow::{Context as _, Result};
use cfg_if::cfg_if;
use log::warn;
use once_cell::sync::OnceCell;
use std::{
    path::{Path, PathBuf},
    str,
};

/// Probe the installation directory of libtorch and its capabilities.
pub fn probe_libtorch() -> Result<&'static Library> {
    static PROBE: OnceCell<Library> = OnceCell::new();

    PROBE.get_or_try_init(|| -> Result<_> {
        let libtorch_dir = find_or_download_libtorch_dir()?;
        let lib_dir = libtorch_dir.join("lib");

        let probe_file = |name: &str| -> bool {
            cfg_if! {
                if #[cfg(target_os = "linux")] {
                    lib_dir.join(format!("lib{}.so", name)).exists()
                }
                else if #[cfg(target_os = "windows")] {
                    lib_dir.join(format!("{}.dll", name)).exists()
                }
                else { false }
            }
        };

        let api = if let (Some(rocm_home), true) = (&*ROCM_HOME, probe_file("torch_hip")) {
            static MIOPEN_HOME: OnceCell<PathBuf> = OnceCell::new();
            let miopen_home = MIOPEN_HOME.get_or_init(|| rocm_home.join("miopen"));

            HipApi {
                rocm_home,
                miopen_home,
            }
            .into()
        } else if let Some(cuda_home) = &*CUDA_HOME {
            if probe_file("torch_cuda_cu") && probe_file("torch_cuda_cpp") {
                CudaSplitApi {
                    cuda_home,
                    cudnn_home: CUDNN_HOME.as_deref(),
                }
                .into()
            } else if probe_file("torch_cuda") {
                CudaApi {
                    cuda_home,
                    cudnn_home: CUDNN_HOME.as_deref(),
                }
                .into()
            } else {
                warn!(
                    r#"CUDA_HOME is set to "{}", but no CUDA runtime found for libtorch"#,
                    cuda_home.display()
                );
                Api::None
            }
        } else {
            Api::None
        };

        let use_cxx11_abi = check_cxx11_abi();

        Ok(Library {
            libtorch_dir,
            api,
            use_cxx11_abi,
        })
    })
}

/// Locate the libtorch directory, or try to download libtorch if it does not exist.
///
/// The directory is searched for in the following order. If none of the steps succeeds,
/// an error is returned.
/// 1. The `LIBTORCH` environment variable, if it is set.
/// 2. `/usr`, if the host system is Linux and `/usr/lib/libtorch.so` exists.
/// 3. If the `download-libtorch` feature is enabled, libtorch is downloaded from the URL
///    generated by [libtorch_url()](crate::download::libtorch_url) and the extracted
///    directory is returned.
///
/// The result is cached in a process-wide `OnceCell`. Once a lookup succeeds, later calls
/// return the cached path and any download runs only once. A failed lookup is not cached,
/// so it is attempted again on the next call.
///
/// # Errors
///
/// Returns an error if libtorch cannot be found and either the `download-libtorch`
/// feature is disabled or the download fails.
pub fn find_or_download_libtorch_dir() -> Result<&'static Path> {
    static LIBTORCH_DIR: OnceCell<PathBuf> = OnceCell::new();

    LIBTORCH_DIR
        .get_or_try_init(|| {
            // 1. Explicit location from the `LIBTORCH` environment variable.
            let guess = LIBTORCH.to_owned();

            // 2. System-wide installation on Linux.
            #[cfg(target_os = "linux")]
            let guess = guess.or_else(|| {
                Path::new("/usr/lib/libtorch.so")
                    .exists()
                    .then(|| PathBuf::from("/usr"))
            });

            cfg_if! {
                if #[cfg(feature = "download-libtorch")] {
                    // 3. Fall back to downloading libtorch.
                    match guess {
                        Some(dir) => Ok(dir),
                        None => crate::download::download_libtorch()
                            .context("unable to download libtorch"),
                    }
                } else {
                    // This must stay a tail expression (no `?;`): it is the closure's
                    // `Result` return value. `anyhow!` is not imported, hence the full path.
                    guess.ok_or_else(|| anyhow::anyhow!("unable to find libtorch"))
                }
            }
        })
        .map(PathBuf::as_path)
}

/// Returns `true` if the host system uses the C++11 ABI.
///
/// The result is used to set the `_GLIBCXX_USE_CXX11_ABI` macro. It is computed once and
/// then cached for the rest of the process.
///
/// Resolution order:
/// 1. The `LIBTORCH_CXX11_ABI` setting from the crate's `env` module, when present.
/// 2. On Linux, whether a `use_cxx11_abi` file exists in `OUT_DIR`.
/// 3. `true` on every other target (macOS, Windows, and anything else).
pub fn check_cxx11_abi() -> bool {
    static CHECK: OnceCell<bool> = OnceCell::new();

    *CHECK.get_or_init(|| {
        // An explicit setting always wins over platform detection.
        if let Some(val) = *LIBTORCH_CXX11_ABI {
            return val;
        }

        cfg_if! {
            if #[cfg(target_os = "macos")] {
                true
            } else if #[cfg(target_os = "linux")] {
                // NOTE(review): presumably this marker file is produced by the build
                // script when the CXX11 ABI is detected — confirm.
                Path::new(OUT_DIR)
                    .join("use_cxx11_abi")
                    .exists()
            } else if #[cfg(target_os = "windows")] {
                // TODO: check _MSVC_LANG
                true
            } else {
                true
            }
        }
    })
}