onnx_ir/
pipeline.rs

1//! ONNX to IR conversion pipeline orchestrator
2//!
3//! This module provides the high-level orchestration of the ONNX conversion process.
4//! It clearly shows the entire conversion flow from start to finish.
5//!
6//! # Zero-Copy Loading
7//!
8//! When the `mmap` feature is enabled (default), files are memory-mapped for zero-copy
9//! tensor loading. This significantly reduces memory usage for large models.
10//!
11//! # Usage
12//!
13//! ```ignore
14//! use onnx_ir::OnnxGraphBuilder;
15//!
16//! // Build from file
17//! let graph = OnnxGraphBuilder::new().parse_file("model.onnx")?;
18//!
19//! // Build from bytes
20//! let graph = OnnxGraphBuilder::new().parse_bytes(&bytes)?;
21//!
22//! // Build from reader
23//! let graph = OnnxGraphBuilder::new().parse_reader(file)?;
24//! ```
25
use std::collections::HashMap;
use std::io::Read;
use std::{fmt, fs::File, path::Path};

use protobuf::Message;

use crate::{ir::OnnxGraph, processor::ProcessError, protos::ModelProto};

use super::phases::{
    finalization, initialization, node_conversion, post_processing, type_inference,
};
36
/// Errors that can occur when parsing ONNX models
#[derive(Debug)]
pub enum Error {
    /// Failed to open or read the ONNX file
    ///
    /// `path` is the display form of the file path, or `"<reader>"` when the
    /// model was read from an arbitrary reader.
    Io { path: String, error: std::io::Error },

    /// Failed to parse ONNX protobuf format
    ///
    /// `path` is `None` when the model was parsed from bytes or a reader
    /// rather than from a file.
    InvalidFormat { path: Option<String>, error: String },

    /// Model graph nodes are not topologically sorted (ONNX spec violation)
    InvalidGraphStructure { reason: String },

    /// Missing required opset version for default domain
    MissingOpsetVersion,

    /// Type inference failed during IR conversion
    TypeInference(ProcessError),

