1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use crate::prelude::{Const, GenericUpscale2D, NearestNeighbor, UpscaleMethod};
use crate::prelude::{Dim, Dtype, HasErr, Tape, Tensor, Upscale2DKernel, ZerosTensor};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};

/// Resizes an image to a fixed, compile-time output size of `OH x OW`.
///
/// `OW` defaults to `OH` (square output), and the interpolation method `M` defaults to
/// [`NearestNeighbor`]. The wrapped `M` value is not read by the `Module` impl in this
/// file; a fresh `M::default()` is used on every forward call instead.
#[derive(Clone, Debug, Default)]
pub struct Upscale2D<const OH: usize, const OW: usize = OH, M: UpscaleMethod = NearestNeighbor>(M);

impl<const OH: usize, const OW: usize, M: UpscaleMethod> NonMutableModule for Upscale2D<OH, OW, M> {}
impl<const OH: usize, const OW: usize, M: UpscaleMethod> ZeroSizedModule for Upscale2D<OH, OW, M> {}

/// Upscales any input implementing [`GenericUpscale2D`] to exactly `OH x OW`.
///
/// The output type is whatever `Img::Output` produces for `Const<OH>` / `Const<OW>`
/// target dimensions, and errors are the input's own `Img::Err`.
impl<const OH: usize, const OW: usize, M: UpscaleMethod, Img: GenericUpscale2D<M>> Module<Img>
    for Upscale2D<OH, OW, M>
{
    type Output = Img::Output<Const<OH>, Const<OW>>;
    type Error = Img::Err;

    fn try_forward(&self, img: Img) -> Result<Self::Output, Self::Error> {
        // Target sizes are the module's const parameters, independent of the input size.
        let (height, width) = (Const::<OH>, Const::<OW>);
        img.generic_upscale2d_like(M::default(), height, width)
    }
}

/// Upscales an image by integer factors: height is multiplied by `H`, width by `W`.
///
/// `W` defaults to `H` (uniform scaling), and the interpolation method `M` defaults to
/// [`NearestNeighbor`]. As with [`Upscale2D`], the `Module` impls in this file pass a
/// fresh `M::default()` to the kernel rather than reading the wrapped value.
#[derive(Clone, Debug, Default)]
pub struct Upscale2DBy<const H: usize, const W: usize = H, M: UpscaleMethod = NearestNeighbor>(M);

impl<const H: usize, const W: usize, M: UpscaleMethod> NonMutableModule for Upscale2DBy<H, W, M> {}
impl<const H: usize, const W: usize, M: UpscaleMethod> ZeroSizedModule for Upscale2DBy<H, W, M> {}

/// Compile-time sized 3d input `(C, IH, IW)` -> `(C, IH * H, IW * W)`.
///
/// Only built with the `nightly` feature because the output dimensions are const-generic
/// expressions. NOTE(review): the `Sized` bound appears to exist only so that the
/// `{ IH * H }` / `{ IW * W }` expressions are stated as well-formed — confirm before
/// touching it.
#[cfg(feature = "nightly")]
impl<
        const H: usize,
        const W: usize,
        const IH: usize,
        const IW: usize,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(C, Const<IH>, Const<IW>), E, D, T>> for Upscale2DBy<H, W, M>
