neuronika 0.2.0

Tensors and dynamic neural networks.
use super::{Constant, PaddingMode, Reflective, Replicative, Zero};

#[test]
fn constant_pad() {
    let padding = Constant::new(8.);
    let arr = ndarray::Array::range(0., 25., 1.)
        .into_shape((5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2]);
    assert_eq!(
        padded,
        ndarray::array![
            [8., 8., 8., 8., 8., 8., 8., 8., 8.],
            [8., 8., 0., 1., 2., 3., 4., 8., 8.],
            [8., 8., 5., 6., 7., 8., 9., 8., 8.],
            [8., 8., 10., 11., 12., 13., 14., 8., 8.],
            [8., 8., 15., 16., 17., 18., 19., 8., 8.],
            [8., 8., 20., 21., 22., 23., 24., 8., 8.],
            [8., 8., 8., 8., 8., 8., 8., 8., 8.],
        ]
    );
}

#[test]
fn zero_pad() {
    let padding = Zero;
    let arr = ndarray::Array::range(0., 25., 1.)
        .into_shape((5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2]);
    assert_eq!(
        padded,
        ndarray::array![
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 1., 2., 3., 4., 0., 0.],
            [0., 0., 5., 6., 7., 8., 9., 0., 0.],
            [0., 0., 10., 11., 12., 13., 14., 0., 0.],
            [0., 0., 15., 16., 17., 18., 19., 0., 0.],
            [0., 0., 20., 21., 22., 23., 24., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        ]
    );
}

#[test]
fn replication_pad_1d() {
    let padding = Replicative;
    let arr = ndarray::Array::range(0., 5., 1.);
    let padded = padding.pad(&arr, [2]);
    assert_eq!(padded, ndarray::array![0., 0., 0., 1., 2., 3., 4., 4., 4.],);
}

#[test]
fn replication_pad_2d() {
    let padding = Replicative;
    let arr = ndarray::Array::range(0., 25., 1.)
        .into_shape((5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2]);
    assert_eq!(
        padded,
        ndarray::array![
            [0., 0., 0., 1., 2., 3., 4., 4., 4.],
            [0., 0., 0., 1., 2., 3., 4., 4., 4.],
            [5., 5., 5., 6., 7., 8., 9., 9., 9.],
            [10., 10., 10., 11., 12., 13., 14., 14., 14.],
            [15., 15., 15., 16., 17., 18., 19., 19., 19.],
            [20., 20., 20., 21., 22., 23., 24., 24., 24.],
            [20., 20., 20., 21., 22., 23., 24., 24., 24.],
        ]
    );
}

