1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use crate::prelude::*;
/// Averages `t` along axis `I`, removing that axis from the result type.
///
/// Computed as `sum_axis` over axis `I`, then `div_scalar` by the length of
/// that axis (`<T::Array as HasAxis<I>>::SIZE`, cast to `f32`).
/// `I = -1` reduces the trailing axis (e.g. `Tensor2D<5, 3>` -> `Tensor1D<5>`).
pub fn mean_axis<T: Reduce1<I>, const I: isize>(t: T) -> T::Reduced
where
    T::Array: HasAxis<I>,
{
    // Number of elements being averaged along the reduced axis.
    let axis_len = <T::Array as HasAxis<I>>::SIZE as f32;
    let summed = sum_axis::<T, I>(t);
    div_scalar(summed, axis_len)
}
macro_rules! mean_axis_impl {
($typename:ident, [$($Vs:tt),*]) => {
impl<$(const $Vs: usize, )* H: Tape> $typename<$($Vs, )* H> {
pub fn mean_axis<const I: isize>(self) -> <Self as Reduce1<I>>::Reduced
where
Self: Reduce1<I>,
<Self as HasArrayType>::Array: HasAxis<I>,
{
mean_axis::<Self, I>(self)
}
}
};
}
mean_axis_impl!(Tensor0D, []);
mean_axis_impl!(Tensor1D, [M]);
mean_axis_impl!(Tensor2D, [M, N]);
mean_axis_impl!(Tensor3D, [M, N, O]);
mean_axis_impl!(Tensor4D, [M, N, O, P]);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valids_mean_axis() {
        // Rank 1 -> rank 0.
        let _: Tensor0D = Tensor1D::<5>::zeros().mean_axis::<-1>();

        // Rank 2 -> rank 1, reducing each axis in turn.
        let _: Tensor1D<3> = Tensor2D::<5, 3>::zeros().mean_axis::<0>();
        let _: Tensor1D<5> = Tensor2D::<5, 3>::zeros().mean_axis::<-1>();

        // Rank 3 -> rank 2.
        let _: Tensor2D<5, 3> = Tensor3D::<7, 5, 3>::zeros().mean_axis::<0>();
        let _: Tensor2D<7, 3> = Tensor3D::<7, 5, 3>::zeros().mean_axis::<1>();
        let _: Tensor2D<7, 5> = Tensor3D::<7, 5, 3>::zeros().mean_axis::<-1>();

        // Rank 4 -> rank 3.
        let _: Tensor3D<7, 5, 3> = Tensor4D::<9, 7, 5, 3>::zeros().mean_axis::<0>();
        let _: Tensor3D<9, 5, 3> = Tensor4D::<9, 7, 5, 3>::zeros().mean_axis::<1>();
        let _: Tensor3D<9, 7, 3> = Tensor4D::<9, 7, 5, 3>::zeros().mean_axis::<2>();
        let _: Tensor3D<9, 7, 5> = Tensor4D::<9, 7, 5, 3>::zeros().mean_axis::<-1>();
    }

    #[test]
    fn test_mean_axis_0_2d() {
        let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 3.0], [-2.0, 4.0, -6.0]]);
        let expected_mean: [f32; 3] = [-0.5, 3.0, -1.5];
        // Each entry in column j gets exp(r_j) / 3 (outer mean) * 1 / 2 (axis-0 mean).
        let expected_grad: [[f32; 3]; 2] = [[0.10108845, 3.3475895, 0.037188362]; 2];

        let averaged = t.trace().mean_axis::<0>();
        assert_eq!(averaged.data(), &expected_mean);

        let grads = averaged.exp().mean().backward();
        assert_eq!(grads.ref_gradient(&t), &expected_grad);
    }

    #[test]
    fn test_mean_axis_1_2d() {
        let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 3.0], [-2.0, 4.0, -6.0]]);
        let expected_mean: [f32; 2] = [2.0, -4.0 / 3.0];
        // Each entry in row i gets exp(r_i) / 2 (outer mean) * 1 / 3 (last-axis mean).
        let expected_grad: [[f32; 3]; 2] = [[1.2315093; 3], [0.043932855; 3]];

        let averaged = t.trace().mean_axis::<-1>();
        assert_eq!(averaged.data(), &expected_mean);

        let grads = averaged.exp().mean().backward();
        assert_eq!(grads.ref_gradient(&t), &expected_grad);
    }
}