1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use crate::gradients::*;
use crate::prelude::*;
#[cfg(feature = "nightly")]
use crate::{Assert, ConstTrue};
/// A parameter-free layer that flattens tensors.
///
/// With the `nightly` feature enabled:
/// - a rank-3 tensor `(M, N, O)` becomes a rank-1 tensor of length `M * N * O`;
/// - a rank-4 tensor `(M, N, O, P)` keeps its first axis and becomes a rank-2
///   tensor `(M, N * O * P)`.
///
/// The layer holds no state, so resetting parameters and applying gradient
/// updates are no-ops.
#[derive(Debug, Default, Clone, Copy)]
pub struct Flatten2D;
/// `Flatten2D` owns no parameters, so there is nothing to re-initialize.
impl ResetParams for Flatten2D {
    fn reset_params<R>(&mut self, _rng: &mut R)
    where
        R: rand::Rng,
    {
        // Intentionally empty: the layer is a unit struct with no state.
    }
}
/// `Flatten2D` has no trainable tensors, so a gradient update does nothing.
impl CanUpdateWithGradients for Flatten2D {
    fn update<G>(&mut self, _grads: &mut G, _unused: &mut UnusedTensors)
    where
        G: GradientProvider,
    {
        // Intentionally empty: no gradients are read and nothing is recorded
        // in `UnusedTensors`.
    }
}
/// Flattens a rank-3 tensor `(M, N, O)` into a rank-1 tensor of length `M * N * O`.
///
/// Only compiled with the `nightly` feature; the output length is the const
/// expression `{ M * N * O }` over the input's const generics.
#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, H: Tape> Module<Tensor3D<M, N, O, H>>
for Flatten2D
where
// This bound is trivially true. NOTE(review): presumably it is spelled out so the
// compiler can discharge the element-count check that the `Reshape` impl requires
// under const-generic expressions — confirm against `Reshape`'s bounds before
// changing its form.
Assert<{ M * N * O == (M * N * O) }>: ConstTrue,
{
// The tape type `H` is carried through unchanged into the output.
type Output = Tensor1D<{ M * N * O }, H>;
// Delegates entirely to `Reshape`; element ordering is whatever that impl
// produces (NOTE(review): presumably row-major — confirm).
fn forward(&self, input: Tensor3D<M, N, O, H>) -> Self::Output {
Reshape::<Self::Output>::reshape(input)
}
}
/// Flattens a rank-4 tensor `(M, N, O, P)` into a rank-2 tensor `(M, N * O * P)`,
/// keeping the first axis and merging the remaining three.
///
/// NOTE(review): the first axis is presumably a batch dimension — confirm with callers.
/// Only compiled with the `nightly` feature.
#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, const P: usize, H: Tape>
Module<Tensor4D<M, N, O, P, H>> for Flatten2D
where
// States that the total element count is unchanged by the reshape. NOTE(review):
// presumably written in exactly this form to match the bound on the `Reshape`
// impl for rank-4 -> rank-2 — confirm before rewriting the expression.
Assert<{ M * N * O * P == M * (N * O * P) }>: ConstTrue,
{
// The tape type `H` is carried through unchanged into the output.
type Output = Tensor2D<M, { N * O * P }, H>;
// Delegates entirely to `Reshape`.
fn forward(&self, input: Tensor4D<M, N, O, P, H>) -> Self::Output {
Reshape::<Self::Output>::reshape(input)
}
}
/// Mutable-forward entry point for [`Flatten2D`].
///
/// The layer is a unit struct with no state to mutate, so for any input type `T`
/// that has a [`Module`] impl this simply defers to the immutable `forward`.
impl<T> ModuleMut<T> for Flatten2D
where
    Self: Module<T>,
{
    type Output = <Self as Module<T>>::Output;

    fn forward_mut(&mut self, input: T) -> Self::Output {
        // `&mut Self` coerces to the `&Self` receiver that `Module::forward` takes.
        <Self as Module<T>>::forward(self, input)
    }
}
#[cfg(feature = "nightly")]
#[cfg(test)]
mod tests {
    use super::*;

    /// The shape checks happen at compile time: each annotated binding only
    /// type-checks if `Flatten2D` yields exactly that output type.
    #[test]
    fn test_flattens() {
        let mut flatten = Flatten2D;

        // Rank 3 -> rank 1: (15, 10, 5) collapses to a single axis of 15 * 10 * 5.
        let _: Tensor1D<{ 15 * 10 * 5 }> =
            flatten.forward_mut(Tensor3D::<15, 10, 5>::zeros());

        // Rank 4 -> rank 2: (5, 4, 3, 2) keeps the leading 5 and merges 4 * 3 * 2 = 24.
        let _: Tensor2D<5, 24> = flatten.forward_mut(Tensor4D::<5, 4, 3, 2>::zeros());
    }
}