use std::path::Path;
use crate::{
error::{Error, Result},
options::Options,
};
pub(crate) fn build_session(graph: &Path, opts: Options) -> Result<ort::session::Session> {
use ort::session::Session;
let level = opts.optimization_level();
let mut builder = Session::builder()
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source,
})?
.with_optimization_level(level)
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source: ort::Error::from(source),
})?
.with_intra_threads(opts.threads().intra_threads())
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source: ort::Error::from(source),
})?
.with_inter_threads(opts.threads().inter_threads())
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source: ort::Error::from(source),
})?
.with_parallel_execution(opts.threads().parallel_execution())
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source: ort::Error::from(source),
})?;
let providers = collect_execution_providers();
if !providers.is_empty() {
builder = builder
.with_execution_providers(providers)
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source: ort::Error::from(source),
})?;
}
builder
.commit_from_file(graph)
.map_err(|source| Error::LoadGraph {
path: graph.to_path_buf(),
source,
})
}
fn collect_execution_providers() -> Vec<ort::ep::ExecutionProviderDispatch> {
#[allow(unused_mut)]
let mut providers: Vec<ort::ep::ExecutionProviderDispatch> = Vec::new();
#[cfg(feature = "tensorrt")]
{
providers.push(ort::ep::TensorRT::default().build());
}
#[cfg(feature = "cuda")]
{
providers.push(ort::ep::CUDA::default().build());
}
#[cfg(feature = "directml")]
{
providers.push(ort::ep::DirectML::default().build());
}
#[cfg(feature = "rocm")]
{
providers.push(ort::ep::ROCm::default().build());
}
#[cfg(feature = "coreml")]
{
providers.push(ort::ep::CoreML::default().build());
}
providers
}