1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
#[burn_tensor_testgen::testgen(ad_conv_transpose2d)]
mod tests {
    use super::*;
    use burn_tensor::{module::conv_transpose2d, ops::ConvTransposeOptions, Data, Shape};

    #[test]
    fn test_conv_transpose2d_basic() {
        let test = ConvTranspose2dTestCase {
            batch_size: 2,
            channels: [2, 2],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([
                [
                    [
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                    ],
                    [
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                    ],
                ],
                [
                    [
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                        [153., 153., 153., 153.],
                    ],
                    [
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                        [477., 477., 477., 477.],
                    ],
                ],
            ]),
            weight: TestTensor::from_floats([
                [
                    [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
                    [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
                ],
                [
                    [
                        [1264., 1264., 1264.],
                        [1264., 1264., 1264.],
                        [1264., 1264., 1264.],
                    ],
                    [
                        [1264., 1264., 1264.],
                        [1264., 1264., 1264.],
                        [1264., 1264., 1264.],
                    ],
                ],
            ]),
            bias: TestTensor::from_floats([72., 72.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_padding() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [1, 1],
            kernel_size: [3, 3],
            padding: [1, 2],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[[
                [13., 24., 20., 9.],
                [15., 27., 21., 9.],
                [15., 27., 21., 9.],
                [7., 12., 8., 3.],
            ]]]),
            weight: TestTensor::from_floats([[[
                [63., 57., 51.],
                [68., 60., 52.],
                [39., 33., 27.],
            ]]]),
            bias: TestTensor::from_floats([8.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_stride() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [1, 1],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [2, 3],
            dilation: [1, 1],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[[
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
            ]]]),
            weight: TestTensor::from_floats([[[
                [120., 120., 120.],
                [120., 120., 120.],
                [120., 120., 120.],
            ]]]),
            bias: TestTensor::from_floats([108.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_stride_padding_out() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [1, 1],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [1, 2],
            stride: [2, 3],
            dilation: [1, 1],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[[
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
            ]]]),
            weight: TestTensor::from_floats([[[
                [120., 120., 120.],
                [120., 120., 120.],
                [120., 120., 120.],
            ]]]),
            bias: TestTensor::from_floats([140.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_dilation() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [1, 1],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [2, 3],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[[
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
                [36., 36., 36., 36.],
            ]]]),
            weight: TestTensor::from_floats([[[
                [120., 120., 120.],
                [120., 120., 120.],
                [120., 120., 120.],
            ]]]),
            bias: TestTensor::from_floats([80.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_channels() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [2, 3],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[
                [
                    [351., 351., 351., 351.],
                    [351., 351., 351., 351.],
                    [351., 351., 351., 351.],
                    [351., 351., 351., 351.],
                ],
                [
                    [1080., 1080., 1080., 1080.],
                    [1080., 1080., 1080., 1080.],
                    [1080., 1080., 1080., 1080.],
                    [1080., 1080., 1080., 1080.],
                ],
            ]]),
            weight: TestTensor::from_floats([
                [
                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
                ],
                [
                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
                ],
            ]),
            bias: TestTensor::from_floats([36., 36., 36.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_kernel_size() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [1, 1],
            kernel_size: [3, 5],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            size: [6, 6],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[[
                [105., 105., 105., 105., 105., 105.],
                [105., 105., 105., 105., 105., 105.],
                [105., 105., 105., 105., 105., 105.],
                [105., 105., 105., 105., 105., 105.],
                [105., 105., 105., 105., 105., 105.],
                [105., 105., 105., 105., 105., 105.],
            ]]]),
            weight: TestTensor::from_floats([[[
                [630., 630., 630., 630., 630.],
                [630., 630., 630., 630., 630.],
                [630., 630., 630., 630., 630.],
            ]]]),
            bias: TestTensor::from_floats([80.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_groups() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [2, 2],
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 2,
            size: [4, 4],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[
                [
                    [36., 36., 36., 36.],
                    [36., 36., 36., 36.],
                    [36., 36., 36., 36.],
                    [36., 36., 36., 36.],
                ],
                [
                    [117., 117., 117., 117.],
                    [117., 117., 117., 117.],
                    [117., 117., 117., 117.],
                    [117., 117., 117., 117.],
                ],
            ]]),
            weight: TestTensor::from_floats([
                [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]],
                [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]],
            ]),
            bias: TestTensor::from_floats([36., 36.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_complex_no_groups() {
        let test = ConvTranspose2dTestCase {
            batch_size: 2,
            channels: [2, 3],
            kernel_size: [3, 5],
            padding: [1, 2],
            padding_out: [1, 2],
            stride: [2, 3],
            dilation: [2, 3],
            groups: 1,
            size: [6, 8],
        };
        let grads = Grads {
            x: TestTensor::from_floats([
                [
                    [
                        [600., 735., 735., 735., 735., 735., 735., 735.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                    ],
                    [
                        [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                    ],
                ],
                [
                    [
                        [600., 735., 735., 735., 735., 735., 735., 735.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                        [810., 990., 990., 990., 990., 990., 990., 990.],
                    ],
                    [
                        [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
                    ],
                ],
            ]),
            weight: TestTensor::from_floats([
                [
                    [
                        [5320., 6040., 6040., 6040., 6040.],
                        [6048., 6864., 6864., 6864., 6864.],
                        [6048., 6864., 6864., 6864., 6864.],
                    ],
                    [
                        [5320., 6040., 6040., 6040., 6040.],
                        [6048., 6864., 6864., 6864., 6864.],
                        [6048., 6864., 6864., 6864., 6864.],
                    ],
                    [
                        [5320., 6040., 6040., 6040., 6040.],
                        [6048., 6864., 6864., 6864., 6864.],
                        [6048., 6864., 6864., 6864., 6864.],
                    ],
                ],
                [
                    [
                        [8680., 9880., 9880., 9880., 9880.],
                        [10080., 11472., 11472., 11472., 11472.],
                        [10080., 11472., 11472., 11472., 11472.],
                    ],
                    [
                        [8680., 9880., 9880., 9880., 9880.],
                        [10080., 11472., 11472., 11472., 11472.],
                        [10080., 11472., 11472., 11472., 11472.],
                    ],
                    [
                        [8680., 9880., 9880., 9880., 9880.],
                        [10080., 11472., 11472., 11472., 11472.],
                        [10080., 11472., 11472., 11472., 11472.],
                    ],
                ],
            ]),
            bias: TestTensor::from_floats([896., 896., 896.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_complex_no_groups_2() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [4, 2],
            kernel_size: [2, 3],
            padding: [1, 2],
            padding_out: [1, 2],
            stride: [2, 3],
            dilation: [1, 2],
            groups: 1,
            size: [10, 10],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[
                [
                    [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                ],
                [
                    [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
                ],
                [
                    [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
                ],
                [
                    [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
                ],
            ]]),
            weight: TestTensor::from_floats([
                [
                    [[4455., 4905., 4905.], [4500., 4950., 4950.]],
                    [[4455., 4905., 4905.], [4500., 4950., 4950.]],
                ],
                [
                    [[12555., 13905., 13905.], [13500., 14950., 14950.]],
                    [[12555., 13905., 13905.], [13500., 14950., 14950.]],
                ],
                [
                    [[20655., 22905., 22905.], [22500., 24950., 24950.]],
                    [[20655., 22905., 22905.], [22500., 24950., 24950.]],
                ],
                [
                    [[28755., 31905., 31905.], [31500., 34950., 34950.]],
                    [[28755., 31905., 31905.], [31500., 34950., 34950.]],
                ],
            ]),
            bias: TestTensor::from_floats([570., 570.]),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_conv_transpose2d_complex_groups() {
        let test = ConvTranspose2dTestCase {
            batch_size: 1,
            channels: [4, 2],
            kernel_size: [2, 3],
            padding: [1, 2],
            padding_out: [1, 2],
            stride: [2, 3],
            dilation: [1, 2],
            groups: 2,
            size: [10, 10],
        };
        let grads = Grads {
            x: TestTensor::from_floats([[
                [
                    [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                ],
                [
                    [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
                ],
                [
                    [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
                ],
                [
                    [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
                ],
            ]]),
            weight: TestTensor::from_floats([
                [[[4455., 4905., 4905.], [4500., 4950., 4950.]]],
                [[[12555., 13905., 13905.], [13500., 14950., 14950.]]],
                [[[20655., 22905., 22905.], [22500., 24950., 24950.]]],
                [[[28755., 31905., 31905.], [31500., 34950., 34950.]]],
            ]),
            bias: TestTensor::from_floats([570., 570.]),
        };
        test.assert_grads(grads);
    }

    struct ConvTranspose2dTestCase {
        batch_size: usize,
        channels: [usize; 2],
        kernel_size: [usize; 2],
        padding: [usize; 2],
        padding_out: [usize; 2],
        stride: [usize; 2],
        dilation: [usize; 2],
        groups: usize,
        size: [usize; 2],
    }

    struct Grads {
        x: TestTensor<4>,
        weight: TestTensor<4>,
        bias: TestTensor<1>,
    }

    impl ConvTranspose2dTestCase {
        fn assert_grads(self, expected_grads: Grads) {
            let shape_x = Shape::new([
                self.batch_size,
                self.channels[0],
                self.size[0],
                self.size[1],
            ]);
            let shape_weight = Shape::new([
                self.channels[0],
                self.channels[1] / self.groups,
                self.kernel_size[0],
                self.kernel_size[1],
            ]);
            let weight = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_weight.num_elements())
                    .reshape(shape_weight)
                    .into_data()
                    .convert(),
            )
            .require_grad();
            let bias = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..self.channels[1])
                    .into_data()
                    .convert(),
            )
            .require_grad();
            let x = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_x.num_elements())
                    .reshape(shape_x)
                    .into_data()
                    .convert(),
            )
            .require_grad();
            let output = conv_transpose2d(
                x.clone(),
                weight.clone(),
                Some(bias.clone()),
                ConvTransposeOptions::new(
                    self.stride,
                    self.padding,
                    self.padding_out,
                    self.dilation,
                    self.groups,
                ),
            );
            let grads = output.backward();

            // Assert
            let x_grad_actual = x.grad(&grads).unwrap();
            let weight_grad_actual = weight.grad(&grads).unwrap();
            let bias_grad_actual = bias.grad(&grads).unwrap();

            expected_grads
                .bias
                .to_data()
                .assert_approx_eq(&bias_grad_actual.to_data(), 3);
            expected_grads
                .x
                .to_data()
                .assert_approx_eq(&x_grad_actual.to_data(), 3);
            expected_grads
                .weight
                .to_data()
                .assert_approx_eq(&weight_grad_actual.to_data(), 3);
        }
    }
}