1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};

kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");

/// Runs a 2D transposed convolution on the WGPU backend.
///
/// `input` and `weight` are made contiguous before use. The weight's dims are read as
/// `[_, out_channels_per_group, kernel_0, kernel_1]`. The output therefore has
/// `out_channels_per_group * options.groups` channels. Its spatial sizes come from
/// `conv_transpose2d_output_size`.
///
/// When `bias` is `None`, a one-element zero buffer is bound in its place.
///
/// # Panics
///
/// Panics if an input or kernel spatial dimension is zero, or if `options.padding` exceeds
/// the un-padded output extent (the output dimension would be negative).
pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
    input: WgpuTensor<E, 4>,
    weight: WgpuTensor<E, 4>,
    bias: Option<WgpuTensor<E, 1>>,
    options: ConvTransposeOptions<2>,
) -> WgpuTensor<E, 4> {
    let input = kernel::into_contiguous(input);
    let weight = kernel::into_contiguous(weight);
    let [batch_size, _, in_height, in_width] = input.shape.dims;
    // The weight's second dimension is the number of output channels of a single group.
    let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims;

    let out_0 = conv_transpose2d_output_size(
        in_height,
        kernel_0,
        options.stride[0],
        options.padding[0],
        options.padding_out[0],
        options.dilation[0],
    );
    let out_1 = conv_transpose2d_output_size(
        in_width,
        kernel_1,
        options.stride[1],
        options.padding[1],
        options.padding_out[1],
        options.dilation[1],
    );

    let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
    let num_elems = shape_out.num_elements();

    // `shape_out` is not needed afterwards, so it is moved rather than cloned.
    let output = empty_device(input.client.clone(), input.device.clone(), shape_out);

    // Kernel metadata: the `build_info` entries for (input, output, weight), then the scalar
    // parameters below. The kernel reads them positionally, so keep this push order.
    // NOTE(review): `padding_out` is not forwarded; presumably the kernel only needs it
    // through the output shape. Confirm against conv_transpose2d.wgsl.
    let mut info = build_info(&[&input, &output, &weight]);

    info.push(options.stride[0] as u32);
    info.push(options.stride[1] as u32);
    info.push(options.padding[0] as u32);
    info.push(options.padding[1] as u32);
    info.push(options.dilation[0] as u32);
    info.push(options.dilation[1] as u32);
    info.push(options.groups as u32);

    // Without a bias, bind a single zero element so the kernel still has a buffer at this
    // slot. NOTE(review): confirm the template's bias indexing is safe with a
    // one-element buffer.
    let bias_handle = bias
        .map(|bias| bias.handle)
        .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));

    let info_handle = input.client.create(bytemuck::cast_slice(&info));

    // One invocation per output element.
    let kernel = StaticKernel::<
        KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
    input.client.execute(
        Box::new(kernel),
        &[
            &input.handle,
            &weight.handle,
            &bias_handle,
            &output.handle,
            &info_handle,
        ],
    );

    output
}

/// Computes one spatial dimension of a transposed convolution's output.
///
/// Evaluates `(input - 1) * stride + dilation * (kernel - 1) + padding_out + 1 - 2 * padding`.
/// The positive terms are summed before the padding is subtracted. This way an output
/// extent of exactly zero does not underflow an intermediate `usize`.
///
/// # Panics
///
/// Panics if `input` or `kernel` is zero, or if `2 * padding` exceeds the un-padded extent.
fn conv_transpose2d_output_size(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
) -> usize {
    assert!(
        input > 0,
        "conv_transpose2d: input spatial dimension must be non-zero"
    );
    assert!(
        kernel > 0,
        "conv_transpose2d: kernel spatial dimension must be non-zero"
    );

    let unpadded = (input - 1) * stride + dilation * (kernel - 1) + padding_out + 1;
    // Checked so that release builds fail loudly instead of wrapping to a huge shape.
    unpadded.checked_sub(2 * padding).unwrap_or_else(|| {
        panic!(
            "conv_transpose2d: padding {} is too large for an output extent of {}",
            padding, unpadded
        )
    })
}

#[cfg(test)]
mod tests {
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{backend::Backend, module, Distribution, Tensor};

    /// Checks the WGPU transposed convolution against the reference backend. The problem
    /// is sized to need multiple kernel invocations.
    #[test]
    fn conv_transpose2d_should_work_with_multiple_invocations() {
        TestBackend::seed(0);

        let (batch_size, in_channels, out_channels) = (32, 8, 8);
        let (height, width) = (8, 8);
        let (kernel_size_0, kernel_size_1) = (3, 3);
        // Arguments: stride, padding, padding_out, dilation, groups.
        let options =
            burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1);

        let input_shape = [batch_size, in_channels, height, width];
        let weight_shape = [
            in_channels,
            out_channels / options.groups,
            kernel_size_0,
            kernel_size_1,
        ];

        // The seeded RNG is consumed in this order: input, weight, bias.
        let input = Tensor::<TestBackend, 4>::random(input_shape, Distribution::Default);
        let weight = Tensor::<TestBackend, 4>::random(weight_shape, Distribution::Default);
        let bias = Tensor::<TestBackend, 1>::random([out_channels], Distribution::Default);

        // Copy the same values onto the reference backend before the originals are moved.
        let reference_input = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
        let reference_weight = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data());
        let reference_bias = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data());

        let actual = module::conv_transpose2d(input, weight, Some(bias), options.clone());
        let expected = module::conv_transpose2d(
            reference_input,
            reference_weight,
            Some(reference_bias),
            options,
        );

        actual
            .into_data()
            .assert_approx_eq(&expected.into_data(), 3);
    }
}