1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
use super::prelude::*;
use super::subgraph_helper;
use std::collections::HashSet;
/// Build the inline token stream for a Loop body subgraph.
///
/// Inputs that the loop construct itself supplies (iter_num, cond, and the
/// loop-carried variables) are excluded from outer-scope bindings — only
/// genuine references to the parent scope get a binding emitted.
fn generate_loop_body_code(
    subgraph: &onnx_ir::OnnxGraph,
    outer_scope_inputs: &[Argument],
    scope_ref_names: &[String],
    body_input_names: &HashSet<String>,
    scope: &mut Scope,
    node_position: usize,
) -> TokenStream {
    // Restrict bindings to names the body actually references so we don't
    // emit dead `let`s that trigger unused-variable warnings.
    let referenced = subgraph_helper::collect_subgraph_referenced_names(subgraph);
    let outer_bindings = subgraph_helper::generate_outer_scope_bindings(
        outer_scope_inputs,
        scope_ref_names,
        body_input_names,
        Some(&referenced),
        scope,
        node_position,
    );
    // The subgraph's own values must be registered with the scope before its
    // forward code is generated.
    subgraph_helper::register_subgraph_scope(subgraph, scope, node_position);
    let body_forward =
        subgraph_helper::generate_subgraph_forward_code(subgraph, scope, node_position);
    quote! {
        #outer_bindings
        #body_forward
    }
}
/// Code generation for the ONNX `Loop` operator.
///
/// Emits an inline Rust `while` loop implementing ONNX Loop semantics:
/// an optional max trip count `M`, a loop-carried boolean condition,
/// N loop-carried variables, and any remaining body outputs collected
/// as scan outputs (stacked/concatenated along dimension 0).
impl NodeCodegen for onnx_ir::node::loop_node::LoopNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Generate the forward-pass token stream for this Loop node.
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        // Inputs: [M (max_trip_count), cond (initial condition), v_initial..., outer_scope_refs...]
        // Per ONNX spec, M and cond can be empty strings (optional)
        // We added outer-scope references as additional inputs during ONNX conversion
        let max_trip_count_arg = &self.inputs[0];
        let init_cond_arg = &self.inputs[1];

        // Calculate how many v_initial args we have (excluding outer-scope refs)
        // scope_ref_names tells us how many outer-scope refs were added
        let num_outer_scope_refs = self.config.scope_ref_names.len();
        let num_onnx_inputs = self.inputs.len() - num_outer_scope_refs;
        let v_initial_args: Vec<_> = self
            .inputs
            .iter()
            .skip(2)
            .take(num_onnx_inputs - 2)
            .collect();

        // Outer-scope references (values from parent scope that subgraph needs)
        let outer_scope_inputs: Vec<_> =
            self.inputs.iter().skip(num_onnx_inputs).cloned().collect();

        // Extract max trip count, normalizing to a native i64 expression
        // regardless of whether it arrives as a native scalar, a 0-d tensor,
        // or a regular tensor.
        let max_count = if max_trip_count_arg.is_optional() {
            quote! { i64::MAX } // No limit if not provided
        } else {
            match &max_trip_count_arg.ty {
                ArgType::ScalarNative(_) => {
                    let name = arg_to_ident(max_trip_count_arg);
                    quote! { #name }
                }
                ArgType::ScalarTensor(dtype) => {
                    let tensor = scope.arg(max_trip_count_arg);
                    on_device_to_native(tensor, dtype)
                }
                ArgType::Tensor(_) => {
                    let tensor = scope.arg(max_trip_count_arg);
                    quote! { #tensor.into_scalar().elem::<i64>() }
                }
                _ => panic!("Loop max_trip_count must be scalar i64"),
            }
        };

        // Extract initial condition, normalized to a native bool expression.
        let init_cond = if init_cond_arg.is_optional() {
            quote! { true } // Run if not provided
        } else {
            match &init_cond_arg.ty {
                ArgType::ScalarNative(_) => {
                    let name = arg_to_ident(init_cond_arg);
                    quote! { #name }
                }
                ArgType::ScalarTensor(dtype) => {
                    let tensor = scope.arg(init_cond_arg);
                    on_device_to_native(tensor, dtype)
                }
                ArgType::Tensor(_) => {
                    let tensor = scope.arg(init_cond_arg);
                    quote! { #tensor.into_scalar().elem::<bool>() }
                }
                _ => panic!("Loop condition must be scalar bool"),
            }
        };

        // Body inputs: [iter_num, cond_in, v_in...]
        // Body outputs: [cond_out, v_out..., scan_outputs...]
        // Calculate number of loop-carried dependencies
        let num_loop_vars = v_initial_args.len();
        // Per ONNX spec, first N body outputs (after cond_out) are loop-carried deps
        // where N = number of v_initial inputs. Rest are scan outputs.
        let num_loop_carried_outputs = num_loop_vars;
        let num_scan_outputs = self.outputs.len() - num_loop_carried_outputs;

        // Get body input and output variable names
        // Body inputs: [iter_num, cond_in, v_in...]
        let iter_name = arg_to_ident(&self.config.body.inputs[0]);
        let cond_in_name = arg_to_ident(&self.config.body.inputs[1]);
        let loop_var_names: Vec<_> = self
            .config
            .body
            .inputs
            .iter()
            .skip(2) // Skip iter and cond_in
            .map(arg_to_ident)
            .collect();

        // Body outputs: [cond_out, v_out..., scan_outputs...]
        let cond_out_name = arg_to_ident(&self.config.body.outputs[0]);
        let loop_out_names: Vec<_> = self
            .config
            .body
            .outputs
            .iter()
            .skip(1)
            .take(num_loop_carried_outputs)
            .map(arg_to_ident)
            .collect();

        // Initialize loop-carried dependency variables
        // Only mark as mutable if the variable is actually updated (different name from output)
        let mut init_stmts = quote! {};
        for (idx, initial_arg) in v_initial_args.iter().enumerate() {
            let var_name = &loop_var_names[idx];
            let init_value = arg_to_ident(initial_arg);
            // Check if this variable will be updated (different name means it gets assigned)
            let needs_mut = idx < num_loop_carried_outputs
                && loop_out_names.get(idx).is_some_and(|out| var_name != out);
            if needs_mut {
                init_stmts.extend(quote! {
                    let mut #var_name = #init_value;
                });
            } else {
                init_stmts.extend(quote! {
                    let #var_name = #init_value;
                });
            }
        }

        // Initialize scan output collectors if any
        // Each scan output accumulates one value per iteration into a Vec.
        let mut scan_init = quote! {};
        let mut scan_collectors = vec![];
        let scan_out_args: Vec<_> = self
            .config
            .body
            .outputs
            .iter()
            .skip(1 + num_loop_vars)
            .collect();
        if num_scan_outputs > 0 {
            for i in 0..num_scan_outputs {
                let collector_name = syn::Ident::new(
                    &format!("scan_collector_{}", i),
                    proc_macro2::Span::call_site(),
                );
                scan_collectors.push(collector_name.clone());
                scan_init.extend(quote! {
                    let mut #collector_name = alloc::vec::Vec::new();
                });
            }
        }

        // Collect body input names (iter_num, cond, loop-carried vars)
        // These should NOT be treated as outer-scope references even though
        // they're declared as subgraph inputs without initializers
        let body_input_names: HashSet<String> = self
            .config
            .body
            .inputs
            .iter()
            .map(|arg| arg.name.clone())
            .collect();

        // Generate loop body code
        let node_position = scope.node_position();
        let body_code = generate_loop_body_code(
            &self.config.body,
            &outer_scope_inputs,
            &self.config.scope_ref_names,
            &body_input_names,
            scope.scope(),
            node_position,
        );

        // Update loop-carried variables after iteration
        // Skip self-assignments when body passes through a value unchanged (same name)
        let mut update_vars = quote! {};
        for (idx, out_name) in loop_out_names.iter().enumerate() {
            let var_name = &loop_var_names[idx];
            if var_name != out_name {
                update_vars.extend(quote! {
                    #var_name = #out_name;
                });
            }
        }

        // Update condition from body output (skip if same name)
        let update_cond = if cond_in_name != cond_out_name {
            let cond_out_ty = &self.config.body.outputs[0].ty;
            if let ArgType::ScalarTensor(dtype) = cond_out_ty {
                // ScalarTensor -> native bool
                let convert = on_device_to_native(quote! { #cond_out_name }, dtype);
                quote! { #cond_in_name = #convert; }
            } else {
                quote! { #cond_in_name = #cond_out_name; }
            }
        } else {
            quote! {}
        };

        // Collect scan outputs - handle scalar vs tensor
        let mut collect_scans = quote! {};
        for (idx, scan_arg) in scan_out_args.iter().enumerate() {
            let out_name = arg_to_ident(scan_arg);
            let collector = &scan_collectors[idx];
            // Tensors need to be cloned before collecting, scalars can be copied
            match &scan_arg.ty {
                ArgType::ScalarNative(_) => {
                    collect_scans.extend(quote! {
                        #collector.push(#out_name);
                    });
                }
                ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
                    collect_scans.extend(quote! {
                        #collector.push(#out_name.clone());
                    });
                }
                _ => panic!("Scan output must be scalar or tensor"),
            }
        }

        // Build output tuple: (loop-carried values, concatenated scan outputs)
        let output_names: Vec<_> = self.outputs.iter().map(arg_to_ident).collect();

        // Collect final values for outputs
        let mut output_values = vec![];
        // First outputs are final loop-carried dependencies
        for (idx, _) in output_names
            .iter()
            .take(num_loop_carried_outputs)
            .enumerate()
        {
            let var_name = &loop_var_names[idx];
            output_values.push(quote! { #var_name });
        }

        // Remaining outputs are concatenated scan outputs
        for (idx, scan_arg) in scan_out_args.iter().enumerate() {
            let collector = &scan_collectors[idx];
            // Handle scalar vs tensor scan outputs
            match &scan_arg.ty {
                ArgType::ScalarNative(dtype) => {
                    // Convert Vec<scalar> to 2D tensor with shape [N, 1]
                    // ONNX spec: scan outputs from scalars get an added dimension
                    // Use from_data with correct tensor kind to preserve dtype
                    let dtype_tokens = dtype.to_tokens();
                    let tensor_creation = if dtype.is_float() {
                        quote! {
                            Tensor::<B, 1>::from_data(data, (&self.device, #dtype_tokens))
                        }
                    } else if dtype.is_int() || dtype.is_uint() {
                        quote! {
                            Tensor::<B, 1, Int>::from_data(data, (&self.device, #dtype_tokens))
                        }
                    } else {
                        // Bool
                        quote! {
                            Tensor::<B, 1, Bool>::from_data(data, (&self.device, #dtype_tokens))
                        }
                    };
                    output_values.push(quote! {
                        {
                            let data = TensorData::from(#collector.as_slice());
                            let len = #collector.len();
                            let tensor1d = #tensor_creation;
                            tensor1d.reshape([len, 1])
                        }
                    });
                }
                ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
                    // Concatenate tensors
                    output_values.push(quote! { Tensor::cat(#collector, 0) });
                }
                _ => panic!("Scan output must be scalar or tensor"),
            }
        }

        // Generate output declaration (let outputs = ...)
        // A single output binds directly; multiple outputs destructure a tuple.
        let output_decls = if self.outputs.len() == 1 {
            let out = &output_names[0];
            quote! { let #out }
        } else {
            quote! { let (#(#output_names),*) }
        };
        // Generate output tuple
        let output_tuple = if self.outputs.len() == 1 {
            let val = &output_values[0];
            quote! { #val }
        } else {
            quote! { (#(#output_values),*) }
        };

        // cond_in only needs mut if it's actually updated (different name from output)
        let cond_needs_mut = cond_in_name != cond_out_name;
        let cond_init = if cond_needs_mut {
            quote! { let mut #cond_in_name = #init_cond; }
        } else {
            quote! { let #cond_in_name = #init_cond; }
        };

        // Assemble the final while loop. The allow attribute suppresses
        // warnings for pass-through variables the body may not touch.
        quote! {
            #[allow(unused_variables, unused_assignments)]
            #output_decls = {
                #init_stmts
                #scan_init
                let mut #iter_name = 0_i64;
                #cond_init
                while #cond_in_name && #iter_name < #max_count {
                    #body_code
                    // Collect scan outputs from body outputs (before updating variables)
                    #collect_scans
                    // Update loop-carried variables for next iteration
                    #update_vars
                    // Update condition from body output for next iteration
                    #update_cond
                    #iter_name += 1;
                }
                #output_tuple
            };
        }
    }

    /// Register imports needed by the generated loop code, including those
    /// required by every node inside the body subgraph.
    fn register_imports(&self, imports: &mut BurnImports) {
        // Register imports from subgraph nodes
        for node in &self.config.body.nodes {
            NodeCodegen::register_imports(node, imports);
        }

        // Calculate number of loop-carried vars, accounting for outer-scope refs
        // (mirrors the same index math used in `forward`)
        let num_outer_scope_refs = self.config.scope_ref_names.len();
        let num_onnx_inputs = self.inputs.len() - num_outer_scope_refs;
        let num_loop_vars = num_onnx_inputs - 2; // Subtract M and cond
        let num_scan_outputs = self.outputs.len() - num_loop_vars;

        // Register Tensor for scan outputs if needed
        if num_scan_outputs > 0 {
            imports.register("burn::tensor::Tensor");
            // Check if any scan outputs are scalars - need TensorData import
            let scan_out_args: Vec<_> = self
                .config
                .body
                .outputs
                .iter()
                .skip(1 + num_loop_vars)
                .collect();
            for scan_arg in scan_out_args {
                if matches!(&scan_arg.ty, ArgType::ScalarNative(_)) {
                    imports.register("burn::tensor::TensorData");
                    break;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Loop node tests require complex OnnxGraph construction which is better tested
    // through integration tests
}