use ort::execution_providers::CPUExecutionProvider;
#[cfg(feature = "ort-cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-directml")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "ort-rocm")]
use ort::execution_providers::ROCmExecutionProvider;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use std::path::Path;
use crate::accel::{get_ort_accelerator, OrtAccelerator};
/// Assembles the execution providers to register on a session, in push order.
///
/// The list depends on the accelerator returned by [`get_ort_accelerator`] and on
/// which `ort-*` cargo features were compiled in:
///
/// * `CpuOnly` adds nothing beyond the CPU provider.
/// * `Cuda`, `DirectMl` and `Rocm` add their provider when the matching feature is
///   enabled; otherwise a warning is logged and only the CPU provider remains.
/// * `Auto` adds CUDA and then ROCm, each only if its feature is enabled. DirectML
///   is not part of `Auto`.
///
/// The CPU provider is always pushed last, so the returned vector is never empty.
fn execution_providers() -> Vec<ort::execution_providers::ExecutionProviderDispatch> {
    let mut providers = Vec::new();

    match get_ort_accelerator() {
        OrtAccelerator::Auto => {
            #[cfg(feature = "ort-cuda")]
            providers.push(CUDAExecutionProvider::default().build());
            #[cfg(feature = "ort-rocm")]
            providers.push(ROCmExecutionProvider::default().build());
        }
        OrtAccelerator::Cuda => {
            #[cfg(feature = "ort-cuda")]
            providers.push(CUDAExecutionProvider::default().build());
            #[cfg(not(feature = "ort-cuda"))]
            log::warn!(
                "Accelerator set to CUDA but ort-cuda feature is not enabled; falling back to CPU"
            );
        }
        OrtAccelerator::Rocm => {
            #[cfg(feature = "ort-rocm")]
            providers.push(ROCmExecutionProvider::default().build());
            #[cfg(not(feature = "ort-rocm"))]
            log::warn!(
                "Accelerator set to ROCm but ort-rocm feature is not enabled; falling back to CPU"
            );
        }
        OrtAccelerator::DirectMl => {
            #[cfg(feature = "ort-directml")]
            providers.push(DirectMLExecutionProvider::default().build());
            #[cfg(not(feature = "ort-directml"))]
            log::warn!("Accelerator set to DirectML but ort-directml feature is not enabled; falling back to CPU");
        }
        // Nothing to add: the CPU provider below is the only entry.
        OrtAccelerator::CpuOnly => {}
    }

    // Unconditional CPU entry at the end of the list.
    providers.push(CPUExecutionProvider::default().build());
    providers
}
/// Returns `true` only when the configured accelerator is DirectML *and* the
/// `ort-directml` feature was compiled in, i.e. when the DirectML provider will
/// actually be registered.
fn directml_active() -> bool {
    let selected = get_ort_accelerator();
    let compiled_in = cfg!(feature = "ort-directml");
    selected == OrtAccelerator::DirectMl && compiled_in
}
/// Builds an ONNX Runtime session for the model at `path`.
///
/// * `intra_threads` - intra-op thread count; `None` or `Some(0)` leaves the
///   runtime's own setting untouched.
/// * `parallel_execution` - requested parallel execution mode. It is forced off
///   when DirectML is the active accelerator.
///
/// The session is created with graph optimization level 3 and the providers from
/// `execution_providers()`. After loading, every model input and output is logged
/// at `info` level.
///
/// # Errors
///
/// Returns the `ort::Error` raised by any builder step or by loading the model file.
fn build_session(
    path: &Path,
    intra_threads: Option<usize>,
    parallel_execution: bool,
) -> Result<Session, ort::Error> {
    // Read the DirectML decision exactly once so that the parallel-execution and
    // memory-pattern settings below can never disagree with each other, even if
    // the accelerator preference is changed while this function runs.
    let directml = directml_active();

    let mut builder =
        Session::builder()?.with_optimization_level(GraphOptimizationLevel::Level3)?;

    // A thread count of zero is treated the same as "not specified".
    if let Some(n) = intra_threads.filter(|&n| n > 0) {
        builder = builder.with_intra_threads(n)?;
    }

    // With DirectML active, parallel execution is always disabled and memory
    // pattern optimization is switched off; ONNX Runtime's DirectML provider
    // documentation lists both as unsupported.
    builder = builder.with_parallel_execution(parallel_execution && !directml)?;
    if directml {
        builder = builder.with_memory_pattern(false)?;
    }

    let session = builder
        .with_execution_providers(execution_providers())?
        .commit_from_file(path)?;

    for input in session.inputs() {
        log::info!(
            "Model input: name={}, type={:?}",
            input.name(),
            input.dtype()
        );
    }
    for output in session.outputs() {
        log::info!(
            "Model output: name={}, type={:?}",
            output.name(),
            output.dtype()
        );
    }

    Ok(session)
}
/// Creates a session for the model at `path` without specifying an intra-op
/// thread count and with parallel execution requested.
///
/// # Errors
///
/// Propagates any `ort::Error` from `build_session`.
pub fn create_session(path: &Path) -> Result<Session, ort::Error> {
    let intra_threads = None;
    let parallel_execution = true;
    build_session(path, intra_threads, parallel_execution)
}
/// Creates a session for the model at `path`, requesting `num_threads` intra-op
/// threads and parallel execution.
///
/// `num_threads` is forwarded as-is; `build_session` skips the thread setting
/// when the value is `0`.
///
/// # Errors
///
/// Propagates any `ort::Error` from `build_session`.
pub fn create_session_with_threads(path: &Path, num_threads: usize) -> Result<Session, ort::Error> {
    let intra_threads = Some(num_threads);
    let parallel_execution = true;
    build_session(path, intra_threads, parallel_execution)
}
/// Picks the model file to load from `dir` for the requested quantization.
///
/// * `FP32` always resolves to `<name>.onnx`.
/// * `FP16` / `Int8` resolve to `<name>.fp16.onnx` / `<name>.int8.onnx` when that
///   file exists; otherwise a warning is logged and `<name>.onnx` is returned.
///
/// The fallback path `<name>.onnx` is returned without checking that it exists.
pub fn resolve_model_path(
    dir: &Path,
    name: &str,
    quantization: &super::Quantization,
) -> std::path::PathBuf {
    // Full precision has no suffixed variant, so there is nothing to probe for.
    let tag = match quantization {
        super::Quantization::FP32 => return dir.join(format!("{}.onnx", name)),
        super::Quantization::FP16 => "fp16",
        super::Quantization::Int8 => "int8",
    };

    let quantized = dir.join(format!("{}.{}.onnx", name, tag));
    if !quantized.exists() {
        log::warn!(
            "{} model not found at {}, falling back to {}.onnx",
            tag,
            quantized.display(),
            name
        );
        return dir.join(format!("{}.onnx", name));
    }

    log::info!("Loading {} model: {}", tag, quantized.display());
    quantized
}
/// Looks up the custom metadata entry `key` on the session's model.
///
/// Returns `Ok(None)` when the key is absent or its value is the empty string.
///
/// # Errors
///
/// Returns the `ort::Error` produced when the model metadata cannot be obtained.
pub fn read_metadata_str(session: &Session, key: &str) -> Result<Option<String>, ort::Error> {
    let meta = session.metadata()?;
    let value = match meta.custom(key) {
        // An empty value is treated the same as a missing key.
        Some(text) if !text.is_empty() => Some(text),
        _ => None,
    };
    Ok(value)
}
/// Reads the custom metadata entry `key` and parses it as an `i32`.
///
/// Returns `Ok(default)` when the entry is missing or empty, and `Ok(Some(n))`
/// when it is present and parses successfully.
///
/// # Errors
///
/// Returns `TranscribeError::Config` when the metadata cannot be read or when the
/// stored value is not a valid `i32`.
pub fn read_metadata_i32(
    session: &Session,
    key: &str,
    default: Option<i32>,
) -> Result<Option<i32>, crate::TranscribeError> {
    let raw = read_metadata_str(session, key).map_err(|e| {
        crate::TranscribeError::Config(format!("failed to read metadata '{}': {}", key, e))
    })?;

    match raw {
        None => Ok(default),
        Some(text) => text.parse::<i32>().map(Some).map_err(|e| {
            crate::TranscribeError::Config(format!("failed to parse '{}': {}", key, e))
        }),
    }
}
/// Reads the custom metadata entry `key` as a comma-separated list of `f32` values.
///
/// Each comma-separated piece is trimmed of surrounding whitespace before parsing.
/// Returns `Ok(None)` when the entry is missing or empty.
///
/// # Errors
///
/// Returns `TranscribeError::Config` when the metadata cannot be read or when any
/// piece fails to parse as `f32` (the first failure is reported).
pub fn read_metadata_float_vec(
    session: &Session,
    key: &str,
) -> Result<Option<Vec<f32>>, crate::TranscribeError> {
    let raw = read_metadata_str(session, key).map_err(|e| {
        crate::TranscribeError::Config(format!("failed to read metadata '{}': {}", key, e))
    })?;

    // Parse only when a value is present, then flip Option<Result<..>> into
    // Result<Option<..>> so a parse failure surfaces as an error.
    raw.map(|text| {
        text.split(',')
            .map(|piece| piece.trim().parse::<f32>())
            .collect::<Result<Vec<f32>, _>>()
            .map_err(|e| {
                crate::TranscribeError::Config(format!(
                    "failed to parse floats in '{}': {}",
                    key, e
                ))
            })
    })
    .transpose()
}