burn-onnx 0.21.0-pre.3

Library for importing ONNX models into the Burn framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
use std::{
    env,
    fs::{self, create_dir_all},
    path::{Path, PathBuf},
};

use crate::{burn::graph::BurnGraph, format_tokens, logger::init_log};

use onnx_ir::{OnnxGraphBuilder, ir::OnnxGraph};

/// Controls how model weights are loaded at runtime.
///
/// This determines which constructors are generated on the `Model` struct.
/// `from_bytes()` is generated for all variants except `None`.
///
/// Select a strategy with [`ModelGen::load_strategy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadStrategy {
    /// Keep weights in a separate `.bpk` file. Generates `from_file()`, `from_bytes()`,
    /// and a `Default` impl that calls `from_file()`. This is the default strategy.
    #[default]
    File,

    /// Embed weights in the binary via `include_bytes!`. Generates `from_embedded()`,
    /// `from_bytes()`, and a `Default` impl that calls `from_embedded()`.
    Embedded,

    /// No built-in file or embedded loading. Generates only `from_bytes()`.
    /// The caller must provide weight bytes at runtime (useful for WASM/embedded).
    Bytes,

    /// No weight-loading constructors at all.
    None,
}

/// Builder for generating Burn model code from ONNX files.
///
/// `ModelGen` converts ONNX models into Burn-compatible Rust source code and model weights.
/// It can be used from both build scripts and CLI applications.
///
/// # Conversion Process
///
/// 1. Parses ONNX model file(s)
/// 2. Converts ONNX operations to Burn nodes using the node registry
/// 3. Generates Rust source code with type-safe tensor operations
/// 4. Saves model weights in BurnPack (.bpk) format
///
/// # Examples
///
/// ## Using in a build script (`build.rs`)
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .run_from_script();
/// ```
///
/// This generates code in `$OUT_DIR/model/model.rs` which can be included in your crate:
///
/// ```ignore
/// include!(concat!(env!("OUT_DIR"), "/model/model.rs"));
/// ```
///
/// ## Using from CLI
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("src/model/")
///     .run_from_cli();
/// ```
///
/// ## Development mode for debugging
///
/// ```no_run
/// use burn_onnx::ModelGen;
///
/// ModelGen::new()
///     .input("path/to/model.onnx")
///     .out_dir("model/")
///     .development(true)  // Generates .onnx.txt and .graph.txt debug files
///     .run_from_cli();
/// ```
#[derive(Debug)]
pub struct ModelGen {
    /// Output directory for generated files. Resolved relative to `$OUT_DIR`
    /// in build-script mode, used as-is in CLI mode. Must be set before running.
    out_dir: Option<PathBuf>,
    /// List of onnx files to generate source code from.
    inputs: Vec<PathBuf>,
    /// Whether to emit `.onnx.txt` / `.graph.txt` debug dumps next to the source.
    development: bool,
    /// How the generated model loads its weights at runtime.
    load_strategy: LoadStrategy,
    /// Whether to run graph simplification passes (default: true)
    simplify: bool,
    /// Whether to partition large models into submodules (default: true)
    partition: bool,
}

impl Default for ModelGen {
    fn default() -> Self {
        Self {
            out_dir: None,
            inputs: Vec::new(),
            development: false,
            load_strategy: LoadStrategy::default(),
            simplify: true,
            partition: true,
        }
    }
}

impl ModelGen {
    /// Creates a new `ModelGen` builder with default settings.
    ///
    /// Defaults: development mode off, [`LoadStrategy::File`], graph
    /// simplification and submodule partitioning enabled.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .input("model.onnx")
    ///     .out_dir("./out")
    ///     .run_from_cli();
    /// ```
    pub fn new() -> Self {
        // Logging may already have been initialized by the host application;
        // a repeated initialization attempt is harmless, so the error is dropped.
        let _ = init_log();
        Self::default()
    }

    /// Sets the output directory for generated files.
    ///
    /// When used with [`run_from_script`](Self::run_from_script), this path is appended to
    /// `$OUT_DIR`. When used with [`run_from_cli`](Self::run_from_cli), this is the absolute
    /// or relative path where files will be written.
    ///
    /// # Arguments
    ///
    /// * `out_dir` - Directory path where generated `.rs` and record files will be saved.
    ///   Accepts anything path-like (`&str`, `String`, `&Path`, `PathBuf`, ...).
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .out_dir("model/")  // In build.rs: $OUT_DIR/model/
    ///     .input("model.onnx")
    ///     .run_from_script();
    /// ```
    pub fn out_dir(&mut self, out_dir: impl AsRef<Path>) -> &mut Self {
        self.out_dir = Some(out_dir.as_ref().to_path_buf());
        self
    }

    /// Adds an ONNX model file to convert.
    ///
    /// Multiple input files can be added by calling this method multiple times.
    /// Each input file will generate a separate `.rs` file with the same base name.
    ///
    /// # Arguments
    ///
    /// * `input` - Path to the ONNX model file (`.onnx`)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .input("encoder.onnx")
    ///     .input("decoder.onnx")  // Generate multiple models
    ///     .out_dir("models/")
    ///     .run_from_cli();
    /// ```
    pub fn input(&mut self, input: &str) -> &mut Self {
        self.inputs.push(input.into());
        self
    }

