use anyhow::{Context, Result};
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
/// ONNX Runtime release version used for the per-version cache directory and
/// for building the download URL.
const ORT_VERSION: &str = "1.23.0";
/// File name of the ONNX Runtime shared library on macOS.
#[cfg(target_os = "macos")]
const ORT_LIB_NAME: &str = "libonnxruntime.dylib";
/// File name of the ONNX Runtime shared library on Linux.
#[cfg(target_os = "linux")]
const ORT_LIB_NAME: &str = "libonnxruntime.so";
/// File name of the ONNX Runtime shared library on Windows.
// NOTE(review): no `ORT_LIB_NAME` is defined for any other `target_os`, so
// code referencing it will not compile there — confirm that is intended.
#[cfg(target_os = "windows")]
const ORT_LIB_NAME: &str = "onnxruntime.dll";
/// Resolves the path of an ONNX Runtime shared library, downloading it if
/// necessary.
///
/// Resolution order:
/// 1. `ORT_DYLIB_PATH`, when it is set and points at an existing path.
/// 2. The first match reported by [`find_onnx_runtime`].
/// 3. A fresh download via [`download_onnx_runtime`].
///
/// In cases 2 and 3 the resolved path is also exported to this process as
/// `ORT_DYLIB_PATH`.
///
/// # Errors
/// Propagates any error returned by [`download_onnx_runtime`].
pub fn ensure_onnx_runtime() -> Result<PathBuf> {
    // An explicit, valid override wins and leaves the environment untouched.
    let preset = env::var("ORT_DYLIB_PATH")
        .ok()
        .map(PathBuf::from)
        .filter(|candidate| candidate.exists());
    if let Some(candidate) = preset {
        return Ok(candidate);
    }

    // Otherwise prefer a library already on disk; download only as a last resort.
    let resolved = match find_onnx_runtime() {
        Some(found) => found,
        None => download_onnx_runtime()?,
    };
    env::set_var("ORT_DYLIB_PATH", &resolved);
    Ok(resolved)
}
/// Returns `true` when `name` is a file name of the core ONNX Runtime shared
/// library in its Unix spellings: `libonnxruntime.so`, `libonnxruntime.so.<ver>`,
/// `libonnxruntime.dylib` or `libonnxruntime.<ver>.dylib`.
///
/// Requiring the `libonnxruntime.` prefix (with the dot) rejects sibling
/// libraries such as `libonnxruntime_providers_shared.so`.
fn is_onnx_runtime_lib_name(name: &str) -> bool {
    name.strip_prefix("libonnxruntime.").map_or(false, |rest| {
        rest == "so" || rest.starts_with("so.") || rest == "dylib" || rest.ends_with(".dylib")
    })
}

