use crate::{
    shapes::{Axes, Dtype, ReduceShape, Shape},
    tensor::{HasErr, Tape, Tensor},
};

use super::{BroadcastTo, Device, MeanTo, TryAdd, TryDiv, TrySub};

/// Normalizes `t` to have mean `0.0` and stddev `1.0` along `Ax`. `epsilon` is added to the
/// variance before the square root so normalization stays numerically stable.
/// Computes `(t - t.mean(Ax)) / t.std(Ax, epsilon)`.
///
/// Normalizing a single axis:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let _ = t.normalize::<Axis<1>>(1e-5);
/// ```
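/// Normalizing multiple axes at once (a sketch assuming `Axes2<0, 1>` reduces both axes here,
/// as it does for the other reduction ops):
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let _ = t.normalize::<Axes2<0, 1>>(1e-5);
/// ```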
pub fn normalize<Ax: Axes, S: Shape + ReduceShape<Ax>, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
    epsilon: E,
) -> Tensor<S, E, D, T> {
    t.normalize::<Ax>(epsilon)
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// See [normalize]
    pub fn normalize<Ax: Axes>(self, epsilon: E) -> Self
    where
        S: ReduceShape<Ax>,
    {
        self.try_normalize::<Ax>(epsilon).unwrap()
    }

    /// See [normalize]
    pub fn try_normalize<Ax: Axes>(self, epsilon: E) -> Result<Self, <Self as HasErr>::Err>
    where
        S: ReduceShape<Ax>,
    {
        let shape = self.shape;
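        // Mean over `Ax`, computed on a retaped copy so `self` can still be consumed by the
        // subtraction below; the tapes are merged there so gradients flow through both branches.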
        let mean = self.retaped::<T>().try_mean::<_, Ax>()?;
        let centered = self.try_sub(mean.try_broadcast_like(&shape)?)?;
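        // Standard deviation along `Ax`: sqrt(mean(centered^2) + epsilon). Adding epsilon
        // before the sqrt keeps the upcoming division finite when the variance is zero.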
        let std = centered
            .retaped::<T>()
            .try_square()?
            .try_mean::<_, Ax>()?
            .try_add(epsilon)?
            .try_sqrt()?;
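        // Broadcast the reduced std back to the full shape and divide to finish normalizing.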
        centered.try_div(std.try_broadcast_like(&shape)?)
    }
}

#[cfg(test)]
mod tests {
    use crate::tests::*;
    use crate::{shapes::*, tensor::*, tensor_ops::*};

    #[test]
    fn test_1d_normalize_axis_last() {
        let dev: TestDevice = Default::default();
        let a: Tensor<_, TestDtype, _> = dev.tensor([-2.0, 0.0, 5.0]);
        let r = a.leaky_trace().normalize(1e-5);
        assert_close(&r.array(), &[-1.0190487, -0.3396829, 1.3587316]);
        // NOTE: .exp() makes the incoming gradient non-uniform, so this verifies that
        // normalize propagates the result gradient correctly.
        let g = r.exp().mean().backward();
        assert_close(&g.get(&a).array(), &[0.033410847, -0.04677555, 0.013364702]);
    }

    #[test]
    fn test_2d_normalize_axis_last() {
        let dev: TestDevice = Default::default();
        let a: Tensor<_, TestDtype, _> = dev.tensor([[-2.0, 0.0, 5.0], [1.0, 2.0, 3.0]]);
        let r = a.leaky_trace().normalize::<Axis<1>>(1e-5);
        assert_close(
            &r.array(),
            &[
                [-1.0190487, -0.3396829, 1.3587316],
                [-1.2247356, 0.0, 1.2247356],
            ],
        );
        let g = r.exp().mean().backward();
        assert_close(
            &g.get(&a).array(),
            &[
                [0.016705424, -0.023387775, 0.006682351],
                [0.05773133, -0.11547226, 0.057740927],
            ],
        );
    }

    #[test]
    fn test_2d_normalize_axis_first() {
        let dev: TestDevice = Default::default();
        let a: Tensor<_, TestDtype, _> = dev.tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]]);
        let r = a.leaky_trace().normalize::<Axis<0>>(1e-5);
        assert_close(
            &r.array(),
            &[
                [-1.2247438, -1.1355485],
                [0.0, -0.16222118],
                [1.2247438, 1.2977698],
            ],
        );
        let g = r.exp().mean().backward();
        assert_close(
            &g.get(&a).array(),
            &[
                [0.019245632, 0.025835907],
                [-0.038491584, -0.043060362],
                [0.019245982, 0.01722446],
            ],
        );
    }

    #[test]
    fn test_3d_normalize_axis_last() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank3<4, 2, 3>, TestDtype, _> = dev.ones();
        let r = a.leaky_trace().normalize::<Axis<2>>(1e-5);
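        // All elements are equal, so the variance is 0; epsilon keeps the division finite and
        // the normalized output is exactly 0 everywhere.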
        assert_eq!(r.array(), [[[0.0; 3]; 2]; 4]);
        let g = r.exp().mean().backward();
        assert_eq!(g.get(&a).array(), [[[0.0; 3]; 2]; 4]);
    }
}