use once_cell::sync::OnceCell;
use ort::ep::ExecutionProvider as OrtExecutionProvider;
use ort::session::Session;
use std::path::{Path, PathBuf};
use super::{EmbedderError, ExecutionProvider};
/// Adapts a low-level `ort` error into this crate's `EmbedderError`,
/// preserving the original message text.
pub(super) fn ort_err<T>(e: ort::Error<T>) -> EmbedderError {
    let message = e.to_string();
    EmbedderError::InferenceFailed(message)
}
#[cfg(target_os = "linux")]
/// Makes the CUDA/TensorRT provider shared objects discoverable by symlinking
/// them from ort's download cache next to the running binary (and into an
/// LD_LIBRARY_PATH directory, when a suitable one exists). Best-effort: bails
/// out silently when either directory cannot be determined.
fn ensure_ort_provider_libs() {
    // Where ort downloaded the provider libraries.
    let Some(ort_lib_dir) = find_ort_provider_dir() else {
        return;
    };
    // Where the dynamic loader will look at runtime (the binary's directory).
    let Some(ort_search_dir) = ort_runtime_search_dir() else {
        return;
    };
    const PROVIDER_LIBS: [&str; 3] = [
        "libonnxruntime_providers_shared.so",
        "libonnxruntime_providers_cuda.so",
        "libonnxruntime_providers_tensorrt.so",
    ];
    symlink_providers(&ort_lib_dir, &ort_search_dir, &PROVIDER_LIBS);
    if let Some(ld_dir) = find_ld_library_dir(&ort_lib_dir) {
        symlink_providers(&ort_lib_dir, &ld_dir, &PROVIDER_LIBS);
    }
}
#[cfg(target_os = "linux")]
/// Returns the directory containing the running executable, derived from
/// argv[0] (the first NUL-terminated entry of `/proc/self/cmdline`).
/// Relative argv[0] values are resolved against the current directory.
fn ort_runtime_search_dir() -> Option<PathBuf> {
    let cmdline = std::fs::read("/proc/self/cmdline").ok()?;
    // /proc/self/cmdline is a sequence of NUL-terminated strings.
    let nul = cmdline.iter().position(|&b| b == 0)?;
    let argv0 = std::str::from_utf8(&cmdline[..nul]).ok()?;
    let exe = PathBuf::from(argv0);
    let abs_path = if exe.is_absolute() {
        exe
    } else {
        std::env::current_dir().ok()?.join(exe)
    };
    Some(abs_path.parent()?.to_path_buf())
}
#[cfg(target_os = "linux")]
/// Locates ort's binary-download cache directory for the current architecture
/// and returns its lexicographically greatest subdirectory (equivalent to
/// sorting descending by file name and taking the first entry).
fn find_ort_provider_dir() -> Option<PathBuf> {
    // Map the running architecture onto ort's distribution triplet.
    let triplet = match std::env::consts::ARCH {
        "x86_64" => "x86_64-unknown-linux-gnu",
        "aarch64" => "aarch64-unknown-linux-gnu",
        _ => return None,
    };
    let ort_cache = dirs::cache_dir()?.join(format!("ort.pyke.io/dfbin/{triplet}"));
    let entries = match std::fs::read_dir(&ort_cache) {
        Ok(entries) => entries,
        Err(e) => {
            tracing::debug!(path = %ort_cache.display(), error = %e, "ORT cache not found");
            return None;
        }
    };
    entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .max_by_key(|path| path.file_name().map(|name| name.to_os_string()))
}
#[cfg(target_os = "linux")]
/// Returns the first LD_LIBRARY_PATH entry that is a non-empty, existing
/// directory and is not a string prefix of `ort_lib_dir` (so we never pick
/// the ort cache itself as the symlink target directory).
fn find_ld_library_dir(ort_lib_dir: &Path) -> Option<PathBuf> {
    let ld_path = std::env::var("LD_LIBRARY_PATH").unwrap_or_default();
    let ort_cache_str = ort_lib_dir.to_string_lossy();
    for entry in ld_path.split(':') {
        if entry.is_empty() || ort_cache_str.starts_with(entry) {
            continue;
        }
        if Path::new(entry).is_dir() {
            return Some(PathBuf::from(entry));
        }
    }
    None
}
#[cfg(target_os = "linux")]
fn symlink_providers(src_dir: &Path, target_dir: &Path, libs: &[&str]) {
for lib in libs {
let src = src_dir.join(lib);
let dst = target_dir.join(lib);
if !src.exists() {
continue;
}
if let Ok(existing) = std::fs::read_link(&dst) {
let existing_canon = dunce::canonicalize(&existing).unwrap_or(existing);
let src_canon = dunce::canonicalize(&src).unwrap_or_else(|_| src.clone());
if existing_canon == src_canon {
continue;
}
let _ = std::fs::remove_file(&dst);
}
if let Err(e) = std::os::unix::fs::symlink(&src, &dst) {
tracing::debug!("Failed to symlink {}: {}", lib, e);
}
}
}
#[cfg(not(target_os = "linux"))]
/// Non-Linux stub: there is no provider-library symlinking here, so just
/// record that GPU execution providers may fail to load.
fn ensure_ort_provider_libs() {
    tracing::debug!("Provider library setup not implemented for this platform — GPU may not activate");
}
static CACHED_PROVIDER: OnceCell<ExecutionProvider> = OnceCell::new();
pub(crate) fn select_provider() -> ExecutionProvider {
*CACHED_PROVIDER.get_or_init(detect_provider)
}
/// Probes available execution providers and picks one, in priority order:
/// CUDA, then TensorRT, then CPU fallback. Probe errors are treated as
/// "not available". Always logs the final selection.
fn detect_provider() -> ExecutionProvider {
    let _span = tracing::info_span!("detect_provider").entered();
    use ort::ep::{TensorRT, CUDA};
    // Make sure the provider shared objects are discoverable before probing.
    ensure_ort_provider_libs();
    let provider = if CUDA::default().is_available().unwrap_or(false) {
        ExecutionProvider::CUDA { device_id: 0 }
    } else if TensorRT::default().is_available().unwrap_or(false) {
        ExecutionProvider::TensorRT { device_id: 0 }
    } else {
        ExecutionProvider::CPU
    };
    tracing::info!(provider = ?provider, "Execution provider selected");
    provider
}
/// Builds an ONNX Runtime [`Session`] for the model at `model_path` using the
/// requested execution provider.
///
/// For `TensorRT`, a CUDA EP for the same device is registered as well —
/// presumably so nodes TensorRT cannot handle fall back to CUDA (registration
/// order appears to express priority; TODO confirm against the ort version in
/// use). For `CPU`, no explicit EP is registered and ort's default applies.
///
/// # Errors
/// Returns `EmbedderError::InferenceFailed` (via `ort_err`) when the builder,
/// EP registration, or model loading fails.
pub(crate) fn create_session(
model_path: &Path,
provider: ExecutionProvider,
) -> Result<Session, EmbedderError> {
let _span = tracing::info_span!("create_session", provider = ?provider).entered();
use ort::ep::{TensorRT, CUDA};
tracing::info!(provider = ?provider, model_path = %model_path.display(), "Creating ONNX session");
// NOTE(review): `mut` may be unnecessary if ort's builder methods take `self`
// by value — verify; the compiler will warn if so.
let mut builder = Session::builder().map_err(ort_err)?;
let session = match provider {
ExecutionProvider::CUDA { device_id } => builder
.with_execution_providers([CUDA::default().with_device_id(device_id).build()])
.map_err(ort_err)?
.commit_from_file(model_path)
.map_err(ort_err)?,
ExecutionProvider::TensorRT { device_id } => {
// TensorRT first, CUDA second on the same device_id.
builder
.with_execution_providers([
TensorRT::default().with_device_id(device_id).build(),
CUDA::default().with_device_id(device_id).build(),
])
.map_err(ort_err)?
.commit_from_file(model_path)
.map_err(ort_err)?
}
// No EP registered: ort falls back to its built-in CPU provider.
ExecutionProvider::CPU => builder.commit_from_file(model_path).map_err(ort_err)?,
};
Ok(session)
}