Skip to main content

burn_onnx/burn/
graph.rs

1use super::{BurnImports, Scope, ToTokens};
2use crate::LoadStrategy;
3use crate::burn::node::NodeCodegen;
4use crate::burn::partition::{
5    MIN_GRAPH_SIZE, Partition, reorder_constants_to_consumers, try_partition,
6};
7use burn_store::{BurnpackWriter, TensorSnapshot};
8use onnx_ir::{Node, ir::ArgType};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use std::{collections::HashMap, path::PathBuf};
12
13/// Burn graph intermediate representation of modules and tensor operations.
14#[derive(Debug)]
15pub struct BurnGraph {
16    nodes: Vec<Node>,
17    scope: Scope,
18    imports: BurnImports,
19    top_comment: Option<String>,
20    default: Option<TokenStream>,
21    blank_spaces: bool,
22    graph_input_args: Vec<onnx_ir::Argument>,
23    graph_output_args: Vec<onnx_ir::Argument>,
24    /// Whether to partition large graphs into submodules (default: true)
25    partition: bool,
26    /// Cached partition result (computed once, reused by snapshot collection and codegen)
27    cached_partition: Option<Option<Partition>>,
28    /// Graph I/O args that were converted from ScalarTensor to ScalarNative at the
29    /// boundary. Maps arg name -> DType. Used to insert conversion code:
30    /// - Outputs: `.into_scalar().elem::<T>()` before the return
31    /// - Inputs: `Tensor::from_data([name as T], &self.device)` after the params
32    boundary_output_conversions: HashMap<String, onnx_ir::ir::DType>,
33    boundary_input_conversions: HashMap<String, onnx_ir::ir::DType>,
34}
35
36impl Default for BurnGraph {
37    fn default() -> Self {
38        Self {
39            nodes: Vec::new(),
40            scope: Scope::default(),
41            imports: BurnImports::default(),
42            top_comment: None,
43            default: None,
44            blank_spaces: false,
45            graph_input_args: Vec::new(),
46            graph_output_args: Vec::new(),
47            partition: true,
48            cached_partition: None,
49            boundary_output_conversions: HashMap::new(),
50            boundary_input_conversions: HashMap::new(),
51        }
52    }
53}
54
55impl BurnGraph {
56    /// Register a new operation node into the graph.
57    ///
58    /// # Notes
59    ///
60    /// The node must be registered in the same order they will be executed in the forward pass.
61    pub fn register(&mut self, node: Node) {
62        log::debug!("Registering node => '{}'", node.name());
63        self.nodes.push(node);
64    }
65
66    /// Save the state of each node in a burnpack file and generate weight-loading constructors.
67    ///
68    /// The [`LoadStrategy`] controls which constructors are generated on the `Model` struct.
69    pub fn with_burnpack(mut self, out_file: PathBuf, strategy: LoadStrategy) -> Self {
70        // Collect all tensor snapshots from nodes
71        let snapshots = self.collect_all_snapshots();
72
73        // Write burnpack file
74        let burnpack_file = out_file.with_extension("bpk");
75        BurnpackWriter::new(snapshots)
76            .with_metadata("producer", "burn-onnx")
77            .write_to_file(&burnpack_file)
78            .expect("Failed to write burnpack file");
79
80        // Register the loading code based on strategy
81        if strategy != LoadStrategy::None {
82            self.register_burnpack_loaders(burnpack_file, strategy);
83        }
84
85        self
86    }
87
88    /// Collect all tensor snapshots from nodes recursively.
89    ///
90    /// When partitioned into submodules, snapshot paths are prefixed with the submodule
91    /// field name (e.g. "submodule1.linear1.weight") so that `load_from` routes weights
92    /// to the correct nested module.
93    fn collect_all_snapshots(&mut self) -> Vec<TensorSnapshot> {
94        let partition = self.compute_partition();
95
96        if let Some(partition) = partition {
97            self.collect_snapshots_partitioned(&partition)
98        } else {
99            self.collect_snapshots_flat()
100        }
101    }
102
103    /// Compute the partition once and cache it for reuse by both snapshot
104    /// collection and codegen, avoiding redundant work and ensuring consistency.
105    fn compute_partition(&mut self) -> Option<Partition> {
106        if let Some(ref cached) = self.cached_partition {
107            return cached.clone();
108        }
109        let result = if self.partition {
110            // Move constants to just before their first consumer so they land
111            // in the same chunk, avoiding wide forward() interfaces.
112            // Only reorder for graphs large enough to actually partition.
113            if self.nodes.len() >= MIN_GRAPH_SIZE {
114                reorder_constants_to_consumers(&mut self.nodes);
115            }
116            try_partition(&self.nodes, &self.graph_input_args, &self.graph_output_args)
117        } else {
118            None
119        };
120        self.cached_partition = Some(result.clone());
121        result
122    }
123
124    fn collect_snapshots_flat(&self) -> Vec<TensorSnapshot> {
125        let mut snapshots = Vec::new();
126        let mut field_name_counts: HashMap<String, usize> = HashMap::new();
127        collect_snapshots_from_nodes(&self.nodes, "", &mut field_name_counts, &mut snapshots);
128        snapshots
129    }
130
131    fn collect_snapshots_partitioned(&self, partition: &Partition) -> Vec<TensorSnapshot> {
132        let mut snapshots = Vec::new();
133
134        for (chunk_idx, range) in partition.chunks.iter().enumerate() {
135            let prefix = format!("submodule{}", chunk_idx + 1);
136            let chunk_nodes = &self.nodes[range.clone()];
137            // Each chunk gets its own counter to match collect_fields_for_nodes (per-chunk)
138            let mut field_name_counts: HashMap<String, usize> = HashMap::new();
139            collect_snapshots_from_nodes(
140                chunk_nodes,
141                &prefix,
142                &mut field_name_counts,
143                &mut snapshots,
144            );
145        }
146        snapshots
147    }
148
149    /// Add blank spaces in some places
150    ///
151    /// # Notes
152    ///
153    /// It can be problematic when testing.
154    pub fn with_blank_space(mut self, blank_spaces: bool) -> Self {
155        self.blank_spaces = blank_spaces;
156        self
157    }
158
159    /// Add a comment at the top of the generated file.
160    pub fn with_top_comment(mut self, top_comment: Option<String>) -> Self {
161        self.top_comment = top_comment;
162        self
163    }
164
165    /// Enable or disable submodule partitioning for large models.
166    pub fn with_partition(mut self, partition: bool) -> Self {
167        self.partition = partition;
168        self
169    }
170
171    /// Generate tokens representing the graph with Burn modules and tensor operations.
172    pub fn codegen(mut self) -> TokenStream {
173        self.register_imports();
174
175        let partition = self.compute_partition();
176
177        if let Some(partition) = partition {
178            self.codegen_partitioned(partition)
179        } else {
180            self.codegen_flat()
181        }
182    }
183
184    /// Generate flat code (no submodules) for small graphs.
185    fn codegen_flat(mut self) -> TokenStream {
186        self.build_scope();
187
188        let codegen_imports = self.imports.codegen();
189        let codegen_struct = self.codegen_struct();
190        let codegen_new = self.codegen_new();
191        let codegen_forward = self.codegen_forward();
192
193        let maybe_blank = match self.blank_spaces {
194            true => quote! {
195                _blank_!();
196            },
197            false => quote! {},
198        };
199        let codegen_default = match self.default {
200            Some(default) => quote! {
201                #default
202                #maybe_blank
203            },
204            None => quote! {},
205        };
206
207        let maybe_top_file_comment = match self.top_comment {
208            Some(comment) => quote! {
209                _comment_!(#comment);
210            },
211            None => quote! {},
212        };
213
214        quote! {
215            // @generated
216            // This file is automatically generated by burn-onnx
217
218            #maybe_top_file_comment
219            #codegen_imports
220            #maybe_blank
221            #maybe_blank
222
223            #codegen_struct
224            #maybe_blank
225
226            #codegen_default
227
228            impl<B: Backend> Model<B> {
229                #codegen_new
230
231                #maybe_blank
232
233                #codegen_forward
234            }
235        }
236    }
237
238    /// Generate partitioned code with submodule structs.
239    fn codegen_partitioned(self, partition: Partition) -> TokenStream {
240        let maybe_blank = match self.blank_spaces {
241            true => quote! { _blank_!(); },
242            false => quote! {},
243        };
244
245        let codegen_imports = self.imports.codegen();
246        let maybe_top_file_comment = match &self.top_comment {
247            Some(comment) => {
248                let c = comment.clone();
249                quote! { _comment_!(#c); }
250            }
251            None => quote! {},
252        };
253
254        let num_chunks = partition.chunks.len();
255        let mut submodule_defs = Vec::with_capacity(num_chunks);
256        let mut submodule_field_decls = Vec::with_capacity(num_chunks);
257        let mut submodule_field_inits = Vec::with_capacity(num_chunks);
258        let mut submodule_field_names = Vec::with_capacity(num_chunks);
259        let mut forward_calls = Vec::with_capacity(num_chunks);
260
261        // Count how many times each tensor is consumed across all chunk inputs.
262        // This tells us when we need .clone() in the top-level forward.
263        let mut remaining_uses: HashMap<String, usize> = HashMap::new();
264        for inputs in &partition.chunk_inputs {
265            for arg in inputs {
266                *remaining_uses.entry(arg.name.clone()).or_insert(0) += 1;
267            }
268        }
269
270        for (chunk_idx, range) in partition.chunks.iter().enumerate() {
271            let struct_name = format_ident!("Submodule{}", chunk_idx + 1);
272            let field_name = format_ident!("submodule{}", chunk_idx + 1);
273            let chunk_nodes = &self.nodes[range.clone()];
274            let chunk_inputs = &partition.chunk_inputs[chunk_idx];
275            let chunk_outputs = &partition.chunk_outputs[chunk_idx];
276
277            // Build scope for this chunk
278            let mut scope = Scope::default();
279
280            // Register chunk inputs as variables at position 0.
281            // Mirror build_scope: also register boundary-converted inputs (ScalarNative
282            // that were originally ScalarTensor) as tensor variables, since the top-level
283            // forward converts them to Tensor<B, 1> before calling submodule.forward().
284            for arg in chunk_inputs {
285                if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
286                    || self.boundary_input_conversions.contains_key(&arg.name)
287                {
288                    scope.tensor_register_variable(arg, 0);
289                }
290            }
291
292            // Register node outputs and future uses (positions are local to this chunk)
293            for (local_pos, node) in chunk_nodes.iter().enumerate() {
294                for arg in node.outputs() {
295                    if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
296                        scope.tensor_register_variable(arg, local_pos + 1);
297                    }
298                }
299                for arg in node.inputs() {
300                    if (arg.is_dynamic() || arg.is_constant())
301                        && matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
302                    {
303                        scope.tensor_register_future_use(arg, local_pos);
304                    }
305                }
306            }
307
308            // Register chunk outputs as future uses at the end
309            let chunk_len = chunk_nodes.len();
310            for arg in chunk_outputs {
311                if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
312                    scope.tensor_register_future_use(arg, chunk_len);
313                }
314            }
315
316            // Collect fields from this chunk's nodes
317            let chunk_fields = collect_fields_for_nodes(chunk_nodes);
318
319            // Generate the submodule struct body
320            let struct_fields: Vec<_> = chunk_fields
321                .iter()
322                .map(|(name, ty, _)| quote! { #name: #ty, })
323                .collect();
324
325            // Generate new() body
326            let field_init_code: TokenStream = chunk_fields
327                .iter()
328                .filter_map(|(_, _, init)| init.clone())
329                .collect();
330            let field_names_for_init: Vec<_> = chunk_fields
331                .iter()
332                .map(|(name, _, _)| name.clone())
333                .collect();
334
335            // Generate forward() body
336            let input_params = crate::burn::codegen_fn_params(chunk_inputs);
337            let output_type = crate::burn::codegen_return_type(chunk_outputs);
338            let output_return = crate::burn::codegen_return_expr(chunk_outputs);
339
340            let mut forward_body = quote! {};
341            for (local_pos, node) in chunk_nodes.iter().enumerate() {
342                let mut scope_at_pos = scope.at_position(local_pos);
343                let code = NodeCodegen::forward(node, &mut scope_at_pos);
344                forward_body.extend(code);
345            }
346
347            let submodule_def = quote! {
348                #[derive(Module, Debug)]
349                pub struct #struct_name<B: Backend> {
350                    #(#struct_fields)*
351                    phantom: core::marker::PhantomData<B>,
352                    #[module(skip)]
353                    device: B::Device,
354                }
355
356                impl<B: Backend> #struct_name<B> {
357                    #[allow(unused_variables)]
358                    pub fn new(device: &B::Device) -> Self {
359                        #field_init_code
360                        Self {
361                            #(#field_names_for_init,)*
362                            phantom: core::marker::PhantomData,
363                            device: device.clone(),
364                        }
365                    }
366
367                    #[allow(clippy::let_and_return, clippy::approx_constant)]
368                    pub fn forward(&self, #input_params) -> #output_type {
369                        #forward_body
370                        #output_return
371                    }
372                }
373            };
374            submodule_defs.push(submodule_def);
375
376            // Top-level Model field for this submodule
377            submodule_field_decls.push(quote! { #field_name: #struct_name<B>, });
378            submodule_field_inits.push(quote! { let #field_name = #struct_name::new(device); });
379            submodule_field_names.push(field_name.clone());
380
381            // Generate the forward call in the top-level forward().
382            // Clone tensors that are consumed by later chunks.
383            let input_args: Vec<_> = chunk_inputs
384                .iter()
385                .map(|arg| {
386                    let name = crate::burn::arg_ident(arg);
387                    let remaining = remaining_uses.get(&arg.name).copied().unwrap_or(0);
388                    if remaining > 1 {
389                        // Will be used again by a later chunk
390                        remaining_uses.insert(arg.name.clone(), remaining - 1);
391                        quote! { #name.clone() }
392                    } else {
393                        remaining_uses.remove(&arg.name);
394                        quote! { #name }
395                    }
396                })
397                .collect();
398
399            if chunk_outputs.len() == 1 {
400                let out_name = crate::burn::arg_ident(&chunk_outputs[0]);
401                forward_calls.push(quote! {
402                    let #out_name = self.#field_name.forward(#(#input_args),*);
403                });
404            } else {
405                let out_names: Vec<_> = chunk_outputs.iter().map(crate::burn::arg_ident).collect();
406                forward_calls.push(quote! {
407                    let (#(#out_names),*) = self.#field_name.forward(#(#input_args),*);
408                });
409            }
410        }
411
412        // Top-level Model forward signature
413        let input_def = crate::burn::codegen_fn_params(&self.graph_input_args);
414        let output_type_def = crate::burn::codegen_return_type(&self.graph_output_args);
415        let output_return_def = crate::burn::codegen_return_expr(&self.graph_output_args);
416
417        let input_conversions = self.codegen_boundary_input_conversions();
418        let boundary_conversions = self.codegen_boundary_output_conversions();
419
420        let codegen_default = match &self.default {
421            Some(default) => {
422                let d = default.clone();
423                quote! { #d #maybe_blank }
424            }
425            None => quote! {},
426        };
427
428        quote! {
429            // @generated
430            // This file is automatically generated by burn-onnx
431
432            #maybe_top_file_comment
433            #codegen_imports
434            #maybe_blank
435            #maybe_blank
436
437            #(#submodule_defs)*
438            #maybe_blank
439
440            #[derive(Module, Debug)]
441            pub struct Model<B: Backend> {
442                #(#submodule_field_decls)*
443                phantom: core::marker::PhantomData<B>,
444                #[module(skip)]
445                device: B::Device,
446            }
447            #maybe_blank
448
449            #codegen_default
450
451            impl<B: Backend> Model<B> {
452                #[allow(unused_variables)]
453                pub fn new(device: &B::Device) -> Self {
454                    #(#submodule_field_inits)*
455                    Self {
456                        #(#submodule_field_names,)*
457                        phantom: core::marker::PhantomData,
458                        device: device.clone(),
459                    }
460                }
461
462                #maybe_blank
463
464                #[allow(clippy::let_and_return, clippy::approx_constant)]
465                pub fn forward(&self, #input_def) -> #output_type_def {
466                    #input_conversions
467                    #(#forward_calls)*
468                    #boundary_conversions
469                    #output_return_def
470                }
471            }
472        }
473    }
474
475    fn register_imports(&mut self) {
476        // Register imports from nodes
477        self.nodes
478            .iter()
479            .for_each(|node| NodeCodegen::register_imports(node, &mut self.imports));
480    }
481
482    /// Build the scope state to make sure tensor clones are added where needed.
483    fn build_scope(&mut self) {
484        log::debug!("Building the scope nodes len => '{}'", self.nodes.len());
485
486        // Register graph tensor inputs with 0 as node position
487        self.graph_input_args
488            .iter()
489            .filter(|arg| {
490                matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
491                    || self.boundary_input_conversions.contains_key(&arg.name)
492            })
493            .for_each(|arg| {
494                self.scope.tensor_register_variable(arg, 0);
495            });
496
497        self.nodes
498            .iter()
499            .enumerate()
500            .for_each(|(node_position, node)| {
501                // Register tensor outputs as variables
502                node.outputs()
503                    .iter()
504                    .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
505                    .for_each(|arg| {
506                        self.scope.tensor_register_variable(arg, node_position + 1);
507                    });
508                // Since the graph is guaranteed to be a DAG, we can safely register future uses
509                // of the inputs (which are the previous nodes' outputs)
510                // Filter to only dynamic/constant inputs (exclude static-only initializers)
511                node.inputs()
512                    .iter()
513                    .filter(|arg| arg.is_dynamic() || arg.is_constant())
514                    .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
515                    .for_each(|arg| self.scope.tensor_register_future_use(arg, node_position));
516            });
517
518        // Register graph tensor output with the last node position
519        self.graph_output_args
520            .iter()
521            .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
522            .for_each(|arg| {
523                self.scope.tensor_register_future_use(arg, self.nodes.len());
524            });
525    }
526
527    fn register_burnpack_loaders(&mut self, file: PathBuf, strategy: LoadStrategy) {
528        self.imports.register("burn_store::BurnpackStore");
529        self.imports.register("burn_store::ModuleSnapshot");
530        self.imports.register("burn::tensor::Bytes");
531
532        let mut statics = quote! {};
533        let mut default_impl = quote! {};
534        let mut extra_loaders = quote! {};
535
536        match strategy {
537            LoadStrategy::File => {
538                let file = file.to_str().unwrap();
539                statics = quote! {
540                    // `from_file` requires `std::path::Path`; opt into std so this
541                    // also works when included from `#![no_std]` crates.
542                    extern crate std;
543                    _blank_!();
544                };
545                default_impl = quote! {
546                    impl<B: Backend> Default for Model<B> {
547                        fn default() -> Self {
548                            Self::from_file(#file, &Default::default())
549                        }
550                    }
551                    _blank_!();
552                };
553                extra_loaders = quote! {
554                    /// Load model weights from a burnpack file.
555                    pub fn from_file<P: AsRef<std::path::Path>>(file: P, device: &B::Device) -> Self {
556                        let mut model = Self::new(device);
557                        let mut store = BurnpackStore::from_file(file);
558                        model.load_from(&mut store).expect("Failed to load burnpack file");
559                        model
560                    }
561                    _blank_!();
562                };
563            }
564            LoadStrategy::Embedded => {
565                let file_size = std::fs::metadata(&file)
566                    .expect("Failed to read burnpack file metadata")
567                    .len() as usize;
568                let file = file.to_str().unwrap();
569                statics = quote! {
570                    // Align embedded data to 256-byte boundary to match burnpack's internal alignment.
571                    // This ensures tensor data remains properly aligned for zero-copy loading,
572                    // regardless of where the linker places the static data in the binary.
573                    #[repr(C, align(256))]
574                    struct Aligned256([u8; #file_size]);
575                    static ALIGNED_DATA: Aligned256 = Aligned256(*include_bytes!(#file));
576                    static EMBEDDED_STATES: &[u8] = &ALIGNED_DATA.0;
577                    _blank_!();
578                };
579                default_impl = quote! {
580                    impl<B: Backend> Default for Model<B> {
581                        fn default() -> Self {
582                            Self::from_embedded(&Default::default())
583                        }
584                    }
585                    _blank_!();
586                };
587                extra_loaders = quote! {
588                    /// Load model weights from embedded burnpack data (zero-copy at store level).
589                    ///
590                    /// The embedded data stays in the binary's .rodata section without heap allocation.
591                    /// Tensor data is sliced directly from the static bytes.
592                    ///
593                    /// Note: Some backends may still copy data internally.
594                    /// See <https://github.com/tracel-ai/burn/issues/4153> for true backend zero-copy.
595                    ///
596                    /// See <https://github.com/tracel-ai/burn/issues/4123>
597                    pub fn from_embedded(device: &B::Device) -> Self {
598                        let mut model = Self::new(device);
599                        let mut store = BurnpackStore::from_static(EMBEDDED_STATES);
600                        model.load_from(&mut store).expect("Failed to load embedded burnpack");
601                        model
602                    }
603                    _blank_!();
604                };
605            }
606            LoadStrategy::Bytes | LoadStrategy::None => {}
607        }
608
609        self.default = Some(quote! {
610            _blank_!();
611            #statics
612            #default_impl
613            impl<B: Backend> Model<B> {
614                #extra_loaders
615                /// Load model weights from in-memory bytes.
616                ///
617                /// The bytes must be the contents of a `.bpk` file.
618                pub fn from_bytes(bytes: Bytes, device: &B::Device) -> Self {
619                    let mut model = Self::new(device);
620                    let mut store = BurnpackStore::from_bytes(Some(bytes));
621                    model.load_from(&mut store).expect("Failed to load burnpack bytes");
622                    model
623                }
624            }
625        });
626    }
627
628    /// Recursively collect all fields from nodes, including subgraph nodes in If/Loop/Scan
629    fn collect_all_fields(&self) -> Vec<FieldTuple> {
630        collect_fields_for_nodes(&self.nodes)
631    }
632
633    fn codegen_struct(&self) -> TokenStream {
634        let mut body = quote! {};
635        self.collect_all_fields()
636            .iter()
637            .map(|(name, ty, _)| {
638                quote! {
639                    #name: #ty,
640                }
641            })
642            .for_each(|code| body.extend(code));
643
644        // Extend with phantom data to avoid unused generic type.
645        body.extend(quote! {
646            phantom: core::marker::PhantomData<B>,
647            #[module(skip)]
648            device: B::Device,
649        });
650
651        quote! {
652            #[derive(Module, Debug)]
653            pub struct Model<B: Backend> {
654                #body
655            }
656        }
657    }
658
659    fn codegen_new(&self) -> TokenStream {
660        let mut body = quote! {};
661        let all_fields = self.collect_all_fields();
662
663        // Generate field initialization code
664        for (_, _, field_init) in &all_fields {
665            body.extend(field_init.clone());
666        }
667
668        // Collect field names for struct initialization
669        let field_names: Vec<_> = all_fields.iter().map(|(name, _, _)| name.clone()).collect();
670
671        quote! {
672            #[allow(unused_variables)]
673            pub fn new(device: &B::Device) -> Self {
674                #body
675
676                Self {
677                    #(#field_names,)*
678                    phantom: core::marker::PhantomData,
679                    device: device.clone(),
680                }
681            }
682        }
683    }
684
685    fn codegen_forward(&mut self) -> TokenStream {
686        let input_def = crate::burn::codegen_fn_params(&self.graph_input_args);
687        let output_type_def = crate::burn::codegen_return_type(&self.graph_output_args);
688        let output_return_def = crate::burn::codegen_return_expr(&self.graph_output_args);
689
690        let input_conversions = self.codegen_boundary_input_conversions();
691
692        let mut body = quote! {};
693        for (index, node) in self.nodes.iter().enumerate() {
694            let mut scope_at_pos = self.scope.at_position(index);
695            let code = NodeCodegen::forward(node, &mut scope_at_pos);
696            body.extend(code);
697        }
698
699        let boundary_conversions = self.codegen_boundary_output_conversions();
700
701        // TODO Return the result without a `let` binding from a block,
702        // otherwise let_and_return error will be triggered by clippy.
703        // For now, we just disable the warning.
704        quote! {
705            #[allow(clippy::let_and_return, clippy::approx_constant)]
706            pub fn forward(&self, #input_def) -> #output_type_def {
707                #input_conversions
708                #body
709                #boundary_conversions
710                #output_return_def
711            }
712        }
713    }
714
715    /// Register the input and output types of the graph using the passed in names.
716    /// The names must be unique and match the names of the inputs and outputs of the nodes.
717    /// The order will be preserved.
718    ///
719    /// # Arguments
720    ///
721    /// * `input_names` - The names of the inputs of the graph.
722    /// * `output_names` - The names of the outputs of the graph.
723    /// * `input_args` - The input arguments (from ONNX graph, used for empty graphs).
724    /// * `output_args` - The output arguments (from ONNX graph, used for empty graphs).
725    pub fn register_input_output(
726        &mut self,
727        input_names: Vec<String>,
728        output_names: Vec<String>,
729        input_args: &[onnx_ir::Argument],
730        output_args: &[onnx_ir::Argument],
731    ) {
732        // Handle empty graphs: use provided arguments directly
733        if self.nodes.is_empty() {
734            // For empty graphs, inputs pass through directly to outputs
735            self.graph_input_args.extend_from_slice(input_args);
736            self.graph_output_args.extend_from_slice(output_args);
737            self.convert_graph_boundary_scalars();
738            return;
739        }
740
741        // Get the unique names of each input/output of the nodes
742        let mut inputs = HashMap::new();
743        let mut outputs = HashMap::new();
744        for node in self.nodes.iter() {
745            for input_arg in NodeCodegen::inputs(node) {
746                inputs.insert(input_arg.name.clone(), input_arg.clone());
747            }
748            for output_arg in NodeCodegen::outputs(node) {
749                outputs.insert(output_arg.name.clone(), output_arg.clone());
750            }
751        }
752
753        // Get the input arguments of the graph using passed in names
754        // For outer scope variables, fall back to the provided input_args
755        input_names.iter().enumerate().for_each(|(idx, input)| {
756            let input_arg = inputs
757                .get(input)
758                .cloned()
759                .or_else(|| {
760                    // Fall back to provided input_args for outer scope variables
761                    if idx < input_args.len() {
762                        Some(input_args[idx].clone())
763                    } else {
764                        None
765                    }
766                })
767                .unwrap_or_else(|| panic!("Input argument not found for {input}"));
768
769            self.graph_input_args.push(input_arg);
770        });
771
772        // Handle outputs - if output_args is provided (from ONNX), use it with renaming
773        // Otherwise, look up arguments from node outputs (for tests)
774        if !output_args.is_empty() {
775            output_names
776                .iter()
777                .zip(output_args.iter())
778                .for_each(|(name, arg)| {
779                    // Rename argument to the graph output name
780                    let mut renamed_arg = arg.clone();
781                    renamed_arg.name = name.clone();
782                    self.graph_output_args.push(renamed_arg);
783                });
784        } else {
785            // For tests and non-ONNX usage: look up arguments from node outputs
786            output_names.iter().for_each(|output| {
787                self.graph_output_args.push(
788                    outputs
789                        .get(output)
790                        .unwrap_or_else(|| panic!("Output argument not found for {output}"))
791                        .clone(),
792                );
793            });
794        }
795
796        // Convert ScalarTensor to ScalarNative at graph boundary so user-facing
797        // forward() signatures use native types (f32, i64, etc.) not Tensor<B, 1>
798        self.convert_graph_boundary_scalars();
799    }
800
801    /// Generate ScalarNative -> ScalarTensor input conversion code for graph boundary.
802    fn codegen_boundary_input_conversions(&self) -> TokenStream {
803        let mut tokens = quote! {};
804        for arg in &self.graph_input_args {
805            if let Some(dtype) = self.boundary_input_conversions.get(&arg.name) {
806                let name = crate::burn::arg_ident(arg);
807                let dtype_tokens = dtype.to_tokens();
808                if dtype.is_float() {
809                    tokens.extend(quote! {
810                        let #name = Tensor::<B, 1>::from_data(
811                            burn::tensor::TensorData::from([#name]),
812                            (&self.device, #dtype_tokens)
813                        );
814                    });
815                } else if dtype.is_int() || dtype.is_uint() {
816                    tokens.extend(quote! {
817                        let #name = Tensor::<B, 1, Int>::from_data(
818                            burn::tensor::TensorData::from([#name]),
819                            (&self.device, #dtype_tokens)
820                        );
821                    });
822                } else if dtype.is_bool() {
823                    tokens.extend(quote! {
824                        let #name = Tensor::<B, 1, Bool>::from_data(
825                            burn::tensor::TensorData::from([#name]),
826                            (&self.device, #dtype_tokens)
827                        );
828                    });
829                } else {
830                    panic!(
831                        "Unsupported dtype {:?} for graph boundary ScalarNative -> ScalarTensor conversion",
832                        dtype
833                    );
834                }
835            }
836        }
837        tokens
838    }
839
840    /// Generate ScalarTensor -> ScalarNative output conversion code for graph boundary.
841    fn codegen_boundary_output_conversions(&self) -> TokenStream {
842        let mut tokens = quote! {};
843        for arg in &self.graph_output_args {
844            if let Some(dtype) = self.boundary_output_conversions.get(&arg.name) {
845                let name = crate::burn::arg_ident(arg);
846                let convert = crate::burn::on_device_to_native(quote! { #name }, dtype);
847                tokens.extend(quote! {
848                    let #name = #convert;
849                });
850            }
851        }
852        tokens
853    }
854
855    /// Convert ScalarTensor to ScalarNative at graph I/O boundary.
856    /// Users pass/receive native scalars; internal representation is on-device.
857    fn convert_graph_boundary_scalars(&mut self) {
858        for arg in &mut self.graph_input_args {
859            if let ArgType::ScalarTensor(dtype) = arg.ty {
860                self.boundary_input_conversions
861                    .insert(arg.name.clone(), dtype);
862                arg.ty = ArgType::ScalarNative(dtype);
863            }
864        }
865        for arg in &mut self.graph_output_args {
866            if let ArgType::ScalarTensor(dtype) = arg.ty {
867                self.boundary_output_conversions
868                    .insert(arg.name.clone(), dtype);
869                arg.ty = ArgType::ScalarNative(dtype);
870            }
871        }
872    }
873}
874
875// ============================================================================
876// Free functions shared by flat and partitioned codegen paths
877// ============================================================================
878
/// A collected struct field as `(field identifier, field type tokens, field init tokens)`.
/// `collect_fields_for_nodes` always fills the init slot with `Some(..)`.
type FieldTuple = (proc_macro2::Ident, TokenStream, Option<TokenStream>);
880
881/// Collect fields from a slice of nodes (including If/Loop subgraph fields).
882fn collect_fields_for_nodes(nodes: &[Node]) -> Vec<FieldTuple> {
883    let mut field_name_counts: HashMap<String, usize> = HashMap::new();
884    let mut all_fields: Vec<FieldTuple> = Vec::new();
885
886    fn collect_subgraph_fields_recursive(
887        subgraph: &onnx_ir::OnnxGraph,
888        field_name_counts: &mut HashMap<String, usize>,
889        all_fields: &mut Vec<FieldTuple>,
890    ) {
891        for node in &subgraph.nodes {
892            if let Some(mut field) = NodeCodegen::field(node) {
893                let base_name = field.name.to_string();
894                let count = field_name_counts.entry(base_name.clone()).or_insert(0);
895                *count += 1;
896
897                if *count > 1 {
898                    let new_name_str = format!("{}_{}", base_name, count);
899                    let new_name = syn::Ident::new(&new_name_str, proc_macro2::Span::call_site());
900                    field.name = new_name;
901
902                    let init_str = field.init.to_string();
903                    let updated = init_str
904                        .replace(
905                            &format!("let {} :", base_name),
906                            &format!("let {} :", new_name_str),
907                        )
908                        .replace(
909                            &format!("let {} =", base_name),
910                            &format!("let {} =", new_name_str),
911                        );
912                    field.init = updated.parse().unwrap_or_else(|e| {
913                        log::warn!(
914                            "Failed to parse renamed field init for '{}': {e}",
915                            new_name_str
916                        );
917                        field.init.clone()
918                    });
919                }
920                all_fields.push((field.name.clone(), field.ty.clone(), Some(field.init)));
921            }
922
923            if let Node::If(nested) = node {
924                collect_subgraph_fields_recursive(
925                    &nested.config.then_branch,
926                    field_name_counts,
927                    all_fields,
928                );
929                collect_subgraph_fields_recursive(
930                    &nested.config.else_branch,
931                    field_name_counts,
932                    all_fields,
933                );
934            } else if let Node::Loop(nested) = node {
935                collect_subgraph_fields_recursive(
936                    &nested.config.body,
937                    field_name_counts,
938                    all_fields,
939                );
940            }
941        }
942    }
943
944    for node in nodes {
945        if let Some(field) = NodeCodegen::field(node) {
946            all_fields.push((field.name, field.ty, Some(field.init)));
947        }
948
949        if let Node::If(if_node) = node {
950            collect_subgraph_fields_recursive(
951                &if_node.config.then_branch,
952                &mut field_name_counts,
953                &mut all_fields,
954            );
955            collect_subgraph_fields_recursive(
956                &if_node.config.else_branch,
957                &mut field_name_counts,
958                &mut all_fields,
959            );
960        } else if let Node::Loop(loop_node) = node {
961            collect_subgraph_fields_recursive(
962                &loop_node.config.body,
963                &mut field_name_counts,
964                &mut all_fields,
965            );
966        }
967    }
968
969    all_fields
970}
971
972/// Collect tensor snapshots from a slice of nodes, optionally prefixing paths.
973///
974/// When `prefix` is non-empty, snapshot paths become "prefix.field.weight" etc.
975fn collect_snapshots_from_nodes(
976    nodes: &[Node],
977    prefix: &str,
978    field_name_counts: &mut HashMap<String, usize>,
979    snapshots: &mut Vec<TensorSnapshot>,
980) {
981    fn collect_subgraph_snapshots_recursive(
982        subgraph: &onnx_ir::OnnxGraph,
983        prefix: &str,
984        field_name_counts: &mut HashMap<String, usize>,
985        snapshots: &mut Vec<TensorSnapshot>,
986    ) {
987        for node in &subgraph.nodes {
988            if let Some(field) = NodeCodegen::field(node) {
989                let base_name = field.name.to_string();
990                let count = field_name_counts.entry(base_name.clone()).or_insert(0);
991                *count += 1;
992
993                let unique_name = if *count > 1 {
994                    format!("{}_{}", base_name, count)
995                } else {
996                    base_name
997                };
998
999                let full_name = if prefix.is_empty() {
1000                    unique_name
1001                } else {
1002                    format!("{}.{}", prefix, unique_name)
1003                };
1004                let node_snapshots = NodeCodegen::collect_snapshots(node, &full_name);
1005                snapshots.extend(node_snapshots);
1006            }
1007
1008            if let Node::If(nested) = node {
1009                collect_subgraph_snapshots_recursive(
1010                    &nested.config.then_branch,
1011                    prefix,
1012                    field_name_counts,
1013                    snapshots,
1014                );
1015                collect_subgraph_snapshots_recursive(
1016                    &nested.config.else_branch,
1017                    prefix,
1018                    field_name_counts,
1019                    snapshots,
1020                );
1021            } else if let Node::Loop(nested) = node {
1022                collect_subgraph_snapshots_recursive(
1023                    &nested.config.body,
1024                    prefix,
1025                    field_name_counts,
1026                    snapshots,
1027                );
1028            }
1029        }
1030    }
1031
1032    for node in nodes {
1033        if let Some(field) = NodeCodegen::field(node) {
1034            let base_name = field.name.to_string();
1035            let count = field_name_counts.entry(base_name.clone()).or_insert(0);
1036            *count += 1;
1037
1038            let unique_name = if *count > 1 {
1039                format!("{}_{}", base_name, count)
1040            } else {
1041                base_name
1042            };
1043
1044            let full_name = if prefix.is_empty() {
1045                unique_name
1046            } else {
1047                format!("{}.{}", prefix, unique_name)
1048            };
1049            let node_snapshots = NodeCodegen::collect_snapshots(node, &full_name);
1050            snapshots.extend(node_snapshots);
1051        }
1052
1053        if let Node::If(if_node) = node {
1054            collect_subgraph_snapshots_recursive(
1055                &if_node.config.then_branch,
1056                prefix,
1057                field_name_counts,
1058                snapshots,
1059            );
1060            collect_subgraph_snapshots_recursive(
1061                &if_node.config.else_branch,
1062                prefix,
1063                field_name_counts,
1064                snapshots,
1065            );
1066        } else if let Node::Loop(loop_node) = node {
1067            collect_subgraph_snapshots_recursive(
1068                &loop_node.config.body,
1069                prefix,
1070                field_name_counts,
1071                snapshots,
1072            );
1073        }
1074    }
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079    use super::*;
1080    use burn::tensor::DType;
1081    use onnx_ir::node::abs::AbsNodeBuilder;
1082    use rust_format::{Config, Formatter, PostProcess, PrettyPlease};
1083
1084    fn format_tokens(tokens: TokenStream) -> String {
1085        let config = Config::new_str().post_proc(PostProcess::ReplaceMarkersAndDocBlocks);
1086        let formatter = PrettyPlease::from_config(config);
1087        formatter
1088            .format_tokens(tokens)
1089            .unwrap_or_else(|_| "FORMATTING FAILED".to_string())
1090    }
1091
1092    /// Build a chain of N abs nodes: input -> t0 -> t1 -> ... -> t{N-1}
1093    fn build_abs_chain(n: usize) -> BurnGraph {
1094        let mut graph = BurnGraph::default();
1095
1096        for i in 0..n {
1097            let in_name = if i == 0 {
1098                "input".to_string()
1099            } else {
1100                format!("t{}", i - 1)
1101            };
1102            let out_name = format!("t{}", i);
1103
1104            let node = AbsNodeBuilder::new(&format!("abs{}", i))
1105                .input_tensor(&in_name, 2, DType::F32)
1106                .output_tensor(&out_name, 2, DType::F32)
1107                .build();
1108
1109            graph.register(Node::Abs(node));
1110        }
1111
1112        let last_out = format!("t{}", n - 1);
1113        graph.register_input_output(vec!["input".to_string()], vec![last_out], &[], &[]);
1114
1115        graph
1116    }
1117
1118    /// Two Clip nodes chained through a single intermediate tensor,
1119    /// each with its own independent runtime scalar bounds. The
1120    /// generated `__clip_min` / `__clip_max` temporaries must each
1121    /// live inside their own per-node block so clone-tracking and
1122    /// name resolution don't interleave across the two instances.
1123    fn build_two_clip_chain() -> BurnGraph {
1124        use onnx_ir::clip::{ClipConfig, ClipNodeBuilder};
1125        use onnx_ir::node::clip::ClipInput;
1126
1127        let mut graph = BurnGraph::default();
1128
1129        let mk = |name: &str, in_tensor: &str, min_name: &str, max_name: &str, out: &str| {
1130            ClipNodeBuilder::new(name)
1131                .input_tensor(in_tensor, 2, DType::F32)
1132                .input_scalar(min_name, DType::F32)
1133                .input_scalar(max_name, DType::F32)
1134                .output_tensor(out, 2, DType::F32)
1135                .config(ClipConfig {
1136                    min: Some(ClipInput::Runtime(onnx_ir::ir::RuntimeInputRef::new(
1137                        min_name.to_string(),
1138                        1,
1139                    ))),
1140                    max: Some(ClipInput::Runtime(onnx_ir::ir::RuntimeInputRef::new(
1141                        max_name.to_string(),
1142                        2,
1143                    ))),
1144                })
1145                .build()
1146        };
1147
1148        graph.register(Node::Clip(mk("clip0", "input", "min0", "max0", "t0")));
1149        graph.register(Node::Clip(mk("clip1", "t0", "min1", "max1", "t1")));
1150
1151        graph.register_input_output(
1152            vec![
1153                "input".to_string(),
1154                "min0".to_string(),
1155                "max0".to_string(),
1156                "min1".to_string(),
1157                "max1".to_string(),
1158            ],
1159            vec!["t1".to_string()],
1160            &[],
1161            &[],
1162        );
1163
1164        graph
1165    }
1166
1167    /// Walk the generated Rust text and return the list of innermost
1168    /// `{ ... }` blocks (as substrings of `code`, without the braces).
1169    /// Used by scoping regression tests: counting raw occurrences of a
1170    /// `let __foo` binding is not enough because the same bindings at
1171    /// the outer `forward` scope would still pass. An innermost-block
1172    /// scan lets us assert that the bindings sit inside a per-node
1173    /// subscope, not at the function top level.
1174    ///
1175    /// "Innermost" means the block contains no nested `{...}` children.
1176    /// Tracked per-block on the stack so siblings don't pollute each
1177    /// other (a parent with one inner child is still a parent — not
1178    /// innermost — but its other children can still qualify).
1179    fn innermost_blocks(code: &str) -> Vec<&str> {
1180        let bytes = code.as_bytes();
1181        let mut stack: Vec<(usize, bool)> = Vec::new();
1182        let mut innermost: Vec<(usize, usize)> = Vec::new();
1183        for (i, &b) in bytes.iter().enumerate() {
1184            match b {
1185                b'{' => {
1186                    if let Some(last) = stack.last_mut() {
1187                        last.1 = true;
1188                    }
1189                    stack.push((i, false));
1190                }
1191                b'}' => {
1192                    if let Some((open, has_inner)) = stack.pop()
1193                        && !has_inner
1194                    {
1195                        innermost.push((open + 1, i));
1196                    }
1197                }
1198                _ => {}
1199            }
1200        }
1201        innermost.into_iter().map(|(s, e)| &code[s..e]).collect()
1202    }
1203
1204    /// Regression test for #317, issue 6: verifies that runtime-bound Clip
1205    /// nodes emit their `__clip_min` / `__clip_max` temporaries inside
1206    /// per-node block scopes rather than at the outer `forward` scope.
1207    /// Without the wrapper block, both `let __clip_min = ...;` bindings
1208    /// would land at the outer scope — still legal Rust, but
1209    /// clone-tracking for the runtime-bound inputs and variable
1210    /// resolution for downstream consumers would interleave across nodes
1211    /// in hard-to-debug ways.
1212    ///
1213    /// The test walks the generated code to find every innermost `{ ... }`
1214    /// block and counts the ones that contain both a `let __clip_min = `
1215    /// and a `let __clip_max = `. That count must be exactly two (one per
1216    /// Clip node). A raw `code.matches(...).count() == 2` would also pass
1217    /// if both bindings were at the outer scope, which is exactly the
1218    /// regression we are trying to rule out.
1219    #[test]
1220    fn multi_instance_clip_scoping() {
1221        let graph = build_two_clip_chain();
1222        let code = format_tokens(graph.codegen());
1223
1224        let scoped_blocks: Vec<&str> = innermost_blocks(&code)
1225            .into_iter()
1226            .filter(|b| b.contains("let __clip_min = ") && b.contains("let __clip_max = "))
1227            .collect();
1228
1229        assert_eq!(
1230            scoped_blocks.len(),
1231            2,
1232            "expected exactly two innermost blocks each containing \
1233             both `let __clip_min =` and `let __clip_max =`, got \
1234             {} such blocks. Full generated code:\n{code}",
1235            scoped_blocks.len()
1236        );
1237
1238        // Belt-and-braces: each scoped block must declare exactly one
1239        // `__clip_min` and one `__clip_max`. A block containing two
1240        // `__clip_min` bindings would mean two clip nodes collapsed into
1241        // a single scope, which is the bug we are guarding against.
1242        for (idx, block) in scoped_blocks.iter().enumerate() {
1243            assert_eq!(
1244                block.matches("let __clip_min = ").count(),
1245                1,
1246                "block {idx} should declare exactly one __clip_min, got:\n{block}"
1247            );
1248            assert_eq!(
1249                block.matches("let __clip_max = ").count(),
1250                1,
1251                "block {idx} should declare exactly one __clip_max, got:\n{block}"
1252            );
1253        }
1254    }
1255
1256    #[test]
1257    fn small_graph_uses_flat_codegen() {
1258        let graph = build_abs_chain(5);
1259        let code = format_tokens(graph.codegen());
1260
1261        // Should have a single Model struct, no Submodule structs
1262        assert!(code.contains("pub struct Model<B: Backend>"));
1263        assert!(!code.contains("Submodule"));
1264    }
1265
1266    #[test]
1267    fn large_graph_uses_partitioned_codegen() {
1268        let graph = build_abs_chain(250);
1269        let code = format_tokens(graph.codegen());
1270
1271        // Should have Submodule structs and a Model that delegates
1272        assert!(code.contains("pub struct Submodule1<B: Backend>"));
1273        assert!(code.contains("pub struct Model<B: Backend>"));
1274        assert!(code.contains("submodule1: Submodule1<B>"));
1275
1276        // Submodules should have their own forward methods
1277        assert!(code.contains("self.submodule1.forward("));
1278
1279        // The Model forward should still take `input` and return the final tensor
1280        assert!(code.contains("pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2>"));
1281    }
1282
1283    #[test]
1284    fn large_graph_with_partition_disabled_uses_flat_codegen() {
1285        let graph = build_abs_chain(250);
1286        let code = format_tokens(graph.with_partition(false).codegen());
1287
1288        // Should use flat codegen despite exceeding MIN_GRAPH_SIZE
1289        assert!(code.contains("pub struct Model<B: Backend>"));
1290        assert!(
1291            !code.contains("Submodule"),
1292            "partition(false) should prevent submodules"
1293        );
1294
1295        // Forward should be directly on Model, not delegated
1296        assert!(code.contains("pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2>"));
1297    }
1298
1299    #[test]
1300    fn partitioned_graph_snapshot() {
1301        // Use a graph just above the threshold (200 nodes) for a manageable snapshot
1302        let graph = build_abs_chain(200);
1303        let code = format_tokens(graph.codegen());
1304
1305        // Verify the overall structure by checking key patterns
1306        // (Full snapshot would be too long; check structural invariants instead)
1307
1308        // Must have at least 2 submodules
1309        assert!(code.contains("Submodule1"));
1310        assert!(code.contains("Submodule2"));
1311
1312        // Each submodule must have #[derive(Module, Debug)]
1313        let module_derive_count = code.matches("#[derive(Module, Debug)]").count();
1314        // At least 3: one per submodule + one for Model
1315        assert!(
1316            module_derive_count >= 3,
1317            "Expected at least 3 #[derive(Module, Debug)], got {}",
1318            module_derive_count
1319        );
1320
1321        // Model::new should create submodules
1322        assert!(code.contains("Submodule1::new(device)"));
1323        assert!(code.contains("Submodule2::new(device)"));
1324
1325        // No duplicate struct definitions
1326        let submodule1_count = code.matches("pub struct Submodule1").count();
1327        assert_eq!(submodule1_count, 1, "Submodule1 defined more than once");
1328    }
1329
1330    /// Create a temporary .bpk file for tests that need `with_burnpack`.
1331    fn temp_bpk() -> std::path::PathBuf {
1332        use std::sync::atomic::{AtomicU64, Ordering};
1333        static COUNTER: AtomicU64 = AtomicU64::new(0);
1334        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1335        let path =
1336            std::env::temp_dir().join(format!("burn-onnx-test-{}-{}.bpk", std::process::id(), id));
1337        std::fs::write(&path, [0u8; 4]).unwrap();
1338        path
1339    }
1340
1341    #[test]
1342    fn load_strategy_file_generates_from_file_and_from_bytes() {
1343        let bpk = temp_bpk();
1344        let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::File);
1345        let code = format_tokens(graph.codegen());
1346        let _ = std::fs::remove_file(bpk);
1347
1348        assert!(
1349            code.contains(
1350                "pub fn from_file<P: AsRef<std::path::Path>>(file: P, device: &B::Device)"
1351            )
1352        );
1353        assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1354        assert!(code.contains("impl<B: Backend> Default for Model<B>"));
1355        assert!(code.contains("Self::from_file("));
1356        assert!(!code.contains("from_embedded"));
1357        // `from_file` references `std::path::Path`, which is not resolvable from
1358        // `#![no_std]` consumers unless std is explicitly linked. Pin the opt-in.
1359        assert!(code.contains("extern crate std;"));
1360    }
1361
1362    #[test]
1363    fn load_strategy_embedded_generates_from_embedded_and_from_bytes() {
1364        let bpk = temp_bpk();
1365        let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::Embedded);
1366        let code = format_tokens(graph.codegen());
1367        let _ = std::fs::remove_file(bpk);
1368
1369        assert!(code.contains("pub fn from_embedded("));
1370        assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1371        assert!(code.contains("impl<B: Backend> Default for Model<B>"));
1372        assert!(code.contains("Self::from_embedded("));
1373        assert!(code.contains("include_bytes!"));
1374        assert!(!code.contains("from_file"));
1375        assert!(!code.contains("extern crate std"));
1376    }
1377
1378    #[test]
1379    fn load_strategy_bytes_generates_only_from_bytes() {
1380        let bpk = temp_bpk();
1381        let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::Bytes);
1382        let code = format_tokens(graph.codegen());
1383        let _ = std::fs::remove_file(bpk);
1384
1385        assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1386        assert!(!code.contains("from_file"));
1387        assert!(!code.contains("from_embedded"));
1388        assert!(!code.contains("impl<B: Backend> Default for Model<B>"));
1389        assert!(!code.contains("extern crate std"));
1390    }
1391
1392    #[test]
1393    fn load_strategy_none_generates_no_loaders() {
1394        let bpk = temp_bpk();
1395        let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::None);
1396        let code = format_tokens(graph.codegen());
1397        let _ = std::fs::remove_file(bpk);
1398
1399        assert!(!code.contains("from_file"));
1400        assert!(!code.contains("from_bytes"));
1401        assert!(!code.contains("from_embedded"));
1402        assert!(!code.contains("impl<B: Backend> Default for Model<B>"));
1403        assert!(!code.contains("extern crate std"));
1404    }
1405}