burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
#[burn_tensor_testgen::testgen(module_conv_transpose3d)]
mod tests {
    use super::*;
    use burn_tensor::module::conv_transpose3d;
    use burn_tensor::ops::ConvTransposeOptions;
    use burn_tensor::{Shape, Tensor};

    #[test]
    fn test_conv_transpose3d_simple_1() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 1,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            padding_out_1: 0,
            padding_out_2: 0,
            padding_out_3: 0,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 1,
            depth: 2,
            height: 2,
            width: 2,
        };

        test.assert_output(TestTensor::from([[[
            [[96., 124.], [180., 208.]],
            [[348., 376.], [432., 460.]],
        ]]]));
    }
    #[test]
    fn test_conv_transpose3d_simple_2() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 3,
            channels_out: 3,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            padding_out_1: 0,
            padding_out_2: 0,
            padding_out_3: 0,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 1,
            depth: 4,
            height: 4,
            width: 4,
        };

        test.assert_output(TestTensor::from([[
            [
                [
                    [238452., 360588., 363756., 244488.],
                    [367929., 556353., 561186., 377163.],
                    [380745., 575685., 580518., 390123.],
                    [261192., 394896., 398172., 267564.],
                ],
                [
                    [394083., 595827., 600822., 403749.],
                    [607635., 918648., 926262., 622404.],
                    [627831., 949104., 956718., 642816.],
                    [430353., 650529., 655686., 440523.],
                ],
                [
                    [447075., 675747., 680742., 457317.],
                    [688419., 1040472., 1048086., 704052.],
                    [708615., 1070928., 1078542., 724464.],
                    [485073., 733041., 738198., 495819.],
                ],
                [
                    [328656., 496632., 500124., 335892.],
                    [505611., 763983., 769302., 516645.],
                    [519723., 785259., 790578., 530901.],
                    [355428., 536988., 540588., 363000.],
                ],
            ],
            [
                [
                    [286729., 433489., 437629., 294061.],
                    [442288., 668620., 674911., 453466.],
                    [458992., 693784., 700075., 470314.],
                    [314653., 475573., 479821., 322321.],
                ],
                [
                    [474274., 716842., 723295., 485884.],
                    [730837., 1104544., 1114345., 748522.],
                    [756865., 1143748., 1153549., 774766.],
                    [518320., 783208., 789823., 530434.],
                ],
                [
                    [542818., 820090., 826543., 555004.],
                    [834949., 1261360., 1271161., 853498.],
                    [860977., 1300564., 1310365., 879742.],
                    [588592., 889048., 895663., 601282.],
                ],
                [
                    [397669., 600637., 605101., 406201.],
                    [611074., 922906., 929683., 624052.],
                    [629074., 950014., 956791., 642196.],
                    [429625., 648769., 653341., 438493.],
                ],
            ],
            [
                [
                    [335006., 506390., 511502., 343634.],
                    [516647., 780887., 788636., 529769.],
                    [537239., 811883., 819632., 550505.],
                    [368114., 556250., 561470., 377078.],
                ],
                [
                    [554465., 837857., 845768., 568019.],
                    [854039., 1290440., 1302428., 874640.],
                    [885899., 1338392., 1350380., 906716.],
                    [606287., 915887., 923960., 620345.],
                ],
                [
                    [638561., 964433., 972344., 652691.],
                    [981479., 1482248., 1494236., 1002944.],
                    [1013339., 1530200., 1542188., 1035020.],
                    [692111., 1045055., 1053128., 706745.],
                ],
                [
                    [466682., 704642., 710078., 476510.],
                    [716537., 1081829., 1090064., 731459.],
                    [738425., 1114769., 1123004., 753491.],
                    [503822., 760550., 766094., 513986.],
                ],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose3d_stride_2() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 1,
            kernel_size_1: 2,
            kernel_size_2: 2,
            kernel_size_3: 2,
            padding_1: 0,
            padding_2: 0,
            padding_3: 0,
            padding_out_1: 0,
            padding_out_2: 0,
            padding_out_3: 0,
            stride_1: 2,
            stride_2: 2,
            stride_3: 2,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 1,
            depth: 2,
            height: 2,
            width: 2,
        };

        test.assert_output(TestTensor::from([[[
            [
                [0., 0., 0., 1.],
                [0., 0., 2., 3.],
                [0., 2., 0., 3.],
                [4., 6., 6., 9.],
            ],
            [
                [0., 0., 4., 5.],
                [0., 0., 6., 7.],
                [8., 10., 12., 15.],
                [12., 14., 18., 21.],
            ],
            [
                [0., 4., 0., 5.],
                [8., 12., 10., 15.],
                [0., 6., 0., 7.],
                [12., 18., 14., 21.],
            ],
            [
                [16., 20., 20., 25.],
                [24., 28., 30., 35.],
                [24., 30., 28., 35.],
                [36., 42., 42., 49.],
            ],
        ]]]));
    }

    #[test]
    fn test_conv_transpose3d_dilation_2() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            padding_out_1: 1,
            padding_out_2: 1,
            padding_out_3: 1,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 2,
            dilation_2: 2,
            dilation_3: 2,
            groups: 1,
            depth: 2,
            height: 2,
            width: 2,
        };

        test.assert_output(TestTensor::from([[
            [
                [
                    [810., 776., 832., 796., 854.],
                    [756., 712., 774., 728., 792.],
                    [876., 836., 898., 856., 920.],
                    [810., 760., 828., 776., 846.],
                    [942., 896., 964., 916., 986.],
                ],
                [
                    [720., 660., 734., 672., 748.],
                    [606., 536., 616., 544., 626.],
                    [762., 696., 776., 708., 790.],
                    [636., 560., 646., 568., 656.],
                    [804., 732., 818., 744., 832.],
                ],
                [
                    [1008., 956., 1030., 976., 1052.],
                    [918., 856., 936., 872., 954.],
                    [1074., 1016., 1096., 1036., 1118.],
                    [972., 904., 990., 920., 1008.],
                    [1140., 1076., 1162., 1096., 1184.],
                ],
                [
                    [846., 768., 860., 780., 874.],
                    [696., 608., 706., 616., 716.],
                    [888., 804., 902., 816., 916.],
                    [726., 632., 736., 640., 746.],
                    [930., 840., 944., 852., 958.],
                ],
                [
                    [1206., 1136., 1228., 1156., 1250.],
                    [1080., 1000., 1098., 1016., 1116.],
                    [1272., 1196., 1294., 1216., 1316.],
                    [1134., 1048., 1152., 1064., 1170.],
                    [1338., 1256., 1360., 1276., 1382.],
                ],
            ],
            [
                [
                    [1405., 1317., 1427., 1337., 1449.],
                    [1243., 1145., 1261., 1161., 1279.],
                    [1471., 1377., 1493., 1397., 1515.],
                    [1297., 1193., 1315., 1209., 1333.],
                    [1537., 1437., 1559., 1457., 1581.],
                ],
                [
                    [1099., 985., 1113., 997., 1127.],
                    [877., 753., 887., 761., 897.],
                    [1141., 1021., 1155., 1033., 1169.],
                    [907., 777., 917., 785., 927.],
                    [1183., 1057., 1197., 1069., 1211.],
                ],
                [
                    [1603., 1497., 1625., 1517., 1647.],
                    [1405., 1289., 1423., 1305., 1441.],
                    [1669., 1557., 1691., 1577., 1713.],
                    [1459., 1337., 1477., 1353., 1495.],
                    [1735., 1617., 1757., 1637., 1779.],
                ],
                [
                    [1225., 1093., 1239., 1105., 1253.],
                    [967., 825., 977., 833., 987.],
                    [1267., 1129., 1281., 1141., 1295.],
                    [997., 849., 1007., 857., 1017.],
                    [1309., 1165., 1323., 1177., 1337.],
                ],
                [
                    [1801., 1677., 1823., 1697., 1845.],
                    [1567., 1433., 1585., 1449., 1603.],
                    [1867., 1737., 1889., 1757., 1911.],
                    [1621., 1481., 1639., 1497., 1657.],
                    [1933., 1797., 1955., 1817., 1977.],
                ],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose3d_stride2_out_padding() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            padding_out_1: 1,
            padding_out_2: 1,
            padding_out_3: 1,
            stride_1: 2,
            stride_2: 2,
            stride_3: 2,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 1,
            depth: 2,
            height: 4,
            width: 4,
        };

        test.assert_output(TestTensor::from([[
            [
                [
                    [2144., 4366., 2224., 4526., 2304., 4686., 2384., 2422.],
                    [4584., 9324., 4744., 9644., 4904., 9964., 5064., 5148.],
                    [2464., 5006., 2544., 5166., 2624., 5326., 2704., 2750.],
                    [5224., 10604., 5384., 10924., 5544., 11244., 5704., 5804.],
                    [2784., 5646., 2864., 5806., 2944., 5966., 3024., 3078.],
                    [5864., 11884., 6024., 12204., 6184., 12524., 6344., 6460.],
                    [3104., 6286., 3184., 6446., 3264., 6606., 3344., 3406.],
                    [3272., 6628., 3358., 6800., 3444., 6972., 3530., 3592.],
                ],
                [
                    [5280., 10716., 5440., 11036., 5600., 11356., 5760., 5868.],
                    [
                        11152., 22616., 11472., 23256., 11792., 23896., 12112., 12344.,
                    ],
                    [5920., 11996., 6080., 12316., 6240., 12636., 6400., 6524.],
                    [
                        12432., 25176., 12752., 25816., 13072., 26456., 13392., 13656.,
                    ],
                    [6560., 13276., 6720., 13596., 6880., 13916., 7040., 7180.],
                    [
                        13712., 27736., 14032., 28376., 14352., 29016., 14672., 14968.,
                    ],
                    [7200., 14556., 7360., 14876., 7520., 15196., 7680., 7836.],
                    [7632., 15432., 7804., 15776., 7976., 16120., 8148., 8304.],
                ],
                [
                    [3424., 6926., 3504., 7086., 3584., 7246., 3664., 3734.],
                    [7144., 14444., 7304., 14764., 7464., 15084., 7624., 7772.],
                    [3744., 7566., 3824., 7726., 3904., 7886., 3984., 4062.],
                    [7784., 15724., 7944., 16044., 8104., 16364., 8264., 8428.],
                    [4064., 8206., 4144., 8366., 4224., 8526., 4304., 4390.],
                    [8424., 17004., 8584., 17324., 8744., 17644., 8904., 9084.],
                    [4384., 8846., 4464., 9006., 4544., 9166., 4624., 4718.],
                    [4648., 9380., 4734., 9552., 4820., 9724., 4906., 5000.],
                ],
                [
                    [4000., 8096., 4098., 8292., 4196., 8488., 4294., 4364.],
                    [8368., 16928., 8564., 17320., 8760., 17712., 8956., 9104.],
                    [4392., 8880., 4490., 9076., 4588., 9272., 4686., 4764.],
                    [9152., 18496., 9348., 18888., 9544., 19280., 9740., 9904.],
                    [4784., 9664., 4882., 9860., 4980., 10056., 5078., 5164.],
                    [
                        9936., 20064., 10132., 20456., 10328., 20848., 10524., 10704.,
                    ],
                    [5176., 10448., 5274., 10644., 5372., 10840., 5470., 5564.],
                    [5440., 10982., 5544., 11190., 5648., 11398., 5752., 5846.],
                ],
            ],
            [
                [
                    [3009., 6149., 3143., 6417., 3277., 6685., 3411., 3449.],
                    [6529., 13321., 6797., 13857., 7065., 14393., 7333., 7417.],
                    [3545., 7221., 3679., 7489., 3813., 7757., 3947., 3993.],
                    [7601., 15465., 7869., 16001., 8137., 16537., 8405., 8505.],
                    [4081., 8293., 4215., 8561., 4349., 8829., 4483., 4537.],
                    [8673., 17609., 8941., 18145., 9209., 18681., 9477., 9593.],
                    [4617., 9365., 4751., 9633., 4885., 9901., 5019., 5081.],
                    [4785., 9707., 4925., 9987., 5065., 10267., 5205., 5267.],
                ],
                [
                    [7873., 16009., 8141., 16545., 8409., 17081., 8677., 8785.],
                    [
                        16769., 34065., 17305., 35137., 17841., 36209., 18377., 18609.,
                    ],
                    [8945., 18153., 9213., 18689., 9481., 19225., 9749., 9873.],
                    [
                        18913., 38353., 19449., 39425., 19985., 40497., 20521., 20785.,
                    ],
                    [
                        10017., 20297., 10285., 20833., 10553., 21369., 10821., 10961.,
                    ],
                    [
                        21057., 42641., 21593., 43713., 22129., 44785., 22665., 22961.,
                    ],
                    [
                        11089., 22441., 11357., 22977., 11625., 23513., 11893., 12049.,
                    ],
                    [
                        11521., 23317., 11801., 23877., 12081., 24437., 12361., 12517.,
                    ],
                ],
                [
                    [5153., 10437., 5287., 10705., 5421., 10973., 5555., 5625.],
                    [
                        10817., 21897., 11085., 22433., 11353., 22969., 11621., 11769.,
                    ],
                    [5689., 11509., 5823., 11777., 5957., 12045., 6091., 6169.],
                    [
                        11889., 24041., 12157., 24577., 12425., 25113., 12693., 12857.,
                    ],
                    [6225., 12581., 6359., 12849., 6493., 13117., 6627., 6713.],
                    [
                        12961., 26185., 13229., 26721., 13497., 27257., 13765., 13945.,
                    ],
                    [6761., 13653., 6895., 13921., 7029., 14189., 7163., 7257.],
                    [7025., 14187., 7165., 14467., 7305., 14747., 7445., 7539.],
                ],
                [
                    [5729., 11607., 5881., 11911., 6033., 12215., 6185., 6255.],
                    [
                        12041., 24381., 12345., 24989., 12649., 25597., 12953., 13101.,
                    ],
                    [6337., 12823., 6489., 13127., 6641., 13431., 6793., 6871.],
                    [
                        13257., 26813., 13561., 27421., 13865., 28029., 14169., 14333.,
                    ],
                    [6945., 14039., 7097., 14343., 7249., 14647., 7401., 7487.],
                    [
                        14473., 29245., 14777., 29853., 15081., 30461., 15385., 15565.,
                    ],
                    [7553., 15255., 7705., 15559., 7857., 15863., 8009., 8103.],
                    [7817., 15789., 7975., 16105., 8133., 16421., 8291., 8385.],
                ],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose3d_groups_2() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            padding_out_1: 0,
            padding_out_2: 0,
            padding_out_3: 0,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 2,
            depth: 2,
            height: 2,
            width: 2,
        };

        test.assert_output(TestTensor::from([[
            [[[96., 124.], [180., 208.]], [[348., 376.], [432., 460.]]],
            [
                [[2997., 3089.], [3273., 3365.]],
                [[3825., 3917.], [4101., 4193.]],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose3d_groups_different_channels() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 6,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 0,
            padding_2: 0,
            padding_3: 0,
            padding_out_1: 0,
            padding_out_2: 0,
            padding_out_3: 0,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 2,
            depth: 2,
            height: 2,
            width: 2,
        };

        test.assert_output(TestTensor::from([[
            [
                [
                    [0., 0., 1., 2.],
                    [0., 5., 11., 11.],
                    [6., 23., 29., 23.],
                    [12., 32., 37., 24.],
                ],
                [
                    [0., 13., 23., 21.],
                    [30., 96., 124., 86.],
                    [66., 180., 208., 134.],
                    [66., 161., 179., 107.],
                ],
                [
                    [36., 103., 113., 75.],
                    [138., 348., 376., 230.],
                    [174., 432., 460., 278.],
                    [138., 323., 341., 197.],
                ],
                [
                    [72., 166., 175., 100.],
                    [192., 433., 455., 255.],
                    [222., 499., 521., 291.],
                    [144., 318., 331., 182.],
                ],
            ],
            [
                [
                    [1., 28., 29., 30.],
                    [55., 168., 174., 120.],
                    [61., 186., 192., 132.],
                    [67., 168., 173., 106.],
                ],
                [
                    [109., 284., 294., 184.],
                    [355., 853., 881., 519.],
                    [391., 937., 965., 567.],
                    [283., 648., 666., 378.],
                ],
                [
                    [145., 374., 384., 238.],
                    [463., 1105., 1133., 663.],
                    [499., 1189., 1217., 711.],
                    [355., 810., 828., 468.],
                ],
                [
                    [181., 410., 419., 236.],
                    [463., 1028., 1050., 580.],
                    [493., 1094., 1116., 616.],
                    [307., 670., 683., 372.],
                ],
            ],
            [
                [
                    [2., 56., 57., 58.],
                    [110., 331., 337., 229.],
                    [116., 349., 355., 241.],
                    [122., 304., 309., 188.],
                ],
                [
                    [218., 555., 565., 347.],
                    [680., 1610., 1638., 952.],
                    [716., 1694., 1722., 1000.],
                    [500., 1135., 1153., 649.],
                ],
                [
                    [254., 645., 655., 401.],
                    [788., 1862., 1890., 1096.],
                    [824., 1946., 1974., 1144.],
                    [572., 1297., 1315., 739.],
                ],
                [
                    [290., 654., 663., 372.],
                    [734., 1623., 1645., 905.],
                    [764., 1689., 1711., 941.],
                    [470., 1022., 1035., 562.],
                ],
            ],
            [
                [
                    [651., 1388., 1405., 750.],
                    [1485., 3150., 3188., 1690.],
                    [1539., 3264., 3302., 1750.],
                    [873., 1840., 1861., 982.],
                ],
                [
                    [1695., 3578., 3620., 1910.],
                    [3789., 7967., 8059., 4233.],
                    [3921., 8243., 8335., 4377.],
                    [2181., 4566., 4616., 2416.],
                ],
                [
                    [1875., 3956., 3998., 2108.],
                    [4185., 8795., 8887., 4665.],
                    [4317., 9071., 9163., 4809.],
                    [2397., 5016., 5066., 2650.],
                ],
                [
                    [1191., 2490., 2515., 1316.],
                    [2613., 5450., 5504., 2870.],
                    [2691., 5612., 5666., 2954.],
                    [1473., 3062., 3091., 1608.],
                ],
            ],
            [
                [
                    [868., 1848., 1865., 994.],
                    [1972., 4177., 4215., 2231.],
                    [2026., 4291., 4329., 2291.],
                    [1144., 2408., 2429., 1280.],
                ],
                [
                    [2236., 4713., 4755., 2505.],
                    [4978., 10452., 10544., 5530.],
                    [5110., 10728., 10820., 5674.],
                    [2830., 5917., 5967., 3119.],
                ],
                [
                    [2416., 5091., 5133., 2703.],
                    [5374., 11280., 11372., 5962.],
                    [5506., 11556., 11648., 6106.],
                    [3046., 6367., 6417., 3353.],
                ],
                [
                    [1516., 3166., 3191., 1668.],
                    [3316., 6909., 6963., 3627.],
                    [3394., 7071., 7125., 3711.],
                    [1852., 3846., 3875., 2014.],
                ],
            ],
            [
                [
                    [1085., 2308., 2325., 1238.],
                    [2459., 5204., 5242., 2772.],
                    [2513., 5318., 5356., 2832.],
                    [1415., 2976., 2997., 1578.],
                ],
                [
                    [2777., 5848., 5890., 3100.],
                    [6167., 12937., 13029., 6827.],
                    [6299., 13213., 13305., 6971.],
                    [3479., 7268., 7318., 3822.],
                ],
                [
                    [2957., 6226., 6268., 3298.],
                    [6563., 13765., 13857., 7259.],
                    [6695., 14041., 14133., 7403.],
                    [3695., 7718., 7768., 4056.],
                ],
                [
                    [1841., 3842., 3867., 2020.],
                    [4019., 8368., 8422., 4384.],
                    [4097., 8530., 8584., 4468.],
                    [2231., 4630., 4659., 2420.],
                ],
            ],
        ]]));
    }

    struct ConvTranspose3dTestCase {
        batch_size: usize,
        channels_in: usize,
        channels_out: usize,
        kernel_size_1: usize,
        kernel_size_2: usize,
        kernel_size_3: usize,
        padding_1: usize,
        padding_2: usize,
        padding_3: usize,
        padding_out_1: usize,
        padding_out_2: usize,
        padding_out_3: usize,
        stride_1: usize,
        stride_2: usize,
        stride_3: usize,
        dilation_1: usize,
        dilation_2: usize,
        dilation_3: usize,
        groups: usize,
        depth: usize,
        height: usize,
        width: usize,
    }

    impl ConvTranspose3dTestCase {
        fn assert_output(self, y: TestTensor<5>) {
            let shape_x = Shape::new([
                self.batch_size,
                self.channels_in,
                self.depth,
                self.height,
                self.width,
            ]);
            let shape_weights = Shape::new([
                self.channels_in,
                self.channels_out / self.groups,
                self.kernel_size_1,
                self.kernel_size_2,
                self.kernel_size_3,
            ]);
            let device = Default::default();
            let weights = TestTensor::from(
                TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_weights)
                    .into_data(),
            );
            let bias = TestTensor::from(
                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
            );
            let x = TestTensor::from(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_x)
                    .into_data(),
            );
            let output = conv_transpose3d(
                x,
                weights,
                Some(bias),
                ConvTransposeOptions::new(
                    [self.stride_1, self.stride_2, self.stride_3],
                    [self.padding_1, self.padding_2, self.padding_3],
                    [self.padding_out_1, self.padding_out_2, self.padding_out_3],
                    [self.dilation_1, self.dilation_2, self.dilation_3],
                    self.groups,
                ),
            );

            y.to_data().assert_approx_eq(&output.into_data(), 3);
        }
    }
}