burn_onnx/model_gen.rs
1use std::{
2 env,
3 fs::{self, create_dir_all},
4 path::{Path, PathBuf},
5};
6
7use crate::{burn::graph::BurnGraph, format_tokens, logger::init_log};
8
9use onnx_ir::{OnnxGraphBuilder, ir::OnnxGraph};
10
/// Controls how model weights are loaded at runtime.
///
/// This determines which constructors are generated on the `Model` struct.
/// `from_bytes()` is generated for all variants except `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadStrategy {
    /// Keep weights in a separate `.bpk` file. Generates `from_file()`, `from_bytes()`,
    /// and a `Default` impl that calls `from_file()`.
    #[default]
    File,

    /// Embed weights in the binary via `include_bytes!`. Generates `from_embedded()`,
    /// `from_bytes()`, and a `Default` impl that calls `from_embedded()`.
    Embedded,

    /// No built-in file or embedded loading. Generates only `from_bytes()`.
    /// The caller must provide weight bytes at runtime.
    Bytes,

    /// No weight-loading constructors at all.
    None,
}
33
/// Builder for generating Burn model code from ONNX files.
///
/// `ModelGen` converts ONNX models into Burn-compatible Rust source code and model weights.
/// It can be used from both build scripts and CLI applications.
///
/// # Conversion Process
///
/// 1. Parses ONNX model file(s)
/// 2. Converts ONNX operations to Burn nodes using the node registry
/// 3. Generates Rust source code with type-safe tensor operations
/// 4. Saves model weights in BurnPack (.bpk) format
///
/// # Examples
///
/// ## Using in a build script (`build.rs`)
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .run_from_script();
/// ```
///
/// This generates code in `$OUT_DIR/model/model.rs` which can be included in your crate:
///
/// ```ignore
/// include!(concat!(env!("OUT_DIR"), "/model/model.rs"));
/// ```
///
/// ## Using from CLI
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("src/model/")
///     .run_from_cli();
/// ```
///
/// ## Development mode for debugging
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .development(true) // Generates .onnx.txt and .graph.txt debug files
///     .run_from_cli();
/// ```
#[derive(Debug)]
pub struct ModelGen {
    /// Output directory for generated files; required before running (see `run_from_cli`).
    out_dir: Option<PathBuf>,
    /// List of onnx files to generate source code from.
    inputs: Vec<PathBuf>,
    /// Whether to emit additional `.onnx.txt` debug dumps (default: false).
    development: bool,
    /// Weight-loading strategy baked into the generated model (default: `LoadStrategy::File`).
    load_strategy: LoadStrategy,
    /// Whether to run graph simplification passes (default: true)
    simplify: bool,
    /// Whether to partition large models into submodules (default: true)
    partition: bool,
}
99
100impl Default for ModelGen {
101 fn default() -> Self {
102 Self {
103 out_dir: None,
104 inputs: Vec::new(),
105 development: false,
106 load_strategy: LoadStrategy::default(),
107 simplify: true,
108 partition: true,
109 }
110 }
111}
112
113impl ModelGen {
114 /// Creates a new `ModelGen` builder with default settings.
115 ///
116 /// Default configuration:
117 /// - Development mode: off
118 /// - Load strategy: [`LoadStrategy::File`]
119 ///
120 /// # Examples
121 ///
122 /// ```no_run
123 /// use burn_onnx::ModelGen;
124 ///
125 /// ModelGen::new()
126 /// .input("model.onnx")
127 /// .out_dir("./out")
128 /// .run_from_cli();
129 /// ```
130 pub fn new() -> Self {
131 init_log().ok(); // Error when init multiple times are ignored.
132 Self::default()
133 }
134
135 /// Sets the output directory for generated files.
136 ///
137 /// When used with [`run_from_script`](Self::run_from_script), this path is appended to
138 /// `$OUT_DIR`. When used with [`run_from_cli`](Self::run_from_cli), this is the absolute
139 /// or relative path where files will be written.
140 ///
141 /// # Arguments
142 ///
143 /// * `out_dir` - Directory path where generated `.rs` and record files will be saved
144 ///
145 /// # Examples
146 ///
147 /// ```no_run
148 /// use burn_onnx::ModelGen;
149 ///
150 /// ModelGen::new()
151 /// .out_dir("model/") // In build.rs: $OUT_DIR/model/
152 /// .input("model.onnx")
153 /// .run_from_script();
154 /// ```
155 pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
156 self.out_dir = Some(Path::new(out_dir).into());
157 self
158 }
159
160 /// Adds an ONNX model file to convert.
161 ///
162 /// Multiple input files can be added by calling this method multiple times.
163 /// Each input file will generate a separate `.rs` file with the same base name.
164 ///
165 /// # Arguments
166 ///
167 /// * `input` - Path to the ONNX model file (`.onnx`)
168 ///
169 /// # Examples
170 ///
171 /// ```no_run
172 /// use burn_onnx::ModelGen;
173 ///
174 /// ModelGen::new()
175 /// .input("encoder.onnx")
176 /// .input("decoder.onnx") // Generate multiple models
177 /// .out_dir("models/")
178 /// .run_from_cli();
179 /// ```
180 pub fn input(&mut self, input: &str) -> &mut Self {
181 self.inputs.push(input.into());
182 self
183 }
184
185 /// Enables development mode for debugging.
186 ///
187 /// When enabled, generates additional debug files alongside the Rust source:
188 /// - `<model>.onnx.txt` - Debug representation of the parsed ONNX graph
189 /// - `<model>.graph.txt` - Debug representation of the converted Burn graph
190 ///
191 /// # Arguments
192 ///
193 /// * `development` - If `true`, generate debug files
194 ///
195 /// # Examples
196 ///
197 /// ```no_run
198 /// use burn_onnx::ModelGen;
199 ///
200 /// ModelGen::new()
201 /// .input("model.onnx")
202 /// .out_dir("debug/")
203 /// .development(true) // Generates model.onnx.txt and model.graph.txt
204 /// .run_from_cli();
205 /// ```
206 pub fn development(&mut self, development: bool) -> &mut Self {
207 self.development = development;
208 self
209 }
210
211 /// Sets the weight loading strategy for the generated model.
212 ///
213 /// See [`LoadStrategy`] for available options.
214 ///
215 /// # Examples
216 ///
217 /// ```no_run
218 /// use burn_onnx::{ModelGen, LoadStrategy};
219 ///
220 /// // WASM or embedded: load weights from bytes at runtime
221 /// ModelGen::new()
222 /// .input("model.onnx")
223 /// .out_dir("model/")
224 /// .load_strategy(LoadStrategy::Bytes)
225 /// .run_from_script();
226 /// ```
227 pub fn load_strategy(&mut self, strategy: LoadStrategy) -> &mut Self {
228 self.load_strategy = strategy;
229 self
230 }
231
232 /// Enable or disable graph simplification passes (default: true).
233 ///
234 /// When enabled, optimization passes like dead node elimination, common
235 /// subexpression elimination (CSE), and pattern-based simplifications
236 /// are applied to the ONNX IR before code generation.
237 pub fn simplify(&mut self, simplify: bool) -> &mut Self {
238 self.simplify = simplify;
239 self
240 }
241
242 /// Enable or disable submodule partitioning for large models (default: true).
243 ///
244 /// When enabled, models with more than 200 nodes are automatically split into
245 /// smaller submodule structs to keep generated code compilable. Each submodule
246 /// gets its own `forward()` method, and the top-level `Model` delegates to them.
247 pub fn partition(&mut self, partition: bool) -> &mut Self {
248 self.partition = partition;
249 self
250 }
251
252 /// Runs code generation from a build script context.
253 ///
254 /// Use this method when calling from `build.rs`. The output directory will be
255 /// `$OUT_DIR/<out_dir>`, allowing the generated code to be included with:
256 ///
257 /// ```ignore
258 /// include!(concat!(env!("OUT_DIR"), "/<out_dir>/<model>.rs"));
259 /// ```
260 ///
261 /// # Panics
262 ///
263 /// Panics if `OUT_DIR` environment variable is not set (should be set by Cargo).
264 ///
265 /// # Examples
266 ///
267 /// In `build.rs`:
268 ///
269 /// ```no_run
270 /// use burn_onnx::ModelGen;
271 ///
272 /// ModelGen::new()
273 /// .input("path/to/model.onnx")
274 /// .out_dir("model/")
275 /// .run_from_script();
276 /// ```
277 pub fn run_from_script(&self) {
278 self.run(true);
279 }
280
281 /// Runs code generation from a CLI or application context.
282 ///
283 /// Use this method when calling from a CLI tool or regular application.
284 /// The output directory is used as-is (relative or absolute path).
285 ///
286 /// # Panics
287 ///
288 /// Panics if `out_dir` was not set via [`out_dir`](Self::out_dir).
289 ///
290 /// # Examples
291 ///
292 /// ```no_run
293 /// use burn_onnx::ModelGen;
294 ///
295 /// ModelGen::new()
296 /// .input("model.onnx")
297 /// .out_dir("./generated/")
298 /// .run_from_cli();
299 /// ```
300 pub fn run_from_cli(&self) {
301 self.run(false);
302 }
303
304 /// Run code generation.
305 fn run(&self, is_build_script: bool) {
306 log::info!("Starting to convert ONNX to Burn");
307
308 let out_dir = self.get_output_directory(is_build_script);
309 log::debug!("Output directory: {out_dir:?}");
310
311 create_dir_all(&out_dir).unwrap();
312
313 for input in self.inputs.iter() {
314 let file_name = input.file_stem().unwrap();
315 let out_file: PathBuf = out_dir.join(file_name);
316
317 log::info!("Converting {input:?}");
318 log::debug!("Input file name: {file_name:?}");
319 log::debug!("Output file: {out_file:?}");
320
321 self.generate_model(input, out_file);
322 }
323
324 log::info!("Finished converting ONNX to Burn");
325 }
326
327 /// Get the output directory path based on whether this is a build script or CLI invocation.
328 fn get_output_directory(&self, is_build_script: bool) -> PathBuf {
329 if is_build_script {
330 let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
331 let mut path = PathBuf::from(cargo_out_dir);
332 // Append the out_dir to the cargo_out_dir
333 path.push(self.out_dir.as_ref().unwrap());
334 path
335 } else {
336 self.out_dir.as_ref().expect("out_dir is not set").clone()
337 }
338 }
339
340 /// Generate model source code and model state.
341 fn generate_model(&self, input: &PathBuf, out_file: PathBuf) {
342 log::info!("Generating model from {input:?}");
343 log::debug!("Development mode: {:?}", self.development);
344 log::debug!("Output file: {out_file:?}");
345
346 let graph = OnnxGraphBuilder::new()
347 .simplify(self.simplify)
348 .parse_file(input)
349 .unwrap_or_else(|e| panic!("Failed to parse ONNX file '{}': {}", input.display(), e));
350
351 if self.development {
352 self.write_debug_file(&out_file, "onnx.txt", &graph);
353 }
354
355 let graph = ParsedOnnxGraph(graph);
356
357 let top_comment = Some(format!("Generated from ONNX {input:?} by burn-onnx"));
358
359 let code = self.generate_burn_graph(graph, &out_file, top_comment);
360
361 let code_str = format_tokens(code);
362 let source_code_file = out_file.with_extension("rs");
363 log::info!("Writing source code to {}", source_code_file.display());
364 fs::write(source_code_file, code_str).unwrap();
365
366 log::info!("Model generated");
367 }
368
369 /// Write debug file in development mode.
370 fn write_debug_file<T: std::fmt::Debug>(&self, out_file: &Path, extension: &str, content: &T) {
371 let debug_content = format!("{content:#?}");
372 let debug_file = out_file.with_extension(extension);
373 log::debug!("Writing debug file: {debug_file:?}");
374 fs::write(debug_file, debug_content).unwrap();
375 }
376
377 /// Generate BurnGraph and codegen.
378 fn generate_burn_graph(
379 &self,
380 graph: ParsedOnnxGraph,
381 out_file: &Path,
382 top_comment: Option<String>,
383 ) -> proc_macro2::TokenStream {
384 let bpk_file = out_file.with_extension("bpk");
385 graph
386 .into_burn()
387 .with_burnpack(bpk_file, self.load_strategy)
388 .with_blank_space(true)
389 .with_top_comment(top_comment)
390 .with_partition(self.partition)
391 .codegen()
392 }
393}
394
/// Newtype wrapper around a parsed [`OnnxGraph`], providing conversion into a `BurnGraph`.
#[derive(Debug)]
struct ParsedOnnxGraph(OnnxGraph);
397
398impl ParsedOnnxGraph {
399 /// Converts ONNX graph to Burn graph.
400 pub fn into_burn(self) -> BurnGraph {
401 let mut graph = BurnGraph::default();
402
403 for node in self.0.nodes {
404 // Register node directly (control flow nodes will fail at codegen time)
405 graph.register(node);
406 }
407
408 // Extract input and output names
409 let input_names: Vec<_> = self
410 .0
411 .inputs
412 .iter()
413 .map(|input| input.name.clone())
414 .collect();
415 let output_names: Vec<_> = self
416 .0
417 .outputs
418 .iter()
419 .map(|output| output.name.clone())
420 .collect();
421
422 // Register inputs and outputs with the graph (pass Arguments directly)
423 graph.register_input_output(input_names, output_names, &self.0.inputs, &self.0.outputs);
424
425 graph
426 }
427}