1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use crate::prelude::*;
/// Parameter-free layer that flattens image-shaped tensors.
///
/// - A rank-3 tensor `Tensor3D<M, N, O>` becomes `Tensor1D<{ M * N * O }>`.
/// - A rank-4 tensor `Tensor4D<M, N, O, P>` becomes `Tensor2D<M, { N * O * P }>`:
///   the first axis is kept and the remaining three are merged into one.
///
/// The gradient tape type of the input is carried through to the output.
#[derive(Debug, Default, Clone, Copy)]
pub struct FlattenImage;
/// `FlattenImage` holds no parameters, so resetting it does nothing and the
/// supplied random number generator is never used.
impl ResetParams for FlattenImage {
    fn reset_params<R>(&mut self, _rng: &mut R)
    where
        R: rand::Rng,
    {
    }
}
/// With no trainable state there is nothing to update: both the gradient
/// provider and the unused-tensor tracker are ignored.
impl CanUpdateWithGradients for FlattenImage {
    fn update<G>(&mut self, _grads: &mut G, _unused: &mut UnusedTensors)
    where
        G: GradientProvider,
    {
    }
}
// A unit struct has nothing to (de)serialize, so both traits rely entirely
// on their provided default methods.
impl LoadFromNpz for FlattenImage {}
impl SaveToNpz for FlattenImage {}
/// Flattens a rank-3 tensor `Tensor3D<M, N, O, H>` into a rank-1 tensor with
/// `M * N * O` elements. The tape type `H` is carried through unchanged.
impl<const M: usize, const N: usize, const O: usize, H: Tape> Module<Tensor3D<M, N, O, H>>
for FlattenImage
where
// This condition is trivially true. It appears to exist to name the const
// expression `M * N * O` in a bound, and presumably mirrors the bound required
// by the `Reshape` impl used in `forward` — confirm before editing. Const
// expression bounds are matched by form, so do not "simplify" it.
Assert<{ M * N * O == (M * N * O) }>: ConstTrue,
{
type Output = Tensor1D<{ M * N * O }, H>;
/// Delegates to `Reshape::reshape` with `Self::Output` as the target type.
fn forward(&self, input: Tensor3D<M, N, O, H>) -> Self::Output {
Reshape::<Self::Output>::reshape(input)
}
}
/// Flattens a rank-4 tensor `Tensor4D<M, N, O, P, H>` into a rank-2 tensor
/// `Tensor2D<M, { N * O * P }, H>`. The first axis is kept as-is and the last
/// three axes are merged. (The first axis is presumably a batch dimension —
/// TODO confirm with callers.) The tape type `H` is carried through unchanged.
impl<const M: usize, const N: usize, const O: usize, const P: usize, H: Tape>
Module<Tensor4D<M, N, O, P, H>> for FlattenImage
where
// This asserts that input and output element counts agree, which holds for
// any non-overflowing values. It is presumably the exact form of the bound
// required by the `Reshape` impl used in `forward` — confirm before editing.
// Const expression bounds are matched by form, so do not rewrite it.
Assert<{ M * N * O * P == M * (N * O * P) }>: ConstTrue,
{
type Output = Tensor2D<M, { N * O * P }, H>;
/// Delegates to `Reshape::reshape` with `Self::Output` as the target type.
fn forward(&self, input: Tensor4D<M, N, O, P, H>) -> Self::Output {
Reshape::<Self::Output>::reshape(input)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time shape check: each `forward` call must produce the
    /// annotated output type, otherwise this test fails to build.
    #[test]
    fn test_flattens() {
        let image = Tensor3D::<15, 10, 5>::zeros();
        let _flat: Tensor1D<{ 15 * 10 * 5 }> = FlattenImage.forward(image);

        let batch = Tensor4D::<5, 4, 3, 2>::zeros();
        let _flat_batch: Tensor2D<5, 24> = FlattenImage.forward(batch);
    }
}