1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::{string::String, vec::Vec};

use crate::{shapes::Dtype, tensor_ops::Device};

use super::*;

/// Repeats `T` `N` times. The output of each copy is fed as the input of the next,
/// so this requires that `T`'s input type is the same as its output type.
///
/// # Generics
/// - `T` the [Module] to repeat
/// - `N` the number of times to repeat `T`.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// type Model = Repeated<(Linear<10, 10>, ReLU), 5>;
/// let model = dev.build_module::<Model, f32>();
/// let out: Tensor<Rank1<10>, f32, _> = model.forward(dev.zeros());
/// ```
#[derive(Debug, Clone)]
pub struct Repeated<T, const N: usize> {
    /// The repeated sub-modules, applied in order from index `0` to `N - 1`.
    ///
    /// Forward passes index this vector with `0..N`, so it is expected to hold
    /// (at least) `N` modules; a shorter vector makes a forward pass panic.
    pub modules: Vec<T>,
}

impl<E, D, T, const N: usize> BuildOnDevice<D, E> for Repeated<T, N>
where
    E: Dtype,
    D: Device<E>,
    T: BuildOnDevice<D, E>,
{
    /// The built form wraps `T`'s built type and keeps the same repeat count `N`.
    type Built = Repeated<T::Built, N>;
}

impl<E, D, T, const N: usize> TensorCollection<E, D> for Repeated<T, N>
where
    E: Dtype,
    D: Device<E>,
    T: TensorCollection<E, D>,
{
    type To<E2: Dtype, D2: Device<E2>> = Repeated<T::To<E2, D2>, N>;

    /// Exposes each of the `N` sub-modules to `visitor` as a field named by its
    /// index (`"0"`, `"1"`, ...), and rebuilds a `Repeated` from the visited modules.
    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        // The names are collected up front so they outlive the field descriptors built
        // below (presumably `Self::module` borrows the name — confirm).
        let field_names: Vec<String> = (0..N).map(|idx| idx.to_string()).collect();

        let fields = field_names
            .iter()
            .enumerate()
            .map(|(idx, field_name)| {
                Self::module(
                    field_name,
                    move |s| &s.modules[idx],
                    move |s| &mut s.modules[idx],
                )
            })
            .collect::<Vec<_>>();

        visitor.visit_fields(fields, |modules| Repeated { modules })
    }
}

impl<T, const N: usize> std::ops::Index<usize> for Repeated<T, N> {
    type Output = T;
    fn index(&self, index: usize) -> &Self::Output {
        &self.modules[index]
    }
}

impl<Input, T, const N: usize> Module<Input> for Repeated<T, N>
where
    T: Module<Input, Output = Input>,
{
    type Output = T::Output;
    type Error = T::Error;

    /// Threads `x` through `self.modules[0..N]` in order, feeding each output into
    /// the next module, and returns the first error encountered (if any).
    fn try_forward(&self, x: Input) -> Result<Self::Output, T::Error> {
        (0..N).try_fold(x, |acc, idx| self.modules[idx].try_forward(acc))
    }
}

impl<Input, T, const N: usize> ModuleMut<Input> for Repeated<T, N>
where
    T: ModuleMut<Input, Output = Input>,
{
    type Output = T::Output;
    type Error = T::Error;

    /// Threads `x` through `self.modules[0..N]` in order using each module's mutable
    /// forward pass. It stops at the first error, so modules after the failing one
    /// are not invoked.
    fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, T::Error> {
        let modules = &mut self.modules;
        (0..N).try_fold(x, |acc, idx| modules[idx].try_forward_mut(acc))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::*, tests::*};

    #[test]
    fn test_default_and_reset() {
        let dev: TestDevice = Default::default();

        type Model = Repeated<(Linear<3, 3>, ReLU), 5>;
        let m = dev.build_module::<Model, TestDtype>();

        // Building must yield exactly `N` sub-modules, none of which is left with
        // all-default (zero) parameters.
        assert_eq!(m.modules.len(), 5);
        for module in m.modules.iter() {
            assert_ne!(module.0.weight.array(), [[TestDtype::default(); 3]; 3]);
            assert_ne!(module.0.bias.array(), [TestDtype::default(); 3]);
        }
    }

    #[test]
    fn test_forward() {
        let dev: TestDevice = Default::default();

        type Model = Repeated<(Linear<3, 3>, ReLU), 5>;
        let mut m = dev.build_module::<Model, TestDtype>();

        // Reference result: apply every sub-module by hand, in order.
        let mut x = dev.zeros::<Rank1<3>>();
        for module in m.modules.iter() {
            x = module.forward(x);
        }
        let expected = x.array();

        // Both the immutable and the mutable forward passes must agree with it.
        assert_eq!(m.forward(dev.zeros::<Rank1<3>>()).array(), expected);
        assert_eq!(m.forward_mut(dev.zeros::<Rank1<3>>()).array(), expected);
    }
}