1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use crate::gradients::{CanUpdateWithGradients, GradientProvider, UnusedTensors};
use crate::prelude::*;
/// A generalized residual connection that sums the outputs of two sub-modules:
/// `output = f(x) + r(x)` (see the `Module` impl below).
///
/// Unlike a plain residual block (`x + f(x)`), the skip path `r` is itself a
/// learnable module, so both branches are updated during training.
/// Requires `F` and `R` to produce outputs of the same type/shape.
#[derive(Debug, Clone, Default)]
pub struct GeneralizedResidual<F, R> {
// Main branch applied to the input.
pub f: F,
// Residual (skip) branch applied to the same input.
pub r: R,
}
/// Gradient application is delegated to both sub-modules: each branch pulls
/// its own parameter gradients from `grads` and records any parameters with
/// missing gradients in `unused`.
impl<F: CanUpdateWithGradients, R: CanUpdateWithGradients> CanUpdateWithGradients
for GeneralizedResidual<F, R>
{
// Updates `f` first, then `r`; both receive the same provider/collector.
fn update<G: GradientProvider>(&mut self, grads: &mut G, unused: &mut UnusedTensors) {
self.f.update(grads, unused);
self.r.update(grads, unused);
}
}
/// Parameter initialization is delegated to both sub-modules.
///
/// NOTE: `f` is reset before `r` using the same stateful `rng`, so the draw
/// order (and therefore the resulting values) depends on this call sequence —
/// do not reorder without re-deriving any seed-dependent test expectations.
impl<F: ResetParams, R: ResetParams> ResetParams for GeneralizedResidual<F, R> {
fn reset_params<RNG: rand::Rng>(&mut self, rng: &mut RNG) {
self.f.reset_params(rng);
self.r.reset_params(rng);
}
}
/// Forward pass: `f(x) + r(x)`.
///
/// The main branch `f` is fed a copy of the input carrying an empty tape
/// (via `with_empty_tape()`), while the residual branch `r` consumes the
/// original input; the two branch outputs are then summed with `add`.
/// Both branches must produce the same output type with a matching tape.
impl<F, R, T> Module<T> for GeneralizedResidual<F, R>
where
    T: Tensor<Dtype = f32>,
    F: Module<T>,
    R: Module<T, Output = F::Output>,
    F::Output: Tensor<Dtype = f32, Tape = T::Tape>,
{
    type Output = F::Output;

    fn forward(&self, x: T) -> Self::Output {
        // Evaluate the main branch on the empty-tape copy first (this matches
        // the original left-to-right argument evaluation order), then the
        // residual branch on the original input, and sum the results.
        let main_out = self.f.forward(x.with_empty_tape());
        let skip_out = self.r.forward(x);
        add(main_out, skip_out)
    }
}
/// Mutable forward pass: identical dataflow to [`Module::forward`], but calls
/// `forward_mut` on both branches so modules with train-time state (e.g.
/// dropout-like layers) can update themselves.
impl<F, R, T> ModuleMut<T> for GeneralizedResidual<F, R>
where
    T: Tensor<Dtype = f32>,
    F: ModuleMut<T>,
    R: ModuleMut<T, Output = F::Output>,
    F::Output: Tensor<Dtype = f32, Tape = T::Tape>,
{
    type Output = F::Output;

    fn forward_mut(&mut self, x: T) -> Self::Output {
        // `f` sees an empty-tape copy, `r` consumes the original input —
        // same evaluation order as the immutable forward pass.
        let main_out = self.f.forward_mut(x.with_empty_tape());
        let skip_out = self.r.forward_mut(x);
        add(main_out, skip_out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::assert_close;
    use rand::{prelude::StdRng, SeedableRng};

    /// `reset_params` must move every parameter of both branches away from
    /// the all-zeros state produced by `Default`.
    #[test]
    fn test_reset_generalized_residual() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut net: GeneralizedResidual<Linear<2, 5>, Linear<2, 5>> = Default::default();

        // Default-constructed branches start fully zeroed.
        assert_eq!(net.f.weight.data(), &[[0.0; 2]; 5]);
        assert_eq!(net.f.bias.data(), &[0.0; 5]);
        assert_eq!(net.r.weight.data(), &[[0.0; 2]; 5]);
        assert_eq!(net.r.bias.data(), &[0.0; 5]);

        net.reset_params(&mut rng);

        // After resetting, no parameter tensor should remain all zeros.
        assert_ne!(net.f.weight.data(), &[[0.0; 2]; 5]);
        assert_ne!(net.f.bias.data(), &[0.0; 5]);
        assert_ne!(net.r.weight.data(), &[[0.0; 2]; 5]);
        assert_ne!(net.r.bias.data(), &[0.0; 5]);
    }

    /// Checks forward output and the gradients flowing to the input and to
    /// both branches. Expected values are pinned to seed 0; the RNG draw
    /// order (reset_params, then randn) must not change.
    #[test]
    fn test_generalized_residual_gradients() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut net: GeneralizedResidual<Linear<2, 2>, Linear<2, 2>> = Default::default();
        net.reset_params(&mut rng);

        let input: Tensor2D<4, 2> = TensorCreator::randn(&mut rng);
        let out = net.forward_mut(input.trace());
        #[rustfmt::skip]
        assert_close(out.data(), &[[-0.81360567, -1.1473482], [1.0925694, 0.17383915], [-0.32519114, 0.49806428], [0.08259219, -0.7277866]]);

        let grads = backward(out.mean());
        assert_close(grads.ref_gradient(&input), &[[0.15889636, 0.062031522]; 4]);
        // Both branches receive the same input, so their weight/bias
        // gradients under a mean loss are identical.
        assert_close(grads.ref_gradient(&net.f.weight), &[[-0.025407, 0.155879]; 2]);
        assert_close(grads.ref_gradient(&net.f.bias), &[0.5; 2]);
        assert_close(grads.ref_gradient(&net.r.weight), &[[-0.025407, 0.155879]; 2]);
        assert_close(grads.ref_gradient(&net.r.bias), &[0.5; 2]);
    }
}