burn-onnx 0.21.0-pre.3

Library for importing ONNX models into the Burn framework.
use super::prelude::*;

/// Code generation for the ONNX `Unsqueeze` operator.
///
/// The axes at which size-1 dimensions are inserted come either from the node config
/// (`UnsqueezeConfig::Static`) or, at runtime, from one of the node's tensor inputs
/// (`UnsqueezeConfig::Runtime`).
impl NodeCodegen for onnx_ir::unsqueeze::UnsqueezeNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Emits the forward-pass statement that binds the unsqueezed value to the output.
    ///
    /// Supported input → output combinations:
    /// - on-device input → tensor: `unsqueeze_dims` with the resolved axes;
    /// - native scalar → tensor: a one-element rank-1 tensor unsqueezed to the output rank;
    /// - native scalar / scalar tensor → shape: delegated to the shared shape helpers;
    /// - shape → tensor: the shape values loaded as a 1-D `Int` tensor, then unsqueezed.
    ///
    /// # Panics
    ///
    /// Panics if the node has no input or output, if runtime axes are not a tensor, if the
    /// output tensor dtype is not int/uint/float/bool, or for any other combination.
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        use onnx_ir::ir::ArgType;

        let input_arg = self
            .inputs
            .first()
            .expect("UnsqueezeNode must have a data input");
        let output_arg = self
            .outputs
            .first()
            .expect("UnsqueezeNode must have an output");
        let output = arg_to_ident(output_arg);

        // Axes as a token stream: a literal array for static axes, or an expression that
        // reads the axes tensor into a `Vec<i64>` at runtime.
        let axes = match &self.config {
            onnx_ir::unsqueeze::UnsqueezeConfig::Static(static_axes) => static_axes.to_tokens(),
            onnx_ir::unsqueeze::UnsqueezeConfig::Runtime(axes_ref) => {
                let axes_arg = &self.inputs[axes_ref.input_index];
                match &axes_arg.ty {
                    ArgType::Tensor(_) => {
                        let tensor_name = arg_to_ident(axes_arg);
                        quote! {
                            #tensor_name.to_data().convert::<i64>().into_vec::<i64>().unwrap()
                        }
                    }
                    _ => panic!(
                        "UnsqueezeNode received invalid axes type: expected tensor but got {:?}",
                        axes_arg.ty
                    ),
                }
            }
        };

        match (&input_arg.ty, &output_arg.ty) {
            (input_ty, ArgType::Tensor(output_tensor)) if input_ty.is_on_device() => {
                let input = scope.arg(input_arg);
                let output_rank = output_tensor.rank.to_tokens();

                // The annotated output type depends on the tensor kind.
                let output_type = match &output_tensor.dtype {
                    dtype if dtype.is_int() || dtype.is_uint() => {
                        quote! { Tensor<B, #output_rank, Int> }
                    }
                    dtype if dtype.is_float() => {
                        quote! { Tensor<B, #output_rank> }
                    }
                    dtype if dtype.is_bool() => {
                        quote! { Tensor<B, #output_rank, Bool> }
                    }
                    _ => panic!("Unsupported tensor dtype: {:?}", output_tensor.dtype),
                };

                quote! {
                    let #output: #output_type = #input.unsqueeze_dims::<#output_rank>(&#axes);
                }
            }
            (ArgType::ScalarNative(_), ArgType::Tensor(output_tensor)) => {
                let scalar_name = arg_to_ident(input_arg);
                let output_rank = output_tensor.rank.to_tokens();
                let dtype_tokens = output_tensor.dtype.to_tokens();

                // Choose the rank-1 tensor type and the TensorData element for the scalar.
                // Int/uint values go through i64 and floats through f64; `from_data` then
                // converts to the requested dtype.
                let (tensor_1d, element) = match &output_tensor.dtype {
                    dtype if dtype.is_int() || dtype.is_uint() => (
                        quote! { Tensor::<B, 1, Int> },
                        quote! { #scalar_name as i64 },
                    ),
                    dtype if dtype.is_float() => (
                        quote! { Tensor::<B, 1> },
                        quote! { #scalar_name as f64 },
                    ),
                    // NOTE(review): `!= 0` only type-checks for integer scalar inputs; a
                    // native bool/float scalar would need a different expression — confirm
                    // which scalar dtypes can reach this arm.
                    dtype if dtype.is_bool() => (
                        quote! { Tensor::<B, 1, Bool> },
                        quote! { #scalar_name != 0 },
                    ),
                    _ => panic!("Unsupported tensor dtype: {:?}", output_tensor.dtype),
                };

                // `[element]` has shape [1], and `from_data` requires the data rank to equal
                // the tensor's const rank. So build a rank-1 tensor and unsqueeze it to the
                // output rank; every dimension of an unsqueezed scalar has size 1. Creating it
                // directly at `output_rank` only worked when the output rank was 1.
                quote! {
                    let #output = #tensor_1d::from_data(
                        burn::tensor::TensorData::from([#element]),
                        (&self.device, #dtype_tokens)
                    ).unsqueeze::<#output_rank>();
                }
            }
            (ArgType::ScalarNative(_), ArgType::Shape(_)) => {
                let input_name = arg_to_ident(input_arg);
                let value_expr = scalar_native_to_shape(quote! { #input_name });
                quote! {
                    let #output = #value_expr;
                }
            }
            (ArgType::ScalarTensor(dtype), ArgType::Shape(_)) => {
                let input = scope.arg(input_arg);
                let value_expr = scalar_tensor_to_shape(input, dtype);
                quote! {
                    let #output = #value_expr;
                }
            }
            (ArgType::Shape(_), ArgType::Tensor(output_tensor)) => {
                let input_name = arg_to_ident(input_arg);
                let output_rank = output_tensor.rank.to_tokens();
                let dtype_tokens = output_tensor.dtype.to_tokens();
                // Shape values are an array (e.g. `[i64; N]`), so they load as a 1-D Int
                // tensor before being unsqueezed at the requested axes.
                quote! {
                    let #output: Tensor<B, #output_rank, Int> = Tensor::<B, 1, Int>::from_data(
                        burn::tensor::TensorData::from(#input_name.as_slice()),
                        (&self.device, #dtype_tokens)
                    ).unsqueeze_dims::<#output_rank>(&#axes);
                }
            }
            _ => panic!(
                "UnsqueezeNode received unsupported input/output combination: {:?} -> {:?}",
                input_arg.ty, output_arg.ty
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::test_helpers::*;
    use burn::tensor::DType;
    use insta::assert_snapshot;
    use onnx_ir::unsqueeze::{UnsqueezeConfig, UnsqueezeNode, UnsqueezeNodeBuilder};

    /// Builds an Unsqueeze node that maps a rank-2 F32 tensor `input` to a rank-3 F32
    /// tensor `output`, using the given static `axes`.
    fn create_unsqueeze_node(name: &str, axes: Vec<i64>) -> UnsqueezeNode {
        UnsqueezeNodeBuilder::new(name)
            .input_tensor("input", 2, DType::F32)
            .output_tensor("output", 3, DType::F32)
            .config(UnsqueezeConfig::Static(axes))
            .build()
    }

    // Leading axis: [a, b] -> [1, a, b].
    #[test]
    fn test_unsqueeze_forward_single_axis() {
        let unsqueeze = create_unsqueeze_node("unsqueeze1", vec![0]);
        let generated = codegen_forward_default(&unsqueeze);
        assert_snapshot!(generated, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 3> {
            let output: Tensor<B, 3> = input.unsqueeze_dims::<3>(&[0]);
            output
        }
        ");
    }

    // Middle axis: [a, b] -> [a, 1, b].
    #[test]
    fn test_unsqueeze_forward_axis_1() {
        let unsqueeze = create_unsqueeze_node("unsqueeze1", vec![1]);
        let generated = codegen_forward_default(&unsqueeze);
        assert_snapshot!(generated, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 3> {
            let output: Tensor<B, 3> = input.unsqueeze_dims::<3>(&[1]);
            output
        }
        ");
    }

    // Trailing axis: [a, b] -> [a, b, 1].
    #[test]
    fn test_unsqueeze_forward_axis_2() {
        let unsqueeze = create_unsqueeze_node("unsqueeze1", vec![2]);
        let generated = codegen_forward_default(&unsqueeze);
        assert_snapshot!(generated, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 3> {
            let output: Tensor<B, 3> = input.unsqueeze_dims::<3>(&[2]);
            output
        }
        ");
    }

    // A shape input is materialized as a 1-D Int tensor before unsqueezing.
    #[test]
    fn test_unsqueeze_shape_input() {
        let unsqueeze = UnsqueezeNodeBuilder::new("unsqueeze_shape")
            .input_shape("shape_val", 4)
            .output_tensor("output", 2, DType::I64)
            .config(UnsqueezeConfig::Static(vec![0]))
            .build();

        let generated = codegen_forward_default(&unsqueeze);
        assert_snapshot!(generated, @r"
        pub fn forward(&self, shape_val: [i64; 4]) -> Tensor<B, 2, Int> {
            let output: Tensor<B, 2, Int> = Tensor::<
                B,
                1,
                Int,
            >::from_data(
                    burn::tensor::TensorData::from(shape_val.as_slice()),
                    (&self.device, burn::tensor::DType::I64),
                )
                .unsqueeze_dims::<2>(&[0]);
            output
        }
        ");
    }
}