//! burn_onnx — `model_gen.rs`: ONNX-to-Burn model code generation.

1use std::{
2    env,
3    fs::{self, create_dir_all},
4    path::{Path, PathBuf},
5};
6
7use crate::{burn::graph::BurnGraph, format_tokens, logger::init_log};
8
9use onnx_ir::{OnnxGraphBuilder, ir::OnnxGraph};
10
/// Controls how model weights are loaded at runtime.
///
/// This determines which constructors are generated on the `Model` struct.
/// `from_bytes()` is generated for all variants except `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadStrategy {
    /// Keep weights in a separate `.bpk` file. Generates `from_file()`, `from_bytes()`,
    /// and a `Default` impl that calls `from_file()`.
    #[default]
    File,

    /// Embed weights in the binary via `include_bytes!`. Generates `from_embedded()`,
    /// `from_bytes()`, and a `Default` impl that calls `from_embedded()`.
    Embedded,

    /// No built-in file or embedded loading. Generates only `from_bytes()`.
    /// The caller must provide weight bytes at runtime.
    Bytes,

    /// No weight-loading constructors at all.
    None,
}
33
/// Builder for generating Burn model code from ONNX files.
///
/// `ModelGen` converts ONNX models into Burn-compatible Rust source code and model weights.
/// It can be used from both build scripts and CLI applications.
///
/// # Conversion Process
///
/// 1. Parses ONNX model file(s)
/// 2. Converts ONNX operations to Burn nodes using the node registry
/// 3. Generates Rust source code with type-safe tensor operations
/// 4. Saves model weights in BurnPack (.bpk) format
///
/// # Examples
///
/// ## Using in a build script (`build.rs`)
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .run_from_script();
/// ```
///
/// This generates code in `$OUT_DIR/model/model.rs` which can be included in your crate:
///
/// ```ignore
/// include!(concat!(env!("OUT_DIR"), "/model/model.rs"));
/// ```
///
/// ## Using from CLI
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("src/model/")
///     .run_from_cli();
/// ```
///
/// ## Development mode for debugging
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .development(true)  // Generates .onnx.txt and .graph.txt debug files
///     .run_from_cli();
/// ```
#[derive(Debug)]
pub struct ModelGen {
    /// Output directory for generated files. Interpreted relative to `$OUT_DIR`
    /// in build-script mode, or used as-is in CLI mode. `None` until set.
    out_dir: Option<PathBuf>,
    /// List of onnx files to generate source code from.
    inputs: Vec<PathBuf>,
    /// When `true`, additionally write `.onnx.txt` / `.graph.txt` debug dumps.
    development: bool,
    /// How the generated model loads its weights; see [`LoadStrategy`].
    load_strategy: LoadStrategy,
    /// Whether to run graph simplification passes (default: true)
    simplify: bool,
    /// Whether to partition large models into submodules (default: true)
    partition: bool,
}
99
100impl Default for ModelGen {
101    fn default() -> Self {
102        Self {
103            out_dir: None,
104            inputs: Vec::new(),
105            development: false,
106            load_strategy: LoadStrategy::default(),
107            simplify: true,
108            partition: true,
109        }
110    }
111}
112
113impl ModelGen {
114    /// Creates a new `ModelGen` builder with default settings.
115    ///
116    /// Default configuration:
117    /// - Development mode: off
118    /// - Load strategy: [`LoadStrategy::File`]
119    ///
120    /// # Examples
121    ///
122    /// ```no_run
123    /// use burn_onnx::ModelGen;
124    ///
125    /// ModelGen::new()
126    ///     .input("model.onnx")
127    ///     .out_dir("./out")
128    ///     .run_from_cli();
129    /// ```
130    pub fn new() -> Self {
131        init_log().ok(); // Error when init multiple times are ignored.
132        Self::default()
133    }
134
135    /// Sets the output directory for generated files.
136    ///
137    /// When used with [`run_from_script`](Self::run_from_script), this path is appended to
138    /// `$OUT_DIR`. When used with [`run_from_cli`](Self::run_from_cli), this is the absolute
139    /// or relative path where files will be written.
140    ///
141    /// # Arguments
142    ///
143    /// * `out_dir` - Directory path where generated `.rs` and record files will be saved
144    ///
145    /// # Examples
146    ///
147    /// ```no_run
148    /// use burn_onnx::ModelGen;
149    ///
150    /// ModelGen::new()
151    ///     .out_dir("model/")  // In build.rs: $OUT_DIR/model/
152    ///     .input("model.onnx")
153    ///     .run_from_script();
154    /// ```
155    pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
156        self.out_dir = Some(Path::new(out_dir).into());
157        self
158    }
159
160    /// Adds an ONNX model file to convert.
161    ///
162    /// Multiple input files can be added by calling this method multiple times.
163    /// Each input file will generate a separate `.rs` file with the same base name.
164    ///
165    /// # Arguments
166    ///
167    /// * `input` - Path to the ONNX model file (`.onnx`)
168    ///
169    /// # Examples
170    ///
171    /// ```no_run
172    /// use burn_onnx::ModelGen;
173    ///
174    /// ModelGen::new()
175    ///     .input("encoder.onnx")
176    ///     .input("decoder.onnx")  // Generate multiple models
177    ///     .out_dir("models/")
178    ///     .run_from_cli();
179    /// ```
180    pub fn input(&mut self, input: &str) -> &mut Self {
181        self.inputs.push(input.into());
182        self
183    }
184
185    /// Enables development mode for debugging.
186    ///
187    /// When enabled, generates additional debug files alongside the Rust source:
188    /// - `<model>.onnx.txt` - Debug representation of the parsed ONNX graph
189    /// - `<model>.graph.txt` - Debug representation of the converted Burn graph
190    ///
191    /// # Arguments
192    ///
193    /// * `development` - If `true`, generate debug files
194    ///
195    /// # Examples
196    ///
197    /// ```no_run
198    /// use burn_onnx::ModelGen;
199    ///
200    /// ModelGen::new()
201    ///     .input("model.onnx")
202    ///     .out_dir("debug/")
203    ///     .development(true)  // Generates model.onnx.txt and model.graph.txt
204    ///     .run_from_cli();
205    /// ```
206    pub fn development(&mut self, development: bool) -> &mut Self {
207        self.development = development;
208        self
209    }
210
211    /// Sets the weight loading strategy for the generated model.
212    ///
213    /// See [`LoadStrategy`] for available options.
214    ///
215    /// # Examples
216    ///
217    /// ```no_run
218    /// use burn_onnx::{ModelGen, LoadStrategy};
219    ///
220    /// // WASM or embedded: load weights from bytes at runtime
221    /// ModelGen::new()
222    ///     .input("model.onnx")
223    ///     .out_dir("model/")
224    ///     .load_strategy(LoadStrategy::Bytes)
225    ///     .run_from_script();
226    /// ```
227    pub fn load_strategy(&mut self, strategy: LoadStrategy) -> &mut Self {
228        self.load_strategy = strategy;
229        self
230    }
231
232    /// Enable or disable graph simplification passes (default: true).
233    ///
234    /// When enabled, optimization passes like dead node elimination, common
235    /// subexpression elimination (CSE), and pattern-based simplifications
236    /// are applied to the ONNX IR before code generation.
237    pub fn simplify(&mut self, simplify: bool) -> &mut Self {
238        self.simplify = simplify;
239        self
240    }
241
242    /// Enable or disable submodule partitioning for large models (default: true).
243    ///
244    /// When enabled, models with more than 200 nodes are automatically split into
245    /// smaller submodule structs to keep generated code compilable. Each submodule
246    /// gets its own `forward()` method, and the top-level `Model` delegates to them.
247    pub fn partition(&mut self, partition: bool) -> &mut Self {
248        self.partition = partition;
249        self
250    }
251
252    /// Runs code generation from a build script context.
253    ///
254    /// Use this method when calling from `build.rs`. The output directory will be
255    /// `$OUT_DIR/<out_dir>`, allowing the generated code to be included with:
256    ///
257    /// ```ignore
258    /// include!(concat!(env!("OUT_DIR"), "/<out_dir>/<model>.rs"));
259    /// ```
260    ///
261    /// # Panics
262    ///
263    /// Panics if `OUT_DIR` environment variable is not set (should be set by Cargo).
264    ///
265    /// # Examples
266    ///
267    /// In `build.rs`:
268    ///
269    /// ```no_run
270    /// use burn_onnx::ModelGen;
271    ///
272    /// ModelGen::new()
273    ///     .input("path/to/model.onnx")
274    ///     .out_dir("model/")
275    ///     .run_from_script();
276    /// ```
277    pub fn run_from_script(&self) {
278        self.run(true);
279    }
280
281    /// Runs code generation from a CLI or application context.
282    ///
283    /// Use this method when calling from a CLI tool or regular application.
284    /// The output directory is used as-is (relative or absolute path).
285    ///
286    /// # Panics
287    ///
288    /// Panics if `out_dir` was not set via [`out_dir`](Self::out_dir).
289    ///
290    /// # Examples
291    ///
292    /// ```no_run
293    /// use burn_onnx::ModelGen;
294    ///
295    /// ModelGen::new()
296    ///     .input("model.onnx")
297    ///     .out_dir("./generated/")
298    ///     .run_from_cli();
299    /// ```
300    pub fn run_from_cli(&self) {
301        self.run(false);
302    }
303
304    /// Run code generation.
305    fn run(&self, is_build_script: bool) {
306        log::info!("Starting to convert ONNX to Burn");
307
308        let out_dir = self.get_output_directory(is_build_script);
309        log::debug!("Output directory: {out_dir:?}");
310
311        create_dir_all(&out_dir).unwrap();
312
313        for input in self.inputs.iter() {
314            let file_name = input.file_stem().unwrap();
315            let out_file: PathBuf = out_dir.join(file_name);
316
317            log::info!("Converting {input:?}");
318            log::debug!("Input file name: {file_name:?}");
319            log::debug!("Output file: {out_file:?}");
320
321            self.generate_model(input, out_file);
322        }
323
324        log::info!("Finished converting ONNX to Burn");
325    }
326
327    /// Get the output directory path based on whether this is a build script or CLI invocation.
328    fn get_output_directory(&self, is_build_script: bool) -> PathBuf {
329        if is_build_script {
330            let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
331            let mut path = PathBuf::from(cargo_out_dir);
332            // Append the out_dir to the cargo_out_dir
333            path.push(self.out_dir.as_ref().unwrap());
334            path
335        } else {
336            self.out_dir.as_ref().expect("out_dir is not set").clone()
337        }
338    }
339
340    /// Generate model source code and model state.
341    fn generate_model(&self, input: &PathBuf, out_file: PathBuf) {
342        log::info!("Generating model from {input:?}");
343        log::debug!("Development mode: {:?}", self.development);
344        log::debug!("Output file: {out_file:?}");
345
346        let graph = OnnxGraphBuilder::new()
347            .simplify(self.simplify)
348            .parse_file(input)
349            .unwrap_or_else(|e| panic!("Failed to parse ONNX file '{}': {}", input.display(), e));
350
351        if self.development {
352            self.write_debug_file(&out_file, "onnx.txt", &graph);
353        }
354
355        let graph = ParsedOnnxGraph(graph);
356
357        let top_comment = Some(format!("Generated from ONNX {input:?} by burn-onnx"));
358
359        let code = self.generate_burn_graph(graph, &out_file, top_comment);
360
361        let code_str = format_tokens(code);
362        let source_code_file = out_file.with_extension("rs");
363        log::info!("Writing source code to {}", source_code_file.display());
364        fs::write(source_code_file, code_str).unwrap();
365
366        log::info!("Model generated");
367    }
368
369    /// Write debug file in development mode.
370    fn write_debug_file<T: std::fmt::Debug>(&self, out_file: &Path, extension: &str, content: &T) {
371        let debug_content = format!("{content:#?}");
372        let debug_file = out_file.with_extension(extension);
373        log::debug!("Writing debug file: {debug_file:?}");
374        fs::write(debug_file, debug_content).unwrap();
375    }
376
377    /// Generate BurnGraph and codegen.
378    fn generate_burn_graph(
379        &self,
380        graph: ParsedOnnxGraph,
381        out_file: &Path,
382        top_comment: Option<String>,
383    ) -> proc_macro2::TokenStream {
384        let bpk_file = out_file.with_extension("bpk");
385        graph
386            .into_burn()
387            .with_burnpack(bpk_file, self.load_strategy)
388            .with_blank_space(true)
389            .with_top_comment(top_comment)
390            .with_partition(self.partition)
391            .codegen()
392    }
393}
394
/// Newtype wrapper around a parsed [`OnnxGraph`], used as the staging point for
/// conversion into a [`BurnGraph`] (see `into_burn`).
#[derive(Debug)]
struct ParsedOnnxGraph(OnnxGraph);
397
398impl ParsedOnnxGraph {
399    /// Converts ONNX graph to Burn graph.
400    pub fn into_burn(self) -> BurnGraph {
401        let mut graph = BurnGraph::default();
402
403        for node in self.0.nodes {
404            // Register node directly (control flow nodes will fail at codegen time)
405            graph.register(node);
406        }
407
408        // Extract input and output names
409        let input_names: Vec<_> = self
410            .0
411            .inputs
412            .iter()
413            .map(|input| input.name.clone())
414            .collect();
415        let output_names: Vec<_> = self
416            .0
417            .outputs
418            .iter()
419            .map(|output| output.name.clone())
420            .collect();
421
422        // Register inputs and outputs with the graph (pass Arguments directly)
423        graph.register_input_output(input_names, output_names, &self.0.inputs, &self.0.outputs);
424
425        graph
426    }
427}