1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#[allow(unused)]
use crate::{
shapes::*,
tensor::{Tape, Tensor},
tensor_ops::*,
};
use super::*;
/// Flattens the trailing three dimensions `(C, H, W)` of a tensor into a single
/// dimension of length `C * H * W`.
///
/// Supported input shapes:
/// - `Rank3<C, H, W>` becomes `Rank1<{ C * H * W }>`.
/// - `Rank4<B, C, H, W>` becomes `Rank2<B, { C * H * W }>`; the leading
///   dimension `B` is kept as-is.
///
/// This is a unit struct, so the layer carries no state of its own. Its `Module`
/// implementations are only compiled with the `nightly` feature, because the
/// output shapes use const-generic arithmetic (`{ C * H * W }`), which needs
/// nightly Rust.
#[derive(Debug, Default, Clone, Copy)]
pub struct Flatten2D;

impl ZeroSizedModule for Flatten2D {}
impl NonMutableModule for Flatten2D {}
/// Flattens a single `(C, H, W)` tensor into a rank-1 tensor of length `C * H * W`.
///
/// The output keeps the input's element type `E`, device `D` and tape type `T`;
/// only the shape changes.
#[cfg(feature = "nightly")]
impl<const C: usize, const H: usize, const W: usize, D, E, T>
    Module<Tensor<Rank3<C, H, W>, E, D, T>> for Flatten2D
where
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
    // Well-formedness bound so the compiler accepts the `{ C * H * W }`
    // const expression used in the output shape.
    Rank1<{ C * H * W }>: Sized,
{
    type Output = Tensor<Rank1<{ C * H * W }>, E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<Rank3<C, H, W>, E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        // The target shape is inferred from the annotated binding.
        let flat: Self::Output = input.try_reshape()?;
        Ok(flat)
    }
}
/// Flattens a `(B, C, H, W)` tensor into a `(B, C * H * W)` tensor.
///
/// The leading dimension `B` (presumably the batch, TODO confirm with callers)
/// is left untouched, and the remaining three dimensions are merged into one.
/// Element type, device and tape type pass through unchanged.
#[cfg(feature = "nightly")]
impl<const B: usize, const C: usize, const H: usize, const W: usize, D, E, T>
    Module<Tensor<Rank4<B, C, H, W>, E, D, T>> for Flatten2D
where
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
    // Well-formedness bound so the compiler accepts the `{ C * H * W }`
    // const expression used in the output shape.
    Rank2<B, { C * H * W }>: Sized,
{
    type Output = Tensor<Rank2<B, { C * H * W }>, E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<Rank4<B, C, H, W>, E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        // The target shape is inferred from the annotated binding.
        let rows: Self::Output = input.try_reshape()?;
        Ok(rows)
    }
}
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use crate::{tensor::ZerosTensor, tests::*};

    /// Shape-level check for both supported ranks:
    /// `(10, 5, 2)` -> `(100,)` and `(5, 4, 3, 2)` -> `(5, 24)`.
    /// The type annotations on the outputs are the assertions here.
    #[test]
    fn test_flattens() {
        let dev: TestDevice = Default::default();

        let single = dev.zeros::<Rank3<10, 5, 2>>();
        let _: Tensor<Rank1<100>, TestDtype, _> = Flatten2D.forward_mut(single);

        let batched = dev.zeros::<Rank4<5, 4, 3, 2>>();
        let _: Tensor<Rank2<5, 24>, TestDtype, _> = Flatten2D.forward_mut(batched);
    }
}