where
    Tensor<(C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>: Sized,
{
    type Output = Tensor<(C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        img: Tensor<(C, Const<IH>, Const<IW>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        // Both target dims are `Const`s whose values are pinned by `Self::Output`.
        let method = M::default();
        img.generic_upscale2d_like(method, Const, Const)
    }
}

/// Compile-time sized batched input `(B, C, IH, IW)` -> `(B, C, IH * H, IW * W)`.
///
/// Only built with the `nightly` feature because the output dimensions are const-generic
/// expressions. NOTE(review): the `Sized` bound appears to exist only so that the
/// `{ IH * H }` / `{ IW * W }` expressions are stated as well-formed — confirm before
/// touching it.
#[cfg(feature = "nightly")]
impl<
        const H: usize,
        const W: usize,
        const IH: usize,
        const IW: usize,
        B: Dim,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(B, C, Const<IH>, Const<IW>), E, D, T>> for Upscale2DBy<H, W, M>
where
    Tensor<(B, C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>: Sized,
{
    type Output = Tensor<(B, C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        img: Tensor<(B, C, Const<IH>, Const<IW>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        // Both target dims are `Const`s whose values are pinned by `Self::Output`.
        let method = M::default();
        img.generic_upscale2d_like(method, Const, Const)
    }
}

/// Runtime-sized 3d input `(C, height, width)` -> `(C, height * H, width * W)`.
///
/// Available without the `nightly` feature, since the output sizes are plain `usize`
/// values computed at call time.
impl<
        const H: usize,
        const W: usize,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(C, usize, usize), E, D, T>> for Upscale2DBy<H, W, M>
{
    type Output = Tensor<(C, usize, usize), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        img: Tensor<(C, usize, usize), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        let (_, height, width) = img.shape;
        img.generic_upscale2d_like(M::default(), height * H, width * W)
    }
}

/// Runtime-sized batched input `(B, C, height, width)` -> `(B, C, height * H, width * W)`.
///
/// Available without the `nightly` feature, since the output sizes are plain `usize`
/// values computed at call time. The batch and channel dimensions pass through unchanged.
impl<
        const H: usize,
        const W: usize,
        B: Dim,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(B, C, usize, usize), E, D, T>> for Upscale2DBy<H, W, M>
{
    type Output = Tensor<(B, C, usize, usize), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        x: Tensor<(B, C, usize, usize), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        let (_, _, height, width) = x.shape;
        x.generic_upscale2d_like(M::default(), height * H, width * W)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{shapes::*, tensor::*, tests::*};

    #[test]
    fn test_upscale2d() {
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank3<3, 4, 4>, TestDtype, _> = dev.zeros();
        let _: Tensor<Rank3<3, 8, 8>, _, _> = Upscale2D::<8>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 8, 12>, _, _> = Upscale2D::<8, 12>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 9, 9>, _, _> =
            Upscale2D::<9, 9, NearestNeighbor>::default().forward(x);
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_upscale2dby() {
        use crate::prelude::Bilinear;
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank3<3, 4, 4>, TestDtype, _> = dev.zeros();
        let _: Tensor<Rank3<3, 8, 8>, _, _> = Upscale2DBy::<2>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 8, 12>, _, _> = Upscale2DBy::<2, 3>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 12, 12>, _, _> =
            Upscale2DBy::<3, 3, Bilinear>::default().forward(x.clone());
    }

    /// The runtime-dimension (`usize`) impls are the only `Upscale2DBy` impls compiled
    /// without `nightly`, so they are exercised unconditionally here. Each output spatial
    /// size must be the input size times the module's `H` / `W` factor.
    #[test]
    fn test_upscale2dby_runtime_dims() {
        use crate::prelude::Bilinear;
        let dev: TestDevice = Default::default();

        // Unbatched (C, height, width); non-square input so H and W can't be swapped silently.
        let x: Tensor<(Const<3>, usize, usize), TestDtype, _> =
            dev.zeros_like(&(Const::<3>, 4usize, 5usize));
        let y = Upscale2DBy::<2>::default().forward(x.clone());
        assert_eq!((y.shape.1, y.shape.2), (8, 10));
        let y = Upscale2DBy::<2, 3>::default().forward(x.clone());
        assert_eq!((y.shape.1, y.shape.2), (8, 15));
        let y = Upscale2DBy::<3, 3, Bilinear>::default().forward(x);
        assert_eq!((y.shape.1, y.shape.2), (12, 15));

        // Batched (B, C, height, width): batch dim passes through unchanged.
        let x: Tensor<(usize, Const<3>, usize, usize), TestDtype, _> =
            dev.zeros_like(&(2usize, Const::<3>, 4usize, 5usize));
        let y = Upscale2DBy::<2, 3>::default().forward(x);
        assert_eq!(y.shape.0, 2);
        assert_eq!((y.shape.2, y.shape.3), (8, 15));
    }
}