1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#[allow(unused)]
use crate::{
    shapes::*,
    tensor::{Tape, Tensor},
    tensor_ops::*,
};

use super::*;

/// **Requires Nightly** Collapses the trailing `(C, H, W)` axes of a tensor into a single axis.
///
/// - A 3d input of shape `(C, H, W)` becomes a 1d tensor of shape `(C * H * W,)`.
/// - A 4d input of shape `(B, C, H, W)` becomes a 2d tensor of shape `(B, C * H * W)`;
///   the leading `B` axis is left as-is.
///
/// The `Module` implementations for this layer are only compiled with the `nightly`
/// feature, because their output shapes are computed with const-generic arithmetic
/// (`{ C * H * W }`).
#[derive(Clone, Copy, Default)]
pub struct Flatten2D;

// Marker impls for this field-less unit struct.
impl NonMutableModule for Flatten2D {}
impl ZeroSizedModule for Flatten2D {}

#[cfg(feature = "nightly")]
impl<const C: usize, const H: usize, const W: usize, D, E, T>
    Module<Tensor<Rank3<C, H, W>, E, D, T>> for Flatten2D
where
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
    Rank1<{ C * H * W }>: Sized,
{
    /// All three input axes are merged into one axis of length `C * H * W`.
    type Output = Tensor<Rank1<{ C * H * W }>, E, D, T>;
    type Error = D::Err;

    /// Reshapes a `(C, H, W)` tensor into `(C * H * W,)`.
    ///
    /// Any error reported by the device's reshape is returned unchanged.
    fn try_forward(
        &self,
        input: Tensor<Rank3<C, H, W>, E, D, T>,
    ) -> Result<Self::Output, D::Err> {
        let flattened: Self::Output = input.try_reshape()?;
        Ok(flattened)
    }
}

#[cfg(feature = "nightly")]
impl<const B: usize, const C: usize, const H: usize, const W: usize, D, E, T>
    Module<Tensor<Rank4<B, C, H, W>, E, D, T>> for Flatten2D
where
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
    Rank2<B, { C * H * W }>: Sized,
{
    /// The leading `B` axis is kept; the remaining `(C, H, W)` axes are merged into one
    /// axis of length `C * H * W`.
    type Output = Tensor<Rank2<B, { C * H * W }>, E, D, T>;
    type Error = D::Err;

    /// Reshapes a `(B, C, H, W)` tensor into `(B, C * H * W)`.
    ///
    /// Any error reported by the device's reshape is returned unchanged.
    fn try_forward(
        &self,
        input: Tensor<Rank4<B, C, H, W>, E, D, T>,
    ) -> Result<Self::Output, D::Err> {
        let flattened: Self::Output = input.try_reshape()?;
        Ok(flattened)
    }
}

#[cfg(feature = "nightly")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{tensor::ZerosTensor, tests::*};

    /// Checks (at compile time, via the type annotations) that `Flatten2D` yields the
    /// expected output shapes for both 3d and 4d inputs.
    #[test]
    fn test_flattens() {
        let dev: TestDevice = Default::default();

        // 3d input: 10 * 5 * 2 = 100 elements in a single axis.
        let unbatched = dev.zeros::<Rank3<10, 5, 2>>();
        let _: Tensor<Rank1<100>, TestDtype, _> = Flatten2D.forward_mut(unbatched);

        // 4d input: leading axis of 5 kept, 4 * 3 * 2 = 24 elements per row.
        let batched = dev.zeros::<Rank4<5, 4, 3, 2>>();
        let _: Tensor<Rank2<5, 24>, TestDtype, _> = Flatten2D.forward_mut(batched);
    }
}