    /// Enables development mode for debugging.
    ///
    /// When enabled, two extra debug artifacts are written next to the
    /// generated Rust source:
    /// - `<model>.onnx.txt` - Debug representation of the parsed ONNX graph
    /// - `<model>.graph.txt` - Debug representation of the converted Burn graph
    ///
    /// # Arguments
    ///
    /// * `development` - If `true`, generate debug files
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .input("model.onnx")
    ///     .out_dir("debug/")
    ///     .development(true)  // Generates model.onnx.txt and model.graph.txt
    ///     .run_from_cli();
    /// ```
    pub fn development(&mut self, development: bool) -> &mut Self {
        self.development = development;
        self
    }

    /// Sets the weight loading strategy for the generated model.
    ///
    /// The strategy controls which constructors the generated `Model` struct
    /// exposes; see [`LoadStrategy`] for the available options.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::{ModelGen, LoadStrategy};
    ///
    /// // WASM or embedded: load weights from bytes at runtime
    /// ModelGen::new()
    ///     .input("model.onnx")
    ///     .out_dir("model/")
    ///     .load_strategy(LoadStrategy::Bytes)
    ///     .run_from_script();
    /// ```
    pub fn load_strategy(&mut self, strategy: LoadStrategy) -> &mut Self {
        self.load_strategy = strategy;
        self
    }

    /// Enable or disable graph simplification passes (default: true).
    ///
    /// When enabled, the ONNX IR is optimized before code generation with
    /// passes such as dead node elimination, common subexpression
    /// elimination (CSE), and pattern-based simplifications.
    pub fn simplify(&mut self, simplify: bool) -> &mut Self {
        self.simplify = simplify;
        self
    }

    /// Enable or disable submodule partitioning for large models (default: true).
    ///
    /// When enabled, models exceeding 200 nodes are split into smaller
    /// submodule structs so the generated code stays compilable. Each
    /// submodule has its own `forward()` method, and the top-level `Model`
    /// delegates to them.
    pub fn partition(&mut self, partition: bool) -> &mut Self {
        self.partition = partition;
        self
    }

