use crate::{env::TARGET, utils::IteratorExt as _};
use anyhow::{bail, Result};
use cfg_if::cfg_if;
use itertools::chain;
use std::{
iter,
path::{Path, PathBuf},
str,
};
/// A libtorch distribution plus the GPU backend it was built against.
///
/// The path/library helpers on this type resolve everything relative to
/// `libtorch_dir` and, when GPU support is requested, the backend stored in `api`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Library {
    /// Root of the libtorch distribution; `include/` and `lib/` are resolved beneath it.
    pub libtorch_dir: &'static Path,
    /// GPU backend (or [`Api::None`]) whose headers and libraries accompany libtorch.
    pub api: Api,
    /// Not consulted by the path/library helpers on this type. NOTE(review): presumably
    /// mirrors libtorch's `_GLIBCXX_USE_CXX11_ABI` setting — confirm with callers.
    pub use_cxx11_abi: bool,
}
impl Library {
    /// Returns the C/C++ include directories needed to compile against this libtorch.
    ///
    /// The base directories under `<libtorch_dir>/include` (including
    /// `torch/csrc/api/include`, `TH` and `THC`) are always yielded. When GPU support
    /// is requested, the backend's headers are appended:
    /// - HIP: `<libtorch_dir>/include/THH`, `<rocm_home>/include` and `<miopen_home>/include`.
    /// - CUDA: `<cuda_home>/include`, plus `<cudnn_home>/include` when a cuDNN home is set.
    ///
    /// `use_cuda_api` forces GPU headers on or off. `None` falls back to
    /// [`Library::is_cuda_api_available`].
    ///
    /// On Linux hosts, `/usr/include` is filtered out of the result. NOTE(review):
    /// presumably this keeps the system include directory from being passed as an explicit
    /// `-I` flag, which can perturb header search order. Confirm.
    ///
    /// # Errors
    ///
    /// Fails if GPU headers are requested but `self.api` is [`Api::None`].
    pub fn include_paths(
        &self,
        use_cuda_api: impl Into<Option<bool>>,
    ) -> Result<impl Iterator<Item = PathBuf>> {
        let Self {
            libtorch_dir, api, ..
        } = self;
        let include_dir = libtorch_dir.join("include");
        let use_cuda_api = use_cuda_api
            .into()
            .unwrap_or_else(|| self.is_cuda_api_available());

        let base_includes = [
            include_dir.clone(),
            include_dir
                .join("torch")
                .join("csrc")
                .join("api")
                .join("include"),
            include_dir.join("TH"),
            include_dir.join("THC"),
        ];

        let extra_includes = if use_cuda_api {
            match api {
                Api::Hip(HipApi {
                    rocm_home,
                    miopen_home,
                    ..
                }) => {
                    let thh_include = include_dir.join("THH");
                    let rocm_include = rocm_home.join("include");
                    let miopen_include = miopen_home.join("include");
                    [thh_include, rocm_include, miopen_include]
                        .into_iter()
                        .boxed()
                }
                Api::Cuda(CudaApi {
                    cuda_home,
                    cudnn_home,
                })
                | Api::CudaSplit(CudaSplitApi {
                    cuda_home,
                    cudnn_home,
                }) => {
                    let cuda_include = cuda_home.join("include");
                    let cudnn_include = cudnn_home.map(|path| path.join("include"));
                    chain!([cuda_include], cudnn_include).boxed()
                }
                // GPU headers were requested, but no GPU backend is configured.
                Api::None => bail!("CUDA runtime is not available"),
            }
        } else {
            iter::empty().boxed()
        };

        let all_includes = chain!(base_includes, extra_includes);
        #[cfg(target_os = "linux")]
        let all_includes = all_includes.filter(|path| path != Path::new("/usr/include"));
        Ok(all_includes)
    }

