// burn-tensor 0.16.1
//
// Tensor library with user-friendly APIs and automatic differentiation support.
// Documentation
#[burn_tensor_testgen::testgen(module_conv_transpose1d)]
mod tests {
    use super::*;
    use burn_tensor::module::conv_transpose1d;
    use burn_tensor::ops::ConvTransposeOptions;
    use burn_tensor::{Shape, Tensor};

    /// Transposed 1D convolution where the input has more channels (3) than
    /// the output (2), with unit stride, unit dilation and a single group.
    #[test]
    fn test_conv_transpose1d_diff_channels() {
        let case = ConvTranspose1dTestCase {
            // Input geometry.
            batch_size: 1,
            length: 4,
            // Channel layout.
            channels_in: 3,
            channels_out: 2,
            groups: 1,
            // Kernel geometry.
            kernel_size: 3,
            stride: 1,
            dilation: 1,
            padding: 1,
            padding_out: 0,
        };

        // One batch entry, one row per output channel, four positions each.
        let expected: TestTensor<3> = TestTensor::from([[
            [270., 453., 516., 387.],
            [352., 589., 679., 505.],
        ]]);

        case.assert_output(expected);
    }

    /// Transposed 1D convolution with a stride of 2 and one element of output
    /// padding; the expected result holds eight positions per output channel.
    #[test]
    fn test_conv_transpose1d_stride() {
        let case = ConvTranspose1dTestCase {
            // Input geometry.
            batch_size: 1,
            length: 4,
            // Channel layout.
            channels_in: 2,
            channels_out: 2,
            groups: 1,
            // Kernel geometry.
            kernel_size: 3,
            stride: 2,
            dilation: 1,
            padding: 1,
            padding_out: 1,
        };

        // One batch entry, one row per output channel.
        let expected: TestTensor<3> = TestTensor::from([[
            [28., 62., 36., 78., 44., 94., 52., 62.],
            [41., 93., 55., 121., 69., 149., 83., 93.],
        ]]);

        case.assert_output(expected);
    }

    /// Transposed 1D convolution with a dilation of 2 and unit stride; the
    /// expected result holds six positions per output channel.
    #[test]
    fn test_conv_transpose1d_dilation() {
        let case = ConvTranspose1dTestCase {
            // Input geometry.
            batch_size: 1,
            length: 4,
            // Channel layout.
            channels_in: 2,
            channels_out: 2,
            groups: 1,
            // Kernel geometry.
            kernel_size: 3,
            stride: 1,
            dilation: 2,
            padding: 1,
            padding_out: 0,
        };

        // One batch entry, one row per output channel.
        let expected: TestTensor<3> = TestTensor::from([[
            [30., 64., 78., 76., 94., 52.],
            [49., 101., 127., 113., 143., 77.],
        ]]);

        case.assert_output(expected);
    }

    /// Transposed 1D convolution split into two groups.
    ///
    /// With `groups: 2` and two channels on each side, the weight tensor that
    /// `assert_output` builds has shape `[2, 1, 3]`
    /// (`[channels_in, channels_out / groups, kernel_size]`).
    #[test]
    fn test_conv_transpose1d_groups() {
        let test = ConvTranspose1dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size: 3,
            padding: 1,
            padding_out: 0,
            stride: 1,
            dilation: 1,
            groups: 2,
            length: 4,
        };

        // Built with `TestTensor::from`, like the other tests in this module,
        // rather than `from_floats` with an explicit default device.
        test.assert_output(TestTensor::from([[
            [0., 1., 4., 7.],
            [32., 59., 71., 59.],
        ]]));
    }

    /// Configuration for a single `conv_transpose1d` test scenario.
    ///
    /// The fields determine the shapes of the generated input and weight
    /// tensors and the options handed to the convolution.
    struct ConvTranspose1dTestCase {
        // --- Input tensor shape: [batch_size, channels_in, length] ---
        /// First dimension of the input tensor.
        batch_size: usize,
        /// Second dimension of the input tensor and first dimension of the
        /// weight tensor.
        channels_in: usize,
        /// Last dimension of the input tensor.
        length: usize,

        // --- Weight / bias shape ---
        /// Number of bias entries; divided by `groups` to obtain the second
        /// dimension of the weight tensor.
        channels_out: usize,
        /// Last dimension of the weight tensor.
        kernel_size: usize,

        // --- Values forwarded to `ConvTransposeOptions::new` ---
        /// Stride option (first argument).
        stride: usize,
        /// Padding option (second argument).
        padding: usize,
        /// Output padding option (third argument).
        padding_out: usize,
        /// Dilation option (fourth argument).
        dilation: usize,
        /// Group count (fifth argument); also divides `channels_out` in the
        /// weight shape.
        groups: usize,
    }

    impl ConvTranspose1dTestCase {
        /// Runs `conv_transpose1d` on inputs generated from this configuration
        /// and compares the result with the expected tensor `y`.
        ///
        /// The input, weights and bias are each filled with consecutive
        /// integers starting at 0 (produced by `arange`) and loaded into
        /// `TestTensor`s on the default device. The comparison uses
        /// `assert_approx_eq` with a precision argument of 3.
        fn assert_output(self, y: TestTensor<3>) {
            let device = Default::default();

            // Input: [batch_size, channels_in, length], filled with 0, 1, 2, ...
            let input_shape = Shape::new([self.batch_size, self.channels_in, self.length]);
            let input_count = input_shape.num_elements() as i64;
            let input_ramp =
                TestTensorInt::arange(0..input_count, &device).reshape::<3, _>(input_shape);
            let x = TestTensor::from_data(input_ramp.into_data(), &device);

            // Weights: [channels_in, channels_out / groups, kernel_size],
            // filled with 0, 1, 2, ...
            let weight_shape = Shape::new([
                self.channels_in,
                self.channels_out / self.groups,
                self.kernel_size,
            ]);
            let weight_count = weight_shape.num_elements() as i64;
            let weight_ramp =
                TestTensorInt::arange(0..weight_count, &device).reshape::<3, _>(weight_shape);
            let weights = TestTensor::from_data(weight_ramp.into_data(), &device);

            // Bias: one entry per output channel, 0..channels_out.
            let bias_ramp = TestTensorInt::arange(0..self.channels_out as i64, &device);
            let bias = TestTensor::from_data(bias_ramp.into_data(), &device);

            let options = ConvTransposeOptions::new(
                [self.stride],
                [self.padding],
                [self.padding_out],
                [self.dilation],
                self.groups,
            );

            let actual = conv_transpose1d(x, weights, Some(bias), options).into_data();
            let expected = y.into_data();

            expected.assert_approx_eq(&actual, 3);
        }
    }
}