#[test]
fn replication_pad_3d() {
    let padding = Replicative;
    let arr = ndarray::Array::range(0., 125., 1.)
        .into_shape((5, 5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2, 3]);
    assert_eq!(
        padded,
        ndarray::array![
            [
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [5., 5., 5., 5., 6., 7., 8., 9., 9., 9., 9.],
                [10., 10., 10., 10., 11., 12., 13., 14., 14., 14., 14.],
                [15., 15., 15., 15., 16., 17., 18., 19., 19., 19., 19.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.]
            ],
            [
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [0., 0., 0., 0., 1., 2., 3., 4., 4., 4., 4.],
                [5., 5., 5., 5., 6., 7., 8., 9., 9., 9., 9.],
                [10., 10., 10., 10., 11., 12., 13., 14., 14., 14., 14.],
                [15., 15., 15., 15., 16., 17., 18., 19., 19., 19., 19.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.],
                [20., 20., 20., 20., 21., 22., 23., 24., 24., 24., 24.]
            ],
            [
                [25., 25., 25., 25., 26., 27., 28., 29., 29., 29., 29.],
                [25., 25., 25., 25., 26., 27., 28., 29., 29., 29., 29.],
                [25., 25., 25., 25., 26., 27., 28., 29., 29., 29., 29.],
                [30., 30., 30., 30., 31., 32., 33., 34., 34., 34., 34.],
                [35., 35., 35., 35., 36., 37., 38., 39., 39., 39., 39.],
                [40., 40., 40., 40., 41., 42., 43., 44., 44., 44., 44.],
                [45., 45., 45., 45., 46., 47., 48., 49., 49., 49., 49.],
                [45., 45., 45., 45., 46., 47., 48., 49., 49., 49., 49.],
                [45., 45., 45., 45., 46., 47., 48., 49., 49., 49., 49.]
            ],
            [
                [50., 50., 50., 50., 51., 52., 53., 54., 54., 54., 54.],
                [50., 50., 50., 50., 51., 52., 53., 54., 54., 54., 54.],
                [50., 50., 50., 50., 51., 52., 53., 54., 54., 54., 54.],
                [55., 55., 55., 55., 56., 57., 58., 59., 59., 59., 59.],
                [60., 60., 60., 60., 61., 62., 63., 64., 64., 64., 64.],
                [65., 65., 65., 65., 66., 67., 68., 69., 69., 69., 69.],
                [70., 70., 70., 70., 71., 72., 73., 74., 74., 74., 74.],
                [70., 70., 70., 70., 71., 72., 73., 74., 74., 74., 74.],
                [70., 70., 70., 70., 71., 72., 73., 74., 74., 74., 74.]
            ],
            [
                [75., 75., 75., 75., 76., 77., 78., 79., 79., 79., 79.],
                [75., 75., 75., 75., 76., 77., 78., 79., 79., 79., 79.],
                [75., 75., 75., 75., 76., 77., 78., 79., 79., 79., 79.],
                [80., 80., 80., 80., 81., 82., 83., 84., 84., 84., 84.],
                [85., 85., 85., 85., 86., 87., 88., 89., 89., 89., 89.],
                [90., 90., 90., 90., 91., 92., 93., 94., 94., 94., 94.],
                [95., 95., 95., 95., 96., 97., 98., 99., 99., 99., 99.],
                [95., 95., 95., 95., 96., 97., 98., 99., 99., 99., 99.],
                [95., 95., 95., 95., 96., 97., 98., 99., 99., 99., 99.]
            ],
            [
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [105., 105., 105., 105., 106., 107., 108., 109., 109., 109., 109.],
                [110., 110., 110., 110., 111., 112., 113., 114., 114., 114., 114.],
                [115., 115., 115., 115., 116., 117., 118., 119., 119., 119., 119.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.]
            ],
            [
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [100., 100., 100., 100., 101., 102., 103., 104., 104., 104., 104.],
                [105., 105., 105., 105., 106., 107., 108., 109., 109., 109., 109.],
                [110., 110., 110., 110., 111., 112., 113., 114., 114., 114., 114.],
                [115., 115., 115., 115., 116., 117., 118., 119., 119., 119., 119.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.],
                [120., 120., 120., 120., 121., 122., 123., 124., 124., 124., 124.]
            ]
        ]
    );
}

#[test]
fn reflection_pad_1d() {
    let padding = Reflective;
    let arr = ndarray::Array::range(0., 5., 1.);
    let padded = padding.pad(&arr, [2]);
    assert_eq!(padded, ndarray::array![2., 1., 0., 1., 2., 3., 4., 3., 2.],);
}

#[test]
fn reflection_pad_2d() {
    let padding = Reflective;
    let arr = ndarray::Array::range(0., 25., 1.)
        .into_shape((5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2]);
    assert_eq!(
        padded,
        ndarray::array![
            [7., 6., 5., 6., 7., 8., 9., 8., 7.],
            [2., 1., 0., 1., 2., 3., 4., 3., 2.],
            [7., 6., 5., 6., 7., 8., 9., 8., 7.],
            [12., 11., 10., 11., 12., 13., 14., 13., 12.],
            [17., 16., 15., 16., 17., 18., 19., 18., 17.],
            [22., 21., 20., 21., 22., 23., 24., 23., 22.],
            [17., 16., 15., 16., 17., 18., 19., 18., 17.]
        ]
    );
}