    /// Returns the linker search directories for this libtorch.
    ///
    /// `<libtorch_dir>/lib` is always yielded first. When GPU support is requested, the
    /// backend's library directories are appended:
    /// - HIP: `<rocm_home>/lib`.
    /// - CUDA on Windows: `<cuda_home>\lib\x64`.
    /// - CUDA on Linux/macOS: `<cuda_home>/lib64`, falling back to `<cuda_home>/lib`. The
    ///   same probe is applied to `cudnn_home` when it is set.
    ///
    /// `use_cuda_api` forces GPU directories on or off. `None` falls back to
    /// [`Library::is_cuda_api_available`].
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - GPU directories are requested but `self.api` is [`Api::None`].
    /// - On Linux/macOS, neither `lib64` nor `lib` exists under the CUDA (or configured
    ///   cuDNN) home.
    /// - The CUDA branch runs on an unsupported host OS.
    pub fn link_paths(
        &self,
        use_cuda_api: impl Into<Option<bool>>,
    ) -> Result<impl Iterator<Item = PathBuf>> {
        let Self {
            libtorch_dir, api, ..
        } = self;
        let use_cuda_api = use_cuda_api
            .into()
            .unwrap_or_else(|| self.is_cuda_api_available());
        let lib_dir = libtorch_dir.join("lib");

        let extra_dirs = if use_cuda_api {
            match api {
                Api::Hip(HipApi { rocm_home, .. }) => iter::once(rocm_home.join("lib")).boxed(),
                Api::Cuda(CudaApi {
                    cuda_home,
                    cudnn_home,
                })
                | Api::CudaSplit(CudaSplitApi {
                    cuda_home,
                    cudnn_home,
                }) => {
                    cfg_if! {
                        if #[cfg(target_os = "windows")] {
                            // NOTE(review): the original Windows branch never added a cuDNN
                            // directory. That behavior is kept here; confirm whether
                            // `cudnn_home` should contribute `lib\x64` as well.
                            let _ = cudnn_home;
                            let cuda_lib_dir = cuda_home.join("lib").join("x64");
                            iter::once(cuda_lib_dir).boxed()
                        }
                        else if #[cfg(any(target_os = "linux", target_os = "macos"))] {
                            // Returns `<home>/lib64` if it exists, otherwise `<home>/lib`.
                            // Errors, naming both probed paths, when neither exists.
                            fn find_lib_dir(home: &Path, what: &str) -> Result<PathBuf> {
                                let lib64 = home.join("lib64");
                                if lib64.exists() {
                                    return Ok(lib64);
                                }
                                let lib = home.join("lib");
                                if lib.exists() {
                                    return Ok(lib);
                                }
                                bail!(
                                    "unable to find the {} library directory: neither {} nor {} exists",
                                    what,
                                    lib64.display(),
                                    lib.display()
                                )
                            }

                            let cuda_lib_dir = find_lib_dir(cuda_home, "CUDA")?;
                            let cudnn_lib_dir = cudnn_home
                                .map(|home| find_lib_dir(home, "cuDNN"))
                                .transpose()?;
                            chain!([cuda_lib_dir], cudnn_lib_dir).boxed()
                        }
                        else {
                            bail!("Unsupported OS");
                        }
                    }
                }
                // GPU directories were requested, but no GPU backend is configured.
                Api::None => bail!("CUDA runtime is not available"),
            }
        } else {
            iter::empty().boxed()
        };

        Ok(chain!([lib_dir], extra_dirs))
    }

    /// Returns the names of the libraries to link, without `lib` prefixes or extensions.
    ///
    /// The yielded names, in order:
    /// 1. `c10`, `torch_cpu` and `torch`.
    /// 2. `torch_python`, when `use_python` is set.
    /// 3. The GPU backend libraries, when GPU support is requested:
    ///    - HIP: `amdhip64`, `c10_hip`, `torch_hip`.
    ///    - Cuda: `cudart`, `c10_cuda`, `torch_cuda`.
    ///    - CudaSplit: `cudart`, `c10_cuda`, `torch_cuda_cu`, `torch_cuda_cpp`.
    /// 4. `gomp`, when `TARGET` is set and contains neither `msvc` nor `apple`.
    ///
    /// `use_cuda_api` forces GPU libraries on or off. `None` falls back to
    /// [`Library::is_cuda_api_available`].
    ///
    /// # Errors
    ///
    /// Fails if GPU libraries are requested but `self.api` is [`Api::None`].
    pub fn libraries(
        &self,
        use_cuda_api: impl Into<Option<bool>>,
        use_python: bool,
    ) -> Result<impl Iterator<Item = &'static str>> {
        let Self { api, .. } = self;
        let use_cuda_api = use_cuda_api
            .into()
            .unwrap_or_else(|| self.is_cuda_api_available());

        let base_libraries = ["c10", "torch_cpu", "torch"];
        let python_library = use_python.then(|| "torch_python");
        let base_cuda_libraries = ["cudart", "c10_cuda"];

        let cuda_libraries = if use_cuda_api {
            match api {
                // GPU libraries were requested, but no GPU backend is configured.
                Api::None => bail!("CUDA runtime is not available"),
                Api::Hip(_) => ["amdhip64", "c10_hip", "torch_hip"].into_iter().boxed(),
                Api::Cuda(_) => chain!(base_cuda_libraries, ["torch_cuda"]).boxed(),
                Api::CudaSplit(_) => {
                    chain!(base_cuda_libraries, ["torch_cuda_cu", "torch_cuda_cpp"]).boxed()
                }
            }
        } else {
            iter::empty().boxed()
        };

        // NOTE(review): `TARGET` is presumably the Cargo target triple. GNU OpenMP is
        // skipped for MSVC and Apple targets — confirm this matches how libtorch is built.
        let gomp = TARGET.as_ref().and_then(|target| {
            let ok = !target.contains("msvc") && !target.contains("apple");
            ok.then(|| "gomp")
        });

        Ok(chain!(base_libraries, python_library, cuda_libraries, gomp))
    }

    /// Returns `true` when a GPU backend (HIP, CUDA or split CUDA) is configured.
    ///
    /// Delegates to [`Api::is_cuda_api_available`].
    pub fn is_cuda_api_available(&self) -> bool {
        self.api.is_cuda_api_available()
    }
}
/// The GPU backend that libtorch was built against, with the locations of its toolchain.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Api {
    /// No GPU backend; requesting GPU paths or libraries for this variant is an error.
    None,
    /// AMD ROCm/HIP backend.
    Hip(HipApi),
    /// CUDA backend whose torch GPU code is linked as a single `torch_cuda` library.
    Cuda(CudaApi),
    /// CUDA backend whose torch GPU code is split into `torch_cuda_cu` and `torch_cuda_cpp`.
    CudaSplit(CudaSplitApi),
}
impl Api {
pub fn is_cuda_api_available(&self) -> bool {
!matches!(self, Self::None)
}
}
impl From<HipApi> for Api {
fn from(from: HipApi) -> Self {
Self::Hip(from)
}
}
impl From<CudaApi> for Api {
fn from(from: CudaApi) -> Self {
Self::Cuda(from)
}
}
impl From<CudaSplitApi> for Api {
fn from(from: CudaSplitApi) -> Self {
Self::CudaSplit(from)
}
}
/// Installation roots for an AMD ROCm/HIP backend.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HipApi {
    /// ROCm root; its `include/` and `lib/` subdirectories are used.
    pub rocm_home: &'static Path,
    /// MIOpen root; its `include/` subdirectory is used.
    pub miopen_home: &'static Path,
}
/// Installation roots for a CUDA backend whose torch GPU code is one `torch_cuda` library.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CudaApi {
    /// CUDA toolkit root; its `include/` and library directories are used.
    pub cuda_home: &'static Path,
    /// cuDNN root, when configured; `None` contributes no cuDNN include or library directory.
    pub cudnn_home: Option<&'static Path>,
}
/// Installation roots for a CUDA backend whose torch GPU code is split across
/// `torch_cuda_cu` and `torch_cuda_cpp`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CudaSplitApi {
    /// CUDA toolkit root; its `include/` and library directories are used.
    pub cuda_home: &'static Path,
    /// cuDNN root, when configured; `None` contributes no cuDNN include or library directory.
    pub cudnn_home: Option<&'static Path>,
}