onnx_ir/pipeline.rs
//! ONNX to IR conversion pipeline orchestrator
//!
//! This module provides the high-level orchestration of the ONNX conversion process.
//! It lays out the entire conversion flow from start to finish.
//!
//! # Zero-Copy Loading
//!
//! When the `mmap` feature is enabled (default), files are memory-mapped for zero-copy
//! tensor loading. This significantly reduces memory usage for large models.
//!
//! # Usage
//!
//! ```ignore
//! use onnx_ir::OnnxGraphBuilder;
//!
//! // Build from file
//! let graph = OnnxGraphBuilder::new().parse_file("model.onnx")?;
//!
//! // Build from bytes
//! let graph = OnnxGraphBuilder::new().parse_bytes(&bytes)?;
//!
//! // Build from reader
//! let graph = OnnxGraphBuilder::new().parse_reader(file)?;
//! ```

use std::io::Read;
use std::{fmt, fs::File, path::Path};

use protobuf::Message;

use crate::{ir::OnnxGraph, processor::ProcessError, protos::ModelProto};

use super::phases::{
    finalization, initialization, node_conversion, post_processing, type_inference,
};

/// Errors that can occur when parsing ONNX models
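///
/// A minimal sketch of matching on the error kinds; how the caller reacts here is
/// illustrative, not something this module prescribes:
///
/// ```ignore
/// match OnnxGraphBuilder::new().parse_file("model.onnx") {
///     Ok(graph) => { /* use the parsed OnnxGraph */ }
///     Err(Error::MissingOpsetVersion) => eprintln!("model has no default-domain opset"),
///     Err(e) => eprintln!("failed to parse model: {e}"),
/// }
/// ```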
#[derive(Debug)]
pub enum Error {
    /// Failed to open or read the ONNX file
    Io { path: String, error: std::io::Error },

    /// Failed to parse ONNX protobuf format
    InvalidFormat { path: Option<String>, error: String },

    /// Model graph nodes are not topologically sorted (ONNX spec violation)
    InvalidGraphStructure { reason: String },

    /// Missing required opset version for default domain
    MissingOpsetVersion,

    /// Type inference failed during IR conversion
    TypeInference(ProcessError),

    /// Generic processing error
    Processing(ProcessError),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::Io { path, error } => {
                write!(f, "Failed to open ONNX file '{}': {}", path, error)
            }
            Error::InvalidFormat { path, error } => {
                if let Some(p) = path {
                    write!(f, "Invalid ONNX format in '{}': {}", p, error)
                } else {
                    write!(f, "Invalid ONNX format: {}", error)
                }
            }
            Error::InvalidGraphStructure { reason } => {
                write!(f, "Invalid ONNX graph structure: {}", reason)
            }
            Error::MissingOpsetVersion => {
                write!(
                    f,
                    "ONNX model must specify opset version for default domain"
                )
            }
            Error::TypeInference(e) => {
                write!(f, "Type inference failed: {:?}", e)
            }
            Error::Processing(e) => {
                write!(f, "Processing error: {:?}", e)
            }
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io { error, .. } => Some(error),
            _ => None,
        }
    }
}

impl From<ProcessError> for Error {
    fn from(error: ProcessError) -> Self {
        Error::Processing(error)
    }
}

/// ONNX IR builder with fluent API
///
/// Builds ONNX IR graphs from various sources (files, bytes, readers).
/// Future configuration options can be added without breaking changes.
///
/// # Examples
///
/// ```ignore
/// use onnx_ir::OnnxGraphBuilder;
///
/// // Build from file (uses mmap when feature is enabled)
/// let graph = OnnxGraphBuilder::new().parse_file("model.onnx")?;
///
/// // Build from bytes
/// let graph = OnnxGraphBuilder::new().parse_bytes(&model_bytes)?;
///
/// // Build from reader
/// let graph = OnnxGraphBuilder::new().parse_reader(std::io::Cursor::new(data))?;
/// ```
#[derive(Debug, Clone, Default)]
pub struct OnnxGraphBuilder {
    // Future options can be added here without breaking changes
    // e.g., strict_mode: bool, min_opset_version: Option<usize>
}

impl OnnxGraphBuilder {
    /// Create a new ONNX graph builder with default settings
    pub fn new() -> Self {
        Self::default()
    }

    /// Parse an ONNX model from a file path
    ///
    /// When the `mmap` feature is enabled (default), the file is memory-mapped
    /// for zero-copy tensor loading, significantly reducing memory usage.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - File cannot be opened or read
    /// - File is not valid ONNX protobuf format
    /// - Graph nodes are not topologically sorted
    /// - Type inference fails
    pub fn parse_file(self, path: impl AsRef<Path>) -> Result<OnnxGraph, Error> {
        let path = path.as_ref();
        log::info!("Parsing ONNX file: {}", path.display());

        // Load file contents - mmap when feature is enabled
        #[cfg(feature = "mmap")]
        let buffer = {
            let file = File::open(path).map_err(|error| Error::Io {
                path: path.display().to_string(),
                error,
            })?;
            // SAFETY: We're mapping a read-only file. The bytes::Bytes keeps
            // the mmap alive for as long as tensor data references it.
            let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(|error| Error::Io {
                path: path.display().to_string(),
                error,
            })?;
            log::debug!("Memory-mapped ONNX file ({} bytes)", mmap.len());
            bytes::Bytes::from_owner(mmap)
        };

        #[cfg(not(feature = "mmap"))]
        let buffer = {
            let mut file = File::open(path).map_err(|error| Error::Io {
                path: path.display().to_string(),
                error,
            })?;
            let mut buf = Vec::new();
            file.read_to_end(&mut buf).map_err(|error| Error::Io {
                path: path.display().to_string(),
                error,
            })?;
            log::debug!("Read ONNX file into memory ({} bytes)", buf.len());
            bytes::Bytes::from(buf)
        };

        self.parse_buffer(buffer, Some(path))
    }

    /// Parse an ONNX model from a byte slice
    ///
    /// Note: This copies the data internally. For large models already in memory
    /// as `bytes::Bytes`, consider using the internal buffer directly.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Data is not valid ONNX protobuf format
    /// - Graph nodes are not topologically sorted
    /// - Type inference fails
    pub fn parse_bytes(self, data: &[u8]) -> Result<OnnxGraph, Error> {
        let buffer = bytes::Bytes::copy_from_slice(data);
        self.parse_buffer(buffer, None)
    }

    /// Parse an ONNX model from a reader
    ///
    /// Reads all data into memory before parsing.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Reading from the reader fails
    /// - Data is not valid ONNX protobuf format
    /// - Graph nodes are not topologically sorted
    /// - Type inference fails
    pub fn parse_reader<R: Read>(self, mut reader: R) -> Result<OnnxGraph, Error> {
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).map_err(|error| Error::Io {
            path: "<reader>".to_string(),
            error,
        })?;
        log::debug!("Read ONNX from reader ({} bytes)", buf.len());
        let buffer = bytes::Bytes::from(buf);
        self.parse_buffer(buffer, None)
    }

    /// Internal: Parse from a bytes::Bytes buffer
    fn parse_buffer(
        self,
        buffer: bytes::Bytes,
        source_path: Option<&Path>,
    ) -> Result<OnnxGraph, Error> {
        let path_str = source_path.map(|p| p.display().to_string());

        // Get the base directory for external data resolution
        let base_path = source_path.and_then(|p| p.parent());

        let model: ModelProto =
            Message::parse_from_tokio_bytes(&buffer).map_err(|e| Error::InvalidFormat {
                path: path_str.clone(),
                error: e.to_string(),
            })?;

        // ONNX nodes must be topologically sorted per spec:
        // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
        if !model.graph.node.is_top_sorted() {
            return Err(Error::InvalidGraphStructure {
                reason: "Nodes are not topologically sorted (ONNX spec violation)".to_string(),
            });
        }

        let graph = build_graph_with_base_path(&model, base_path)?;

        if let Some(path) = path_str {
            log::info!("Finished parsing ONNX file: {}", path);
        } else {
            log::info!("Finished parsing ONNX from bytes");
        }
        Ok(graph)
    }
}

/// Build IR graph from ONNX model with base path for external data support
///
/// The `base_path` is the directory containing the ONNX file, used for resolving
/// external tensor data paths (for models >2GB).
///
/// # Errors
///
/// Returns an error if:
/// - Missing opset version for default domain
/// - Type inference fails
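///
/// A hedged usage sketch (assumes the caller has already decoded a `ModelProto`,
/// e.g. with `protobuf::Message::parse_from_bytes`; the buffer and directory names
/// here are illustrative):
///
/// ```ignore
/// let model: ModelProto = Message::parse_from_bytes(&model_bytes)?;
/// let graph = build_graph_with_base_path(&model, Some(Path::new("models/")))?;
/// ```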
pub fn build_graph_with_base_path(
    model: &ModelProto,
    base_path: Option<&Path>,
) -> Result<OnnxGraph, Error> {
    let opset_version = extract_opset_version(model)?;
    build_graph_from_proto_with_base_path(&model.graph, opset_version, base_path)
}

/// Build IR graph from ONNX GraphProto with base path for external data
///
/// The `base_path` is used for resolving external tensor data paths (for models >2GB).
/// Subgraphs that need a shared name registry should use `build_graph_builder_from_proto`
/// directly (see `DeferredGraph::build_with_outer_scope`).
///
/// # Errors
///
/// Returns an error if node conversion or type inference fails
pub fn build_graph_from_proto_with_base_path(
    graph: &crate::protos::GraphProto,
    opset_version: usize,
    base_path: Option<&Path>,
) -> Result<OnnxGraph, Error> {
    let graph_builder = build_graph_builder_from_proto(graph, opset_version, None, base_path)?;

    log::debug!(" PHASE 6: Node Conversion (RawNode -> Node) ");
    Ok(graph_builder.convert_to_graph(opset_version))
}

/// Build IR graph as OnnxGraphBuilder (for subgraphs during processing)
///
/// This returns OnnxGraphBuilder which still contains RawNode instances.
/// Call convert_to_graph() to get the final OnnxGraph with Node enum instances.
///
/// # Errors
///
/// Returns an error if node conversion or type inference fails
pub(crate) fn build_graph_builder_from_proto(
    graph: &crate::protos::GraphProto,
    opset_version: usize,
    name_registry: Option<crate::graph_state::NameRegistry>,
    base_path: Option<&Path>,
) -> Result<crate::ir::OnnxGraphBuilder, Error> {
    build_graph_builder_from_proto_with_outer_scope(
        graph,
        opset_version,
        name_registry,
        crate::ir::OuterScopeTypes::new(),
        base_path,
    )
}

/// Build IR graph as OnnxGraphBuilder with access to outer scope types
///
/// This is used for building subgraphs that reference values from parent graphs.
/// The `outer_scope` map provides types for values that the subgraph references
/// but doesn't define internally.
///
/// The `base_path` is the directory containing the ONNX file, used for resolving
/// external tensor data paths (for models >2GB).
///
/// # Errors
///
/// Returns an error if node conversion or type inference fails
pub(crate) fn build_graph_builder_from_proto_with_outer_scope(
    graph: &crate::protos::GraphProto,
    opset_version: usize,
    name_registry: Option<crate::graph_state::NameRegistry>,
    outer_scope: crate::ir::OuterScopeTypes,
    base_path: Option<&Path>,
) -> Result<crate::ir::OnnxGraphBuilder, Error> {
    log::debug!(" PHASE 1: Initialization ");
    let state_rc = initialization::initialize_from_graph_with_registry_and_outer_scope(
        graph,
        name_registry,
        outer_scope,
        base_path,
    );

    log::debug!(" PHASE 2: Node Conversion (Proto -> RawNode) ");
    node_conversion::convert_nodes_from_graph(graph, &state_rc, opset_version)?;

    log::debug!(" PHASE 3: Type Inference ");
    type_inference::infer_types(&state_rc, opset_version).map_err(Error::TypeInference)?;

    log::debug!(" PHASE 4: Post-processing ");
    let (mut nodes, inputs, mut outputs) = post_processing::post_process(&state_rc);

    log::debug!(" PHASE 5: Finalization ");
    Ok(finalization::finalize(
        &mut nodes,
        inputs,
        &mut outputs,
        state_rc,
    ))
}

/// Extract opset version from model (default ONNX domain)
fn extract_opset_version(model: &ModelProto) -> Result<usize, Error> {
    model
        .opset_import
        .iter()
        .find(|opset| opset.domain.is_empty())
        .map(|opset| opset.version as usize)
        .ok_or(Error::MissingOpsetVersion)
}

/// Trait for checking if a list of nodes is topologically sorted
pub(crate) trait TopologicalSortable {
    fn is_top_sorted(&self) -> bool;
}

use crate::protos::NodeProto;

impl TopologicalSortable for Vec<NodeProto> {
    fn is_top_sorted(&self) -> bool {
        // Iterate over each node in the vector
        for (node_position, node) in self.iter().enumerate() {
            // Iterate over each output of the node
            for output in &node.output {
                // Optional inputs and outputs can appear as empty strings, so an empty
                // output name is not evidence that the graph is unsorted; skip it.
                if output.is_empty() {
                    continue;
                }
                // Check every other node in the vector
                for (other_node_position, other_node) in self.iter().enumerate() {
                    // Does the other node consume this output?
                    if other_node.input.contains(output) {
                        // A consumer that appears before its producer means the
                        // list is not topologically sorted
                        if node_position > other_node_position {
                            return false;
                        }
                    }
                }
            }
        }

        // Every consumer appears after its producer
        true
    }
}
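
// A minimal illustrative test for the topological-sort check above; a sketch, assuming
// the generated `NodeProto` exposes `input`/`output` as public repeated-string fields
// (as the implementation above already relies on) and derives `Clone`.
#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical helper that builds a bare NodeProto with the given input/output names.
    fn node(inputs: &[&str], outputs: &[&str]) -> NodeProto {
        let mut n = NodeProto::new();
        for input in inputs {
            n.input.push(input.to_string());
        }
        for output in outputs {
            n.output.push(output.to_string());
        }
        n
    }

    #[test]
    fn is_top_sorted_detects_ordering_violations() {
        // `consumer` reads the output of `producer`, so [producer, consumer] is
        // topologically sorted and the reversed order is not.
        let producer = node(&["x"], &["y"]);
        let consumer = node(&["y"], &["z"]);
        assert!(vec![producer.clone(), consumer.clone()].is_top_sorted());
        assert!(!vec![consumer, producer].is_top_sorted());
    }
}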