use super::prelude::*;
use super::subgraph_helper;
use std::collections::HashSet;
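
/// Generates the code that runs a Scan node's body subgraph for one iteration.
///
/// Outer-scope values referenced by the subgraph are bound to local names
/// first (filtered to the names the subgraph actually uses), then the
/// subgraph is registered in `scope` and its forward pass is emitted inline
/// after the bindings.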
fn generate_scan_body_code(
subgraph: &onnx_ir::OnnxGraph,
outer_scope_inputs: &[Argument],
scope_ref_names: &[String],
body_input_names: &HashSet<String>,
scope: &mut Scope,
node_position: usize,
) -> TokenStream {
let used_names = subgraph_helper::collect_subgraph_referenced_names(subgraph);
let bindings = subgraph_helper::generate_outer_scope_bindings(
outer_scope_inputs,
scope_ref_names,
body_input_names,
Some(&used_names),
scope,
node_position,
);
subgraph_helper::register_subgraph_scope(subgraph, scope, node_position);
let forward_code =
subgraph_helper::generate_subgraph_forward_code(subgraph, scope, node_position);
quote! {
#bindings
#forward_code
}
}
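
/// Code generation for the ONNX `Scan` operator.
///
/// `Scan` runs its body subgraph once per step along the scan axis of each
/// scan input, carrying state variables between iterations and stacking the
/// body's per-iteration outputs into the final scan outputs.
///
/// The emitted code has roughly this shape (illustrative sketch only; the
/// real identifiers are generated, and the body subgraph's forward code is
/// inlined where the comment sits):
///
/// ```ignore
/// let (final_state, scan_output) = {
///     let mut state = initial_state;
///     let mut scan_output_collector_0 = alloc::vec::Vec::new();
///     let seq_len = scan_in.shape()[0];
///     for i in 0..seq_len {
///         let step = scan_in.clone().slice_dim(0, i..i + 1).squeeze_dim::<1>(0);
///         // ...body subgraph forward code...
///         state = body_state_out.clone();
///         scan_output_collector_0.push(body_scan_out);
///     }
///     let scan_output_0 = Tensor::stack(scan_output_collector_0, 0);
///     (state, scan_output_0)
/// };
/// ```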
impl NodeCodegen for onnx_ir::node::scan_node::ScanNode {
fn inputs(&self) -> &[Argument] {
&self.inputs
}
fn outputs(&self) -> &[Argument] {
&self.outputs
}
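    /// Emits a `for` loop over the sequence length: each iteration slices one
    /// step from every scan input, runs the inlined body subgraph, threads the
    /// loop-carried state forward, and pushes scan outputs into collectors
    /// that are stacked into tensors after the loop.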
fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
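        // Node inputs are ordered [initial_state..., scan_sequences...], with
        // any captured outer-scope references appended at the end.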
let num_scan_inputs = self.config.num_scan_inputs as usize;
let num_outer_scope_refs = self.config.scope_ref_names.len();
let num_onnx_inputs = self.inputs.len() - num_outer_scope_refs;
let num_state_vars = num_onnx_inputs - num_scan_inputs;
let outer_scope_inputs: Vec<_> =
self.inputs.iter().skip(num_onnx_inputs).cloned().collect();
let initial_state_vars: Vec<_> = self.inputs.iter().take(num_state_vars).collect();
let scan_input_sequences: Vec<_> = self
.inputs
.iter()
.skip(num_state_vars)
.take(num_scan_inputs)
.collect();
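        // Body outputs are ordered [final_state..., per-iteration scan outputs...].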
let num_scan_outputs = self.config.body.outputs.len() - num_state_vars;
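        // Per the ONNX spec, all scan inputs have the same length along their
        // scan axes, so the first input determines the trip count.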
let first_scan_input = scan_input_sequences[0];
let scan_axis = self.config.scan_input_axes.first().copied().unwrap_or(0) as usize;
let first_scan_name = arg_to_ident(first_scan_input);
let seq_len_expr = quote! { #first_scan_name.shape()[#scan_axis] };
let mut init_stmts = quote! {};
let state_var_names: Vec<_> = self
.config
.body
.inputs
.iter()
.take(num_state_vars)
.map(arg_to_ident)
.collect();
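        // Seed each loop-carried state variable with its initial value.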
for (idx, initial_arg) in initial_state_vars.iter().enumerate() {
let var_name = &state_var_names[idx];
let init_value = arg_to_ident(initial_arg);
init_stmts.extend(quote! {
let mut #var_name = #init_value;
});
}
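        // One collector per scan output accumulates the per-iteration results.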
let scan_output_collectors: Vec<_> = (0..num_scan_outputs)
.map(|i| {
syn::Ident::new(
&format!("scan_output_collector_{}", i),
proc_macro2::Span::call_site(),
)
})
.collect();
for collector in &scan_output_collectors {
init_stmts.extend(quote! {
let mut #collector = alloc::vec::Vec::new();
});
}
init_stmts.extend(quote! {
let seq_len = #seq_len_expr;
});
let scan_input_names: Vec<_> = self
.config
.body
.inputs
.iter()
.skip(num_state_vars)
.map(arg_to_ident)
.collect();
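        // For each scan input, slice out the current step (honoring per-input
        // direction and axis) and squeeze away the scan dimension so the body
        // receives a rank-reduced tensor.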
let mut slice_stmts = quote! {};
for (idx, scan_input_arg) in scan_input_sequences.iter().enumerate() {
let var_name = &scan_input_names[idx];
let tensor_name = arg_to_ident(scan_input_arg);
let reverse = self
.config
.scan_input_directions
.get(idx)
.copied()
.unwrap_or(0)
== 1;
let scan_axis = self.config.scan_input_axes.get(idx).copied().unwrap_or(0) as usize;
let iter_expr = if reverse {
quote! { seq_len - i - 1 }
} else {
quote! { i }
};
let ArgType::Tensor(tensor_ty) = &scan_input_arg.ty else {
                panic!("Scan input '{}' must be a tensor", scan_input_arg.name);
};
let input_rank = tensor_ty.rank;
let output_rank = input_rank - 1;
slice_stmts.extend(quote! {
let #var_name = #tensor_name
.clone()
.slice_dim(#scan_axis, #iter_expr..#iter_expr + 1)
.squeeze_dim::<#output_rank>(#scan_axis);
});
}
let body_input_names: HashSet<String> = self
.config
.body
.inputs
.iter()
.map(|arg| arg.name.clone())
.collect();
let node_position = scope.node_position();
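        // Inline the body subgraph, binding only the outer-scope names it
        // actually references.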
let body_code = generate_scan_body_code(
&self.config.body,
&outer_scope_inputs,
&self.config.scope_ref_names,
&body_input_names,
scope.scope(),
node_position,
);
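        // Thread the body's final-state outputs back into the state variables.
        // Clone when scan outputs exist, since a body output value may also be
        // pushed to a collector below.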
let mut update_stmts = quote! {};
let should_clone = num_scan_outputs > 0;
for (idx, var_name) in state_var_names.iter().enumerate() {
let out_name = arg_to_ident(&self.config.body.outputs[idx]);
if should_clone {
update_stmts.extend(quote! {
#var_name = #out_name.clone();
});
} else {
update_stmts.extend(quote! {
#var_name = #out_name;
});
}
}
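        // Collect this iteration's scan outputs.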
for (idx, collector) in scan_output_collectors.iter().enumerate() {
let body_out_name = arg_to_ident(&self.config.body.outputs[num_state_vars + idx]);
update_stmts.extend(quote! {
#collector.push(#body_out_name);
});
}
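        // Stack each collector along its output axis, reversing first when the
        // output direction is backward.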
let mut finalize_stmts = quote! {};
let final_scan_output_names: Vec<_> = (0..num_scan_outputs)
.map(|i| {
syn::Ident::new(
&format!("scan_output_{}", i),
proc_macro2::Span::call_site(),
)
})
.collect();
for (idx, (collector, output_name)) in scan_output_collectors
.iter()
.zip(final_scan_output_names.iter())
.enumerate()
{
let reverse = self
.config
.scan_output_directions
.get(idx)
.copied()
.unwrap_or(0)
== 1;
if reverse {
finalize_stmts.extend(quote! {
#collector.reverse();
});
}
let stack_axis = self.config.scan_output_axes.get(idx).copied().unwrap_or(0) as usize;
finalize_stmts.extend(quote! {
                let #output_name = Tensor::stack(#collector, #stack_axis);
});
}
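        // Bind the node's declared outputs: state variables first, then the
        // stacked scan outputs.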
let all_output_names: Vec<_> = self.outputs.iter().map(arg_to_ident).collect();
let output_vars: Vec<_> = (0..self.outputs.len())
.map(|idx| {
if idx < num_state_vars {
state_var_names[idx].clone()
} else {
final_scan_output_names[idx - num_state_vars].clone()
}
})
.collect();
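        // A single output binds directly; multiple outputs destructure a tuple.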
if self.outputs.len() == 1 {
let output_name = &all_output_names[0];
let var_name = &output_vars[0];
quote! {
#[allow(unused_variables, unused_assignments)]
let #output_name = {
#init_stmts
for i in 0..seq_len {
#slice_stmts
#body_code
#update_stmts
}
#finalize_stmts
#var_name
};
}
} else {
quote! {
#[allow(unused_variables, unused_assignments)]
let (#(#all_output_names),*) = {
#init_stmts
for i in 0..seq_len {
#slice_stmts
#body_code
#update_stmts
}
#finalize_stmts
(#(#output_vars),*)
};
}
}
}
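
    /// Collects imports from every node in the body subgraph.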
fn register_imports(&self, imports: &mut BurnImports) {
for node in &self.config.body.nodes {
NodeCodegen::register_imports(node, imports);
}
}
}