1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use crate::{
    shapes::*,
    tensor::*,
    tensor_ops::{Device, TryAdd},
};

use super::*;

/// A residual connection around `F`: `F(x) + x`,
/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
///
/// The skip connection lets gradients flow directly through the `+ x` path,
/// which is what makes very deep networks trainable.
///
/// # Generics
/// - `F`: The underlying module to do a skip connection around.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// type Model = Residual<ReLU>;
/// let model = dev.build_module::<Model, f32>();
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
/// assert_eq!(y.array(), [-2.0, -1.0, 0.0, 2.0, 4.0]);
/// ```
// Tuple struct so the wrapped module is accessible as `.0`; `Default` lets
// `Residual` participate in builder-style model construction.
#[derive(Debug, Clone, Default)]
pub struct Residual<F>(pub F);

// Building `Residual<F>` on a device simply builds the inner module `F`;
// the wrapper itself holds no tensors of its own.
impl<D: Device<E>, E: Dtype, F: BuildOnDevice<D, E>> BuildOnDevice<D, E> for Residual<F> {
    type Built = Residual<F::Built>;
}

// Tensor iteration (used for init, serialization, dtype/device conversion)
// delegates entirely to the wrapped module.
impl<E: Dtype, D: Device<E>, F: TensorCollection<E, D>> TensorCollection<E, D> for Residual<F> {
    // Converting dtype/device preserves the `Residual` wrapper around the
    // converted inner module.
    type To<E2: Dtype, D2: Device<E2>> = Residual<F::To<E2, D2>>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        // Register the inner module under field name "0" (matching the tuple
        // field), and rebuild via the `Residual` tuple-struct constructor.
        visitor.visit_fields(Self::module("0", |s| &s.0, |s| &mut s.0), Residual)
    }
}

impl<T: WithEmptyTape + TryAdd<T>, F: Module<T, Output = T, Error = T::Err>> Module<T>
    for Residual<F>
{
    type Output = T;
    type Error = F::Error;

    /// Computes `F(x) + x`.
    ///
    /// The inner module receives a copy of `x` carrying an empty gradient
    /// tape; the original `x` (which keeps the tape) is added back afterwards,
    /// so gradients flow through both the transformed and the skip path.
    fn try_forward(&self, x: T) -> Result<Self::Output, F::Error> {
        let transformed = self.0.try_forward(x.with_empty_tape())?;
        transformed.try_add(x)
    }
}

impl<T: WithEmptyTape + TryAdd<T>, F: ModuleMut<T, Output = T, Error = T::Err>> ModuleMut<T>
    for Residual<F>
{
    type Output = T;
    type Error = F::Error;

    /// Mutable-forward version of `F(x) + x`.
    ///
    /// Mirrors `try_forward`: the inner module sees `x` with an empty tape,
    /// and the tape-carrying `x` is added to the result.
    fn try_forward_mut(&mut self, x: T) -> Result<Self::Output, F::Error> {
        let transformed = self.0.try_forward_mut(x.with_empty_tape())?;
        transformed.try_add(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;
    use crate::{nn::builders::Linear, tensor_ops::*};

    #[test]
    fn test_residual_reset() {
        let dev: TestDevice = Default::default();
        let model = dev.build_module::<Residual<Linear<2, 5>>, TestDtype>();
        // Building the model should randomly initialize the inner Linear's
        // parameters, i.e. they must not all be the dtype's default (zero).
        assert_ne!(model.0.weight.array(), [[TestDtype::default(); 2]; 5]);
        assert_ne!(model.0.bias.array(), [TestDtype::default(); 5]);
    }

    #[test]
    fn test_residual_gradients() {
        let dev: TestDevice = Default::default();

        // Build in f32 first so the random init is reproducible across
        // TestDtype configurations, then convert.
        let model = dev
            .build_module::<Residual<Linear<2, 2>>, f32>()
            .to_dtype::<TestDtype>();

        let x: Tensor<Rank2<4, 2>, f32, _> = dev.sample_normal();
        let x = x.to_dtype::<TestDtype>();
        let y = model.forward(x.leaky_trace());

        // Regression values pinned from a known-good run of the seeded RNG.
        #[rustfmt::skip]
        assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]);

        let g = y.mean().backward();
        assert_close_to_literal!(g.get(&model.0.weight), [[0.475242, -0.075136]; 2]);
        assert_close_to_literal!(g.get(&model.0.bias), [0.5; 2]);
        // The input gradient includes the skip path's identity contribution,
        // so it differs from a plain Linear's input gradient.
        assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]);
    }
}