    /// Generic processing error
    Processing(ProcessError),
}
58
59impl fmt::Display for Error {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self {
62            Error::Io { path, error } => {
63                write!(f, "Failed to open ONNX file '{}': {}", path, error)
64            }
65            Error::InvalidFormat { path, error } => {
66                if let Some(p) = path {
67                    write!(f, "Invalid ONNX format in '{}': {}", p, error)
68                } else {
69                    write!(f, "Invalid ONNX format: {}", error)
70                }
71            }
72            Error::InvalidGraphStructure { reason } => {
73                write!(f, "Invalid ONNX graph structure: {}", reason)
74            }
75            Error::MissingOpsetVersion => {
76                write!(
77                    f,
78                    "ONNX model must specify opset version for default domain"
79                )
80            }
81            Error::TypeInference(e) => {
82                write!(f, "Type inference failed: {:?}", e)
83            }
84            Error::Processing(e) => {
85                write!(f, "Processing error: {:?}", e)
86            }
87        }
88    }
89}
90
91impl std::error::Error for Error {
92    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
93        match self {
94            Error::Io { error, .. } => Some(error),
95            _ => None,
96        }
97    }
98}
99
100impl From<ProcessError> for Error {
101    fn from(error: ProcessError) -> Self {
102        Error::Processing(error)
103    }
104}
105
/// ONNX IR builder with fluent API
///
/// Builds ONNX IR graphs from various sources (files, bytes, readers).
/// Future configuration options can be added without breaking changes.
///
/// # Examples
///
/// ```ignore
/// use onnx_ir::OnnxGraphBuilder;
///
/// // Build from file (uses mmap when feature is enabled)
/// let graph = OnnxGraphBuilder::new().parse_file("model.onnx")?;
///
/// // Build from bytes
/// let graph = OnnxGraphBuilder::new().parse_bytes(&model_bytes)?;
///
/// // Build from reader
/// let graph = OnnxGraphBuilder::new().parse_reader(std::io::Cursor::new(data))?;
/// ```
#[derive(Debug, Clone, Default)]
pub struct OnnxGraphBuilder {
    // Intentionally field-less for now: the builder exists so that options can
    // be introduced later without breaking the public API.
    // Future options can be added here without breaking changes
    // e.g., strict_mode: bool, min_opset_version: Option<usize>
}
130
131impl OnnxGraphBuilder {
132    /// Create a new ONNX graph builder with default settings
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Parse an ONNX model from a file path
138    ///
139    /// When the `mmap` feature is enabled (default), the file is memory-mapped
140    /// for zero-copy tensor loading, significantly reducing memory usage.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if:
145    /// - File cannot be opened or read
146    /// - File is not valid ONNX protobuf format
147    /// - Graph nodes are not topologically sorted
148    /// - Type inference fails
149    pub fn parse_file(self, path: impl AsRef<Path>) -> Result<OnnxGraph, Error> {
150        let path = path.as_ref();
151        log::info!("Parsing ONNX file: {}", path.display());
152
153        // Load file contents - mmap when feature is enabled
154        #[cfg(feature = "mmap")]
155        let buffer = {
156            let file = File::open(path).map_err(|error| Error::Io {
157                path: path.display().to_string(),
158                error,
159            })?;
160            // SAFETY: We're mapping a read-only file. The bytes::Bytes keeps
161            // the mmap alive for as long as tensor data references it.
162            let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(|error| Error::Io {
163                path: path.display().to_string(),
164                error,
165            })?;
166            log::debug!("Memory-mapped ONNX file ({} bytes)", mmap.len());
167            bytes::Bytes::from_owner(mmap)
168        };
169
170        #[cfg(not(feature = "mmap"))]
171        let buffer = {
172            let mut file = File::open(path).map_err(|error| Error::Io {
173                path: path.display().to_string(),
174                error,
175            })?;
176            let mut buf = Vec::new();
177            file.read_to_end(&mut buf).map_err(|error| Error::Io {
178                path: path.display().to_string(),
179                error,
180            })?;
181            log::debug!("Read ONNX file into memory ({} bytes)", buf.len());
182            bytes::Bytes::from(buf)
183        };
184
185        self.parse_buffer(buffer, Some(path))
186    }
187
188    /// Parse an ONNX model from a byte slice
189    ///
190    /// Note: This copies the data internally. For large models already in memory
191    /// as `bytes::Bytes`, consider using the internal buffer directly.
192    ///
193    /// # Errors
194    ///
195    /// Returns an error if:
196    /// - Data is not valid ONNX protobuf format
197    /// - Graph nodes are not topologically sorted
198    /// - Type inference fails
199    pub fn parse_bytes(self, data: &[u8]) -> Result<OnnxGraph, Error> {
200        let buffer = bytes::Bytes::copy_from_slice(data);
201        self.parse_buffer(buffer, None)
202    }
203
204    /// Parse an ONNX model from a reader
205    ///
206    /// Reads all data into memory before parsing.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if:
211    /// - Reading from the reader fails
212    /// - Data is not valid ONNX protobuf format
213    /// - Graph nodes are not topologically sorted
214    /// - Type inference fails
215    pub fn parse_reader<R: Read>(self, mut reader: R) -> Result<OnnxGraph, Error> {
216        let mut buf = Vec::new();
217        reader.read_to_end(&mut buf).map_err(|error| Error::Io {
218            path: "<reader>".to_string(),
219            error,
220        })?;
221        log::debug!("Read ONNX from reader ({} bytes)", buf.len());
222        let buffer = bytes::Bytes::from(buf);
223        self.parse_buffer(buffer, None)
224    }
225
226    /// Internal: Parse from a bytes::Bytes buffer
227    fn parse_buffer(
228        self,
229        buffer: bytes::Bytes,
230        source_path: Option<&Path>,
231    ) -> Result<OnnxGraph, Error> {
232        let path_str = source_path.map(|p| p.display().to_string());
233
234        // Get the base directory for external data resolution
235        let base_path = source_path.and_then(|p| p.parent());
236
237        let model: ModelProto =
238            Message::parse_from_tokio_bytes(&buffer).map_err(|e| Error::InvalidFormat {
239                path: path_str.clone(),
240                error: e.to_string(),
241            })?;
242
243        // ONNX nodes must be topologically sorted per spec:
244        // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
245        if !model.graph.node.is_top_sorted() {
246            return Err(Error::InvalidGraphStructure {
247                reason: "Nodes are not topologically sorted (ONNX spec violation)".to_string(),
248            });
249        }
250
251        let graph = build_graph_with_base_path(&model, base_path)?;
252
253        if let Some(path) = path_str {
254            log::info!("Finished parsing ONNX file: {}", path);
255        } else {
256            log::info!("Finished parsing ONNX from bytes");
257        }
258        Ok(graph)
259    }
260}
261
262/// Build IR graph from ONNX model with base path for external data support
263///
264/// The `base_path` is the directory containing the ONNX file, used for resolving
265/// external tensor data paths (for models >2GB).
266///
267/// # Errors
268///
269/// Returns an error if:
270/// - Missing opset version for default domain
271/// - Type inference fails
272pub fn build_graph_with_base_path(
273    model: &ModelProto,
274    base_path: Option<&Path>,
275) -> Result<OnnxGraph, Error> {
276    let opset_version = extract_opset_version(model)?;
277    build_graph_from_proto_with_base_path(&model.graph, opset_version, base_path)
278}
279
280/// Build IR graph from ONNX GraphProto with base path for external data
281///
282/// The `base_path` is used for resolving external tensor data paths (for models >2GB).
283/// Subgraphs that need a shared name registry should use `build_graph_builder_from_proto`
284/// directly (see `DeferredGraph::build_with_outer_scope`).
285///
286/// # Errors
287///
288/// Returns an error if node conversion or type inference fails
289pub fn build_graph_from_proto_with_base_path(
290    graph: &crate::protos::GraphProto,
291    opset_version: usize,
292    base_path: Option<&Path>,
293) -> Result<OnnxGraph, Error> {
294    let graph_builder = build_graph_builder_from_proto(graph, opset_version, None, base_path)?;
295
296    log::debug!(" PHASE 6: Node Conversion (RawNode -> Node) ");
297    Ok(graph_builder.convert_to_graph(opset_version))
298}
299
300/// Build IR graph as OnnxGraphBuilder (for subgraphs during processing)
301///
302/// This returns OnnxGraphBuilder which still contains RawNode instances.
303/// Call convert_to_graph() to get the final OnnxGraph with Node enum instances.
304///
305/// # Errors
306///
307/// Returns an error if node conversion or type inference fails
308pub(crate) fn build_graph_builder_from_proto(
309    graph: &crate::protos::GraphProto,
310    opset_version: usize,
311    name_registry: Option<crate::graph_state::NameRegistry>,
312    base_path: Option<&Path>,
313) -> Result<crate::ir::OnnxGraphBuilder, Error> {
314    build_graph_builder_from_proto_with_outer_scope(
315        graph,
316        opset_version,
317        name_registry,
318        crate::ir::OuterScopeTypes::new(),
319        base_path,
320    )
321}
322
323/// Build IR graph as OnnxGraphBuilder with access to outer scope types
324///
325/// This is used for building subgraphs that reference values from parent graphs.
326/// The `outer_scope` map provides types for values that the subgraph references
327/// but doesn't define internally.
328///
329/// The `base_path` is the directory containing the ONNX file, used for resolving
330/// external tensor data paths (for models >2GB).
331///
332/// # Errors
333///
334/// Returns an error if node conversion or type inference fails
335pub(crate) fn build_graph_builder_from_proto_with_outer_scope(
336    graph: &crate::protos::GraphProto,
337    opset_version: usize,
338    name_registry: Option<crate::graph_state::NameRegistry>,
339    outer_scope: crate::ir::OuterScopeTypes,
340    base_path: Option<&Path>,
341) -> Result<crate::ir::OnnxGraphBuilder, Error> {
342    log::debug!(" PHASE 1: Initialization ");
343    let state_rc = initialization::initialize_from_graph_with_registry_and_outer_scope(
344        graph,
345        name_registry,
346        outer_scope,
347        base_path,
348    );
349
350    log::debug!(" PHASE 2: Node Conversion (Proto -> RawNode) ");
351    node_conversion::convert_nodes_from_graph(graph, &state_rc, opset_version)?;
352
353    log::debug!(" PHASE 3: Type Inference ");
354    type_inference::infer_types(&state_rc, opset_version).map_err(Error::TypeInference)?;
355
356    log::debug!(" PHASE 4: Post-processing ");
357    let (mut nodes, inputs, mut outputs) = post_processing::post_process(&state_rc);
358
359    log::debug!(" PHASE 5: Finalization ");
360    Ok(finalization::finalize(
361        &mut nodes,
362        inputs,
363        &mut outputs,
364        state_rc,
365    ))
366}
367
368/// Extract opset version from model (default ONNX domain)
369fn extract_opset_version(model: &ModelProto) -> Result<usize, Error> {
370    model
371        .opset_import
372        .iter()
373        .find(|opset| opset.domain.is_empty())
374        .map(|opset| opset.version as usize)
375        .ok_or(Error::MissingOpsetVersion)
376}
377
/// Trait for checking if a list of nodes is topologically sorted
pub(crate) trait TopologicalSortable {
    // Returns true when no node consumes a value produced by a strictly later
    // node in the list; used to validate the ONNX spec ordering requirement.
    fn is_top_sorted(&self) -> bool;
}
382
383use crate::protos::NodeProto;
384
385impl TopologicalSortable for Vec<NodeProto> {
386    fn is_top_sorted(&self) -> bool {
387        // Iterate over each node in the vector
388        for (node_position, node) in self.iter().enumerate() {
389            // Iterate over each output of the node
390            for output in &node.output {
391                // If the output is empty, we don't want to check the rest of the graph, inputs and outputs that are optional
392                // can end up as empty strings, so we can't use that as a reason to count the graph as not sorted
393                if output.is_empty() {
394                    continue;
395                }
396                // Iterate over each other node in the vector
397                for (other_node_position, other_node) in self.iter().enumerate() {
398                    // If the other node has an input that matches the current output
399                    if other_node.input.contains(output) {
400                        // If the position of the current node is greater than the position of the other node
401                        if node_position > other_node_position {
402                            // The vector is not topologically sorted
403                            return false;
404                        }
405                    }
406                }
407            }
408        }
409
410        // The vector is topologically sorted
411        true
412    }
413}