1use super::{BurnImports, Scope, ToTokens};
2use crate::LoadStrategy;
3use crate::burn::node::NodeCodegen;
4use crate::burn::partition::{
5 MIN_GRAPH_SIZE, Partition, reorder_constants_to_consumers, try_partition,
6};
7use burn_store::{BurnpackWriter, TensorSnapshot};
8use onnx_ir::{Node, ir::ArgType};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use std::{collections::HashMap, path::PathBuf};
12
/// Collects registered nodes together with the graph-level inputs and outputs,
/// and generates the Rust source for the resulting `Model`.
#[derive(Debug)]
pub struct BurnGraph {
    /// Registered nodes, in registration order (`compute_partition` may reorder them once).
    nodes: Vec<Node>,
    /// Tensor variable / future-use bookkeeping filled in by `build_scope` for flat codegen.
    scope: Scope,
    /// Imports gathered from the nodes and from the burnpack loader registration.
    imports: BurnImports,
    /// Optional text emitted through the `_comment_!` marker at the top of the output.
    top_comment: Option<String>,
    /// Extra generated items (loader functions, optional `Default` impl) set by
    /// `register_burnpack_loaders`.
    default: Option<TokenStream>,
    /// When `true`, `_blank_!();` markers are emitted between generated items.
    blank_spaces: bool,
    /// Arguments of the generated top-level `forward` function.
    graph_input_args: Vec<onnx_ir::Argument>,
    /// Values returned by the generated top-level `forward` function.
    graph_output_args: Vec<onnx_ir::Argument>,
    /// Whether `compute_partition` is allowed to attempt partitioning.
    partition: bool,
    /// Memoized result of `compute_partition`: the outer `None` means "not computed yet",
    /// `Some(None)` means "computed, no partition".
    cached_partition: Option<Option<Partition>>,
    /// Graph outputs whose type was rewritten from `ScalarTensor` to `ScalarNative`,
    /// keyed by argument name, with the scalar dtype.
    boundary_output_conversions: HashMap<String, onnx_ir::ir::DType>,
    /// Graph inputs whose type was rewritten from `ScalarTensor` to `ScalarNative`,
    /// keyed by argument name, with the scalar dtype.
    boundary_input_conversions: HashMap<String, onnx_ir::ir::DType>,
}
35
36impl Default for BurnGraph {
37 fn default() -> Self {
38 Self {
39 nodes: Vec::new(),
40 scope: Scope::default(),
41 imports: BurnImports::default(),
42 top_comment: None,
43 default: None,
44 blank_spaces: false,
45 graph_input_args: Vec::new(),
46 graph_output_args: Vec::new(),
47 partition: true,
48 cached_partition: None,
49 boundary_output_conversions: HashMap::new(),
50 boundary_input_conversions: HashMap::new(),
51 }
52 }
53}
54
55impl BurnGraph {
56 pub fn register(&mut self, node: Node) {
62 log::debug!("Registering node => '{}'", node.name());
63 self.nodes.push(node);
64 }
65
66 pub fn with_burnpack(mut self, out_file: PathBuf, strategy: LoadStrategy) -> Self {
70 let snapshots = self.collect_all_snapshots();
72
73 let burnpack_file = out_file.with_extension("bpk");
75 BurnpackWriter::new(snapshots)
76 .with_metadata("producer", "burn-onnx")
77 .write_to_file(&burnpack_file)
78 .expect("Failed to write burnpack file");
79
80 if strategy != LoadStrategy::None {
82 self.register_burnpack_loaders(burnpack_file, strategy);
83 }
84
85 self
86 }
87
88 fn collect_all_snapshots(&mut self) -> Vec<TensorSnapshot> {
94 let partition = self.compute_partition();
95
96 if let Some(partition) = partition {
97 self.collect_snapshots_partitioned(&partition)
98 } else {
99 self.collect_snapshots_flat()
100 }
101 }
102
103 fn compute_partition(&mut self) -> Option<Partition> {
106 if let Some(ref cached) = self.cached_partition {
107 return cached.clone();
108 }
109 let result = if self.partition {
110 if self.nodes.len() >= MIN_GRAPH_SIZE {
114 reorder_constants_to_consumers(&mut self.nodes);
115 }
116 try_partition(&self.nodes, &self.graph_input_args, &self.graph_output_args)
117 } else {
118 None
119 };
120 self.cached_partition = Some(result.clone());
121 result
122 }
123
124 fn collect_snapshots_flat(&self) -> Vec<TensorSnapshot> {
125 let mut snapshots = Vec::new();
126 let mut field_name_counts: HashMap<String, usize> = HashMap::new();
127 collect_snapshots_from_nodes(&self.nodes, "", &mut field_name_counts, &mut snapshots);
128 snapshots
129 }
130
131 fn collect_snapshots_partitioned(&self, partition: &Partition) -> Vec<TensorSnapshot> {
132 let mut snapshots = Vec::new();
133
134 for (chunk_idx, range) in partition.chunks.iter().enumerate() {
135 let prefix = format!("submodule{}", chunk_idx + 1);
136 let chunk_nodes = &self.nodes[range.clone()];
137 let mut field_name_counts: HashMap<String, usize> = HashMap::new();
139 collect_snapshots_from_nodes(
140 chunk_nodes,
141 &prefix,
142 &mut field_name_counts,
143 &mut snapshots,
144 );
145 }
146 snapshots
147 }
148
149 pub fn with_blank_space(mut self, blank_spaces: bool) -> Self {
155 self.blank_spaces = blank_spaces;
156 self
157 }
158
159 pub fn with_top_comment(mut self, top_comment: Option<String>) -> Self {
161 self.top_comment = top_comment;
162 self
163 }
164
165 pub fn with_partition(mut self, partition: bool) -> Self {
167 self.partition = partition;
168 self
169 }
170
171 pub fn codegen(mut self) -> TokenStream {
173 self.register_imports();
174
175 let partition = self.compute_partition();
176
177 if let Some(partition) = partition {
178 self.codegen_partitioned(partition)
179 } else {
180 self.codegen_flat()
181 }
182 }
183
184 fn codegen_flat(mut self) -> TokenStream {
186 self.build_scope();
187
188 let codegen_imports = self.imports.codegen();
189 let codegen_struct = self.codegen_struct();
190 let codegen_new = self.codegen_new();
191 let codegen_forward = self.codegen_forward();
192
193 let maybe_blank = match self.blank_spaces {
194 true => quote! {
195 _blank_!();
196 },
197 false => quote! {},
198 };
199 let codegen_default = match self.default {
200 Some(default) => quote! {
201 #default
202 #maybe_blank
203 },
204 None => quote! {},
205 };
206
207 let maybe_top_file_comment = match self.top_comment {
208 Some(comment) => quote! {
209 _comment_!(#comment);
210 },
211 None => quote! {},
212 };
213
214 quote! {
215 #maybe_top_file_comment
219 #codegen_imports
220 #maybe_blank
221 #maybe_blank
222
223 #codegen_struct
224 #maybe_blank
225
226 #codegen_default
227
228 impl<B: Backend> Model<B> {
229 #codegen_new
230
231 #maybe_blank
232
233 #codegen_forward
234 }
235 }
236 }
237
238 fn codegen_partitioned(self, partition: Partition) -> TokenStream {
240 let maybe_blank = match self.blank_spaces {
241 true => quote! { _blank_!(); },
242 false => quote! {},
243 };
244
245 let codegen_imports = self.imports.codegen();
246 let maybe_top_file_comment = match &self.top_comment {
247 Some(comment) => {
248 let c = comment.clone();
249 quote! { _comment_!(#c); }
250 }
251 None => quote! {},
252 };
253
254 let num_chunks = partition.chunks.len();
255 let mut submodule_defs = Vec::with_capacity(num_chunks);
256 let mut submodule_field_decls = Vec::with_capacity(num_chunks);
257 let mut submodule_field_inits = Vec::with_capacity(num_chunks);
258 let mut submodule_field_names = Vec::with_capacity(num_chunks);
259 let mut forward_calls = Vec::with_capacity(num_chunks);
260
261 let mut remaining_uses: HashMap<String, usize> = HashMap::new();
264 for inputs in &partition.chunk_inputs {
265 for arg in inputs {
266 *remaining_uses.entry(arg.name.clone()).or_insert(0) += 1;
267 }
268 }
269
270 for (chunk_idx, range) in partition.chunks.iter().enumerate() {
271 let struct_name = format_ident!("Submodule{}", chunk_idx + 1);
272 let field_name = format_ident!("submodule{}", chunk_idx + 1);
273 let chunk_nodes = &self.nodes[range.clone()];
274 let chunk_inputs = &partition.chunk_inputs[chunk_idx];
275 let chunk_outputs = &partition.chunk_outputs[chunk_idx];
276
277 let mut scope = Scope::default();
279
280 for arg in chunk_inputs {
285 if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
286 || self.boundary_input_conversions.contains_key(&arg.name)
287 {
288 scope.tensor_register_variable(arg, 0);
289 }
290 }
291
292 for (local_pos, node) in chunk_nodes.iter().enumerate() {
294 for arg in node.outputs() {
295 if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
296 scope.tensor_register_variable(arg, local_pos + 1);
297 }
298 }
299 for arg in node.inputs() {
300 if (arg.is_dynamic() || arg.is_constant())
301 && matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
302 {
303 scope.tensor_register_future_use(arg, local_pos);
304 }
305 }
306 }
307
308 let chunk_len = chunk_nodes.len();
310 for arg in chunk_outputs {
311 if matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
312 scope.tensor_register_future_use(arg, chunk_len);
313 }
314 }
315
316 let chunk_fields = collect_fields_for_nodes(chunk_nodes);
318
319 let struct_fields: Vec<_> = chunk_fields
321 .iter()
322 .map(|(name, ty, _)| quote! { #name: #ty, })
323 .collect();
324
325 let field_init_code: TokenStream = chunk_fields
327 .iter()
328 .filter_map(|(_, _, init)| init.clone())
329 .collect();
330 let field_names_for_init: Vec<_> = chunk_fields
331 .iter()
332 .map(|(name, _, _)| name.clone())
333 .collect();
334
335 let input_params = crate::burn::codegen_fn_params(chunk_inputs);
337 let output_type = crate::burn::codegen_return_type(chunk_outputs);
338 let output_return = crate::burn::codegen_return_expr(chunk_outputs);
339
340 let mut forward_body = quote! {};
341 for (local_pos, node) in chunk_nodes.iter().enumerate() {
342 let mut scope_at_pos = scope.at_position(local_pos);
343 let code = NodeCodegen::forward(node, &mut scope_at_pos);
344 forward_body.extend(code);
345 }
346
347 let submodule_def = quote! {
348 #[derive(Module, Debug)]
349 pub struct #struct_name<B: Backend> {
350 #(#struct_fields)*
351 phantom: core::marker::PhantomData<B>,
352 #[module(skip)]
353 device: B::Device,
354 }
355
356 impl<B: Backend> #struct_name<B> {
357 #[allow(unused_variables)]
358 pub fn new(device: &B::Device) -> Self {
359 #field_init_code
360 Self {
361 #(#field_names_for_init,)*
362 phantom: core::marker::PhantomData,
363 device: device.clone(),
364 }
365 }
366
367 #[allow(clippy::let_and_return, clippy::approx_constant)]
368 pub fn forward(&self, #input_params) -> #output_type {
369 #forward_body
370 #output_return
371 }
372 }
373 };
374 submodule_defs.push(submodule_def);
375
376 submodule_field_decls.push(quote! { #field_name: #struct_name<B>, });
378 submodule_field_inits.push(quote! { let #field_name = #struct_name::new(device); });
379 submodule_field_names.push(field_name.clone());
380
381 let input_args: Vec<_> = chunk_inputs
384 .iter()
385 .map(|arg| {
386 let name = crate::burn::arg_ident(arg);
387 let remaining = remaining_uses.get(&arg.name).copied().unwrap_or(0);
388 if remaining > 1 {
389 remaining_uses.insert(arg.name.clone(), remaining - 1);
391 quote! { #name.clone() }
392 } else {
393 remaining_uses.remove(&arg.name);
394 quote! { #name }
395 }
396 })
397 .collect();
398
399 if chunk_outputs.len() == 1 {
400 let out_name = crate::burn::arg_ident(&chunk_outputs[0]);
401 forward_calls.push(quote! {
402 let #out_name = self.#field_name.forward(#(#input_args),*);
403 });
404 } else {
405 let out_names: Vec<_> = chunk_outputs.iter().map(crate::burn::arg_ident).collect();
406 forward_calls.push(quote! {
407 let (#(#out_names),*) = self.#field_name.forward(#(#input_args),*);
408 });
409 }
410 }
411
412 let input_def = crate::burn::codegen_fn_params(&self.graph_input_args);
414 let output_type_def = crate::burn::codegen_return_type(&self.graph_output_args);
415 let output_return_def = crate::burn::codegen_return_expr(&self.graph_output_args);
416
417 let input_conversions = self.codegen_boundary_input_conversions();
418 let boundary_conversions = self.codegen_boundary_output_conversions();
419
420 let codegen_default = match &self.default {
421 Some(default) => {
422 let d = default.clone();
423 quote! { #d #maybe_blank }
424 }
425 None => quote! {},
426 };
427
428 quote! {
429 #maybe_top_file_comment
433 #codegen_imports
434 #maybe_blank
435 #maybe_blank
436
437 #(#submodule_defs)*
438 #maybe_blank
439
440 #[derive(Module, Debug)]
441 pub struct Model<B: Backend> {
442 #(#submodule_field_decls)*
443 phantom: core::marker::PhantomData<B>,
444 #[module(skip)]
445 device: B::Device,
446 }
447 #maybe_blank
448
449 #codegen_default
450
451 impl<B: Backend> Model<B> {
452 #[allow(unused_variables)]
453 pub fn new(device: &B::Device) -> Self {
454 #(#submodule_field_inits)*
455 Self {
456 #(#submodule_field_names,)*
457 phantom: core::marker::PhantomData,
458 device: device.clone(),
459 }
460 }
461
462 #maybe_blank
463
464 #[allow(clippy::let_and_return, clippy::approx_constant)]
465 pub fn forward(&self, #input_def) -> #output_type_def {
466 #input_conversions
467 #(#forward_calls)*
468 #boundary_conversions
469 #output_return_def
470 }
471 }
472 }
473 }
474
475 fn register_imports(&mut self) {
476 self.nodes
478 .iter()
479 .for_each(|node| NodeCodegen::register_imports(node, &mut self.imports));
480 }
481
482 fn build_scope(&mut self) {
484 log::debug!("Building the scope nodes len => '{}'", self.nodes.len());
485
486 self.graph_input_args
488 .iter()
489 .filter(|arg| {
490 matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_))
491 || self.boundary_input_conversions.contains_key(&arg.name)
492 })
493 .for_each(|arg| {
494 self.scope.tensor_register_variable(arg, 0);
495 });
496
497 self.nodes
498 .iter()
499 .enumerate()
500 .for_each(|(node_position, node)| {
501 node.outputs()
503 .iter()
504 .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
505 .for_each(|arg| {
506 self.scope.tensor_register_variable(arg, node_position + 1);
507 });
508 node.inputs()
512 .iter()
513 .filter(|arg| arg.is_dynamic() || arg.is_constant())
514 .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
515 .for_each(|arg| self.scope.tensor_register_future_use(arg, node_position));
516 });
517
518 self.graph_output_args
520 .iter()
521 .filter(|arg| matches!(arg.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)))
522 .for_each(|arg| {
523 self.scope.tensor_register_future_use(arg, self.nodes.len());
524 });
525 }
526
527 fn register_burnpack_loaders(&mut self, file: PathBuf, strategy: LoadStrategy) {
528 self.imports.register("burn_store::BurnpackStore");
529 self.imports.register("burn_store::ModuleSnapshot");
530 self.imports.register("burn::tensor::Bytes");
531
532 let mut statics = quote! {};
533 let mut default_impl = quote! {};
534 let mut extra_loaders = quote! {};
535
536 match strategy {
537 LoadStrategy::File => {
538 let file = file.to_str().unwrap();
539 statics = quote! {
540 extern crate std;
543 _blank_!();
544 };
545 default_impl = quote! {
546 impl<B: Backend> Default for Model<B> {
547 fn default() -> Self {
548 Self::from_file(#file, &Default::default())
549 }
550 }
551 _blank_!();
552 };
553 extra_loaders = quote! {
554 pub fn from_file<P: AsRef<std::path::Path>>(file: P, device: &B::Device) -> Self {
556 let mut model = Self::new(device);
557 let mut store = BurnpackStore::from_file(file);
558 model.load_from(&mut store).expect("Failed to load burnpack file");
559 model
560 }
561 _blank_!();
562 };
563 }
564 LoadStrategy::Embedded => {
565 let file_size = std::fs::metadata(&file)
566 .expect("Failed to read burnpack file metadata")
567 .len() as usize;
568 let file = file.to_str().unwrap();
569 statics = quote! {
570 #[repr(C, align(256))]
574 struct Aligned256([u8; #file_size]);
575 static ALIGNED_DATA: Aligned256 = Aligned256(*include_bytes!(#file));
576 static EMBEDDED_STATES: &[u8] = &ALIGNED_DATA.0;
577 _blank_!();
578 };
579 default_impl = quote! {
580 impl<B: Backend> Default for Model<B> {
581 fn default() -> Self {
582 Self::from_embedded(&Default::default())
583 }
584 }
585 _blank_!();
586 };
587 extra_loaders = quote! {
588 pub fn from_embedded(device: &B::Device) -> Self {
598 let mut model = Self::new(device);
599 let mut store = BurnpackStore::from_static(EMBEDDED_STATES);
600 model.load_from(&mut store).expect("Failed to load embedded burnpack");
601 model
602 }
603 _blank_!();
604 };
605 }
606 LoadStrategy::Bytes | LoadStrategy::None => {}
607 }
608
609 self.default = Some(quote! {
610 _blank_!();
611 #statics
612 #default_impl
613 impl<B: Backend> Model<B> {
614 #extra_loaders
615 pub fn from_bytes(bytes: Bytes, device: &B::Device) -> Self {
619 let mut model = Self::new(device);
620 let mut store = BurnpackStore::from_bytes(Some(bytes));
621 model.load_from(&mut store).expect("Failed to load burnpack bytes");
622 model
623 }
624 }
625 });
626 }
627
628 fn collect_all_fields(&self) -> Vec<FieldTuple> {
630 collect_fields_for_nodes(&self.nodes)
631 }
632
633 fn codegen_struct(&self) -> TokenStream {
634 let mut body = quote! {};
635 self.collect_all_fields()
636 .iter()
637 .map(|(name, ty, _)| {
638 quote! {
639 #name: #ty,
640 }
641 })
642 .for_each(|code| body.extend(code));
643
644 body.extend(quote! {
646 phantom: core::marker::PhantomData<B>,
647 #[module(skip)]
648 device: B::Device,
649 });
650
651 quote! {
652 #[derive(Module, Debug)]
653 pub struct Model<B: Backend> {
654 #body
655 }
656 }
657 }
658
659 fn codegen_new(&self) -> TokenStream {
660 let mut body = quote! {};
661 let all_fields = self.collect_all_fields();
662
663 for (_, _, field_init) in &all_fields {
665 body.extend(field_init.clone());
666 }
667
668 let field_names: Vec<_> = all_fields.iter().map(|(name, _, _)| name.clone()).collect();
670
671 quote! {
672 #[allow(unused_variables)]
673 pub fn new(device: &B::Device) -> Self {
674 #body
675
676 Self {
677 #(#field_names,)*
678 phantom: core::marker::PhantomData,
679 device: device.clone(),
680 }
681 }
682 }
683 }
684
685 fn codegen_forward(&mut self) -> TokenStream {
686 let input_def = crate::burn::codegen_fn_params(&self.graph_input_args);
687 let output_type_def = crate::burn::codegen_return_type(&self.graph_output_args);
688 let output_return_def = crate::burn::codegen_return_expr(&self.graph_output_args);
689
690 let input_conversions = self.codegen_boundary_input_conversions();
691
692 let mut body = quote! {};
693 for (index, node) in self.nodes.iter().enumerate() {
694 let mut scope_at_pos = self.scope.at_position(index);
695 let code = NodeCodegen::forward(node, &mut scope_at_pos);
696 body.extend(code);
697 }
698
699 let boundary_conversions = self.codegen_boundary_output_conversions();
700
701 quote! {
705 #[allow(clippy::let_and_return, clippy::approx_constant)]
706 pub fn forward(&self, #input_def) -> #output_type_def {
707 #input_conversions
708 #body
709 #boundary_conversions
710 #output_return_def
711 }
712 }
713 }
714
715 pub fn register_input_output(
726 &mut self,
727 input_names: Vec<String>,
728 output_names: Vec<String>,
729 input_args: &[onnx_ir::Argument],
730 output_args: &[onnx_ir::Argument],
731 ) {
732 if self.nodes.is_empty() {
734 self.graph_input_args.extend_from_slice(input_args);
736 self.graph_output_args.extend_from_slice(output_args);
737 self.convert_graph_boundary_scalars();
738 return;
739 }
740
741 let mut inputs = HashMap::new();
743 let mut outputs = HashMap::new();
744 for node in self.nodes.iter() {
745 for input_arg in NodeCodegen::inputs(node) {
746 inputs.insert(input_arg.name.clone(), input_arg.clone());
747 }
748 for output_arg in NodeCodegen::outputs(node) {
749 outputs.insert(output_arg.name.clone(), output_arg.clone());
750 }
751 }
752
753 input_names.iter().enumerate().for_each(|(idx, input)| {
756 let input_arg = inputs
757 .get(input)
758 .cloned()
759 .or_else(|| {
760 if idx < input_args.len() {
762 Some(input_args[idx].clone())
763 } else {
764 None
765 }
766 })
767 .unwrap_or_else(|| panic!("Input argument not found for {input}"));
768
769 self.graph_input_args.push(input_arg);
770 });
771
772 if !output_args.is_empty() {
775 output_names
776 .iter()
777 .zip(output_args.iter())
778 .for_each(|(name, arg)| {
779 let mut renamed_arg = arg.clone();
781 renamed_arg.name = name.clone();
782 self.graph_output_args.push(renamed_arg);
783 });
784 } else {
785 output_names.iter().for_each(|output| {
787 self.graph_output_args.push(
788 outputs
789 .get(output)
790 .unwrap_or_else(|| panic!("Output argument not found for {output}"))
791 .clone(),
792 );
793 });
794 }
795
796 self.convert_graph_boundary_scalars();
799 }
800
801 fn codegen_boundary_input_conversions(&self) -> TokenStream {
803 let mut tokens = quote! {};
804 for arg in &self.graph_input_args {
805 if let Some(dtype) = self.boundary_input_conversions.get(&arg.name) {
806 let name = crate::burn::arg_ident(arg);
807 let dtype_tokens = dtype.to_tokens();
808 if dtype.is_float() {
809 tokens.extend(quote! {
810 let #name = Tensor::<B, 1>::from_data(
811 burn::tensor::TensorData::from([#name]),
812 (&self.device, #dtype_tokens)
813 );
814 });
815 } else if dtype.is_int() || dtype.is_uint() {
816 tokens.extend(quote! {
817 let #name = Tensor::<B, 1, Int>::from_data(
818 burn::tensor::TensorData::from([#name]),
819 (&self.device, #dtype_tokens)
820 );
821 });
822 } else if dtype.is_bool() {
823 tokens.extend(quote! {
824 let #name = Tensor::<B, 1, Bool>::from_data(
825 burn::tensor::TensorData::from([#name]),
826 (&self.device, #dtype_tokens)
827 );
828 });
829 } else {
830 panic!(
831 "Unsupported dtype {:?} for graph boundary ScalarNative -> ScalarTensor conversion",
832 dtype
833 );
834 }
835 }
836 }
837 tokens
838 }
839
840 fn codegen_boundary_output_conversions(&self) -> TokenStream {
842 let mut tokens = quote! {};
843 for arg in &self.graph_output_args {
844 if let Some(dtype) = self.boundary_output_conversions.get(&arg.name) {
845 let name = crate::burn::arg_ident(arg);
846 let convert = crate::burn::on_device_to_native(quote! { #name }, dtype);
847 tokens.extend(quote! {
848 let #name = #convert;
849 });
850 }
851 }
852 tokens
853 }
854
855 fn convert_graph_boundary_scalars(&mut self) {
858 for arg in &mut self.graph_input_args {
859 if let ArgType::ScalarTensor(dtype) = arg.ty {
860 self.boundary_input_conversions
861 .insert(arg.name.clone(), dtype);
862 arg.ty = ArgType::ScalarNative(dtype);
863 }
864 }
865 for arg in &mut self.graph_output_args {
866 if let ArgType::ScalarTensor(dtype) = arg.ty {
867 self.boundary_output_conversions
868 .insert(arg.name.clone(), dtype);
869 arg.ty = ArgType::ScalarNative(dtype);
870 }
871 }
872 }
873}
874
/// One generated struct field: `(field identifier, field type tokens, optional initializer tokens)`.
type FieldTuple = (proc_macro2::Ident, TokenStream, Option<TokenStream>);
880
881fn collect_fields_for_nodes(nodes: &[Node]) -> Vec<FieldTuple> {
883 let mut field_name_counts: HashMap<String, usize> = HashMap::new();
884 let mut all_fields: Vec<FieldTuple> = Vec::new();
885
886 fn collect_subgraph_fields_recursive(
887 subgraph: &onnx_ir::OnnxGraph,
888 field_name_counts: &mut HashMap<String, usize>,
889 all_fields: &mut Vec<FieldTuple>,
890 ) {
891 for node in &subgraph.nodes {
892 if let Some(mut field) = NodeCodegen::field(node) {
893 let base_name = field.name.to_string();
894 let count = field_name_counts.entry(base_name.clone()).or_insert(0);
895 *count += 1;
896
897 if *count > 1 {
898 let new_name_str = format!("{}_{}", base_name, count);
899 let new_name = syn::Ident::new(&new_name_str, proc_macro2::Span::call_site());
900 field.name = new_name;
901
902 let init_str = field.init.to_string();
903 let updated = init_str
904 .replace(
905 &format!("let {} :", base_name),
906 &format!("let {} :", new_name_str),
907 )
908 .replace(
909 &format!("let {} =", base_name),
910 &format!("let {} =", new_name_str),
911 );
912 field.init = updated.parse().unwrap_or_else(|e| {
913 log::warn!(
914 "Failed to parse renamed field init for '{}': {e}",
915 new_name_str
916 );
917 field.init.clone()
918 });
919 }
920 all_fields.push((field.name.clone(), field.ty.clone(), Some(field.init)));
921 }
922
923 if let Node::If(nested) = node {
924 collect_subgraph_fields_recursive(
925 &nested.config.then_branch,
926 field_name_counts,
927 all_fields,
928 );
929 collect_subgraph_fields_recursive(
930 &nested.config.else_branch,
931 field_name_counts,
932 all_fields,
933 );
934 } else if let Node::Loop(nested) = node {
935 collect_subgraph_fields_recursive(
936 &nested.config.body,
937 field_name_counts,
938 all_fields,
939 );
940 }
941 }
942 }
943
944 for node in nodes {
945 if let Some(field) = NodeCodegen::field(node) {
946 all_fields.push((field.name, field.ty, Some(field.init)));
947 }
948
949 if let Node::If(if_node) = node {
950 collect_subgraph_fields_recursive(
951 &if_node.config.then_branch,
952 &mut field_name_counts,
953 &mut all_fields,
954 );
955 collect_subgraph_fields_recursive(
956 &if_node.config.else_branch,
957 &mut field_name_counts,
958 &mut all_fields,
959 );
960 } else if let Node::Loop(loop_node) = node {
961 collect_subgraph_fields_recursive(
962 &loop_node.config.body,
963 &mut field_name_counts,
964 &mut all_fields,
965 );
966 }
967 }
968
969 all_fields
970}
971
972fn collect_snapshots_from_nodes(
976 nodes: &[Node],
977 prefix: &str,
978 field_name_counts: &mut HashMap<String, usize>,
979 snapshots: &mut Vec<TensorSnapshot>,
980) {
981 fn collect_subgraph_snapshots_recursive(
982 subgraph: &onnx_ir::OnnxGraph,
983 prefix: &str,
984 field_name_counts: &mut HashMap<String, usize>,
985 snapshots: &mut Vec<TensorSnapshot>,
986 ) {
987 for node in &subgraph.nodes {
988 if let Some(field) = NodeCodegen::field(node) {
989 let base_name = field.name.to_string();
990 let count = field_name_counts.entry(base_name.clone()).or_insert(0);
991 *count += 1;
992
993 let unique_name = if *count > 1 {
994 format!("{}_{}", base_name, count)
995 } else {
996 base_name
997 };
998
999 let full_name = if prefix.is_empty() {
1000 unique_name
1001 } else {
1002 format!("{}.{}", prefix, unique_name)
1003 };
1004 let node_snapshots = NodeCodegen::collect_snapshots(node, &full_name);
1005 snapshots.extend(node_snapshots);
1006 }
1007
1008 if let Node::If(nested) = node {
1009 collect_subgraph_snapshots_recursive(
1010 &nested.config.then_branch,
1011 prefix,
1012 field_name_counts,
1013 snapshots,
1014 );
1015 collect_subgraph_snapshots_recursive(
1016 &nested.config.else_branch,
1017 prefix,
1018 field_name_counts,
1019 snapshots,
1020 );
1021 } else if let Node::Loop(nested) = node {
1022 collect_subgraph_snapshots_recursive(
1023 &nested.config.body,
1024 prefix,
1025 field_name_counts,
1026 snapshots,
1027 );
1028 }
1029 }
1030 }
1031
1032 for node in nodes {
1033 if let Some(field) = NodeCodegen::field(node) {
1034 let base_name = field.name.to_string();
1035 let count = field_name_counts.entry(base_name.clone()).or_insert(0);
1036 *count += 1;
1037
1038 let unique_name = if *count > 1 {
1039 format!("{}_{}", base_name, count)
1040 } else {
1041 base_name
1042 };
1043
1044 let full_name = if prefix.is_empty() {
1045 unique_name
1046 } else {
1047 format!("{}.{}", prefix, unique_name)
1048 };
1049 let node_snapshots = NodeCodegen::collect_snapshots(node, &full_name);
1050 snapshots.extend(node_snapshots);
1051 }
1052
1053 if let Node::If(if_node) = node {
1054 collect_subgraph_snapshots_recursive(
1055 &if_node.config.then_branch,
1056 prefix,
1057 field_name_counts,
1058 snapshots,
1059 );
1060 collect_subgraph_snapshots_recursive(
1061 &if_node.config.else_branch,
1062 prefix,
1063 field_name_counts,
1064 snapshots,
1065 );
1066 } else if let Node::Loop(loop_node) = node {
1067 collect_subgraph_snapshots_recursive(
1068 &loop_node.config.body,
1069 prefix,
1070 field_name_counts,
1071 snapshots,
1072 );
1073 }
1074 }
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079 use super::*;
1080 use burn::tensor::DType;
1081 use onnx_ir::node::abs::AbsNodeBuilder;
1082 use rust_format::{Config, Formatter, PostProcess, PrettyPlease};
1083
1084 fn format_tokens(tokens: TokenStream) -> String {
1085 let config = Config::new_str().post_proc(PostProcess::ReplaceMarkersAndDocBlocks);
1086 let formatter = PrettyPlease::from_config(config);
1087 formatter
1088 .format_tokens(tokens)
1089 .unwrap_or_else(|_| "FORMATTING FAILED".to_string())
1090 }
1091
1092 fn build_abs_chain(n: usize) -> BurnGraph {
1094 let mut graph = BurnGraph::default();
1095
1096 for i in 0..n {
1097 let in_name = if i == 0 {
1098 "input".to_string()
1099 } else {
1100 format!("t{}", i - 1)
1101 };
1102 let out_name = format!("t{}", i);
1103
1104 let node = AbsNodeBuilder::new(&format!("abs{}", i))
1105 .input_tensor(&in_name, 2, DType::F32)
1106 .output_tensor(&out_name, 2, DType::F32)
1107 .build();
1108
1109 graph.register(Node::Abs(node));
1110 }
1111
1112 let last_out = format!("t{}", n - 1);
1113 graph.register_input_output(vec!["input".to_string()], vec![last_out], &[], &[]);
1114
1115 graph
1116 }
1117
1118 fn build_two_clip_chain() -> BurnGraph {
1124 use onnx_ir::clip::{ClipConfig, ClipNodeBuilder};
1125 use onnx_ir::node::clip::ClipInput;
1126
1127 let mut graph = BurnGraph::default();
1128
1129 let mk = |name: &str, in_tensor: &str, min_name: &str, max_name: &str, out: &str| {
1130 ClipNodeBuilder::new(name)
1131 .input_tensor(in_tensor, 2, DType::F32)
1132 .input_scalar(min_name, DType::F32)
1133 .input_scalar(max_name, DType::F32)
1134 .output_tensor(out, 2, DType::F32)
1135 .config(ClipConfig {
1136 min: Some(ClipInput::Runtime(onnx_ir::ir::RuntimeInputRef::new(
1137 min_name.to_string(),
1138 1,
1139 ))),
1140 max: Some(ClipInput::Runtime(onnx_ir::ir::RuntimeInputRef::new(
1141 max_name.to_string(),
1142 2,
1143 ))),
1144 })
1145 .build()
1146 };
1147
1148 graph.register(Node::Clip(mk("clip0", "input", "min0", "max0", "t0")));
1149 graph.register(Node::Clip(mk("clip1", "t0", "min1", "max1", "t1")));
1150
1151 graph.register_input_output(
1152 vec![
1153 "input".to_string(),
1154 "min0".to_string(),
1155 "max0".to_string(),
1156 "min1".to_string(),
1157 "max1".to_string(),
1158 ],
1159 vec!["t1".to_string()],
1160 &[],
1161 &[],
1162 );
1163
1164 graph
1165 }
1166
1167 fn innermost_blocks(code: &str) -> Vec<&str> {
1180 let bytes = code.as_bytes();
1181 let mut stack: Vec<(usize, bool)> = Vec::new();
1182 let mut innermost: Vec<(usize, usize)> = Vec::new();
1183 for (i, &b) in bytes.iter().enumerate() {
1184 match b {
1185 b'{' => {
1186 if let Some(last) = stack.last_mut() {
1187 last.1 = true;
1188 }
1189 stack.push((i, false));
1190 }
1191 b'}' => {
1192 if let Some((open, has_inner)) = stack.pop()
1193 && !has_inner
1194 {
1195 innermost.push((open + 1, i));
1196 }
1197 }
1198 _ => {}
1199 }
1200 }
1201 innermost.into_iter().map(|(s, e)| &code[s..e]).collect()
1202 }
1203
1204 #[test]
1220 fn multi_instance_clip_scoping() {
1221 let graph = build_two_clip_chain();
1222 let code = format_tokens(graph.codegen());
1223
1224 let scoped_blocks: Vec<&str> = innermost_blocks(&code)
1225 .into_iter()
1226 .filter(|b| b.contains("let __clip_min = ") && b.contains("let __clip_max = "))
1227 .collect();
1228
1229 assert_eq!(
1230 scoped_blocks.len(),
1231 2,
1232 "expected exactly two innermost blocks each containing \
1233 both `let __clip_min =` and `let __clip_max =`, got \
1234 {} such blocks. Full generated code:\n{code}",
1235 scoped_blocks.len()
1236 );
1237
1238 for (idx, block) in scoped_blocks.iter().enumerate() {
1243 assert_eq!(
1244 block.matches("let __clip_min = ").count(),
1245 1,
1246 "block {idx} should declare exactly one __clip_min, got:\n{block}"
1247 );
1248 assert_eq!(
1249 block.matches("let __clip_max = ").count(),
1250 1,
1251 "block {idx} should declare exactly one __clip_max, got:\n{block}"
1252 );
1253 }
1254 }
1255
1256 #[test]
1257 fn small_graph_uses_flat_codegen() {
1258 let graph = build_abs_chain(5);
1259 let code = format_tokens(graph.codegen());
1260
1261 assert!(code.contains("pub struct Model<B: Backend>"));
1263 assert!(!code.contains("Submodule"));
1264 }
1265
1266 #[test]
1267 fn large_graph_uses_partitioned_codegen() {
1268 let graph = build_abs_chain(250);
1269 let code = format_tokens(graph.codegen());
1270
1271 assert!(code.contains("pub struct Submodule1<B: Backend>"));
1273 assert!(code.contains("pub struct Model<B: Backend>"));
1274 assert!(code.contains("submodule1: Submodule1<B>"));
1275
1276 assert!(code.contains("self.submodule1.forward("));
1278
1279 assert!(code.contains("pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2>"));
1281 }
1282
1283 #[test]
1284 fn large_graph_with_partition_disabled_uses_flat_codegen() {
1285 let graph = build_abs_chain(250);
1286 let code = format_tokens(graph.with_partition(false).codegen());
1287
1288 assert!(code.contains("pub struct Model<B: Backend>"));
1290 assert!(
1291 !code.contains("Submodule"),
1292 "partition(false) should prevent submodules"
1293 );
1294
1295 assert!(code.contains("pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2>"));
1297 }
1298
1299 #[test]
1300 fn partitioned_graph_snapshot() {
1301 let graph = build_abs_chain(200);
1303 let code = format_tokens(graph.codegen());
1304
1305 assert!(code.contains("Submodule1"));
1310 assert!(code.contains("Submodule2"));
1311
1312 let module_derive_count = code.matches("#[derive(Module, Debug)]").count();
1314 assert!(
1316 module_derive_count >= 3,
1317 "Expected at least 3 #[derive(Module, Debug)], got {}",
1318 module_derive_count
1319 );
1320
1321 assert!(code.contains("Submodule1::new(device)"));
1323 assert!(code.contains("Submodule2::new(device)"));
1324
1325 let submodule1_count = code.matches("pub struct Submodule1").count();
1327 assert_eq!(submodule1_count, 1, "Submodule1 defined more than once");
1328 }
1329
1330 fn temp_bpk() -> std::path::PathBuf {
1332 use std::sync::atomic::{AtomicU64, Ordering};
1333 static COUNTER: AtomicU64 = AtomicU64::new(0);
1334 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1335 let path =
1336 std::env::temp_dir().join(format!("burn-onnx-test-{}-{}.bpk", std::process::id(), id));
1337 std::fs::write(&path, [0u8; 4]).unwrap();
1338 path
1339 }
1340
1341 #[test]
1342 fn load_strategy_file_generates_from_file_and_from_bytes() {
1343 let bpk = temp_bpk();
1344 let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::File);
1345 let code = format_tokens(graph.codegen());
1346 let _ = std::fs::remove_file(bpk);
1347
1348 assert!(
1349 code.contains(
1350 "pub fn from_file<P: AsRef<std::path::Path>>(file: P, device: &B::Device)"
1351 )
1352 );
1353 assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1354 assert!(code.contains("impl<B: Backend> Default for Model<B>"));
1355 assert!(code.contains("Self::from_file("));
1356 assert!(!code.contains("from_embedded"));
1357 assert!(code.contains("extern crate std;"));
1360 }
1361
1362 #[test]
1363 fn load_strategy_embedded_generates_from_embedded_and_from_bytes() {
1364 let bpk = temp_bpk();
1365 let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::Embedded);
1366 let code = format_tokens(graph.codegen());
1367 let _ = std::fs::remove_file(bpk);
1368
1369 assert!(code.contains("pub fn from_embedded("));
1370 assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1371 assert!(code.contains("impl<B: Backend> Default for Model<B>"));
1372 assert!(code.contains("Self::from_embedded("));
1373 assert!(code.contains("include_bytes!"));
1374 assert!(!code.contains("from_file"));
1375 assert!(!code.contains("extern crate std"));
1376 }
1377
1378 #[test]
1379 fn load_strategy_bytes_generates_only_from_bytes() {
1380 let bpk = temp_bpk();
1381 let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::Bytes);
1382 let code = format_tokens(graph.codegen());
1383 let _ = std::fs::remove_file(bpk);
1384
1385 assert!(code.contains("pub fn from_bytes(bytes: Bytes"));
1386 assert!(!code.contains("from_file"));
1387 assert!(!code.contains("from_embedded"));
1388 assert!(!code.contains("impl<B: Backend> Default for Model<B>"));
1389 assert!(!code.contains("extern crate std"));
1390 }
1391
1392 #[test]
1393 fn load_strategy_none_generates_no_loaders() {
1394 let bpk = temp_bpk();
1395 let graph = build_abs_chain(1).with_burnpack(bpk.clone(), LoadStrategy::None);
1396 let code = format_tokens(graph.codegen());
1397 let _ = std::fs::remove_file(bpk);
1398
1399 assert!(!code.contains("from_file"));
1400 assert!(!code.contains("from_bytes"));
1401 assert!(!code.contains("from_embedded"));
1402 assert!(!code.contains("impl<B: Backend> Default for Model<B>"));
1403 assert!(!code.contains("extern crate std"));
1404 }
1405}