use anyhow::{Context, Result};
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
#[cfg(all(feature = "cuda", target_os = "linux"))]
static CUDNN_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
#[cfg(all(feature = "cuda", target_os = "linux"))]
pub fn is_cudnn_available() -> bool {
*CUDNN_AVAILABLE.get().unwrap_or(&false)
}
#[cfg(all(feature = "cuda", not(target_os = "linux")))]
pub fn is_cudnn_available() -> bool {
true
}
#[cfg(not(feature = "cuda"))]
pub fn is_cudnn_available() -> bool {
false }
const ORT_VERSION: &str = "1.23.0";
#[cfg(target_os = "macos")]
const ORT_LIB_NAME: &str = "libonnxruntime.dylib";
#[cfg(target_os = "linux")]
const ORT_LIB_NAME: &str = "libonnxruntime.so";
#[cfg(target_os = "windows")]
const ORT_LIB_NAME: &str = "onnxruntime.dll";
#[cfg(feature = "cuda")]
const USE_GPU: bool = true;
#[cfg(not(feature = "cuda"))]
const USE_GPU: bool = false;
#[cfg(feature = "cuda")]
const ORT_CACHE_SUBDIR: &str = "gpu";
#[cfg(not(feature = "cuda"))]
const ORT_CACHE_SUBDIR: &str = "cpu";
pub fn ensure_onnx_runtime() -> Result<PathBuf> {
#[cfg(all(target_os = "linux", feature = "cuda"))]
if crate::acceleration::env_acceleration_mode_lossy()
!= crate::acceleration::AccelerationMode::ForceCpu
{
if env::var("_COLGREP_CUDA_SETUP").is_err() {
if let Some(cudnn_dir) = find_cudnn_directory() {
let current_ld = env::var("LD_LIBRARY_PATH").unwrap_or_default();
let cudnn_str = cudnn_dir.to_string_lossy();
if !current_ld.contains(&*cudnn_str) {
let new_ld = if current_ld.is_empty() {
cudnn_str.to_string()
} else {
format!("{}:{}", cudnn_str, current_ld)
};
let ort_gpu_dir = dirs::home_dir()
.map(|h| {
h.join(".cache")
.join("colgrep")
.join("onnxruntime")
.join(ORT_VERSION)
.join("gpu")
})
.filter(|p| p.exists());
let final_ld = if let Some(ort_dir) = ort_gpu_dir {
let ort_str = ort_dir.to_string_lossy();
if new_ld.contains(&*ort_str) {
new_ld
} else {
format!("{}:{}", ort_str, new_ld)
}
} else {
new_ld
};
env::set_var("LD_LIBRARY_PATH", &final_ld);
env::set_var("_COLGREP_CUDA_SETUP", "1");
let exe = env::current_exe().context("Failed to get current executable")?;
let args: Vec<String> = env::args().collect();
let err = exec::execvp(&exe, &args);
return Err(anyhow::anyhow!(
"Failed to re-exec with CUDA environment: {}",
err
));
}
}
env::set_var("_COLGREP_CUDA_SETUP", "1");
}
}
if let Ok(path) = env::var("ORT_DYLIB_PATH") {
let path = PathBuf::from(&path);
if path.exists() && is_valid_ort_dylib(&path) {
pin_runtime_library(&path);
return Ok(path);
}
eprintln!(
"⚠️ ORT_DYLIB_PATH={} is not a loadable ONNX Runtime dylib; ignoring.",
path.display()
);
env::remove_var("ORT_DYLIB_PATH");
}
#[cfg(not(feature = "cuda"))]
if let Some(path) = find_onnx_runtime() {
pin_runtime_library(&path);
return Ok(path);
}
let path = download_onnx_runtime()?;
pin_runtime_library(&path);
Ok(path)
}
fn pin_runtime_library(path: &Path) {
env::set_var("ORT_DYLIB_PATH", path);
#[cfg(target_os = "linux")]
if let Some(parent) = path.parent() {
prepend_ld_library_path(parent);
}
#[cfg(all(target_os = "linux", feature = "cuda"))]
{
let _ = check_cudnn_available();
}
}
#[cfg(all(target_os = "linux", feature = "cuda"))]
fn find_cudnn_directory() -> Option<PathBuf> {
let search_dirs = get_cudnn_search_dirs();
let cudnn_lib_names = ["libcudnn.so.9", "libcudnn.so.8", "libcudnn.so"];
for dir in &search_dirs {
for lib_name in &cudnn_lib_names {
let cudnn_path = dir.join(lib_name);
if cudnn_path.exists() {
return Some(dir.clone());
}
}
if dir.exists() {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("libcudnn") && name_str.contains(".so") {
return Some(dir.clone());
}
}
}
}
}
None
}
#[cfg(target_os = "linux")]
fn prepend_ld_library_path(dir: &Path) {
let dir_str = dir.to_string_lossy();
let current = env::var("LD_LIBRARY_PATH").unwrap_or_default();
if !current.contains(&*dir_str) {
let new_path = if current.is_empty() {
dir_str.to_string()
} else {
format!("{}:{}", dir_str, current)
};
env::set_var("LD_LIBRARY_PATH", &new_path);
}
}
#[cfg(all(target_os = "linux", feature = "cuda"))]
fn get_cudnn_search_dirs() -> Vec<PathBuf> {
let mut dirs = Vec::new();
if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
dirs.push(PathBuf::from(&conda_prefix).join("lib"));
dirs.push(PathBuf::from(&conda_prefix).join("lib64"));
let site_packages = PathBuf::from(&conda_prefix).join("lib");
if let Ok(entries) = std::fs::read_dir(&site_packages) {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("python") {
let cudnn_lib = entry
.path()
.join("site-packages")
.join("nvidia")
.join("cudnn")
.join("lib");
dirs.push(cudnn_lib);
}
}
}
}
for var in ["CUDA_HOME", "CUDA_PATH", "CUDNN_PATH", "CUDNN_HOME"] {
if let Ok(path) = env::var(var) {
let base = PathBuf::from(&path);
dirs.push(base.join("lib"));
dirs.push(base.join("lib64"));
dirs.push(base.clone());
}
}
if let Ok(ld_path) = env::var("LD_LIBRARY_PATH") {
for dir in ld_path.split(':') {
if !dir.is_empty() {
dirs.push(PathBuf::from(dir));
}
}
}
if let Ok(lib_path) = env::var("LIBRARY_PATH") {
for dir in lib_path.split(':') {
if !dir.is_empty() {
dirs.push(PathBuf::from(dir));
}
}
}
dirs.extend([
PathBuf::from("/usr/local/cuda/lib64"),
PathBuf::from("/usr/local/cuda/lib"),
PathBuf::from("/usr/lib/x86_64-linux-gnu"),
PathBuf::from("/usr/lib64"),
PathBuf::from("/usr/lib"),
PathBuf::from("/opt/cuda/lib64"),
PathBuf::from("/opt/cuda/lib"),
]);
if let Ok(nvhpc) = env::var("NVHPC_ROOT") {
dirs.push(PathBuf::from(&nvhpc).join("cuda/lib64"));
}
if let Some(home) = dirs::home_dir() {
dirs.push(home.join(".local/lib"));
dirs.push(home.join(".local/lib64"));
}
dirs
}
#[cfg(all(target_os = "linux", feature = "cuda"))]
fn check_cudnn_available() -> bool {
let cudnn_lib_names = [
"libcudnn.so.9",
"libcudnn.so.8",
"libcudnn.so",
"libcudnn.so.9.0.0",
"libcudnn.so.8.0.0",
];
let search_dirs = get_cudnn_search_dirs();
for dir in &search_dirs {
for lib_name in &cudnn_lib_names {
let cudnn_path = dir.join(lib_name);
if cudnn_path.exists() {
prepend_ld_library_path(dir);
let _ = CUDNN_AVAILABLE.set(true);
return true;
}
}
if dir.exists() {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("libcudnn") && name_str.contains(".so") {
prepend_ld_library_path(dir);
let _ = CUDNN_AVAILABLE.set(true);
return true;
}
}
}
}
}
let _ = CUDNN_AVAILABLE.set(false);
false
}
fn is_valid_ort_dylib(path: &Path) -> bool {
unsafe {
match libloading::Library::new(path) {
Ok(lib) => lib
.get::<unsafe extern "C" fn() -> *const std::ffi::c_void>(b"OrtGetApiBase\0")
.is_ok(),
Err(_) => false,
}
}
}
#[cfg(not(feature = "cuda"))]
fn find_onnx_runtime() -> Option<PathBuf> {
let search_paths = get_search_paths();
let mut rejected: Vec<PathBuf> = Vec::new();
let try_candidate = |candidate: PathBuf, rejected: &mut Vec<PathBuf>| -> Option<PathBuf> {
if !candidate.exists() {
return None;
}
if is_valid_ort_dylib(&candidate) {
Some(candidate)
} else {
rejected.push(candidate);
None
}
};
for base_path in search_paths {
if let Some(p) = try_candidate(base_path.join(ORT_LIB_NAME), &mut rejected) {
return Some(p);
}
if let Ok(entries) = fs::read_dir(&base_path) {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("libonnxruntime.so")
|| name_str.starts_with("libonnxruntime.dylib")
|| (name_str.starts_with("libonnxruntime.") && name_str.ends_with(".dylib"))
{
if let Some(p) = try_candidate(entry.path(), &mut rejected) {
return Some(p);
}
}
}
}
if let Some(p) = try_candidate(base_path.join("lib").join(ORT_LIB_NAME), &mut rejected) {
return Some(p);
}
}
if !rejected.is_empty() {
use std::sync::atomic::{AtomicBool, Ordering};
static WARNED: AtomicBool = AtomicBool::new(false);
if !WARNED.swap(true, Ordering::Relaxed) {
let mut seen: std::collections::HashSet<PathBuf> = std::collections::HashSet::new();
let unique: Vec<&PathBuf> = rejected
.iter()
.filter(|p| {
let canon = p.canonicalize().unwrap_or_else(|_| (*p).clone());
seen.insert(canon)
})
.collect();
eprintln!(
"⚠️ Found {} ONNX Runtime candidate(s) that failed to load (wrong arch, broken \
signature, or companion library); downloading a managed copy instead:",
unique.len()
);
for p in unique {
eprintln!(" - {}", p.display());
}
}
}
None
}
#[cfg(not(feature = "cuda"))]
fn get_search_paths() -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Some(home) = dirs::home_dir() {
paths.push(
home.join(".cache")
.join("colgrep")
.join("onnxruntime")
.join(ORT_VERSION)
.join(ORT_CACHE_SUBDIR),
);
paths.push(home.join(".cache").join("onnxruntime").join(ORT_VERSION));
if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
let conda_path = PathBuf::from(&conda_prefix);
paths.push(conda_path.join("lib"));
for entry in [
"lib/python3.12",
"lib/python3.11",
"lib/python3.10",
"lib/python3.9",
] {
paths.push(
conda_path
.join(entry)
.join("site-packages/onnxruntime/capi"),
);
}
}
for venv_name in [".venv", "venv", ".env", "env"] {
let venv_path = std::env::current_dir()
.map(|cwd| cwd.join(venv_name))
.unwrap_or_default();
#[cfg(target_os = "windows")]
paths.push(venv_path.join("Lib/site-packages/onnxruntime/capi"));
#[cfg(not(target_os = "windows"))]
for py in ["python3.12", "python3.11", "python3.10", "python3.9"] {
paths.push(
venv_path
.join("lib")
.join(py)
.join("site-packages/onnxruntime/capi"),
);
}
}
paths.push(home.join(".cache/uv"));
#[cfg(target_os = "macos")]
{
paths.push(PathBuf::from("/opt/homebrew/lib"));
paths.push(PathBuf::from("/usr/local/lib"));
}
#[cfg(target_os = "linux")]
{
}
}
paths
}
fn download_onnx_runtime() -> Result<PathBuf> {
let cache_dir = dirs::home_dir()
.context("Could not find home directory")?
.join(".cache")
.join("colgrep")
.join("onnxruntime")
.join(ORT_VERSION)
.join(ORT_CACHE_SUBDIR);
let lib_path = cache_dir.join(ORT_LIB_NAME);
#[cfg(all(feature = "cuda", target_os = "linux"))]
let already_cached = lib_path.exists()
&& cache_dir
.join("libonnxruntime_providers_shared.so")
.exists()
&& cache_dir.join("libonnxruntime_providers_cuda.so").exists();
#[cfg(all(feature = "cuda", target_os = "windows"))]
let already_cached = lib_path.exists()
&& cache_dir.join("onnxruntime_providers_shared.dll").exists()
&& cache_dir.join("onnxruntime_providers_cuda.dll").exists();
#[cfg(not(feature = "cuda"))]
let already_cached = lib_path.exists();
if already_cached {
return Ok(lib_path);
}
fs::create_dir_all(&cache_dir)?;
let (url, files_to_extract) = get_download_info()?;
if USE_GPU {
eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION);
} else {
eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION);
}
let response = ureq::get(&url)
.call()
.context("Failed to download ONNX Runtime")?;
let mut archive_data = Vec::new();
response.into_reader().read_to_end(&mut archive_data)?;
extract_libraries(&archive_data, &files_to_extract, &cache_dir)?;
Ok(lib_path)
}
type FileToExtract = (String, String);
fn get_download_info() -> Result<(String, Vec<FileToExtract>)> {
let base = format!(
"https://github.com/microsoft/onnxruntime/releases/download/v{}",
ORT_VERSION
);
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.dylib".to_string(),
)],
);
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
let (archive, files) = (
format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.dylib".to_string(),
)],
);
#[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION);
(
format!("{}.tgz", archive_name),
vec![
(
format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION),
"libonnxruntime.so".to_string(),
),
(
format!("{}/lib/libonnxruntime_providers_shared.so", archive_name),
"libonnxruntime_providers_shared.so".to_string(),
),
(
format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name),
"libonnxruntime_providers_cuda.so".to_string(),
),
],
)
};
#[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))]
let (archive, files) = (
format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.so".to_string(),
)],
);
#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.so".to_string(),
)],
);
#[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION);
(
format!("{}.zip", archive_name),
vec![
(
format!("{}/lib/onnxruntime.dll", archive_name),
"onnxruntime.dll".to_string(),
),
(
format!("{}/lib/onnxruntime_providers_shared.dll", archive_name),
"onnxruntime_providers_shared.dll".to_string(),
),
(
format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name),
"onnxruntime_providers_cuda.dll".to_string(),
),
],
)
};
#[cfg(all(target_os = "windows", target_arch = "x86_64", not(feature = "cuda")))]
let (archive, files) = (
format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
vec![(
format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
"onnxruntime.dll".to_string(),
)],
);
#[cfg(not(any(
all(target_os = "macos", target_arch = "aarch64"),
all(target_os = "macos", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "aarch64"),
all(target_os = "windows", target_arch = "x86_64"),
)))]
return Err(anyhow::anyhow!(
"Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
));
Ok((format!("{}/{}", base, archive), files))
}
#[cfg(not(target_os = "windows"))]
fn extract_libraries(
archive_data: &[u8],
files_to_extract: &[FileToExtract],
dest_dir: &Path,
) -> Result<()> {
use flate2::read::GzDecoder;
use std::collections::HashSet;
use std::io::Read;
let decoder = GzDecoder::new(archive_data);
let mut archive = tar::Archive::new(decoder);
let files_map: std::collections::HashMap<&str, &str> = files_to_extract
.iter()
.map(|(src, dst)| (src.as_str(), dst.as_str()))
.collect();
let mut extracted: HashSet<String> = HashSet::new();
for entry in archive.entries()? {
let mut entry = entry?;
let path = entry.path()?;
let path_str = path.to_string_lossy().to_string();
let normalized_path = path_str.strip_prefix("./").unwrap_or(&path_str);
if let Some(&dest_name) = files_map.get(normalized_path) {
let dest_path = dest_dir.join(dest_name);
let mut lib_data = Vec::new();
entry.read_to_end(&mut lib_data)?;
fs::write(&dest_path, lib_data)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(&dest_path, fs::Permissions::from_mode(0o755))?;
}
extracted.insert(normalized_path.to_string());
}
}
for (src, _) in files_to_extract {
if !extracted.contains(src.as_str()) {
return Err(anyhow::anyhow!("Library not found in archive: {}", src));
}
}
Ok(())
}
#[cfg(target_os = "windows")]
fn extract_libraries(
archive_data: &[u8],
files_to_extract: &[FileToExtract],
dest_dir: &Path,
) -> Result<()> {
use std::collections::HashSet;
use std::io::{Cursor, Read};
let cursor = Cursor::new(archive_data);
let mut archive = zip::ZipArchive::new(cursor)?;
let files_map: std::collections::HashMap<&str, &str> = files_to_extract
.iter()
.map(|(src, dst)| (src.as_str(), dst.as_str()))
.collect();
let mut extracted: HashSet<String> = HashSet::new();
for i in 0..archive.len() {
let mut file = archive.by_index(i)?;
let path = file.name().to_string();
let normalized_path = path.strip_prefix("./").unwrap_or(&path);
if let Some(&dest_name) = files_map.get(normalized_path) {
let dest_path = dest_dir.join(dest_name);
let mut lib_data = Vec::new();
file.read_to_end(&mut lib_data)?;
fs::write(&dest_path, lib_data)?;
extracted.insert(normalized_path.to_string());
}
}
for (src, _) in files_to_extract {
if !extracted.contains(src.as_str()) {
return Err(anyhow::anyhow!("Library not found in archive: {}", src));
}
}
Ok(())
}