use crate::execution_providers::OpPlacement;
use crate::graph::Graph;
use crate::tensor::Tensor;
use crate::OnnxError;
use oxionnx_core::OperatorRegistry;
use std::collections::HashMap;
use std::path::Path;
use super::types::{raw_meta_to_model_metadata, ModelMetadata, OptLevel};
use super::Session;
impl Session {
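    /// Loads an ONNX model from a file, using the default operator registry
    /// and all graph optimizations ([`OptLevel::All`]).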
    pub fn from_file(path: &Path) -> Result<Self, OnnxError> {
        let bytes = std::fs::read(path).map_err(|e| {
            OnnxError::Parse(format!("Cannot read ONNX file {}: {e}", path.display()))
        })?;
        // Relative references (e.g. external weight files) are resolved
        // against the model's directory.
        let base_path = path.parent().unwrap_or_else(|| Path::new("."));
        let registry = oxionnx_ops::default_registry();
        let (raw_meta, graph, weights) =
            crate::model::load_with_metadata_and_path(&bytes, base_path)
                .map_err(OnnxError::Parse)?;
        let metadata = raw_meta_to_model_metadata(raw_meta);
        Self::build_from_graph(
            graph,
            weights,
            metadata,
            registry,
            OptLevel::All,
            false, // enable_profiling
            false, // enable_memory_pool
            false, // parallel
            false, // mixed_precision
            None,  // num_threads
            OpPlacement::default(),
        )
    }
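
    /// Loads an ONNX model from an in-memory buffer with the default operator
    /// registry. Unlike [`Session::from_file`], no base path is available for
    /// resolving external data references.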
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, OnnxError> {
        Self::from_bytes_with_registry(bytes, oxionnx_ops::default_registry())
    }
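
    /// Like [`Session::from_file`], but with a caller-supplied
    /// [`OperatorRegistry`] instead of the default one.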
    pub fn from_file_with_registry(
        path: &Path,
        registry: OperatorRegistry,
    ) -> Result<Self, OnnxError> {
        let bytes = std::fs::read(path).map_err(|e| {
            OnnxError::Parse(format!("Cannot read ONNX file {}: {e}", path.display()))
        })?;
        let base_path = path.parent().unwrap_or_else(|| Path::new("."));
        let (raw_meta, graph, weights) =
            crate::model::load_with_metadata_and_path(&bytes, base_path)
                .map_err(OnnxError::Parse)?;
        let metadata = raw_meta_to_model_metadata(raw_meta);
        Self::build_from_graph(
            graph,
            weights,
            metadata,
            registry,
            OptLevel::All,
            false, // enable_profiling
            false, // enable_memory_pool
            false, // parallel
            false, // mixed_precision
            None,  // num_threads
            OpPlacement::default(),
        )
    }
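
    /// Like [`Session::from_bytes`], but with a caller-supplied
    /// [`OperatorRegistry`].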
    pub fn from_bytes_with_registry(
        bytes: &[u8],
        registry: OperatorRegistry,
    ) -> Result<Self, OnnxError> {
        let (raw_meta, graph, weights) =
            crate::model::load_with_metadata(bytes).map_err(OnnxError::Parse)?;
        let metadata = raw_meta_to_model_metadata(raw_meta);
        Self::build_from_graph(
            graph,
            weights,
            metadata,
            registry,
            OptLevel::All,
            false, // enable_profiling
            false, // enable_memory_pool
            false, // parallel
            false, // mixed_precision
            None,  // num_threads
            OpPlacement::default(),
        )
    }
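
    /// Builds a session directly from an already-constructed [`Graph`] and its
    /// weight tensors, skipping ONNX parsing entirely.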
    pub fn from_graph(graph: Graph, weights: HashMap<String, Tensor>) -> Result<Self, OnnxError> {
        Self::from_graph_with_registry(graph, weights, oxionnx_ops::default_registry())
    }
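
    /// Like [`Session::from_graph`], but with a caller-supplied
    /// [`OperatorRegistry`]. Metadata is left at [`ModelMetadata::default`].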
    pub fn from_graph_with_registry(
        graph: Graph,
        weights: HashMap<String, Tensor>,
        registry: OperatorRegistry,
    ) -> Result<Self, OnnxError> {
        Self::build_from_graph(
            graph,
            weights,
            ModelMetadata::default(),
            registry,
            OptLevel::All,
            false, // enable_profiling
            false, // enable_memory_pool
            false, // parallel
            false, // mixed_precision
            None,  // num_threads
            OpPlacement::default(),
        )
    }
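
    /// Shared constructor behind all public entry points: runs the optimizer,
    /// topologically sorts the nodes, and wires up the optional features
    /// (profiling, memory pooling, a dedicated thread pool, accelerator
    /// backends).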
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn build_from_graph(
        graph: Graph,
        mut weights: HashMap<String, Tensor>,
        metadata: ModelMetadata,
        registry: OperatorRegistry,
        opt_level: OptLevel,
        enable_profiling: bool,
        enable_memory_pool: bool,
        parallel: bool,
        mixed_precision: bool,
        num_threads: Option<usize>,
        op_placement: OpPlacement,
    ) -> Result<Self, OnnxError> {
        use crate::memory::SizeClassPool;
        use std::sync::Mutex;
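
        // Capture the I/O names and type info before `graph.nodes` is moved
        // out by the optimizer match below.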
        let input_names = graph.input_names.clone();
        let output_names = graph.output_names.clone();
        let input_infos = graph.input_infos.clone();
        let output_infos = graph.output_infos.clone();
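
        // Every level above `None` currently runs the same full optimizer pass.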
        let optimized_nodes = match opt_level {
            OptLevel::None => graph.nodes,
            OptLevel::Basic | OptLevel::Extended | OptLevel::All => crate::optimizer::optimize(
                graph.nodes,
                &mut weights,
                &graph.output_names,
                &registry,
            ),
        };
        let opt_graph = Graph {
            nodes: optimized_nodes,
            input_names: input_names.clone(),
            output_names: output_names.clone(),
            ..Default::default()
        };
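
        // Topologically sort so execution respects data dependencies; the
        // weights and graph inputs are the initially-known values.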
        let known: Vec<String> = weights
            .keys()
            .cloned()
            .chain(input_names.iter().cloned())
            .collect();
        let order = opt_graph.topological_sort(&known);
        let sorted_nodes: Vec<crate::graph::Node> =
            order.iter().map(|&i| opt_graph.nodes[i].clone()).collect();
        let profiling_data = if enable_profiling {
            Some(Mutex::new(Vec::new()))
        } else {
            None
        };
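
        // Input shapes are not known at session-build time, so shape
        // inference runs with an empty input-shape map to seed the cache.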
        let (pool, shape_cache) = if enable_memory_pool {
            let input_shapes: HashMap<String, Vec<usize>> = HashMap::new();
            let shapes = crate::optimizer::shape_inference::infer_shapes(
                &sorted_nodes,
                &weights,
                &input_shapes,
            );
            (Some(Mutex::new(SizeClassPool::new())), Some(shapes))
        } else {
            (None, None)
        };
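
        // Probe the accelerator backends that were compiled in.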
        #[cfg(feature = "cuda")]
        let cuda = oxionnx_cuda::CudaContext::try_new();
        #[cfg(feature = "directml")]
        let dml = oxionnx_directml::DirectMLContext::try_new();
        #[cfg(feature = "gpu")]
        let gpu = crate::gpu::GpuContext::try_new();

        if mixed_precision {
            tracing::info!("Mixed-precision inference enabled (f16 activations, f32 accumulation)");
        }
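
        // Build a dedicated rayon pool only when an explicit thread count was
        // requested; `thread_pool` stays `None` otherwise.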
        #[cfg(not(target_arch = "wasm32"))]
        let thread_pool = num_threads
            .map(|n| rayon::ThreadPoolBuilder::new().num_threads(n).build())
            .transpose()
            .map_err(|e| OnnxError::Internal(format!("thread pool: {e}")))?;
        Ok(Self {
            sorted_nodes,
            weights,
            input_names,
            output_names,
            input_infos,
            output_infos,
            metadata,
            registry,
            profiling_data,
            pool,
            shape_cache,
            parallel,
            mixed_precision,
            op_placement,
            dynamic_dims: Mutex::new(HashMap::new()),
            resolved_shapes: Mutex::new(HashMap::new()),
            #[cfg(not(target_arch = "wasm32"))]
            thread_pool,
            #[cfg(feature = "cuda")]
            cuda,
            #[cfg(feature = "directml")]
            dml,
            #[cfg(feature = "gpu")]
            gpu,
        })
    }
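
    /// Returns a [`SessionBuilder`] for configuring the options (optimization
    /// level, profiling, thread count, op placement, ...) that the convenience
    /// constructors above leave at their defaults.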
    pub fn builder() -> super::SessionBuilder {
        super::SessionBuilder::new()
    }
}