1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use crate::gradients::*;
use crate::prelude::*;
#[cfg(feature = "nightly")]
use crate::{Assert, ConstTrue};

/// Flattens 3d tensors to 1d, and 4d tensors to 2d.
///
/// **Requires Nightly**: the `Module` implementations for this type are only compiled
/// when the `nightly` feature is enabled, since the flattened dimension is written as
/// a const generic expression (e.g. `{ M * N * O }`).
///
/// The layer itself is a stateless unit struct: resetting parameters and applying
/// gradient updates are both no-ops.
///
/// Specifically:
/// ```ignore
/// # use dfdx::prelude::*;
/// // All three dimensions are merged into one.
/// let _: Tensor1D<{3 * 5 * 7}> = Flatten2D.forward(Tensor3D::<3, 5, 7>::zeros());
/// // The leading dimension (8) is kept; the remaining three are merged.
/// let _: Tensor2D<8, {3 * 5 * 7}> = Flatten2D.forward(Tensor4D::<8, 3, 5, 7>::zeros());
/// ```
#[derive(Debug, Default, Clone, Copy)]
pub struct Flatten2D;

/// `Flatten2D` stores no parameters, so resetting is a no-op and the random number
/// generator is never used.
impl ResetParams for Flatten2D {
    fn reset_params<R>(&mut self, _rng: &mut R)
    where
        R: rand::Rng,
    {
    }
}

/// `Flatten2D` has nothing to update: the gradient provider is not queried and nothing
/// is recorded in the unused-tensor collection.
impl CanUpdateWithGradients for Flatten2D {
    fn update<G>(&mut self, _grads: &mut G, _unused: &mut UnusedTensors)
    where
        G: GradientProvider,
    {
    }
}

/// Flattening of 3d tensors: `Tensor3D<M, N, O, H>` becomes `Tensor1D<{ M * N * O }, H>`.
/// The tape type `H` of the input is carried over to the output type unchanged.
///
/// Only compiled when the `nightly` feature is enabled.
#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, H: Tape> Module<Tensor3D<M, N, O, H>>
    for Flatten2D
where
    // The asserted condition compares `M * N * O` with itself, so it always holds.
    // NOTE(review): presumably this bound is spelled out only because the `Reshape`
    // implementation used in `forward` requires it — confirm against `Reshape`.
    Assert<{ M * N * O == (M * N * O) }>: ConstTrue,
{
    /// A 1d tensor whose length is the product of all three input dimensions.
    type Output = Tensor1D<{ M * N * O }, H>;

    /// Consumes `input` and returns it reshaped into [`Self::Output`] by delegating to
    /// `Reshape::reshape`.
    fn forward(&self, input: Tensor3D<M, N, O, H>) -> Self::Output {
        Reshape::<Self::Output>::reshape(input)
    }
}

/// Flattening of 4d tensors: `Tensor4D<M, N, O, P, H>` becomes
/// `Tensor2D<M, { N * O * P }, H>`. The leading dimension `M` is kept and the remaining
/// three are merged; the tape type `H` is carried over to the output type unchanged.
///
/// Only compiled when the `nightly` feature is enabled.
#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, const P: usize, H: Tape>
    Module<Tensor4D<M, N, O, P, H>> for Flatten2D
where
    // Asserts that the input and output shapes hold the same number of elements.
    // NOTE(review): presumably this is the bound the `Reshape` implementation used in
    // `forward` requires — confirm against `Reshape`.
    Assert<{ M * N * O * P == M * (N * O * P) }>: ConstTrue,
{
    /// A 2d tensor with `M` rows and `N * O * P` columns.
    type Output = Tensor2D<M, { N * O * P }, H>;

    /// Consumes `input` and returns it reshaped into [`Self::Output`] by delegating to
    /// `Reshape::reshape`.
    fn forward(&self, input: Tensor4D<M, N, O, P, H>) -> Self::Output {
        Reshape::<Self::Output>::reshape(input)
    }
}

/// Mutable forward pass for every input type `T` that `Flatten2D` already supports
/// through `Module<T>`. The output type and the result are exactly those of the
/// immutable `forward`.
impl<T> ModuleMut<T> for Flatten2D
where
    Flatten2D: Module<T>,
{
    type Output = <Flatten2D as Module<T>>::Output;

    fn forward_mut(&mut self, input: T) -> Self::Output {
        // Reborrow as shared: the immutable forward pass does all the work.
        <Flatten2D as Module<T>>::forward(&*self, input)
    }
}

#[cfg(all(feature = "nightly", test))]
mod tests {
    use super::*;

    /// Compile-time shape check: the let-bindings only type-check if flattening
    /// produces the expected output dimensions.
    #[test]
    fn test_flattens() {
        // 3d -> 1d: every dimension is merged.
        let cube = Tensor3D::<15, 10, 5>::zeros();
        let _: Tensor1D<{ 15 * 10 * 5 }> = Flatten2D.forward_mut(cube);

        // 4d -> 2d: the leading dimension is kept, 4 * 3 * 2 = 24 is merged.
        let batch = Tensor4D::<5, 4, 3, 2>::zeros();
        let _: Tensor2D<5, 24> = Flatten2D.forward_mut(batch);
    }
}