    /// Runs code generation from a build script context (`build.rs`).
    ///
    /// The output directory becomes `$OUT_DIR/<out_dir>`, so the generated
    /// code can be pulled into the crate with:
    ///
    /// ```ignore
    /// include!(concat!(env!("OUT_DIR"), "/<out_dir>/<model>.rs"));
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the `OUT_DIR` environment variable is not set (Cargo sets
    /// it automatically for build scripts).
    ///
    /// # Examples
    ///
    /// In `build.rs`:
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .input("path/to/model.onnx")
    ///     .out_dir("model/")
    ///     .run_from_script();
    /// ```
    pub fn run_from_script(&self) {
        self.run(true);
    }

    /// Runs code generation from a CLI or application context.
    ///
    /// Intended for CLI tools and regular applications: the output
    /// directory is taken verbatim (relative or absolute).
    ///
    /// # Panics
    ///
    /// Panics if `out_dir` was not set via [`out_dir`](Self::out_dir).
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use burn_onnx::ModelGen;
    ///
    /// ModelGen::new()
    ///     .input("model.onnx")
    ///     .out_dir("./generated/")
    ///     .run_from_cli();
    /// ```
    pub fn run_from_cli(&self) {
        self.run(false);
    }

    /// Run code generation for every registered input file.
    ///
    /// `is_build_script` selects how the output directory is resolved:
    /// under `$OUT_DIR` when `true` (build-script mode), or as given when
    /// `false` (CLI mode).
    ///
    /// # Panics
    ///
    /// Panics if the output directory cannot be created or an input path
    /// has no file name component.
    fn run(&self, is_build_script: bool) {
        log::info!("Starting to convert ONNX to Burn");

        let out_dir = self.get_output_directory(is_build_script);
        log::debug!("Output directory: {out_dir:?}");

        // Contextual panic instead of a bare unwrap: surface the offending
        // path and the underlying I/O error to the user.
        create_dir_all(&out_dir).unwrap_or_else(|e| {
            panic!(
                "Failed to create output directory '{}': {}",
                out_dir.display(),
                e
            )
        });

        for input in self.inputs.iter() {
            // A path without a file stem (e.g. `..` or `/`) cannot name a model file.
            let file_name = input.file_stem().unwrap_or_else(|| {
                panic!("Invalid input file path (no file name): {}", input.display())
            });
            let out_file: PathBuf = out_dir.join(file_name);

            log::info!("Converting {input:?}");
            log::debug!("Input file name: {file_name:?}");
            log::debug!("Output file: {out_file:?}");

            self.generate_model(input, out_file);
        }

        log::info!("Finished converting ONNX to Burn");
    }

    /// Get the output directory path based on whether this is a build script or CLI invocation.
    ///
    /// In build-script mode the configured `out_dir` is appended to Cargo's
    /// `$OUT_DIR`; in CLI mode it is returned as-is.
    ///
    /// # Panics
    ///
    /// Panics if `out_dir` was never set, or (in build-script mode) if the
    /// `OUT_DIR` environment variable is missing.
    fn get_output_directory(&self, is_build_script: bool) -> PathBuf {
        // Check `out_dir` up front so both modes fail with the same clear
        // message (previously the build-script path used a message-less unwrap).
        let out_dir = self.out_dir.as_ref().expect("out_dir is not set");
        if is_build_script {
            let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
            let mut path = PathBuf::from(cargo_out_dir);
            // Append the configured out_dir to Cargo's OUT_DIR
            path.push(out_dir);
            path
        } else {
            out_dir.clone()
        }
    }

    /// Generate model source code and model state for a single input file.
    ///
    /// Parses the ONNX file, optionally writes a debug dump of the parsed IR
    /// (development mode), converts it to a Burn graph, and writes the
    /// formatted Rust source next to the model weights.
    ///
    /// # Panics
    ///
    /// Panics if the ONNX file cannot be parsed or the generated source
    /// cannot be written.
    fn generate_model(&self, input: &PathBuf, out_file: PathBuf) {
        log::info!("Generating model from {input:?}");
        log::debug!("Development mode: {:?}", self.development);
        log::debug!("Output file: {out_file:?}");

        let graph = OnnxGraphBuilder::new()
            .simplify(self.simplify)
            .parse_file(input)
            .unwrap_or_else(|e| panic!("Failed to parse ONNX file '{}': {}", input.display(), e));

        if self.development {
            // Dump the parsed ONNX IR for inspection.
            self.write_debug_file(&out_file, "onnx.txt", &graph);
        }

        let graph = ParsedOnnxGraph(graph);

        let top_comment = Some(format!("Generated from ONNX {input:?} by burn-onnx"));

        let code = self.generate_burn_graph(graph, &out_file, top_comment);

        let code_str = format_tokens(code);
        let source_code_file = out_file.with_extension("rs");
        log::info!("Writing source code to {}", source_code_file.display());
        // Contextual panic instead of a bare unwrap, matching the parse-error style.
        fs::write(&source_code_file, code_str).unwrap_or_else(|e| {
            panic!(
                "Failed to write source code to '{}': {}",
                source_code_file.display(),
                e
            )
        });

        log::info!("Model generated");
    }

    /// Write debug file in development mode.
    ///
    /// Serializes `content` with its pretty `Debug` representation and writes
    /// it to `out_file` with its extension replaced by `extension`.
    ///
    /// # Panics
    ///
    /// Panics if the debug file cannot be written.
    fn write_debug_file<T: std::fmt::Debug>(&self, out_file: &Path, extension: &str, content: &T) {
        let debug_content = format!("{content:#?}");
        let debug_file = out_file.with_extension(extension);
        log::debug!("Writing debug file: {debug_file:?}");
        // Contextual panic instead of a bare unwrap on I/O.
        fs::write(&debug_file, debug_content).unwrap_or_else(|e| {
            panic!("Failed to write debug file '{}': {}", debug_file.display(), e)
        });
    }

    /// Generate BurnGraph and codegen.
    ///
    /// Converts the parsed ONNX graph into a Burn graph, attaches the `.bpk`
    /// weight-file path and the configured codegen options, and returns the
    /// generated token stream.
    fn generate_burn_graph(
        &self,
        graph: ParsedOnnxGraph,
        out_file: &Path,
        top_comment: Option<String>,
    ) -> proc_macro2::TokenStream {
        // Weights live next to the generated source as `<model>.bpk`.
        let burnpack_path = out_file.with_extension("bpk");
        let burn_graph = graph.into_burn();
        burn_graph
            .with_burnpack(burnpack_path, self.load_strategy)
            .with_blank_space(true)
            .with_top_comment(top_comment)
            .with_partition(self.partition)
            .codegen()
    }
}

/// Newtype wrapper around a parsed [`OnnxGraph`], ready for conversion to Burn.
#[derive(Debug)]
struct ParsedOnnxGraph(OnnxGraph);

impl ParsedOnnxGraph {
    /// Converts the parsed ONNX graph into a [`BurnGraph`].
    ///
    /// Registers every node, then wires up the graph's input and output
    /// arguments by name.
    pub fn into_burn(self) -> BurnGraph {
        let mut burn_graph = BurnGraph::default();

        // Register each node directly (control flow nodes will fail at codegen time).
        for node in self.0.nodes {
            burn_graph.register(node);
        }

        // Collect the input and output argument names.
        let input_names = self
            .0
            .inputs
            .iter()
            .map(|arg| arg.name.clone())
            .collect::<Vec<_>>();
        let output_names = self
            .0
            .outputs
            .iter()
            .map(|arg| arg.name.clone())
            .collect::<Vec<_>>();

        // Register inputs and outputs with the graph (pass Arguments directly).
        burn_graph.register_input_output(
            input_names,
            output_names,
            &self.0.inputs,
            &self.0.outputs,
        );

        burn_graph
    }
}