#[test]
fn reflection_pad_3d() {
    let padding = Reflective;
    let arr = ndarray::Array::range(0., 125., 1.)
        .into_shape((5, 5, 5))
        .unwrap();
    let padded = padding.pad(&arr, [1, 2, 3]);
    assert_eq!(
        padded,
        ndarray::array![
            [
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.],
                [33., 32., 31., 30., 31., 32., 33., 34., 33., 32., 31.],
                [28., 27., 26., 25., 26., 27., 28., 29., 28., 27., 26.],
                [33., 32., 31., 30., 31., 32., 33., 34., 33., 32., 31.],
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.],
                [43., 42., 41., 40., 41., 42., 43., 44., 43., 42., 41.],
                [48., 47., 46., 45., 46., 47., 48., 49., 48., 47., 46.],
                [43., 42., 41., 40., 41., 42., 43., 44., 43., 42., 41.],
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.]
            ],
            [
                [13., 12., 11., 10., 11., 12., 13., 14., 13., 12., 11.],
                [8., 7., 6., 5., 6., 7., 8., 9., 8., 7., 6.],
                [3., 2., 1., 0., 1., 2., 3., 4., 3., 2., 1.],
                [8., 7., 6., 5., 6., 7., 8., 9., 8., 7., 6.],
                [13., 12., 11., 10., 11., 12., 13., 14., 13., 12., 11.],
                [18., 17., 16., 15., 16., 17., 18., 19., 18., 17., 16.],
                [23., 22., 21., 20., 21., 22., 23., 24., 23., 22., 21.],
                [18., 17., 16., 15., 16., 17., 18., 19., 18., 17., 16.],
                [13., 12., 11., 10., 11., 12., 13., 14., 13., 12., 11.]
            ],
            [
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.],
                [33., 32., 31., 30., 31., 32., 33., 34., 33., 32., 31.],
                [28., 27., 26., 25., 26., 27., 28., 29., 28., 27., 26.],
                [33., 32., 31., 30., 31., 32., 33., 34., 33., 32., 31.],
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.],
                [43., 42., 41., 40., 41., 42., 43., 44., 43., 42., 41.],
                [48., 47., 46., 45., 46., 47., 48., 49., 48., 47., 46.],
                [43., 42., 41., 40., 41., 42., 43., 44., 43., 42., 41.],
                [38., 37., 36., 35., 36., 37., 38., 39., 38., 37., 36.]
            ],
            [
                [63., 62., 61., 60., 61., 62., 63., 64., 63., 62., 61.],
                [58., 57., 56., 55., 56., 57., 58., 59., 58., 57., 56.],
                [53., 52., 51., 50., 51., 52., 53., 54., 53., 52., 51.],
                [58., 57., 56., 55., 56., 57., 58., 59., 58., 57., 56.],
                [63., 62., 61., 60., 61., 62., 63., 64., 63., 62., 61.],
                [68., 67., 66., 65., 66., 67., 68., 69., 68., 67., 66.],
                [73., 72., 71., 70., 71., 72., 73., 74., 73., 72., 71.],
                [68., 67., 66., 65., 66., 67., 68., 69., 68., 67., 66.],
                [63., 62., 61., 60., 61., 62., 63., 64., 63., 62., 61.]
            ],
            [
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.],
                [83., 82., 81., 80., 81., 82., 83., 84., 83., 82., 81.],
                [78., 77., 76., 75., 76., 77., 78., 79., 78., 77., 76.],
                [83., 82., 81., 80., 81., 82., 83., 84., 83., 82., 81.],
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.],
                [93., 92., 91., 90., 91., 92., 93., 94., 93., 92., 91.],
                [98., 97., 96., 95., 96., 97., 98., 99., 98., 97., 96.],
                [93., 92., 91., 90., 91., 92., 93., 94., 93., 92., 91.],
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.]
            ],
            [
                [113., 112., 111., 110., 111., 112., 113., 114., 113., 112., 111.],
                [108., 107., 106., 105., 106., 107., 108., 109., 108., 107., 106.],
                [103., 102., 101., 100., 101., 102., 103., 104., 103., 102., 101.],
                [108., 107., 106., 105., 106., 107., 108., 109., 108., 107., 106.],
                [113., 112., 111., 110., 111., 112., 113., 114., 113., 112., 111.],
                [118., 117., 116., 115., 116., 117., 118., 119., 118., 117., 116.],
                [123., 122., 121., 120., 121., 122., 123., 124., 123., 122., 121.],
                [118., 117., 116., 115., 116., 117., 118., 119., 118., 117., 116.],
                [113., 112., 111., 110., 111., 112., 113., 114., 113., 112., 111.]
            ],
            [
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.],
                [83., 82., 81., 80., 81., 82., 83., 84., 83., 82., 81.],
                [78., 77., 76., 75., 76., 77., 78., 79., 78., 77., 76.],
                [83., 82., 81., 80., 81., 82., 83., 84., 83., 82., 81.],
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.],
                [93., 92., 91., 90., 91., 92., 93., 94., 93., 92., 91.],
                [98., 97., 96., 95., 96., 97., 98., 99., 98., 97., 96.],
                [93., 92., 91., 90., 91., 92., 93., 94., 93., 92., 91.],
                [88., 87., 86., 85., 86., 87., 88., 89., 88., 87., 86.]
            ]
        ]
    )
}