1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//! Shared utilities for subgraph code generation in control flow nodes (If, Loop, Scan)
use super::prelude::*;
use std::collections::HashSet;
/// Collect all names referenced within a subgraph (node inputs).
///
/// This is used to filter outer-scope bindings so we only generate bindings
/// for variables that are actually used within a specific subgraph branch.
pub(super) fn collect_subgraph_referenced_names(subgraph: &onnx_ir::OnnxGraph) -> HashSet<String> {
    // Every input name of every node in the subgraph counts as a reference.
    let node_input_names = subgraph
        .nodes
        .iter()
        .flat_map(|node| node.inputs())
        .map(|input| input.name.clone());

    // Graph output names are included too: in pass-through cases an output
    // may reference an outer-scope variable directly (no producing node).
    let graph_output_names = subgraph.outputs.iter().map(|output| output.name.clone());

    // Empty names are placeholders for optional/absent arguments — drop them.
    node_input_names
        .chain(graph_output_names)
        .filter(|name| !name.is_empty())
        .collect()
}
/// Generate outer-scope reference bindings for a subgraph.
///
/// Creates `let` bindings that map outer-scope values (from the parent graph)
/// to the names used within the subgraph.
///
/// # Parameters
/// - `outer_scope_inputs`: The node inputs that provide values for outer-scope references
/// - `scope_ref_names`: The original sanitized ONNX names that the subgraph uses
/// - `exclude_names`: Names to exclude from binding generation (e.g., loop-provided variables)
/// - `used_names`: Optional set of names actually used in this subgraph. If provided, only
/// bindings for names in this set will be generated (avoids unused variable warnings).
/// - `scope`: The parent scope for accessing outer values
/// - `node_position`: The position of the control flow node in the graph
pub(super) fn generate_outer_scope_bindings(
outer_scope_inputs: &[Argument],
scope_ref_names: &[String],
exclude_names: &HashSet<String>,
used_names: Option<&HashSet<String>>,
scope: &mut Scope,
node_position: usize,
) -> TokenStream {
let mut bindings = quote! {};
for (idx, scope_ref_name) in scope_ref_names.iter().enumerate() {
// Skip names that should be excluded (e.g., loop-provided variables)
if exclude_names.contains(scope_ref_name) {
continue;
}
// Skip names not actually used in this subgraph (if used_names filter is provided)
if let Some(used) = used_names
&& !used.contains(scope_ref_name)
{
continue;
}
if let Some(outer_input) = outer_scope_inputs.get(idx) {
let var_name = quote::format_ident!("{}", scope_ref_name);
let outer_var = scope.at_position(node_position).arg(outer_input);
match &outer_input.ty {
ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
bindings.extend(quote! {
let #var_name = #outer_var.clone();
});
}
ArgType::ScalarNative(_) => {
bindings.extend(quote! {
let #var_name = #outer_var;
});
}
_ => {}
}
}
}
bindings
}
/// Register subgraph inputs and build scope for generating node forward code.
///
/// This registers all subgraph tensors in the scope so they can be properly
/// referenced and cloned during code generation.
pub(super) fn register_subgraph_scope(
subgraph: &onnx_ir::OnnxGraph,
scope: &mut Scope,
node_position: usize,
) {
// Register subgraph inputs in scope
for input in &subgraph.inputs {
if matches!(&input.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
scope.tensor_register_variable(input, node_position);
}
}
// Build scope for subgraph nodes: register outputs and future uses
for (idx, node) in subgraph.nodes.iter().enumerate() {
let subgraph_node_pos = node_position + idx + 1;
// Register node outputs
for output in NodeCodegen::outputs(node) {
if matches!(&output.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
scope.tensor_register_variable(output, subgraph_node_pos);
}
}
// Register future uses of node inputs.
// We only track dynamic and constant arguments because:
// - Dynamic: runtime values that need clone tracking for ownership
// - Constant: values embedded in the model that may be referenced multiple times
// - Static initializers are excluded because they're baked into the model at
// compile time and don't need runtime clone management
for input in NodeCodegen::inputs(node)
.iter()
.filter(|arg| arg.is_dynamic() || arg.is_constant())
{
if matches!(&input.ty, ArgType::Tensor(_) | ArgType::ScalarTensor(_)) {
scope.tensor_register_future_use(input, subgraph_node_pos - 1);
}
}
}
// Register future uses for subgraph outputs
for output in &subgraph.outputs {
if let ArgType::Tensor(_) = &output.ty {
scope.tensor_register_future_use(output, node_position + subgraph.nodes.len());
}
}
}
/// Generate forward code for all nodes in a subgraph.
pub(super) fn generate_subgraph_forward_code(
subgraph: &onnx_ir::OnnxGraph,
scope: &mut Scope,
node_position: usize,
) -> TokenStream {
let mut code = quote! {};
for (idx, node) in subgraph.nodes.iter().enumerate() {
let mut scope_at_pos = scope.at_position(node_position + idx + 1);
let node_code = NodeCodegen::forward(node, &mut scope_at_pos);
code.extend(node_code);
}
code
}