/// Looks for an already-installed ONNX Runtime library in the directories
/// returned by `get_search_paths`.
///
/// For each directory, in order, the following are tried:
/// 1. `<dir>/<ORT_LIB_NAME>`;
/// 2. any entry of `<dir>` whose file name satisfies
///    [`is_onnx_runtime_lib_name`] (this picks up versioned file names);
/// 3. `<dir>/lib/<ORT_LIB_NAME>`.
///
/// Returns the first hit, or `None` when nothing matches. Directories that
/// cannot be read are skipped.
fn find_onnx_runtime() -> Option<PathBuf> {
    for base_path in get_search_paths() {
        let lib_path = base_path.join(ORT_LIB_NAME);
        if lib_path.exists() {
            return Some(lib_path);
        }

        if let Ok(entries) = fs::read_dir(&base_path) {
            // Unreadable entries are ignored; only the core library name is accepted.
            let versioned = entries
                .flatten()
                .find(|entry| is_onnx_runtime_lib_name(&entry.file_name().to_string_lossy()));
            if let Some(entry) = versioned {
                return Some(entry.path());
            }
        }

        let lib_subdir = base_path.join("lib").join(ORT_LIB_NAME);
        if lib_subdir.exists() {
            return Some(lib_subdir);
        }
    }
    None
}
/// Builds the ordered list of directories in which to look for an existing
/// ONNX Runtime library.
///
/// Order of the returned entries:
/// 1. `<home>/.cache/onnxruntime/<ORT_VERSION>` (only when a home directory is known);
/// 2. `$CONDA_PREFIX/lib` and the `onnxruntime/capi` package directories below
///    it, when `CONDA_PREFIX` is set;
/// 3. the `onnxruntime/capi` package directories of `.venv`, `venv`, `.env`
///    and `env` under the current working directory, when it can be determined;
/// 4. `<home>/.cache/uv` (only when a home directory is known);
/// 5. fixed system library directories on macOS and Linux.
///
/// Locations that do not depend on the home directory are returned even when
/// the home directory cannot be resolved.
fn get_search_paths() -> Vec<PathBuf> {
    const SITE_PACKAGES_CAPI: &str = "site-packages/onnxruntime/capi";
    // Newest interpreter first; a single list keeps the conda and venv probes in sync.
    const PYTHON_DIRS: [&str; 5] = [
        "python3.13",
        "python3.12",
        "python3.11",
        "python3.10",
        "python3.9",
    ];

    let home = dirs::home_dir();
    let mut paths = Vec::new();

    if let Some(home) = &home {
        paths.push(home.join(".cache").join("onnxruntime").join(ORT_VERSION));
    }

    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
        let conda_lib = PathBuf::from(conda_prefix).join("lib");
        paths.push(conda_lib.clone());
        paths.extend(
            PYTHON_DIRS
                .iter()
                .map(|py| conda_lib.join(py).join(SITE_PACKAGES_CAPI)),
        );
    }

    // Resolve the working directory once; without it there is no meaningful
    // virtual-environment location to probe.
    if let Ok(cwd) = env::current_dir() {
        for venv_name in [".venv", "venv", ".env", "env"] {
            let venv_path = cwd.join(venv_name);
            #[cfg(target_os = "windows")]
            paths.push(venv_path.join("Lib").join(SITE_PACKAGES_CAPI));
            #[cfg(not(target_os = "windows"))]
            for py in PYTHON_DIRS {
                paths.push(venv_path.join("lib").join(py).join(SITE_PACKAGES_CAPI));
            }
        }
    }

    if let Some(home) = &home {
        paths.push(home.join(".cache/uv"));
    }

    #[cfg(target_os = "macos")]
    {
        paths.push(PathBuf::from("/opt/homebrew/lib"));
        paths.push(PathBuf::from("/usr/local/lib"));
    }
    #[cfg(target_os = "linux")]
    {
        paths.push(PathBuf::from("/usr/lib"));
        paths.push(PathBuf::from("/usr/local/lib"));
        paths.push(PathBuf::from("/usr/lib/x86_64-linux-gnu"));
    }

    paths
}
fn download_onnx_runtime() -> Result<PathBuf> {
let cache_dir = dirs::home_dir()
.context("Could not find home directory")?
.join(".cache")
.join("onnxruntime")
.join(ORT_VERSION);
let lib_path = cache_dir.join(ORT_LIB_NAME);
if lib_path.exists() {
return Ok(lib_path);
}
fs::create_dir_all(&cache_dir)?;
let (url, archive_lib_path) = get_download_url()?;
eprintln!("⚙️ Runtime: ONNX {}", ORT_VERSION);
let response = ureq::get(&url)
.call()
.context("Failed to download ONNX Runtime")?;
let mut archive_data = Vec::new();
response.into_reader().read_to_end(&mut archive_data)?;
extract_library(&archive_data, &archive_lib_path, &lib_path)?;
Ok(lib_path)
}
fn get_download_url() -> Result<(String, String)> {
let base = format!(
"https://github.com/microsoft/onnxruntime/releases/download/v{}",
ORT_VERSION
);
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let (archive, lib_path) = (
format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
format!(
"onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
);
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
let (archive, lib_path) = (
format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
format!(
"onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
);
#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
let (archive, lib_path) = (
format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
format!(
"onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
);
#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
let (archive, lib_path) = (
format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
format!(
"onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
);
#[cfg(all(target_os = "windows", target_arch = "x86_64"))]
let (archive, lib_path) = (
format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
);
#[cfg(not(any(
all(target_os = "macos", target_arch = "aarch64"),
all(target_os = "macos", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "aarch64"),
all(target_os = "windows", target_arch = "x86_64"),
)))]
return Err(anyhow::anyhow!(
"Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
));
Ok((format!("{}/{}", base, archive), lib_path))
}
/// Extracts a single file from a gzip-compressed tar archive held in memory.
///
/// `archive_data` is the raw `.tgz` content, `lib_path_in_archive` is the
/// entry to look for (a leading `./` on entry names is ignored when
/// comparing), and `dest` is where its bytes are written. On Unix the written
/// file is given mode `0o755`.
///
/// # Errors
/// Fails when the archive cannot be read, the file cannot be written, or no
/// entry matches `lib_path_in_archive`.
#[cfg(not(target_os = "windows"))]
fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
    use flate2::read::GzDecoder;
    use std::io::Read;

    let mut tarball = tar::Archive::new(GzDecoder::new(archive_data));
    for item in tarball.entries()? {
        let mut item = item?;

        let entry_path = item.path()?;
        let shown = entry_path.to_string_lossy();
        let normalized = shown.strip_prefix("./").unwrap_or(&shown);
        if normalized != lib_path_in_archive {
            continue;
        }

        // Buffer the whole entry first so nothing is written if reading fails.
        let mut bytes = Vec::new();
        item.read_to_end(&mut bytes)?;
        fs::write(dest, bytes)?;

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            fs::set_permissions(dest, fs::Permissions::from_mode(0o755))?;
        }
        return Ok(());
    }

    Err(anyhow::anyhow!(
        "Library not found in archive: {}",
        lib_path_in_archive
    ))
}
/// Extracts a single file from a zip archive held in memory (Windows build).
///
/// `archive_data` is the raw `.zip` content, `lib_path_in_archive` is the
/// member to look for (a leading `./` on member names is ignored when
/// comparing), and `dest` is where its bytes are written.
///
/// # Errors
/// Fails when the archive cannot be read, the file cannot be written, or no
/// member matches `lib_path_in_archive`.
#[cfg(target_os = "windows")]
fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
    use std::io::{Cursor, Read};

    let mut zip_archive = zip::ZipArchive::new(Cursor::new(archive_data))?;
    for index in 0..zip_archive.len() {
        let mut member = zip_archive.by_index(index)?;

        let member_name = member.name();
        let normalized = member_name.strip_prefix("./").unwrap_or(member_name);
        if normalized != lib_path_in_archive {
            continue;
        }

        // Buffer the whole member first so nothing is written if reading fails.
        let mut bytes = Vec::new();
        member.read_to_end(&mut bytes)?;
        fs::write(dest, bytes)?;
        return Ok(());
    }

    Err(anyhow::anyhow!(
        "Library not found in archive: {}",
        lib_path_in_archive
    ))
}