//! burn-onnx 0.21.0-pre.3
//!
//! Codegen documentation for importing ONNX models into the Burn framework.
use super::prelude::*;

impl NodeCodegen for onnx_ir::shape::ShapeNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Emits the forward-pass tokens for the ONNX `Shape` operation.
    ///
    /// The output is always materialized as a fixed-size `[i64; N]` array,
    /// where `N = config.end - config.start` (the slice of dimensions
    /// requested by the ONNX `start`/`end` attributes).
    ///
    /// # Panics
    /// Panics if the output type is not `ArgType::Shape`, or if the input is
    /// an `ArgType::ScalarNative` (a native scalar has no dims to query).
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        use onnx_ir::ir::ArgType;

        let input_arg = self.inputs.first().unwrap();
        let output_arg = self.outputs.first().unwrap();
        let output = arg_to_ident(output_arg);

        // The output must be a Shape(N); its rank fixes the array length in
        // the generated `let #output: [i64; #dim]` binding below.
        let dim = match &output_arg.ty {
            ArgType::Shape(rank) => rank.to_tokens(),
            ArgType::Tensor(_) | ArgType::ScalarNative(_) | ArgType::ScalarTensor(_) => {
                panic!("Shape operation expects Shape output")
            }
        };

        let start_dim_tok = self.config.start.to_tokens();
        let end_dim_tok = self.config.end.to_tokens();
        let output_rank = (self.config.end - self.config.start).to_tokens();

        let function = match &input_arg.ty {
            // Tensor and ScalarTensor (a rank-1 tensor) share identical
            // codegen: slice the runtime dims in [start, end) and copy each
            // dimension into an i64 array. Previously these were two
            // duplicated arms; the generated tokens are the same.
            ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
                let input = scope.arg(input_arg);
                quote! {
                    {
                        let axes = &#input.dims()[#start_dim_tok..#end_dim_tok];
                        let mut output = [0i64; #output_rank];
                        for i in 0..#output_rank {
                            output[i] = axes[i] as i64;
                        }
                        output
                    }
                }
            }
            ArgType::Shape(shape_rank) => {
                // If the input is already a shape array [i64; N], the Shape
                // operation returns the dimensionality of the shape (which is
                // N) as a Shape(1) array. This matches ONNX semantics where
                // Shape of a shape gives you the rank — a compile-time
                // constant, so no runtime code is needed.
                let rank_value = *shape_rank as i64;
                quote! { [#rank_value] }
            }
            ArgType::ScalarNative(_) => {
                panic!("Shape operation does not support ScalarNative inputs")
            }
        };

        quote! {
            let #output: [i64;#dim] = #function;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::test_helpers::*;
    use burn::tensor::DType;
    use insta::assert_snapshot;
    use onnx_ir::ir::{ArgType, Argument, TensorType};
    use onnx_ir::shape::{ShapeConfig, ShapeNode};

    // Shape over every dimension of a rank-3 float tensor.
    #[test]
    fn test_shape_full() {
        let tensor_input = Argument::new(
            "input",
            ArgType::Tensor(TensorType::new(DType::F32, 3, None)),
        );
        let node = ShapeNode {
            name: "shape1".to_string(),
            inputs: vec![tensor_input],
            outputs: vec![Argument::new("output", ArgType::Shape(3))],
            config: ShapeConfig { start: 0, end: 3 },
        };

        assert_snapshot!(codegen_forward_default(&node), @r"
        pub fn forward(&self, input: Tensor<B, 3>) -> [i64; 3] {
            let output: [i64; 3] = {
                let axes = &input.dims()[0..3];
                let mut output = [0i64; 3];
                for i in 0..3 {
                    output[i] = axes[i] as i64;
                }
                output
            };
            output
        }
        ");
    }

    // Shape restricted to a [start, end) dimension slice.
    #[test]
    fn test_shape_partial() {
        let tensor_input = Argument::new(
            "input",
            ArgType::Tensor(TensorType::new(DType::F32, 4, None)),
        );
        let node = ShapeNode {
            name: "shape2".to_string(),
            inputs: vec![tensor_input],
            outputs: vec![Argument::new("output", ArgType::Shape(2))],
            config: ShapeConfig { start: 1, end: 3 },
        };

        assert_snapshot!(codegen_forward_default(&node), @r"
        pub fn forward(&self, input: Tensor<B, 4>) -> [i64; 2] {
            let output: [i64; 2] = {
                let axes = &input.dims()[1..3];
                let mut output = [0i64; 2];
                for i in 0..2 {
                    output[i] = axes[i] as i64;
                }
                output
            };
            output
        }
        ");
    }

    // Shape of a shape input collapses to a compile-time rank constant.
    #[test]
    fn test_shape_of_shape() {
        let shape_input = Argument::new("input", ArgType::Shape(3));
        let node = ShapeNode {
            name: "shape3".to_string(),
            inputs: vec![shape_input],
            outputs: vec![Argument::new("output", ArgType::Shape(1))],
            config: ShapeConfig { start: 0, end: 1 },
        };

        assert_snapshot!(codegen_forward_default(&node), @r"
        pub fn forward(&self, input: [i64; 3]) -> [i64; 1] {
            let output: [i64; 1] = [3i64];
            output
        }
        ");
    }

    // ScalarTensor inputs are rank-1 tensors, so they take the tensor path.
    #[test]
    fn test_shape_of_scalar_tensor() {
        let scalar_input = Argument::new("input", ArgType::ScalarTensor(DType::I64));
        let node = ShapeNode {
            name: "shape4".to_string(),
            inputs: vec![scalar_input],
            outputs: vec![Argument::new("output", ArgType::Shape(1))],
            config: ShapeConfig { start: 0, end: 1 },
        };

        assert_snapshot!(codegen_forward_default(&node), @r"
        pub fn forward(&self, input: Tensor<B, 1, Int>) -> [i64; 1] {
            let output: [i64; 1] = {
                let axes = &input.dims()[0..1];
                let mut output = [0i64; 1];
                for i in 0..1 {
                    output[i] = axes[i] as i64;
                }
                output
            };
            output
        }
        ");
    }
}