use super::*;
mod auxiliary_functions {
use super::*;
#[test]
fn im2col() {
let array = ndarray::array![
[
[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
[12., 13., 14., 15.]
],
[
[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
[12., 13., 14., 15.]
],
[
[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
[12., 13., 14., 15.]
],
];
let im2col = ndarray::array![
[0.0, 1.0, 4.0, 5.0],
[1.0, 2.0, 5.0, 6.0],
[2.0, 3.0, 6.0, 7.0],
[4.0, 5.0, 8.0, 9.0],
[5.0, 6.0, 9.0, 10.0],
[6.0, 7.0, 10.0, 11.0],
[8.0, 9.0, 12.0, 13.0],
[9.0, 10.0, 13.0, 14.0],
[10.0, 11.0, 14.0, 15.0],
[0.0, 1.0, 4.0, 5.0],
[1.0, 2.0, 5.0, 6.0],
[2.0, 3.0, 6.0, 7.0],
[4.0, 5.0, 8.0, 9.0],
[5.0, 6.0, 9.0, 10.0],
[6.0, 7.0, 10.0, 11.0],
[8.0, 9.0, 12.0, 13.0],
[9.0, 10.0, 13.0, 14.0],
[10.0, 11.0, 14.0, 15.0],
[0.0, 1.0, 4.0, 5.0],
[1.0, 2.0, 5.0, 6.0],
[2.0, 3.0, 6.0, 7.0],
[4.0, 5.0, 8.0, 9.0],
[5.0, 6.0, 9.0, 10.0],
[6.0, 7.0, 10.0, 11.0],
[8.0, 9.0, 12.0, 13.0],
[9.0, 10.0, 13.0, 14.0],
[10.0, 11.0, 14.0, 15.0]
];
let input_batch = ndarray::stack(ndarray::Axis(0), &[array.view(), array.view()]).unwrap();
let array_as_image = input_batch.into_shape((2, 3, 4, 4)).unwrap();
assert_eq!(
ndarray::stack(ndarray::Axis(0), &[im2col.t(), im2col.t()]).unwrap(),
as_windows(&array_as_image, &[1, 3, 3, 3], &[1, 1], &[1, 1])
.to_shape(columns_shape(
&array_as_image,
&[1, 3, 3, 3],
&[1, 1],
&[1, 1]
))
.unwrap()
);
}
#[test]
fn flatten() {
use ndarray::stack;
let kernel1 = (0..9)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((3, 3))
.unwrap();
let kernel2 = (9..18)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((3, 3))
.unwrap();
let kernel3 = (18..27)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((3, 3))
.unwrap();
let flattened = stack(Axis(0), &[kernel1.view(), kernel2.view(), kernel3.view()]).unwrap();
let flat_shape = super::super::flat_shape(&flattened);
assert_eq!(
flattened.into_shape(flat_shape).unwrap(),
(0..27)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((3, 9))
.unwrap()
);
}
#[test]
fn apply_and_remove_padding() {
let to_pad = (0..625)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((5, 5, 5, 5))
.unwrap();
let padding = &[4, 2];
let padded_shape = (5, 5, 13, 9);
let padded = pad(&to_pad, padding, &crate::variable::Reflective);
let real_padded_elems = vec![
22., 21., 20., 21., 22., 23., 24., 23., 22., 17., 16., 15., 16., 17., 18., 19., 18.,
17., 12., 11., 10., 11., 12., 13., 14., 13., 12., 7., 6., 5., 6., 7., 8., 9., 8., 7.,
2., 1., 0., 1., 2., 3., 4., 3., 2., 7., 6., 5., 6., 7., 8., 9., 8., 7., 12., 11., 10.,
11., 12., 13., 14., 13., 12., 17., 16., 15., 16., 17., 18., 19., 18., 17., 22., 21.,
20., 21., 22., 23., 24., 23., 22., 17., 16., 15., 16., 17., 18., 19., 18., 17., 12.,
11., 10., 11., 12., 13., 14., 13., 12., 7., 6., 5., 6., 7., 8., 9., 8., 7., 2., 1., 0.,
1., 2., 3., 4., 3., 2., 47., 46., 45., 46., 47., 48., 49., 48., 47., 42., 41., 40.,
41., 42., 43., 44., 43., 42., 37., 36., 35., 36., 37., 38., 39., 38., 37., 32., 31.,
30., 31., 32., 33., 34., 33., 32., 27., 26., 25., 26., 27., 28., 29., 28., 27., 32.,
31., 30., 31., 32., 33., 34., 33., 32., 37., 36., 35., 36., 37., 38., 39., 38., 37.,
42., 41., 40., 41., 42., 43., 44., 43., 42., 47., 46., 45., 46., 47., 48., 49., 48.,
47., 42., 41., 40., 41., 42., 43., 44., 43., 42., 37., 36., 35., 36., 37., 38., 39.,
38., 37., 32., 31., 30., 31., 32., 33., 34., 33., 32., 27., 26., 25., 26., 27., 28.,
29., 28., 27., 72., 71., 70., 71., 72., 73., 74., 73., 72., 67., 66., 65., 66., 67.,
68., 69., 68., 67., 62., 61., 60., 61., 62., 63., 64., 63., 62., 57., 56., 55., 56.,
57., 58., 59., 58., 57., 52., 51., 50., 51., 52., 53., 54., 53., 52., 57., 56., 55.,
56., 57., 58., 59., 58., 57., 62., 61., 60., 61., 62., 63., 64., 63., 62., 67., 66.,
65., 66., 67., 68., 69., 68., 67., 72., 71., 70., 71., 72., 73., 74., 73., 72., 67.,
66., 65., 66., 67., 68., 69., 68., 67., 62., 61., 60., 61., 62., 63., 64., 63., 62.,
57., 56., 55., 56., 57., 58., 59., 58., 57., 52., 51., 50., 51., 52., 53., 54., 53.,
52., 97., 96., 95., 96., 97., 98., 99., 98., 97., 92., 91., 90., 91., 92., 93., 94.,
93., 92., 87., 86., 85., 86., 87., 88., 89., 88., 87., 82., 81., 80., 81., 82., 83.,
84., 83., 82., 77., 76., 75., 76., 77., 78., 79., 78., 77., 82., 81., 80., 81., 82.,
83., 84., 83., 82., 87., 86., 85., 86., 87., 88., 89., 88., 87., 92., 91., 90., 91.,
92., 93., 94., 93., 92., 97., 96., 95., 96., 97., 98., 99., 98., 97., 92., 91., 90.,
91., 92., 93., 94., 93., 92., 87., 86., 85., 86., 87., 88., 89., 88., 87., 82., 81.,
80., 81., 82., 83., 84., 83., 82., 77., 76., 75., 76., 77., 78., 79., 78., 77., 122.,
121., 120., 121., 122., 123., 124., 123., 122., 117., 116., 115., 116., 117., 118.,
119., 118., 117., 112., 111., 110., 111., 112., 113., 114., 113., 112., 107., 106.,
105., 106., 107., 108., 109., 108., 107., 102., 101., 100., 101., 102., 103., 104.,
103., 102., 107., 106., 105., 106., 107., 108., 109., 108., 107., 112., 111., 110.,
111., 112., 113., 114., 113., 112., 117., 116., 115., 116., 117., 118., 119., 118.,
117., 122., 121., 120., 121., 122., 123., 124., 123., 122., 117., 116., 115., 116.,
117., 118., 119., 118., 117., 112., 111., 110., 111., 112., 113., 114., 113., 112.,
107., 106., 105., 106., 107., 108., 109., 108., 107., 102., 101., 100., 101., 102.,
103., 104., 103., 102., 147., 146., 145., 146., 147., 148., 149., 148., 147., 142.,
141., 140., 141., 142., 143., 144., 143., 142., 137., 136., 135., 136., 137., 138.,
139., 138., 137., 132., 131., 130., 131., 132., 133., 134., 133., 132., 127., 126.,
125., 126., 127., 128., 129., 128., 127., 132., 131., 130., 131., 132., 133., 134.,
133., 132., 137., 136., 135., 136., 137., 138., 139., 138., 137., 142., 141., 140.,
141., 142., 143., 144., 143., 142., 147., 146., 145., 146., 147., 148., 149., 148.,
147., 142., 141., 140., 141., 142., 143., 144., 143., 142., 137., 136., 135., 136.,
137., 138., 139., 138., 137., 132., 131., 130., 131., 132., 133., 134., 133., 132.,
127., 126., 125., 126., 127., 128., 129., 128., 127., 172., 171., 170., 171., 172.,
173., 174., 173., 172., 167., 166., 165., 166., 167., 168., 169., 168., 167., 162.,
161., 160., 161., 162., 163., 164., 163., 162., 157., 156., 155., 156., 157., 158.,
159., 158., 157., 152., 151., 150., 151., 152., 153., 154., 153., 152., 157., 156.,
155., 156., 157., 158., 159., 158., 157., 162., 161., 160., 161., 162., 163., 164.,
163., 162., 167., 166., 165., 166., 167., 168., 169., 168., 167., 172., 171., 170.,
171., 172., 173., 174., 173., 172., 167., 166., 165., 166., 167., 168., 169., 168.,
167., 162., 161., 160., 161., 162., 163., 164., 163., 162., 157., 156., 155., 156.,
157., 158., 159., 158., 157., 152., 151., 150., 151., 152., 153., 154., 153., 152.,
197., 196., 195., 196., 197., 198., 199., 198., 197., 192., 191., 190., 191., 192.,
193., 194., 193., 192., 187., 186., 185., 186., 187., 188., 189., 188., 187., 182.,
181., 180., 181., 182., 183., 184., 183., 182., 177., 176., 175., 176., 177., 178.,
179., 178., 177., 182., 181., 180., 181., 182., 183., 184., 183., 182., 187., 186.,
185., 186., 187., 188., 189., 188., 187., 192., 191., 190., 191., 192., 193., 194.,
193., 192., 197., 196., 195., 196., 197., 198., 199., 198., 197., 192., 191., 190.,
191., 192., 193., 194., 193., 192., 187., 186., 185., 186., 187., 188., 189., 188.,
187., 182., 181., 180., 181., 182., 183., 184., 183., 182., 177., 176., 175., 176.,
177., 178., 179., 178., 177., 222., 221., 220., 221., 222., 223., 224., 223., 222.,
217., 216., 215., 216., 217., 218., 219., 218., 217., 212., 211., 210., 211., 212.,
213., 214., 213., 212., 207., 206., 205., 206., 207., 208., 209., 208., 207., 202.,
201., 200., 201., 202., 203., 204., 203., 202., 207., 206., 205., 206., 207., 208.,
209., 208., 207., 212., 211., 210., 211., 212., 213., 214., 213., 212., 217., 216.,
215., 216., 217., 218., 219., 218., 217., 222., 221., 220., 221., 222., 223., 224.,
223., 222., 217., 216., 215., 216., 217., 218., 219., 218., 217., 212., 211., 210.,
211., 212., 213., 214., 213., 212., 207., 206., 205., 206., 207., 208., 209., 208.,
207., 202., 201., 200., 201., 202., 203., 204., 203., 202., 247., 246., 245., 246.,
247., 248., 249., 248., 247., 242., 241., 240., 241., 242., 243., 244., 243., 242.,
237., 236., 235., 236., 237., 238., 239., 238., 237., 232., 231., 230., 231., 232.,
233., 234., 233., 232., 227., 226., 225., 226., 227., 228., 229., 228., 227., 232.,
231., 230., 231., 232., 233., 234., 233., 232., 237., 236., 235., 236., 237., 238.,
239., 238., 237., 242., 241., 240., 241., 242., 243., 244., 243., 242., 247., 246.,
245., 246., 247., 248., 249., 248., 247., 242., 241., 240., 241., 242., 243., 244.,
243., 242., 237., 236., 235., 236., 237., 238., 239., 238., 237., 232., 231., 230.,
231., 232., 233., 234., 233., 232., 227., 226., 225., 226., 227., 228., 229., 228.,
227., 272., 271., 270., 271., 272., 273., 274., 273., 272., 267., 266., 265., 266.,
267., 268., 269., 268., 267., 262., 261., 260., 261., 262., 263., 264., 263., 262.,
257., 256., 255., 256., 257., 258., 259., 258., 257., 252., 251., 250., 251., 252.,
253., 254., 253., 252., 257., 256., 255., 256., 257., 258., 259., 258., 257., 262.,
261., 260., 261., 262., 263., 264., 263., 262., 267., 266., 265., 266., 267., 268.,
269., 268., 267., 272., 271., 270., 271., 272., 273., 274., 273., 272., 267., 266.,
265., 266., 267., 268., 269., 268., 267., 262., 261., 260., 261., 262., 263., 264.,
263., 262., 257., 256., 255., 256., 257., 258., 259., 258., 257., 252., 251., 250.,
251., 252., 253., 254., 253., 252., 297., 296., 295., 296., 297., 298., 299., 298.,
297., 292., 291., 290., 291., 292., 293., 294., 293., 292., 287., 286., 285., 286.,
287., 288., 289., 288., 287., 282., 281., 280., 281., 282., 283., 284., 283., 282.,
277., 276., 275., 276., 277., 278., 279., 278., 277., 282., 281., 280., 281., 282.,
283., 284., 283., 282., 287., 286., 285., 286., 287., 288., 289., 288., 287., 292.,
291., 290., 291., 292., 293., 294., 293., 292., 297., 296., 295., 296., 297., 298.,
299., 298., 297., 292., 291., 290., 291., 292., 293., 294., 293., 292., 287., 286.,
285., 286., 287., 288., 289., 288., 287., 282., 281., 280., 281., 282., 283., 284.,
283., 282., 277., 276., 275., 276., 277., 278., 279., 278., 277., 322., 321., 320.,
321., 322., 323., 324., 323., 322., 317., 316., 315., 316., 317., 318., 319., 318.,
317., 312., 311., 310., 311., 312., 313., 314., 313., 312., 307., 306., 305., 306.,
307., 308., 309., 308., 307., 302., 301., 300., 301., 302., 303., 304., 303., 302.,
307., 306., 305., 306., 307., 308., 309., 308., 307., 312., 311., 310., 311., 312.,
313., 314., 313., 312., 317., 316., 315., 316., 317., 318., 319., 318., 317., 322.,
321., 320., 321., 322., 323., 324., 323., 322., 317., 316., 315., 316., 317., 318.,
319., 318., 317., 312., 311., 310., 311., 312., 313., 314., 313., 312., 307., 306.,
305., 306., 307., 308., 309., 308., 307., 302., 301., 300., 301., 302., 303., 304.,
303., 302., 347., 346., 345., 346., 347., 348., 349., 348., 347., 342., 341., 340.,
341., 342., 343., 344., 343., 342., 337., 336., 335., 336., 337., 338., 339., 338.,
337., 332., 331., 330., 331., 332., 333., 334., 333., 332., 327., 326., 325., 326.,
327., 328., 329., 328., 327., 332., 331., 330., 331., 332., 333., 334., 333., 332.,
337., 336., 335., 336., 337., 338., 339., 338., 337., 342., 341., 340., 341., 342.,
343., 344., 343., 342., 347., 346., 345., 346., 347., 348., 349., 348., 347., 342.,
341., 340., 341., 342., 343., 344., 343., 342., 337., 336., 335., 336., 337., 338.,
339., 338., 337., 332., 331., 330., 331., 332., 333., 334., 333., 332., 327., 326.,
325., 326., 327., 328., 329., 328., 327., 372., 371., 370., 371., 372., 373., 374.,
373., 372., 367., 366., 365., 366., 367., 368., 369., 368., 367., 362., 361., 360.,
361., 362., 363., 364., 363., 362., 357., 356., 355., 356., 357., 358., 359., 358.,
357., 352., 351., 350., 351., 352., 353., 354., 353., 352., 357., 356., 355., 356.,
357., 358., 359., 358., 357., 362., 361., 360., 361., 362., 363., 364., 363., 362.,
367., 366., 365., 366., 367., 368., 369., 368., 367., 372., 371., 370., 371., 372.,
373., 374., 373., 372., 367., 366., 365., 366., 367., 368., 369., 368., 367., 362.,
361., 360., 361., 362., 363., 364., 363., 362., 357., 356., 355., 356., 357., 358.,
359., 358., 357., 352., 351., 350., 351., 352., 353., 354., 353., 352., 397., 396.,
395., 396., 397., 398., 399., 398., 397., 392., 391., 390., 391., 392., 393., 394.,
393., 392., 387., 386., 385., 386., 387., 388., 389., 388., 387., 382., 381., 380.,
381., 382., 383., 384., 383., 382., 377., 376., 375., 376., 377., 378., 379., 378.,
377., 382., 381., 380., 381., 382., 383., 384., 383., 382., 387., 386., 385., 386.,
387., 388., 389., 388., 387., 392., 391., 390., 391., 392., 393., 394., 393., 392.,
397., 396., 395., 396., 397., 398., 399., 398., 397., 392., 391., 390., 391., 392.,
393., 394., 393., 392., 387., 386., 385., 386., 387., 388., 389., 388., 387., 382.,
381., 380., 381., 382., 383., 384., 383., 382., 377., 376., 375., 376., 377., 378.,
379., 378., 377., 422., 421., 420., 421., 422., 423., 424., 423., 422., 417., 416.,
415., 416., 417., 418., 419., 418., 417., 412., 411., 410., 411., 412., 413., 414.,
413., 412., 407., 406., 405., 406., 407., 408., 409., 408., 407., 402., 401., 400.,
401., 402., 403., 404., 403., 402., 407., 406., 405., 406., 407., 408., 409., 408.,
407., 412., 411., 410., 411., 412., 413., 414., 413., 412., 417., 416., 415., 416.,
417., 418., 419., 418., 417., 422., 421., 420., 421., 422., 423., 424., 423., 422.,
417., 416., 415., 416., 417., 418., 419., 418., 417., 412., 411., 410., 411., 412.,
413., 414., 413., 412., 407., 406., 405., 406., 407., 408., 409., 408., 407., 402.,
401., 400., 401., 402., 403., 404., 403., 402., 447., 446., 445., 446., 447., 448.,
449., 448., 447., 442., 441., 440., 441., 442., 443., 444., 443., 442., 437., 436.,
435., 436., 437., 438., 439., 438., 437., 432., 431., 430., 431., 432., 433., 434.,
433., 432., 427., 426., 425., 426., 427., 428., 429., 428., 427., 432., 431., 430.,
431., 432., 433., 434., 433., 432., 437., 436., 435., 436., 437., 438., 439., 438.,
437., 442., 441., 440., 441., 442., 443., 444., 443., 442., 447., 446., 445., 446.,
447., 448., 449., 448., 447., 442., 441., 440., 441., 442., 443., 444., 443., 442.,
437., 436., 435., 436., 437., 438., 439., 438., 437., 432., 431., 430., 431., 432.,
433., 434., 433., 432., 427., 426., 425., 426., 427., 428., 429., 428., 427., 472.,
471., 470., 471., 472., 473., 474., 473., 472., 467., 466., 465., 466., 467., 468.,
469., 468., 467., 462., 461., 460., 461., 462., 463., 464., 463., 462., 457., 456.,
455., 456., 457., 458., 459., 458., 457., 452., 451., 450., 451., 452., 453., 454.,
453., 452., 457., 456., 455., 456., 457., 458., 459., 458., 457., 462., 461., 460.,
461., 462., 463., 464., 463., 462., 467., 466., 465., 466., 467., 468., 469., 468.,
467., 472., 471., 470., 471., 472., 473., 474., 473., 472., 467., 466., 465., 466.,
467., 468., 469., 468., 467., 462., 461., 460., 461., 462., 463., 464., 463., 462.,
457., 456., 455., 456., 457., 458., 459., 458., 457., 452., 451., 450., 451., 452.,
453., 454., 453., 452., 497., 496., 495., 496., 497., 498., 499., 498., 497., 492.,
491., 490., 491., 492., 493., 494., 493., 492., 487., 486., 485., 486., 487., 488.,
489., 488., 487., 482., 481., 480., 481., 482., 483., 484., 483., 482., 477., 476.,
475., 476., 477., 478., 479., 478., 477., 482., 481., 480., 481., 482., 483., 484.,
483., 482., 487., 486., 485., 486., 487., 488., 489., 488., 487., 492., 491., 490.,
491., 492., 493., 494., 493., 492., 497., 496., 495., 496., 497., 498., 499., 498.,
497., 492., 491., 490., 491., 492., 493., 494., 493., 492., 487., 486., 485., 486.,
487., 488., 489., 488., 487., 482., 481., 480., 481., 482., 483., 484., 483., 482.,
477., 476., 475., 476., 477., 478., 479., 478., 477., 522., 521., 520., 521., 522.,
523., 524., 523., 522., 517., 516., 515., 516., 517., 518., 519., 518., 517., 512.,
511., 510., 511., 512., 513., 514., 513., 512., 507., 506., 505., 506., 507., 508.,
509., 508., 507., 502., 501., 500., 501., 502., 503., 504., 503., 502., 507., 506.,
505., 506., 507., 508., 509., 508., 507., 512., 511., 510., 511., 512., 513., 514.,
513., 512., 517., 516., 515., 516., 517., 518., 519., 518., 517., 522., 521., 520.,
521., 522., 523., 524., 523., 522., 517., 516., 515., 516., 517., 518., 519., 518.,
517., 512., 511., 510., 511., 512., 513., 514., 513., 512., 507., 506., 505., 506.,
507., 508., 509., 508., 507., 502., 501., 500., 501., 502., 503., 504., 503., 502.,
547., 546., 545., 546., 547., 548., 549., 548., 547., 542., 541., 540., 541., 542.,
543., 544., 543., 542., 537., 536., 535., 536., 537., 538., 539., 538., 537., 532.,
531., 530., 531., 532., 533., 534., 533., 532., 527., 526., 525., 526., 527., 528.,
529., 528., 527., 532., 531., 530., 531., 532., 533., 534., 533., 532., 537., 536.,
535., 536., 537., 538., 539., 538., 537., 542., 541., 540., 541., 542., 543., 544.,
543., 542., 547., 546., 545., 546., 547., 548., 549., 548., 547., 542., 541., 540.,
541., 542., 543., 544., 543., 542., 537., 536., 535., 536., 537., 538., 539., 538.,
537., 532., 531., 530., 531., 532., 533., 534., 533., 532., 527., 526., 525., 526.,
527., 528., 529., 528., 527., 572., 571., 570., 571., 572., 573., 574., 573., 572.,
567., 566., 565., 566., 567., 568., 569., 568., 567., 562., 561., 560., 561., 562.,
563., 564., 563., 562., 557., 556., 555., 556., 557., 558., 559., 558., 557., 552.,
551., 550., 551., 552., 553., 554., 553., 552., 557., 556., 555., 556., 557., 558.,
559., 558., 557., 562., 561., 560., 561., 562., 563., 564., 563., 562., 567., 566.,
565., 566., 567., 568., 569., 568., 567., 572., 571., 570., 571., 572., 573., 574.,
573., 572., 567., 566., 565., 566., 567., 568., 569., 568., 567., 562., 561., 560.,
561., 562., 563., 564., 563., 562., 557., 556., 555., 556., 557., 558., 559., 558.,
557., 552., 551., 550., 551., 552., 553., 554., 553., 552., 597., 596., 595., 596.,
597., 598., 599., 598., 597., 592., 591., 590., 591., 592., 593., 594., 593., 592.,
587., 586., 585., 586., 587., 588., 589., 588., 587., 582., 581., 580., 581., 582.,
583., 584., 583., 582., 577., 576., 575., 576., 577., 578., 579., 578., 577., 582.,
581., 580., 581., 582., 583., 584., 583., 582., 587., 586., 585., 586., 587., 588.,
589., 588., 587., 592., 591., 590., 591., 592., 593., 594., 593., 592., 597., 596.,
595., 596., 597., 598., 599., 598., 597., 592., 591., 590., 591., 592., 593., 594.,
593., 592., 587., 586., 585., 586., 587., 588., 589., 588., 587., 582., 581., 580.,
581., 582., 583., 584., 583., 582., 577., 576., 575., 576., 577., 578., 579., 578.,
577., 622., 621., 620., 621., 622., 623., 624., 623., 622., 617., 616., 615., 616.,
617., 618., 619., 618., 617., 612., 611., 610., 611., 612., 613., 614., 613., 612.,
607., 606., 605., 606., 607., 608., 609., 608., 607., 602., 601., 600., 601., 602.,
603., 604., 603., 602., 607., 606., 605., 606., 607., 608., 609., 608., 607., 612.,
611., 610., 611., 612., 613., 614., 613., 612., 617., 616., 615., 616., 617., 618.,
619., 618., 617., 622., 621., 620., 621., 622., 623., 624., 623., 622., 617., 616.,
615., 616., 617., 618., 619., 618., 617., 612., 611., 610., 611., 612., 613., 614.,
613., 612., 607., 606., 605., 606., 607., 608., 609., 608., 607., 602., 601., 600.,
601., 602., 603., 604., 603., 602.,
];
assert_eq!(
padded,
Array::from_shape_vec(padded_shape, real_padded_elems).unwrap()
);
assert_eq!(unpad(&padded, padding), to_pad);
}
#[test]
fn conv_args_ok() {
let conv_input = ndarray::Array::<f32, _>::zeros((1, 2, 4, 4));
check_conv_args(conv_input.shape(), &[1, 2, 2, 2], &[0, 0], &[1, 1], &[1, 1]);
}
#[test]
#[should_panic(expected = "error: invalid kernel's shape [1, 2, 2] for 2d conv")]
fn conv_args_invalid_kernel() {
let conv_input = ndarray::Array::<f32, _>::zeros((1, 2, 4, 4));
check_conv_args(conv_input.shape(), &[1, 2, 2], &[0, 0], &[1, 1], &[1, 1]);
}
#[test]
fn conv_groups_args_ok() {
check_groups_args(&[3, 3, 10, 10], &[3, 3, 3, 3], 3);
}
#[test]
#[should_panic]
fn conv_groups_args_panic() {
check_groups_args(&[3, 3, 10, 10], &[3, 3, 3, 3], 5);
}
}
mod convolution_numeric {
use super::super::*;
#[test]
fn conv1d() {
use ndarray::prelude::*;
use ndarray::Ix3;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((5, 3, 10)).unwrap();
let kernel = Array::<f32, _>::ones((6, 3, 5));
let stride = &[1];
let padding = &[0];
let dilation = &[1];
let conv_out_shape =
conv_out_shape::<Ix3>(input.shape(), kernel.shape(), padding, stride, dilation);
let true_output_elems = vec![
180., 195., 210., 225., 240., 255., 180., 195., 210., 225., 240., 255., 180., 195.,
210., 225., 240., 255., 180., 195., 210., 225., 240., 255., 180., 195., 210., 225.,
240., 255., 180., 195., 210., 225., 240., 255., 630., 645., 660., 675., 690., 705.,
630., 645., 660., 675., 690., 705., 630., 645., 660., 675., 690., 705., 630., 645.,
660., 675., 690., 705., 630., 645., 660., 675., 690., 705., 630., 645., 660., 675.,
690., 705., 1080., 1095., 1110., 1125., 1140., 1155., 1080., 1095., 1110., 1125.,
1140., 1155., 1080., 1095., 1110., 1125., 1140., 1155., 1080., 1095., 1110., 1125.,
1140., 1155., 1080., 1095., 1110., 1125., 1140., 1155., 1080., 1095., 1110., 1125.,
1140., 1155., 1530., 1545., 1560., 1575., 1590., 1605., 1530., 1545., 1560., 1575.,
1590., 1605., 1530., 1545., 1560., 1575., 1590., 1605., 1530., 1545., 1560., 1575.,
1590., 1605., 1530., 1545., 1560., 1575., 1590., 1605., 1530., 1545., 1560., 1575.,
1590., 1605., 1980., 1995., 2010., 2025., 2040., 2055., 1980., 1995., 2010., 2025.,
2040., 2055., 1980., 1995., 2010., 2025., 2040., 2055., 1980., 1995., 2010., 2025.,
2040., 2055., 1980., 1995., 2010., 2025., 2040., 2055., 1980., 1995., 2010., 2025.,
2040., 2055.,
];
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18.,
12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30.,
24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24.,
30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12.,
18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6.,
6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18.,
12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24., 30., 30.,
24., 18., 12., 6., 6., 12., 18., 24., 30., 30., 24., 18., 12., 6., 6., 12., 18., 24.,
30., 30., 24., 18., 12., 6.,
];
let true_kernel_grad_elems = array![
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
[
[1875., 1905., 1935., 1965., 1995.],
[2175., 2205., 2235., 2265., 2295.],
[2475., 2505., 2535., 2565., 2595.],
],
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(kernel_grad, true_kernel_grad_elems);
}
#[test]
fn conv2d() {
use ndarray::Ix4;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((3, 2, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((3, 2, 2, 2));
let stride = &[1, 1];
let padding = &[0, 0];
let dilation = &[1, 1];
let conv_out_shape =
conv_out_shape::<Ix4>(input.shape(), kernel.shape(), padding, stride, dilation);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems: Vec<f32> = vec![
124., 132., 140., 148., 164., 172., 180., 188., 204., 212., 220., 228., 244., 252.,
260., 268., 124., 132., 140., 148., 164., 172., 180., 188., 204., 212., 220., 228.,
244., 252., 260., 268., 124., 132., 140., 148., 164., 172., 180., 188., 204., 212.,
220., 228., 244., 252., 260., 268., 524., 532., 540., 548., 564., 572., 580., 588.,
604., 612., 620., 628., 644., 652., 660., 668., 524., 532., 540., 548., 564., 572.,
580., 588., 604., 612., 620., 628., 644., 652., 660., 668., 524., 532., 540., 548.,
564., 572., 580., 588., 604., 612., 620., 628., 644., 652., 660., 668., 924., 932.,
940., 948., 964., 972., 980., 988., 1004., 1012., 1020., 1028., 1044., 1052., 1060.,
1068., 924., 932., 940., 948., 964., 972., 980., 988., 1004., 1012., 1020., 1028.,
1044., 1052., 1060., 1068., 924., 932., 940., 948., 964., 972., 980., 988., 1004.,
1012., 1020., 1028., 1044., 1052., 1060., 1068.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros((3, 2, 5, 5));
let mut kernel_grad = Array::<f32, _>::zeros((3, 2, 2, 2));
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems: Vec<f32> = vec![
3., 6., 6., 6., 3., 6., 12., 12., 12., 6., 6., 12., 12., 12., 6., 6., 12., 12., 12.,
6., 3., 6., 6., 6., 3., 3., 6., 6., 6., 3., 6., 12., 12., 12., 6., 6., 12., 12., 12.,
6., 6., 12., 12., 12., 6., 3., 6., 6., 6., 3., 3., 6., 6., 6., 3., 6., 12., 12., 12.,
6., 6., 12., 12., 12., 6., 6., 12., 12., 12., 6., 3., 6., 6., 6., 3., 3., 6., 6., 6.,
3., 6., 12., 12., 12., 6., 6., 12., 12., 12., 6., 6., 12., 12., 12., 6., 3., 6., 6.,
6., 3., 3., 6., 6., 6., 3., 6., 12., 12., 12., 6., 6., 12., 12., 12., 6., 6., 12., 12.,
12., 6., 3., 6., 6., 6., 3., 3., 6., 6., 6., 3., 6., 12., 12., 12., 6., 6., 12., 12.,
12., 6., 6., 12., 12., 12., 6., 3., 6., 6., 6., 3.,
];
let true_kernel_grad_elems: Vec<f32> = vec![
2832., 2880., 3072., 3120., 4032., 4080., 4272., 4320., 2832., 2880., 3072., 3120.,
4032., 4080., 4272., 4320., 2832., 2880., 3072., 3120., 4032., 4080., 4272., 4320.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv3d() {
let input_elems = (0..750).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((2, 3, 5, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((4, 3, 2, 2, 2));
let stride = &[1, 1, 1];
let padding = &[0, 0, 0];
let dilation = &[1, 1, 1];
let conv_out_shape = conv_out_shape::<ndarray::Ix5>(
input.shape(),
kernel.shape(),
padding,
stride,
dilation,
);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems = vec![
3372., 3396., 3420., 3444., 3492., 3516., 3540., 3564., 3612., 3636., 3660., 3684.,
3732., 3756., 3780., 3804., 3972., 3996., 4020., 4044., 4092., 4116., 4140., 4164.,
4212., 4236., 4260., 4284., 4332., 4356., 4380., 4404., 4572., 4596., 4620., 4644.,
4692., 4716., 4740., 4764., 4812., 4836., 4860., 4884., 4932., 4956., 4980., 5004.,
5172., 5196., 5220., 5244., 5292., 5316., 5340., 5364., 5412., 5436., 5460., 5484.,
5532., 5556., 5580., 5604., 3372., 3396., 3420., 3444., 3492., 3516., 3540., 3564.,
3612., 3636., 3660., 3684., 3732., 3756., 3780., 3804., 3972., 3996., 4020., 4044.,
4092., 4116., 4140., 4164., 4212., 4236., 4260., 4284., 4332., 4356., 4380., 4404.,
4572., 4596., 4620., 4644., 4692., 4716., 4740., 4764., 4812., 4836., 4860., 4884.,
4932., 4956., 4980., 5004., 5172., 5196., 5220., 5244., 5292., 5316., 5340., 5364.,
5412., 5436., 5460., 5484., 5532., 5556., 5580., 5604., 3372., 3396., 3420., 3444.,
3492., 3516., 3540., 3564., 3612., 3636., 3660., 3684., 3732., 3756., 3780., 3804.,
3972., 3996., 4020., 4044., 4092., 4116., 4140., 4164., 4212., 4236., 4260., 4284.,
4332., 4356., 4380., 4404., 4572., 4596., 4620., 4644., 4692., 4716., 4740., 4764.,
4812., 4836., 4860., 4884., 4932., 4956., 4980., 5004., 5172., 5196., 5220., 5244.,
5292., 5316., 5340., 5364., 5412., 5436., 5460., 5484., 5532., 5556., 5580., 5604.,
3372., 3396., 3420., 3444., 3492., 3516., 3540., 3564., 3612., 3636., 3660., 3684.,
3732., 3756., 3780., 3804., 3972., 3996., 4020., 4044., 4092., 4116., 4140., 4164.,
4212., 4236., 4260., 4284., 4332., 4356., 4380., 4404., 4572., 4596., 4620., 4644.,
4692., 4716., 4740., 4764., 4812., 4836., 4860., 4884., 4932., 4956., 4980., 5004.,
5172., 5196., 5220., 5244., 5292., 5316., 5340., 5364., 5412., 5436., 5460., 5484.,
5532., 5556., 5580., 5604., 12372., 12396., 12420., 12444., 12492., 12516., 12540.,
12564., 12612., 12636., 12660., 12684., 12732., 12756., 12780., 12804., 12972., 12996.,
13020., 13044., 13092., 13116., 13140., 13164., 13212., 13236., 13260., 13284., 13332.,
13356., 13380., 13404., 13572., 13596., 13620., 13644., 13692., 13716., 13740., 13764.,
13812., 13836., 13860., 13884., 13932., 13956., 13980., 14004., 14172., 14196., 14220.,
14244., 14292., 14316., 14340., 14364., 14412., 14436., 14460., 14484., 14532., 14556.,
14580., 14604., 12372., 12396., 12420., 12444., 12492., 12516., 12540., 12564., 12612.,
12636., 12660., 12684., 12732., 12756., 12780., 12804., 12972., 12996., 13020., 13044.,
13092., 13116., 13140., 13164., 13212., 13236., 13260., 13284., 13332., 13356., 13380.,
13404., 13572., 13596., 13620., 13644., 13692., 13716., 13740., 13764., 13812., 13836.,
13860., 13884., 13932., 13956., 13980., 14004., 14172., 14196., 14220., 14244., 14292.,
14316., 14340., 14364., 14412., 14436., 14460., 14484., 14532., 14556., 14580., 14604.,
12372., 12396., 12420., 12444., 12492., 12516., 12540., 12564., 12612., 12636., 12660.,
12684., 12732., 12756., 12780., 12804., 12972., 12996., 13020., 13044., 13092., 13116.,
13140., 13164., 13212., 13236., 13260., 13284., 13332., 13356., 13380., 13404., 13572.,
13596., 13620., 13644., 13692., 13716., 13740., 13764., 13812., 13836., 13860., 13884.,
13932., 13956., 13980., 14004., 14172., 14196., 14220., 14244., 14292., 14316., 14340.,
14364., 14412., 14436., 14460., 14484., 14532., 14556., 14580., 14604., 12372., 12396.,
12420., 12444., 12492., 12516., 12540., 12564., 12612., 12636., 12660., 12684., 12732.,
12756., 12780., 12804., 12972., 12996., 13020., 13044., 13092., 13116., 13140., 13164.,
13212., 13236., 13260., 13284., 13332., 13356., 13380., 13404., 13572., 13596., 13620.,
13644., 13692., 13716., 13740., 13764., 13812., 13836., 13860., 13884., 13932., 13956.,
13980., 14004., 14172., 14196., 14220., 14244., 14292., 14316., 14340., 14364., 14412.,
14436., 14460., 14484., 14532., 14556., 14580., 14604.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32.,
32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16.,
32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16.,
32., 32., 32., 16., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4.,
8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8.,
4., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32.,
32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16.,
16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8.,
16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16.,
8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16.,
16., 16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16.,
16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16., 32., 32.,
32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 4.,
8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8.,
4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 16., 32., 32., 32.,
16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16.,
16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32.,
16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32.,
32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 8.,
16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4.,
4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32.,
32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16.,
32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16.,
32., 32., 32., 16., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4.,
8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8.,
4., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32.,
32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 16., 32., 32., 32., 16.,
16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 16., 32., 32., 32., 16., 8.,
16., 16., 16., 8., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8.,
8., 16., 16., 16., 8., 4., 8., 8., 8., 4.,
];
let true_kernel_grad_elems = vec![
29952., 30080., 30592., 30720., 33152., 33280., 33792., 33920., 45952., 46080., 46592.,
46720., 49152., 49280., 49792., 49920., 61952., 62080., 62592., 62720., 65152., 65280.,
65792., 65920., 29952., 30080., 30592., 30720., 33152., 33280., 33792., 33920., 45952.,
46080., 46592., 46720., 49152., 49280., 49792., 49920., 61952., 62080., 62592., 62720.,
65152., 65280., 65792., 65920., 29952., 30080., 30592., 30720., 33152., 33280., 33792.,
33920., 45952., 46080., 46592., 46720., 49152., 49280., 49792., 49920., 61952., 62080.,
62592., 62720., 65152., 65280., 65792., 65920., 29952., 30080., 30592., 30720., 33152.,
33280., 33792., 33920., 45952., 46080., 46592., 46720., 49152., 49280., 49792., 49920.,
61952., 62080., 62592., 62720., 65152., 65280., 65792., 65920.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv1d_strided() {
use ndarray::prelude::*;
use ndarray::Ix3;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((5, 3, 10)).unwrap();
let kernel = Array::<f32, _>::ones((6, 3, 5));
let stride = &[2];
let padding = &[0];
let dilation = &[1];
let conv_out_shape =
conv_out_shape::<Ix3>(input.shape(), kernel.shape(), padding, stride, dilation);
let true_output_elems = vec![
180., 210., 240., 180., 210., 240., 180., 210., 240., 180., 210., 240., 180., 210.,
240., 180., 210., 240., 630., 660., 690., 630., 660., 690., 630., 660., 690., 630.,
660., 690., 630., 660., 690., 630., 660., 690., 1080., 1110., 1140., 1080., 1110.,
1140., 1080., 1110., 1140., 1080., 1110., 1140., 1080., 1110., 1140., 1080., 1110.,
1140., 1530., 1560., 1590., 1530., 1560., 1590., 1530., 1560., 1590., 1530., 1560.,
1590., 1530., 1560., 1590., 1530., 1560., 1590., 1980., 2010., 2040., 1980., 2010.,
2040., 1980., 2010., 2040., 1980., 2010., 2040., 1980., 2010., 2040., 1980., 2010.,
2040.,
];
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6.,
0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6.,
6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12.,
6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12.,
12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18.,
12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12.,
18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0., 6., 6., 12.,
12., 18., 12., 12., 6., 6., 0., 6., 6., 12., 12., 18., 12., 12., 6., 6., 0.,
];
let true_kernel_grad_elems = array![
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
],
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
],
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
],
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
],
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
],
[
[930., 945., 960., 975., 990.],
[1080., 1095., 1110., 1125., 1140.],
[1230., 1245., 1260., 1275., 1290.]
]
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(kernel_grad, true_kernel_grad_elems);
}
#[test]
fn conv2d_strided() {
use ndarray::Ix4;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((3, 2, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((3, 2, 2, 2));
let stride = &[2, 2];
let padding = &[0, 0];
let dilation = &[1, 1];
let conv_out_shape =
conv_out_shape::<Ix4>(input.shape(), kernel.shape(), padding, stride, dilation);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems: Vec<f32> = vec![
124., 140., 204., 220., 124., 140., 204., 220., 124., 140., 204., 220., 524., 540.,
604., 620., 524., 540., 604., 620., 524., 540., 604., 620., 924., 940., 1004., 1020.,
924., 940., 1004., 1020., 924., 940., 1004., 1020.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros((3, 2, 5, 5));
let mut kernel_grad = Array::<f32, _>::zeros((3, 2, 2, 2));
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems: Vec<f32> = vec![
3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 0., 0.,
0., 0., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3.,
0., 0., 0., 0., 0., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3.,
3., 3., 3., 0., 0., 0., 0., 0., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3.,
3., 0., 3., 3., 3., 3., 0., 0., 0., 0., 0., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0.,
3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 0., 0., 0., 0., 0., 3., 3., 3., 3., 0., 3., 3.,
3., 3., 0., 3., 3., 3., 3., 0., 3., 3., 3., 3., 0., 0., 0., 0., 0., 0.,
];
let true_kernel_grad_elems: Vec<f32> = vec![
672., 684., 732., 744., 972., 984., 1032., 1044., 672., 684., 732., 744., 972., 984.,
1032., 1044., 672., 684., 732., 744., 972., 984., 1032., 1044.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv3d_strided() {
let input_elems = (0..750).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((2, 3, 5, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((4, 3, 2, 2, 2));
let stride = &[1, 2, 3];
let padding = &[0, 0, 0];
let dilation = &[1, 1, 1];
let conv_out_shape = conv_out_shape::<ndarray::Ix5>(
input.shape(),
kernel.shape(),
padding,
stride,
dilation,
);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems = vec![
3372., 3444., 3612., 3684., 3972., 4044., 4212., 4284., 4572., 4644., 4812., 4884.,
5172., 5244., 5412., 5484., 3372., 3444., 3612., 3684., 3972., 4044., 4212., 4284.,
4572., 4644., 4812., 4884., 5172., 5244., 5412., 5484., 3372., 3444., 3612., 3684.,
3972., 4044., 4212., 4284., 4572., 4644., 4812., 4884., 5172., 5244., 5412., 5484.,
3372., 3444., 3612., 3684., 3972., 4044., 4212., 4284., 4572., 4644., 4812., 4884.,
5172., 5244., 5412., 5484., 12372., 12444., 12612., 12684., 12972., 13044., 13212.,
13284., 13572., 13644., 13812., 13884., 14172., 14244., 14412., 14484., 12372., 12444.,
12612., 12684., 12972., 13044., 13212., 13284., 13572., 13644., 13812., 13884., 14172.,
14244., 14412., 14484., 12372., 12444., 12612., 12684., 12972., 13044., 13212., 13284.,
13572., 13644., 13812., 13884., 14172., 14244., 14412., 14484., 12372., 12444., 12612.,
12684., 12972., 13044., 13212., 13284., 13572., 13644., 13812., 13884., 14172., 14244.,
14412., 14484.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0.,
0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8.,
8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8.,
8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0.,
8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4.,
4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4.,
0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 8., 8., 0., 8.,
8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8.,
8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0.,
0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8.,
0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4.,
0., 4., 4., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4.,
4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8.,
8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0.,
8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8.,
8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 4., 4.,
0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0.,
0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0.,
0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0.,
8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8.,
8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8.,
0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4.,
4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4.,
4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 8., 8., 0.,
8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0.,
8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0.,
0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8.,
8., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4.,
4., 0., 4., 4., 0., 0., 0., 0., 0., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0.,
4., 4., 4., 4., 0., 4., 4., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8.,
8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8., 8., 8., 8.,
0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 8., 8., 0., 8.,
8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 8., 8., 0., 8., 8., 0., 0., 0., 0., 0., 4.,
4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 4., 4., 0., 4., 4., 0., 0., 0.,
0., 0.,
];
let true_kernel_grad_elems = vec![
7408., 7440., 7568., 7600., 8208., 8240., 8368., 8400., 11408., 11440., 11568., 11600.,
12208., 12240., 12368., 12400., 15408., 15440., 15568., 15600., 16208., 16240., 16368.,
16400., 7408., 7440., 7568., 7600., 8208., 8240., 8368., 8400., 11408., 11440., 11568.,
11600., 12208., 12240., 12368., 12400., 15408., 15440., 15568., 15600., 16208., 16240.,
16368., 16400., 7408., 7440., 7568., 7600., 8208., 8240., 8368., 8400., 11408., 11440.,
11568., 11600., 12208., 12240., 12368., 12400., 15408., 15440., 15568., 15600., 16208.,
16240., 16368., 16400., 7408., 7440., 7568., 7600., 8208., 8240., 8368., 8400., 11408.,
11440., 11568., 11600., 12208., 12240., 12368., 12400., 15408., 15440., 15568., 15600.,
16208., 16240., 16368., 16400.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv1d_dilated() {
use ndarray::prelude::*;
use ndarray::Ix3;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((5, 3, 10)).unwrap();
let kernel = Array::<f32, _>::ones((6, 3, 5));
let stride = &[2];
let padding = &[0];
let dilation = &[2];
let conv_out_shape =
conv_out_shape::<Ix3>(input.shape(), kernel.shape(), padding, stride, dilation);
let true_output_elems = vec![
210., 210., 210., 210., 210., 210., 660., 660., 660., 660., 660., 660., 1110., 1110.,
1110., 1110., 1110., 1110., 1560., 1560., 1560., 1560., 1560., 1560., 2010., 2010.,
2010., 2010., 2010., 2010.,
];
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0., 6., 0.,
];
let true_kernel_grad_elems = array![
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
],
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
],
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
],
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
],
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
],
[
[300., 310., 320., 330., 340.],
[350., 360., 370., 380., 390.],
[400., 410., 420., 430., 440.]
]
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(kernel_grad, true_kernel_grad_elems);
}
#[test]
fn conv2d_dilated() {
use ndarray::Ix4;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((3, 2, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((3, 2, 2, 2));
let stride = &[2, 2];
let padding = &[0, 0];
let dilation = &[2, 2];
let conv_out_shape =
conv_out_shape::<Ix4>(input.shape(), kernel.shape(), padding, stride, dilation);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems: Vec<f32> = vec![
148., 164., 228., 244., 148., 164., 228., 244., 148., 164., 228., 244., 548., 564.,
628., 644., 548., 564., 628., 644., 548., 564., 628., 644., 948., 964., 1028., 1044.,
948., 964., 1028., 1044., 948., 964., 1028., 1044.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros((3, 2, 5, 5));
let mut kernel_grad = Array::<f32, _>::zeros((3, 2, 2, 2));
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems: Vec<f32> = vec![
3., 0., 6., 0., 3., 0., 0., 0., 0., 0., 6., 0., 12., 0., 6., 0., 0., 0., 0., 0., 3.,
0., 6., 0., 3., 3., 0., 6., 0., 3., 0., 0., 0., 0., 0., 6., 0., 12., 0., 6., 0., 0.,
0., 0., 0., 3., 0., 6., 0., 3., 3., 0., 6., 0., 3., 0., 0., 0., 0., 0., 6., 0., 12.,
0., 6., 0., 0., 0., 0., 0., 3., 0., 6., 0., 3., 3., 0., 6., 0., 3., 0., 0., 0., 0., 0.,
6., 0., 12., 0., 6., 0., 0., 0., 0., 0., 3., 0., 6., 0., 3., 3., 0., 6., 0., 3., 0.,
0., 0., 0., 0., 6., 0., 12., 0., 6., 0., 0., 0., 0., 0., 3., 0., 6., 0., 3., 3., 0.,
6., 0., 3., 0., 0., 0., 0., 0., 6., 0., 12., 0., 6., 0., 0., 0., 0., 0., 3., 0., 6.,
0., 3.,
];
let true_kernel_grad_elems: Vec<f32> = vec![
672., 696., 792., 816., 972., 996., 1092., 1116., 672., 696., 792., 816., 972., 996.,
1092., 1116., 672., 696., 792., 816., 972., 996., 1092., 1116.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv3d_dilated() {
let input_elems = (0..750).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((2, 3, 5, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((4, 3, 2, 2, 2));
let stride = &[1, 2, 3];
let padding = &[0, 0, 0];
let dilation = &[1, 2, 2];
let conv_out_shape = conv_out_shape::<ndarray::Ix5>(
input.shape(),
kernel.shape(),
padding,
stride,
dilation,
);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems = vec![
3444., 3684., 4044., 4284., 4644., 4884., 5244., 5484., 3444., 3684., 4044., 4284.,
4644., 4884., 5244., 5484., 3444., 3684., 4044., 4284., 4644., 4884., 5244., 5484.,
3444., 3684., 4044., 4284., 4644., 4884., 5244., 5484., 12444., 12684., 13044., 13284.,
13644., 13884., 14244., 14484., 12444., 12684., 13044., 13284., 13644., 13884., 14244.,
14484., 12444., 12684., 13044., 13284., 13644., 13884., 14244., 14484., 12444., 12684.,
13044., 13284., 13644., 13884., 14244., 14484.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&input,
stride,
dilation,
true,
);
let true_input_grad_elems = vec![
4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0.,
4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0.,
0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0.,
0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0.,
0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16.,
0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0.,
0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8.,
0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4.,
0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0.,
0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0.,
0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0.,
4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0.,
0.,
];
let true_kernel_grad_elems = vec![
3680., 3712., 3840., 3872., 4080., 4112., 4240., 4272., 5680., 5712., 5840., 5872.,
6080., 6112., 6240., 6272., 7680., 7712., 7840., 7872., 8080., 8112., 8240., 8272.,
3680., 3712., 3840., 3872., 4080., 4112., 4240., 4272., 5680., 5712., 5840., 5872.,
6080., 6112., 6240., 6272., 7680., 7712., 7840., 7872., 8080., 8112., 8240., 8272.,
3680., 3712., 3840., 3872., 4080., 4112., 4240., 4272., 5680., 5712., 5840., 5872.,
6080., 6112., 6240., 6272., 7680., 7712., 7840., 7872., 8080., 8112., 8240., 8272.,
3680., 3712., 3840., 3872., 4080., 4112., 4240., 4272., 5680., 5712., 5840., 5872.,
6080., 6112., 6240., 6272., 7680., 7712., 7840., 7872., 8080., 8112., 8240., 8272.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn conv2d_padded() {
use ndarray::Ix4;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((3, 2, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((3, 2, 2, 2));
let stride = &[1, 1];
let padding = &[3, 1];
let dilation = &[1, 1];
let padded_input = pad(&input, padding, &crate::variable::Zero);
assert_eq!(padded_input.shape(), &[3, 2, 11, 7]);
let conv_out_shape =
conv_out_shape_padded::<Ix4>(padded_input.shape(), kernel.shape(), stride, dilation);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution(&padded_input, &kernel, &mut conv_out, stride, dilation);
let true_output_elems = vec![
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 52., 56., 60., 64., 33., 60.,
124., 132., 140., 148., 76., 80., 164., 172., 180., 188., 96., 100., 204., 212., 220.,
228., 116., 120., 244., 252., 260., 268., 136., 65., 132., 136., 140., 144., 73., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 25., 52., 56., 60., 64., 33., 60., 124., 132., 140., 148., 76., 80., 164., 172.,
180., 188., 96., 100., 204., 212., 220., 228., 116., 120., 244., 252., 260., 268.,
136., 65., 132., 136., 140., 144., 73., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 52., 56., 60., 64., 33., 60.,
124., 132., 140., 148., 76., 80., 164., 172., 180., 188., 96., 100., 204., 212., 220.,
228., 116., 120., 244., 252., 260., 268., 136., 65., 132., 136., 140., 144., 73., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 125., 252., 256., 260., 264., 133., 260., 524., 532., 540., 548., 276., 280., 564.,
572., 580., 588., 296., 300., 604., 612., 620., 628., 316., 320., 644., 652., 660.,
668., 336., 165., 332., 336., 340., 344., 173., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 125., 252., 256., 260., 264.,
133., 260., 524., 532., 540., 548., 276., 280., 564., 572., 580., 588., 296., 300.,
604., 612., 620., 628., 316., 320., 644., 652., 660., 668., 336., 165., 332., 336.,
340., 344., 173., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 125., 252., 256., 260., 264., 133., 260., 524., 532., 540.,
548., 276., 280., 564., 572., 580., 588., 296., 300., 604., 612., 620., 628., 316.,
320., 644., 652., 660., 668., 336., 165., 332., 336., 340., 344., 173., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 225.,
452., 456., 460., 464., 233., 460., 924., 932., 940., 948., 476., 480., 964., 972.,
980., 988., 496., 500., 1004., 1012., 1020., 1028., 516., 520., 1044., 1052., 1060.,
1068., 536., 265., 532., 536., 540., 544., 273., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 225., 452., 456., 460.,
464., 233., 460., 924., 932., 940., 948., 476., 480., 964., 972., 980., 988., 496.,
500., 1004., 1012., 1020., 1028., 516., 520., 1044., 1052., 1060., 1068., 536., 265.,
532., 536., 540., 544., 273., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 225., 452., 456., 460., 464., 233., 460., 924.,
932., 940., 948., 476., 480., 964., 972., 980., 988., 496., 500., 1004., 1012., 1020.,
1028., 516., 520., 1044., 1052., 1060., 1068., 536., 265., 532., 536., 540., 544.,
273., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_backward_input(
&mut input_grad,
&conv_out_grad,
&kernel,
padding,
stride,
dilation,
true,
);
convolution_backward_kernel(
&mut kernel_grad,
&conv_out_grad,
&padded_input,
stride,
dilation,
true,
);
let true_kernel_grad_elems = vec![
4650., 4650., 4650., 4650., 6525., 6525., 6525., 6525., 4650., 4650., 4650., 4650.,
6525., 6525., 6525., 6525., 4650., 4650., 4650., 4650., 6525., 6525., 6525., 6525.,
];
let true_input_grad_elems = vec![
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
#[test]
fn grouped_conv1d() {
use ndarray::prelude::*;
use ndarray::Ix3;
let input_elems = (0..150).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((5, 3, 10)).unwrap();
let kernel = Array::<f32, _>::ones((6, 1, 5));
let stride = &[2];
let padding = &[0];
let dilation = &[2];
let groups = 3;
let conv_out_shape =
conv_out_shape::<Ix3>(input.shape(), kernel.shape(), padding, stride, dilation);
let true_output_elems = vec![
20., 20., 70., 70., 120., 120., 170., 170., 220., 220., 270., 270., 320., 320., 370.,
370., 420., 420., 470., 470., 520., 520., 570., 570., 620., 620., 670., 670., 720.,
720.,
];
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution_with_groups(&input, &kernel, &mut conv_out, stride, dilation, groups);
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_with_groups_backward(
&mut input_grad,
&mut kernel_grad,
&conv_out_grad,
&input,
&kernel,
padding,
stride,
dilation,
groups,
true,
true,
);
let true_input_grad_elems = vec![
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0., 2., 0.,
];
let true_kernel_grad_elems = array![
[[300., 310., 320., 330., 340.]],
[[300., 310., 320., 330., 340.]],
[[350., 360., 370., 380., 390.]],
[[350., 360., 370., 380., 390.]],
[[400., 410., 420., 430., 440.]],
[[400., 410., 420., 430., 440.]]
];
assert_eq!(
input_grad,
Array::from_shape_vec(input.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(kernel_grad, true_kernel_grad_elems);
}
#[test]
fn grouped_conv2d() {
use ndarray::Ix4;
let input: Array<f32, Ix4> = (0..800)
.map(|el| el as f32)
.collect::<Array<f32, _>>()
.into_shape((4, 8, 5, 5))
.unwrap();
let kernel = Array::<f32, _>::ones((8, 4, 2, 2));
let stride = &[1, 1];
let padding = &[0, 0];
let dilation = &[1, 1];
let groups = 2;
let conv_out_shape =
conv_out_shape::<Ix4>(input.shape(), kernel.shape(), padding, stride, dilation);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution_with_groups(&input, &kernel, &mut conv_out, stride, dilation, groups);
let true_output_elems = vec![
648., 664., 680., 696., 728., 744., 760., 776., 808., 824., 840., 856., 888., 904.,
920., 936., 648., 664., 680., 696., 728., 744., 760., 776., 808., 824., 840., 856.,
888., 904., 920., 936., 648., 664., 680., 696., 728., 744., 760., 776., 808., 824.,
840., 856., 888., 904., 920., 936., 648., 664., 680., 696., 728., 744., 760., 776.,
808., 824., 840., 856., 888., 904., 920., 936., 2248., 2264., 2280., 2296., 2328.,
2344., 2360., 2376., 2408., 2424., 2440., 2456., 2488., 2504., 2520., 2536., 2248.,
2264., 2280., 2296., 2328., 2344., 2360., 2376., 2408., 2424., 2440., 2456., 2488.,
2504., 2520., 2536., 2248., 2264., 2280., 2296., 2328., 2344., 2360., 2376., 2408.,
2424., 2440., 2456., 2488., 2504., 2520., 2536., 2248., 2264., 2280., 2296., 2328.,
2344., 2360., 2376., 2408., 2424., 2440., 2456., 2488., 2504., 2520., 2536., 3848.,
3864., 3880., 3896., 3928., 3944., 3960., 3976., 4008., 4024., 4040., 4056., 4088.,
4104., 4120., 4136., 3848., 3864., 3880., 3896., 3928., 3944., 3960., 3976., 4008.,
4024., 4040., 4056., 4088., 4104., 4120., 4136., 3848., 3864., 3880., 3896., 3928.,
3944., 3960., 3976., 4008., 4024., 4040., 4056., 4088., 4104., 4120., 4136., 3848.,
3864., 3880., 3896., 3928., 3944., 3960., 3976., 4008., 4024., 4040., 4056., 4088.,
4104., 4120., 4136., 5448., 5464., 5480., 5496., 5528., 5544., 5560., 5576., 5608.,
5624., 5640., 5656., 5688., 5704., 5720., 5736., 5448., 5464., 5480., 5496., 5528.,
5544., 5560., 5576., 5608., 5624., 5640., 5656., 5688., 5704., 5720., 5736., 5448.,
5464., 5480., 5496., 5528., 5544., 5560., 5576., 5608., 5624., 5640., 5656., 5688.,
5704., 5720., 5736., 5448., 5464., 5480., 5496., 5528., 5544., 5560., 5576., 5608.,
5624., 5640., 5656., 5688., 5704., 5720., 5736., 7048., 7064., 7080., 7096., 7128.,
7144., 7160., 7176., 7208., 7224., 7240., 7256., 7288., 7304., 7320., 7336., 7048.,
7064., 7080., 7096., 7128., 7144., 7160., 7176., 7208., 7224., 7240., 7256., 7288.,
7304., 7320., 7336., 7048., 7064., 7080., 7096., 7128., 7144., 7160., 7176., 7208.,
7224., 7240., 7256., 7288., 7304., 7320., 7336., 7048., 7064., 7080., 7096., 7128.,
7144., 7160., 7176., 7208., 7224., 7240., 7256., 7288., 7304., 7320., 7336., 8648.,
8664., 8680., 8696., 8728., 8744., 8760., 8776., 8808., 8824., 8840., 8856., 8888.,
8904., 8920., 8936., 8648., 8664., 8680., 8696., 8728., 8744., 8760., 8776., 8808.,
8824., 8840., 8856., 8888., 8904., 8920., 8936., 8648., 8664., 8680., 8696., 8728.,
8744., 8760., 8776., 8808., 8824., 8840., 8856., 8888., 8904., 8920., 8936., 8648.,
8664., 8680., 8696., 8728., 8744., 8760., 8776., 8808., 8824., 8840., 8856., 8888.,
8904., 8920., 8936., 10248., 10264., 10280., 10296., 10328., 10344., 10360., 10376.,
10408., 10424., 10440., 10456., 10488., 10504., 10520., 10536., 10248., 10264., 10280.,
10296., 10328., 10344., 10360., 10376., 10408., 10424., 10440., 10456., 10488., 10504.,
10520., 10536., 10248., 10264., 10280., 10296., 10328., 10344., 10360., 10376., 10408.,
10424., 10440., 10456., 10488., 10504., 10520., 10536., 10248., 10264., 10280., 10296.,
10328., 10344., 10360., 10376., 10408., 10424., 10440., 10456., 10488., 10504., 10520.,
10536., 11848., 11864., 11880., 11896., 11928., 11944., 11960., 11976., 12008., 12024.,
12040., 12056., 12088., 12104., 12120., 12136., 11848., 11864., 11880., 11896., 11928.,
11944., 11960., 11976., 12008., 12024., 12040., 12056., 12088., 12104., 12120., 12136.,
11848., 11864., 11880., 11896., 11928., 11944., 11960., 11976., 12008., 12024., 12040.,
12056., 12088., 12104., 12120., 12136., 11848., 11864., 11880., 11896., 11928., 11944.,
11960., 11976., 12008., 12024., 12040., 12056., 12088., 12104., 12120., 12136.,
];
assert_eq!(
conv_out,
Array::from_shape_vec(conv_out.raw_dim(), true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros((4, 8, 5, 5));
let mut kernel_grad = Array::<f32, _>::zeros((8, 4, 2, 2));
let d_out = Array::<f32, _>::ones(conv_out.raw_dim());
convolution_with_groups_backward(
&mut input_grad,
&mut kernel_grad,
&d_out,
&input,
&kernel,
padding,
stride,
dilation,
groups,
true,
true,
);
let true_kernel_grad_elems: Vec<f32> = vec![
19776., 19840., 20096., 20160., 21376., 21440., 21696., 21760., 22976., 23040., 23296.,
23360., 24576., 24640., 24896., 24960., 19776., 19840., 20096., 20160., 21376., 21440.,
21696., 21760., 22976., 23040., 23296., 23360., 24576., 24640., 24896., 24960., 19776.,
19840., 20096., 20160., 21376., 21440., 21696., 21760., 22976., 23040., 23296., 23360.,
24576., 24640., 24896., 24960., 19776., 19840., 20096., 20160., 21376., 21440., 21696.,
21760., 22976., 23040., 23296., 23360., 24576., 24640., 24896., 24960., 26176., 26240.,
26496., 26560., 27776., 27840., 28096., 28160., 29376., 29440., 29696., 29760., 30976.,
31040., 31296., 31360., 26176., 26240., 26496., 26560., 27776., 27840., 28096., 28160.,
29376., 29440., 29696., 29760., 30976., 31040., 31296., 31360., 26176., 26240., 26496.,
26560., 27776., 27840., 28096., 28160., 29376., 29440., 29696., 29760., 30976., 31040.,
31296., 31360., 26176., 26240., 26496., 26560., 27776., 27840., 28096., 28160., 29376.,
29440., 29696., 29760., 30976., 31040., 31296., 31360.,
];
let true_input_grad_elems: Vec<f32> = vec![
4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16.,
8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8.,
4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8.,
8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8.,
8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8.,
8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16.,
16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16.,
16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16.,
16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8.,
8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4.,
8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8.,
16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8.,
16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8.,
16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4.,
4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16.,
8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8.,
4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8.,
8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16.,
16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8.,
8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8.,
8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16.,
16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16.,
16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16.,
16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8.,
8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4.,
8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8.,
16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8.,
16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8.,
16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4.,
4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 4., 8., 8., 8., 4., 4., 8., 8., 8., 4., 8., 16., 16., 16., 8., 8., 16., 16., 16.,
8., 8., 16., 16., 16., 8., 4., 8., 8., 8., 4.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap()
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap()
);
}
#[test]
fn grouped_conv3d() {
let input_elems = (0..2_000).map(|el| el as f32).collect::<Array<f32, _>>();
let input = input_elems.into_shape((2, 8, 5, 5, 5)).unwrap();
let kernel = Array::<f32, _>::ones((16, 2, 2, 2, 2));
let stride = &[1, 2, 3];
let padding = &[0, 0, 0];
let dilation = &[1, 2, 2];
let groups = 4;
let conv_out_shape = conv_out_shape::<ndarray::Ix5>(
input.shape(),
kernel.shape(),
padding,
stride,
dilation,
);
let mut conv_out = Array::<f32, _>::zeros(conv_out_shape);
convolution_with_groups(&input, &kernel, &mut conv_out, stride, dilation, groups);
let true_output_elems = vec![
1296., 1456., 1696., 1856., 2096., 2256., 2496., 2656., 1296., 1456., 1696., 1856.,
2096., 2256., 2496., 2656., 1296., 1456., 1696., 1856., 2096., 2256., 2496., 2656.,
1296., 1456., 1696., 1856., 2096., 2256., 2496., 2656., 5296., 5456., 5696., 5856.,
6096., 6256., 6496., 6656., 5296., 5456., 5696., 5856., 6096., 6256., 6496., 6656.,
5296., 5456., 5696., 5856., 6096., 6256., 6496., 6656., 5296., 5456., 5696., 5856.,
6096., 6256., 6496., 6656., 9296., 9456., 9696., 9856., 10096., 10256., 10496., 10656.,
9296., 9456., 9696., 9856., 10096., 10256., 10496., 10656., 9296., 9456., 9696., 9856.,
10096., 10256., 10496., 10656., 9296., 9456., 9696., 9856., 10096., 10256., 10496.,
10656., 13296., 13456., 13696., 13856., 14096., 14256., 14496., 14656., 13296., 13456.,
13696., 13856., 14096., 14256., 14496., 14656., 13296., 13456., 13696., 13856., 14096.,
14256., 14496., 14656., 13296., 13456., 13696., 13856., 14096., 14256., 14496., 14656.,
17296., 17456., 17696., 17856., 18096., 18256., 18496., 18656., 17296., 17456., 17696.,
17856., 18096., 18256., 18496., 18656., 17296., 17456., 17696., 17856., 18096., 18256.,
18496., 18656., 17296., 17456., 17696., 17856., 18096., 18256., 18496., 18656., 21296.,
21456., 21696., 21856., 22096., 22256., 22496., 22656., 21296., 21456., 21696., 21856.,
22096., 22256., 22496., 22656., 21296., 21456., 21696., 21856., 22096., 22256., 22496.,
22656., 21296., 21456., 21696., 21856., 22096., 22256., 22496., 22656., 25296., 25456.,
25696., 25856., 26096., 26256., 26496., 26656., 25296., 25456., 25696., 25856., 26096.,
26256., 26496., 26656., 25296., 25456., 25696., 25856., 26096., 26256., 26496., 26656.,
25296., 25456., 25696., 25856., 26096., 26256., 26496., 26656., 29296., 29456., 29696.,
29856., 30096., 30256., 30496., 30656., 29296., 29456., 29696., 29856., 30096., 30256.,
30496., 30656., 29296., 29456., 29696., 29856., 30096., 30256., 30496., 30656., 29296.,
29456., 29696., 29856., 30096., 30256., 30496., 30656.,
];
assert_eq!(
conv_out,
Array::<f32, _>::from_shape_vec(conv_out_shape, true_output_elems).unwrap()
);
let mut input_grad = Array::<f32, _>::zeros(input.raw_dim());
let mut kernel_grad = Array::<f32, _>::zeros(kernel.raw_dim());
let conv_out_grad = Array::<f32, _>::ones(conv_out_shape);
convolution_with_groups_backward(
&mut input_grad,
&mut kernel_grad,
&conv_out_grad,
&input,
&kernel,
padding,
stride,
dilation,
groups,
true,
true,
);
let true_input_grad_elems = vec![
4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0.,
4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0.,
0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0.,
0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0.,
0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16.,
0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0.,
0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8.,
0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4.,
0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0.,
0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0.,
0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0.,
4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0.,
0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4.,
0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0.,
0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16.,
0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0.,
0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8.,
0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4.,
0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0.,
0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0.,
0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0.,
4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0.,
0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4.,
0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0.,
16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0.,
0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16.,
0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0.,
0., 0., 0., 8., 0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 0., 0., 0., 0., 0., 4., 0., 4., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 8., 0., 8., 0.,
0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0.,
8., 0., 8., 0., 0., 0., 0., 0., 0., 0., 16., 0., 16., 0., 0., 0., 0., 0., 0., 0., 8.,
0., 8., 0., 0., 4., 0., 4., 0., 0., 0., 0., 0., 0., 0., 8., 0., 8., 0., 0., 0., 0., 0.,
0., 0., 4., 0., 4., 0., 0.,
];
let true_kernel_grad_elems = vec![
8680., 8712., 8840., 8872., 9080., 9112., 9240., 9272., 10680., 10712., 10840., 10872.,
11080., 11112., 11240., 11272., 8680., 8712., 8840., 8872., 9080., 9112., 9240., 9272.,
10680., 10712., 10840., 10872., 11080., 11112., 11240., 11272., 8680., 8712., 8840.,
8872., 9080., 9112., 9240., 9272., 10680., 10712., 10840., 10872., 11080., 11112.,
11240., 11272., 8680., 8712., 8840., 8872., 9080., 9112., 9240., 9272., 10680., 10712.,
10840., 10872., 11080., 11112., 11240., 11272., 12680., 12712., 12840., 12872., 13080.,
13112., 13240., 13272., 14680., 14712., 14840., 14872., 15080., 15112., 15240., 15272.,
12680., 12712., 12840., 12872., 13080., 13112., 13240., 13272., 14680., 14712., 14840.,
14872., 15080., 15112., 15240., 15272., 12680., 12712., 12840., 12872., 13080., 13112.,
13240., 13272., 14680., 14712., 14840., 14872., 15080., 15112., 15240., 15272., 12680.,
12712., 12840., 12872., 13080., 13112., 13240., 13272., 14680., 14712., 14840., 14872.,
15080., 15112., 15240., 15272., 16680., 16712., 16840., 16872., 17080., 17112., 17240.,
17272., 18680., 18712., 18840., 18872., 19080., 19112., 19240., 19272., 16680., 16712.,
16840., 16872., 17080., 17112., 17240., 17272., 18680., 18712., 18840., 18872., 19080.,
19112., 19240., 19272., 16680., 16712., 16840., 16872., 17080., 17112., 17240., 17272.,
18680., 18712., 18840., 18872., 19080., 19112., 19240., 19272., 16680., 16712., 16840.,
16872., 17080., 17112., 17240., 17272., 18680., 18712., 18840., 18872., 19080., 19112.,
19240., 19272., 20680., 20712., 20840., 20872., 21080., 21112., 21240., 21272., 22680.,
22712., 22840., 22872., 23080., 23112., 23240., 23272., 20680., 20712., 20840., 20872.,
21080., 21112., 21240., 21272., 22680., 22712., 22840., 22872., 23080., 23112., 23240.,
23272., 20680., 20712., 20840., 20872., 21080., 21112., 21240., 21272., 22680., 22712.,
22840., 22872., 23080., 23112., 23240., 23272., 20680., 20712., 20840., 20872., 21080.,
21112., 21240., 21272., 22680., 22712., 22840., 22872., 23080., 23112., 23240., 23272.,
];
assert_eq!(
input_grad,
Array::from_shape_vec(input_grad.raw_dim(), true_input_grad_elems).unwrap(),
);
assert_eq!(
kernel_grad,
Array::from_shape_vec(kernel_grad.raw_dim(), true_kernel_grad_elems).unwrap(),
);
}
}