1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::{
shapes::*,
tensor::*,
tensor_ops::{Device, TryAdd},
};
use super::*;
#[derive(Debug, Clone, Default)]
pub struct GeneralizedResidual<F, R> {
pub f: F,
pub r: R,
}
impl<D: Device<E>, E: Dtype, F: BuildOnDevice<D, E>, R: BuildOnDevice<D, E>> BuildOnDevice<D, E>
for GeneralizedResidual<F, R>
{
type Built = GeneralizedResidual<F::Built, R::Built>;
}
impl<E: Dtype, D: Device<E>, F: TensorCollection<E, D>, R: TensorCollection<E, D>>
TensorCollection<E, D> for GeneralizedResidual<F, R>
{
type To<E2: Dtype, D2: Device<E2>> = GeneralizedResidual<F::To<E2, D2>, R::To<E2, D2>>;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
(
Self::module("f", |s| &s.f, |s| &mut s.f),
Self::module("r", |s| &s.r, |s| &mut s.r),
),
|(f, r)| GeneralizedResidual { f, r },
)
}
}
impl<T: WithEmptyTape, F: Module<T>, R: Module<T, Output = F::Output, Error = F::Error>> Module<T>
for GeneralizedResidual<F, R>
where
F::Output: TryAdd<F::Output> + HasErr<Err = F::Error>,
{
type Output = F::Output;
type Error = F::Error;
fn try_forward(&self, x: T) -> Result<Self::Output, F::Error> {
self.f
.try_forward(x.with_empty_tape())?
.try_add(self.r.try_forward(x)?)
}
}
impl<T: WithEmptyTape, F: ModuleMut<T>, R: ModuleMut<T, Output = F::Output, Error = F::Error>>
ModuleMut<T> for GeneralizedResidual<F, R>
where
F::Output: TryAdd<F::Output> + HasErr<Err = F::Error>,
{
type Output = F::Output;
type Error = F::Error;
fn try_forward_mut(&mut self, x: T) -> Result<Self::Output, F::Error> {
self.f
.try_forward_mut(x.with_empty_tape())?
.try_add(self.r.try_forward_mut(x)?)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nn::builders::{DeviceBuildExt, Linear};
use crate::{tensor_ops::*, tests::*};
#[test]
fn test_reset_generalized_residual() {
let dev: TestDevice = Default::default();
type Model = GeneralizedResidual<Linear<2, 5>, Linear<2, 5>>;
let model = dev.build_module::<Model, f32>();
assert_ne!(model.f.weight.array(), [[0.0; 2]; 5]);
assert_ne!(model.f.bias.array(), [0.0; 5]);
assert_ne!(model.r.weight.array(), [[0.0; 2]; 5]);
assert_ne!(model.r.bias.array(), [0.0; 5]);
}
#[test]
fn test_generalized_residual_gradients() {
let dev: TestDevice = Default::default();
type Model = GeneralizedResidual<Linear<2, 2>, Linear<2, 2>>;
let model = dev.build_module::<Model, f32>();
let x = dev.sample_normal::<Rank2<4, 2>>();
let y = model.forward(x.leaky_trace());
#[rustfmt::skip]
assert_close(&y.array(), &[[-0.81360567, -1.1473482], [1.0925694, 0.17383915], [-0.32519114, 0.49806428], [0.08259219, -0.7277866]]);
let g = y.mean().backward();
assert_close(&g.get(&x).array(), &[[0.15889636, 0.062031522]; 4]);
assert_close(&g.get(&model.f.weight).array(), &[[-0.025407, 0.155879]; 2]);
assert_close(&g.get(&model.f.bias).array(), &[0.5; 2]);
assert_close(&g.get(&model.r.weight).array(), &[[-0.025407, 0.155879]; 2]);
assert_close(&g.get(&model.r.bias).array(), &[0.5; 2]);
}
}