use super::prelude::*;
use super::subgraph_helper;
use std::collections::HashSet;
/// Generates the Rust code for one branch subgraph of an `If` node.
///
/// Returns a pair `(body, output)`:
/// - `body` is the outer-scope bindings (restricted to names the subgraph references)
///   followed by the subgraph's forward code;
/// - `output` is the branch's result expression: the single output identifier when the
///   subgraph has exactly one output, otherwise a parenthesized tuple of all output
///   identifiers.
///
/// The subgraph's scope is registered in `scope` at `node_position` before its forward
/// code is generated.
fn generate_branch_code(
    subgraph: &onnx_ir::OnnxGraph,
    outer_scope_inputs: &[Argument],
    scope_ref_names: &[String],
    scope: &mut Scope,
    node_position: usize,
) -> (TokenStream, TokenStream) {
    // Only bind outer-scope values the subgraph actually mentions; nothing is excluded.
    let referenced = subgraph_helper::collect_subgraph_referenced_names(subgraph);
    let outer_bindings = subgraph_helper::generate_outer_scope_bindings(
        outer_scope_inputs,
        scope_ref_names,
        &HashSet::new(),
        Some(&referenced),
        scope,
        node_position,
    );

    subgraph_helper::register_subgraph_scope(subgraph, scope, node_position);
    let subgraph_code =
        subgraph_helper::generate_subgraph_forward_code(subgraph, scope, node_position);

    let result_idents: Vec<_> = subgraph.outputs.iter().map(arg_to_ident).collect();
    let result = match result_idents.as_slice() {
        [only] => quote! { #only },
        all => quote! { (#(#all),*) },
    };

    (
        quote! {
            #outer_bindings
            #subgraph_code
        },
        result,
    )
}
/// Chooses the axes at which an `If` branch output must be unsqueezed so that its rank
/// matches the `If` node's output rank.
///
/// When both static shapes are known, the target dimensions are walked left to right and
/// matched in order against the branch dimensions. Every target dimension that is not
/// matched and equals `Some(1)` becomes an unsqueeze axis. Alignment fails in three cases:
/// a target dimension is unmatched and not `Some(1)`; branch dimensions are left over; or
/// the number of axes differs from `target_rank - branch_rank`. On failure a warning is
/// logged. Whenever no alignment is accepted, including unknown shapes, the result is the
/// trailing axes `branch_rank..target_rank`.
///
/// Returns an empty vector when `branch_rank >= target_rank`, because no unsqueeze is
/// needed. This early return also prevents the `usize` rank difference from underflowing.
fn find_unsqueeze_dims(
    branch_shape: Option<&[Option<usize>]>,
    target_shape: Option<&[Option<usize>]>,
    branch_rank: usize,
    target_rank: usize,
) -> Vec<isize> {
    // Nothing to insert; also guards the `target_rank - branch_rank` subtraction below,
    // which would otherwise panic (debug) or wrap (release).
    if branch_rank >= target_rank {
        return Vec::new();
    }
    let expected = target_rank - branch_rank;

    if let (Some(b_shape), Some(t_shape)) = (branch_shape, target_shape) {
        let mut dims = Vec::with_capacity(expected);
        let mut b_idx = 0;
        let mut misaligned = false;
        for (t_idx, t_dim) in t_shape.iter().enumerate() {
            if b_shape.get(b_idx) == Some(t_dim) {
                // Matching greedily is safe. If this dim is also `Some(1)`, consuming the
                // branch dim here or inserting a new axis here yields the same final shape.
                b_idx += 1;
            } else if *t_dim == Some(1) {
                dims.push(t_idx as isize);
            } else {
                misaligned = true;
                log::warn!(
                    "If branch shapes don't align cleanly (branch {:?} vs target {:?}), \
                     falling back to trailing unsqueeze dims",
                    b_shape,
                    t_shape
                );
                break;
            }
        }
        if !misaligned {
            // Accept only if the axis count is right and every branch dim was placed.
            if dims.len() == expected && b_idx == b_shape.len() {
                return dims;
            }
            log::warn!(
                "If branch shape alignment produced {} dims but expected {} \
                 (branch {:?} vs target {:?}), falling back to trailing unsqueeze dims",
                dims.len(),
                expected,
                b_shape,
                t_shape,
            );
        }
    }
    (branch_rank as isize..target_rank as isize).collect()
}
impl NodeCodegen for onnx_ir::node::if_node::IfNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Emits `let <outputs> = if <cond> { <then branch> } else { <else branch> };`.
    ///
    /// Input 0 is the condition. Inputs 1.. are passed to both branch subgraphs as
    /// outer-scope values. Each branch's result is passed through `align_branch_output`
    /// against this node's outputs before being placed in the `if` expression.
    ///
    /// # Panics
    /// Panics if the node has no inputs.
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        let condition = self
            .inputs
            .first()
            .expect("If node requires condition input");

        // Convert the condition to a native boolean expression based on its storage kind.
        let cond_expr = match &condition.ty {
            ArgType::ScalarNative(_) => {
                let ident = arg_to_ident(condition);
                quote! { #ident }
            }
            ArgType::ScalarTensor(dtype) => {
                let tensor = scope.arg(condition);
                on_device_to_native(quote! { #tensor }, dtype)
            }
            ArgType::Tensor(_) => {
                let tensor = scope.arg(condition);
                quote! { #tensor.into_scalar().elem::<bool>() }
            }
            // A rank-0 shape has no first element to test, so the condition is `false`.
            ArgType::Shape(rank) if *rank == 0 => quote! { false },
            ArgType::Shape(_) => {
                let ident = arg_to_ident(condition);
                quote! { #ident[0] != 0 }
            }
        };

        // Everything after the condition is made visible to the branch subgraphs.
        let captured = self.inputs[1..].to_vec();
        let position = scope.node_position();
        let config = &self.config;

        let (then_stmts, then_value) = generate_branch_code(
            &config.then_branch,
            &captured,
            &config.scope_ref_names,
            scope.scope(),
            position,
        );
        let (else_stmts, else_value) = generate_branch_code(
            &config.else_branch,
            &captured,
            &config.scope_ref_names,
            scope.scope(),
            position,
        );

        let then_value =
            align_branch_output(then_value, &config.then_branch.outputs, &self.outputs);
        let else_value =
            align_branch_output(else_value, &config.else_branch.outputs, &self.outputs);

        let out_idents: Vec<_> = self.outputs.iter().map(arg_to_ident).collect();
        let binding = match out_idents.as_slice() {
            [single] => quote! { let #single },
            several => quote! { let (#(#several),*) },
        };

        quote! {
            #binding = if #cond_expr {
                #then_stmts
                #then_value
            } else {
                #else_stmts
                #else_value
            };
        }
    }

    /// Registers the imports needed by every node in both branch subgraphs, then-branch
    /// nodes first.
    fn register_imports(&self, imports: &mut BurnImports) {
        let branches = [&self.config.then_branch, &self.config.else_branch];
        for node in branches.into_iter().flat_map(|graph| graph.nodes.iter()) {
            NodeCodegen::register_imports(node, imports);
        }
    }
}
/// Adapts a branch's result expression to the rank of the `If` node's output.
///
/// Only the single-output case is handled. When the branch and the `If` node each have
/// exactly one output and the branch rank is lower, the expression is wrapped in
/// `.unsqueeze_dims::<TARGET_RANK>(&[axes...])`, with axes chosen by `find_unsqueeze_dims`
/// from the static shapes. In every other case `branch_output` is returned unchanged.
fn align_branch_output(
    branch_output: TokenStream,
    branch_outputs: &[Argument],
    if_outputs: &[Argument],
) -> TokenStream {
    let (branch_arg, if_arg) = match (branch_outputs, if_outputs) {
        ([branch_arg], [if_arg]) => (branch_arg, if_arg),
        _ => return branch_output,
    };

    let from_rank = branch_arg.ty.rank();
    let to_rank = if_arg.ty.rank();
    if from_rank >= to_rank {
        return branch_output;
    }

    let axes = find_unsqueeze_dims(
        branch_arg.ty.static_shape().map(|s| s.as_slice()),
        if_arg.ty.static_shape().map(|s| s.as_slice()),
        from_rank,
        to_rank,
    );
    quote! {
        #branch_output.unsqueeze_dims::<#to_rank>(&[#(#axes),*])
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): this test module is empty in this view. `find_unsqueeze_dims` takes
    // only std types (`Option<&[Option<usize>]>`, `usize`) and could be unit-tested here
    // directly.
}