use crate::{
env::{CUDA_HOME, CUDNN_HOME, LIBTORCH, LIBTORCH_CXX11_ABI, OUT_DIR, ROCM_HOME},
library::{Api, CudaApi, CudaSplitApi, HipApi, Library},
};
use anyhow::{Context as _, Result};
use cfg_if::cfg_if;
use log::warn;
use once_cell::sync::OnceCell;
use std::{
path::{Path, PathBuf},
str,
};
/// Locates libtorch and determines which GPU backend, if any, it was built with.
///
/// The probe runs at most once; later calls return the cached [`Library`].
///
/// Backend selection, in order of precedence:
/// 1. ROCm/HIP — when `ROCM_HOME` is set *and* `lib/` contains the `torch_hip` library.
/// 2. CUDA — when `CUDA_HOME` is set:
///    - split build if both `torch_cuda_cu` and `torch_cuda_cpp` are present;
///    - otherwise the monolithic `torch_cuda` library;
///    - otherwise a warning is logged and no backend is used.
/// 3. No backend (`Api::None`).
///
/// Library files are looked up with the platform naming convention
/// (`lib<name>.so` on Linux, `<name>.dll` on Windows). On any other platform
/// no GPU library is ever reported as present.
///
/// # Errors
///
/// Propagates the error from [`find_or_download_libtorch_dir`] when the
/// libtorch directory cannot be obtained.
pub fn probe_libtorch() -> Result<&'static Library> {
    static PROBE: OnceCell<Library> = OnceCell::new();

    PROBE.get_or_try_init(|| -> Result<_> {
        let libtorch_dir = find_or_download_libtorch_dir()?;
        let lib_dir = libtorch_dir.join("lib");

        // True when `lib_dir` holds the shared library `name` for this platform.
        let has_lib = |name: &str| -> bool {
            cfg_if! {
                if #[cfg(target_os = "linux")] {
                    lib_dir.join(format!("lib{}.so", name)).exists()
                }
                else if #[cfg(target_os = "windows")] {
                    lib_dir.join(format!("{}.dll", name)).exists()
                }
                else { false }
            }
        };

        let api = match &*ROCM_HOME {
            // ROCm wins only if this libtorch actually ships the HIP backend.
            Some(rocm_home) if has_lib("torch_hip") => {
                // Kept in a static, presumably so `HipApi` can hold a `'static` borrow.
                static MIOPEN_HOME: OnceCell<PathBuf> = OnceCell::new();
                let miopen_home = MIOPEN_HOME.get_or_init(|| rocm_home.join("miopen"));
                HipApi {
                    rocm_home,
                    miopen_home,
                }
                .into()
            }
            _ => match &*CUDA_HOME {
                Some(cuda_home) => {
                    let split_build = has_lib("torch_cuda_cu") && has_lib("torch_cuda_cpp");
                    if split_build {
                        CudaSplitApi {
                            cuda_home,
                            cudnn_home: CUDNN_HOME.as_deref(),
                        }
                        .into()
                    } else if has_lib("torch_cuda") {
                        CudaApi {
                            cuda_home,
                            cudnn_home: CUDNN_HOME.as_deref(),
                        }
                        .into()
                    } else {
                        warn!(
                            r#"CUDA_HOME is set to "{}", but no CUDA runtime found for libtorch"#,
                            cuda_home.display()
                        );
                        Api::None
                    }
                }
                None => Api::None,
            },
        };

        let use_cxx11_abi = check_cxx11_abi();
        Ok(Library {
            libtorch_dir,
            api,
            use_cxx11_abi,
        })
    })
}
/// Returns the libtorch installation directory.
///
/// The directory is resolved on first use and cached for the rest of the process.
///
/// Lookup order:
/// 1. the value of `LIBTORCH` from `crate::env`;
/// 2. on Linux, `/usr` when `/usr/lib/libtorch.so` exists;
/// 3. with the `download-libtorch` feature enabled, a copy fetched by
///    `crate::download::download_libtorch`.
///
/// # Errors
///
/// - Without `download-libtorch`: returns an error when neither of the first two
///   sources yields a directory.
/// - With `download-libtorch`: returns the download error (with added context)
///   if the download fails.
pub fn find_or_download_libtorch_dir() -> Result<&'static Path> {
    static LIBTORCH_DIR: OnceCell<PathBuf> = OnceCell::new();

    LIBTORCH_DIR
        .get_or_try_init(|| -> Result<PathBuf> {
            let guess = LIBTORCH.to_owned();

            // Fall back to a system-wide install in /usr.
            #[cfg(target_os = "linux")]
            let guess = guess.or_else(|| {
                Path::new("/usr/lib/libtorch.so")
                    .exists()
                    .then(|| PathBuf::from("/usr"))
            });

            cfg_if! {
                if #[cfg(feature = "download-libtorch")] {
                    match guess {
                        Some(dir) => Ok(dir),
                        None => crate::download::download_libtorch()
                            .with_context(|| "unable to download libtorch"),
                    }
                } else {
                    // The tail must be the `Result` itself: a trailing `?;` here would
                    // make this branch evaluate to `()` and fail to compile.
                    guess.ok_or_else(|| anyhow::anyhow!("unable to find libtorch"))
                }
            }
        })
        .map(PathBuf::as_path)
}
/// Decides whether to build against libtorch using the C++11 ABI.
///
/// The answer is computed once and cached for the rest of the process.
///
/// - An explicit `LIBTORCH_CXX11_ABI` value from `crate::env` always wins.
/// - On Linux, the answer is `true` iff a `use_cxx11_abi` marker file exists in
///   `OUT_DIR`. The marker is presumably written by an earlier build step that
///   is not visible here.
/// - Every other platform (macOS, Windows, ...) defaults to `true`.
pub fn check_cxx11_abi() -> bool {
    static CHECK: OnceCell<bool> = OnceCell::new();

    *CHECK.get_or_init(|| {
        if let Some(val) = *LIBTORCH_CXX11_ABI {
            return val;
        }
        cfg_if! {
            if #[cfg(target_os = "linux")] {
                Path::new(OUT_DIR).join("use_cxx11_abi").exists()
            } else {
                // The previous separate `target_os = "window"` branch was a typo
                // (never active, and flagged by `unexpected_cfgs`). All non-Linux
                // platforms shared the same `true` default, so they are merged here.
                true
            }
        }
    })
}