use super::prelude::*;
/// Code generation for the ONNX `Shape` operator.
///
/// The emitted code binds the node's output to a `[i64; N]` array, where `N` is the rank
/// carried by the output's `ArgType::Shape`.
impl NodeCodegen for onnx_ir::shape::ShapeNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Emits `let <output>: [i64; N] = <expr>;` for this `Shape` node.
    ///
    /// - `Tensor` / `ScalarTensor` input: `<expr>` copies `dims()[start..end]` of the input
    ///   into an `i64` array of length `end - start` (from `self.config`).
    /// - `Shape` input: `<expr>` is the single-element array `[rank]` of the input shape
    ///   value; `self.config` is not consulted for this case.
    ///
    /// # Panics
    ///
    /// Panics if the node has no input or no output, if the output is not
    /// `ArgType::Shape`, if the input is `ArgType::ScalarNative`, or if
    /// `config.end < config.start` for a tensor input.
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        use onnx_ir::ir::ArgType;

        let input_arg = self
            .inputs
            .first()
            .expect("Shape operation expects one input");
        let output_arg = self
            .outputs
            .first()
            .expect("Shape operation expects one output");
        let output = arg_to_ident(output_arg);

        // The array length of the emitted binding comes from the declared output type.
        let dim = match &output_arg.ty {
            ArgType::Shape(rank) => rank.to_tokens(),
            ArgType::Tensor(_) | ArgType::ScalarNative(_) | ArgType::ScalarTensor(_) => {
                panic!("Shape operation expects Shape output")
            }
        };

        let function = match &input_arg.ty {
            // Both tensor-like inputs are read through `dims()`, so they share one code path.
            ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
                let input = scope.arg(input_arg);
                let start_dim_tok = self.config.start.to_tokens();
                let end_dim_tok = self.config.end.to_tokens();
                // Checked so an inverted range fails loudly at codegen time instead of
                // overflowing (debug) or wrapping to an enormous array length (release).
                let output_rank = self
                    .config
                    .end
                    .checked_sub(self.config.start)
                    .expect("Shape operation expects config.end >= config.start")
                    .to_tokens();
                quote! {
                    {
                        let axes = &#input.dims()[#start_dim_tok..#end_dim_tok];
                        let mut output = [0i64; #output_rank];
                        for i in 0..#output_rank {
                            output[i] = axes[i] as i64;
                        }
                        output
                    }
                }
            }
            ArgType::Shape(shape_rank) => {
                // The shape of a shape value is just its (single) length.
                let rank_value = *shape_rank as i64;
                quote! { [#rank_value] }
            }
            ArgType::ScalarNative(_) => {
                panic!("Shape operation does not support ScalarNative inputs")
            }
        };

        quote! {
            let #output: [i64;#dim] = #function;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_helpers::*;
    use burn::tensor::DType;
    use insta::assert_snapshot;
    use onnx_ir::ir::{ArgType, Argument, TensorType};
    use onnx_ir::shape::{ShapeConfig, ShapeNode};

    /// Assembles a single-input, single-output `ShapeNode` for the snapshot tests.
    fn shape_node(
        name: &str,
        input: Argument,
        config: ShapeConfig,
        output: Argument,
    ) -> ShapeNode {
        ShapeNode {
            name: name.to_string(),
            inputs: vec![input],
            outputs: vec![output],
            config,
        }
    }

    /// Full-range shape of a rank-3 float tensor.
    #[test]
    fn test_shape_full() {
        let node = shape_node(
            "shape1",
            Argument::new(
                "input",
                ArgType::Tensor(TensorType::new(DType::F32, 3, None)),
            ),
            ShapeConfig { start: 0, end: 3 },
            Argument::new("output", ArgType::Shape(3)),
        );
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
pub fn forward(&self, input: Tensor<B, 3>) -> [i64; 3] {
let output: [i64; 3] = {
let axes = &input.dims()[0..3];
let mut output = [0i64; 3];
for i in 0..3 {
output[i] = axes[i] as i64;
}
output
};
output
}
");
    }

    /// Sliced shape (`start = 1`, `end = 3`) of a rank-4 float tensor.
    #[test]
    fn test_shape_partial() {
        let node = shape_node(
            "shape2",
            Argument::new(
                "input",
                ArgType::Tensor(TensorType::new(DType::F32, 4, None)),
            ),
            ShapeConfig { start: 1, end: 3 },
            Argument::new("output", ArgType::Shape(2)),
        );
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
pub fn forward(&self, input: Tensor<B, 4>) -> [i64; 2] {
let output: [i64; 2] = {
let axes = &input.dims()[1..3];
let mut output = [0i64; 2];
for i in 0..2 {
output[i] = axes[i] as i64;
}
output
};
output
}
");
    }

    /// Shape of a shape value collapses to its length.
    #[test]
    fn test_shape_of_shape() {
        let node = shape_node(
            "shape3",
            Argument::new("input", ArgType::Shape(3)),
            ShapeConfig { start: 0, end: 1 },
            Argument::new("output", ArgType::Shape(1)),
        );
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
pub fn forward(&self, input: [i64; 3]) -> [i64; 1] {
let output: [i64; 1] = [3i64];
output
}
");
    }

    /// Shape of a scalar tensor goes through the same `dims()` path as a tensor.
    #[test]
    fn test_shape_of_scalar_tensor() {
        let node = shape_node(
            "shape4",
            Argument::new("input", ArgType::ScalarTensor(DType::I64)),
            ShapeConfig { start: 0, end: 1 },
            Argument::new("output", ArgType::Shape(1)),
        );
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
pub fn forward(&self, input: Tensor<B, 1, Int>) -> [i64; 1] {
let output: [i64; 1] = {
let axes = &input.dims()[0..1];
let mut output = [0i64; 1];
for i in 0..1 {
output[i] = axes[i] as i64;
}
output
};
output
}
");
    }
}