#[allow(unused)]
use crate::{
    shapes::*,
    tensor::{Tape, Tensor},
    tensor_ops::*,
};
use super::*;
/// Module that reshapes its input tensor to the compile-time shape `S`.
///
/// The single private field stores a value of the target shape; the `Module`
/// implementation in this file passes a reference to it to `try_reshape_like`.
/// Instances are obtained through `Default`.
#[derive(Default, Clone, Copy)]
pub struct Reshape<S>(S)
where
    S: ConstShape;
// Opt `Reshape` into the `ZeroSizedModule` and `NonMutableModule` traits from `super`.
// NOTE(review): both impls are empty, so any behaviour they unlock (e.g. the
// `forward_mut` used by the tests below) presumably comes from blanket impls
// defined next to those traits — confirm against their definitions.
impl<S> ZeroSizedModule for Reshape<S> where S: ConstShape {}
impl<S> NonMutableModule for Reshape<S> where S: ConstShape {}
impl<Src: Shape, Dst: ConstShape, D: Device<E>, E: Dtype, T: Tape<E, D>>
    Module<Tensor<Src, E, D, T>> for Reshape<Dst>
{
    type Output = Tensor<Dst, E, D, T>;
    type Error = D::Err;
    fn try_forward(&self, input: Tensor<Src, E, D, T>) -> Result<Self::Output, D::Err> {
        input.try_reshape_like(&self.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{tensor::ZerosTensor, tests::*};

    /// Smoke test: for several source/target shape pairs, running a zero tensor
    /// through `Reshape` must type-check as a tensor of the target shape and the
    /// call must complete. No tensor contents are inspected.
    #[test]
    fn test_flattens() {
        let dev = TestDevice::default();

        // Rank3<10, 5, 2> -> Rank1<100>
        let src_3d = dev.zeros::<Rank3<10, 5, 2>>();
        let _: Tensor<Rank1<100>, TestDtype, _> =
            Reshape::<Rank1<100>>::default().forward_mut(src_3d);

        // Rank4<5, 4, 3, 2> -> Rank2<5, 24>
        let src_4d = dev.zeros::<Rank4<5, 4, 3, 2>>();
        let _: Tensor<Rank2<5, 24>, TestDtype, _> =
            Reshape::<Rank2<5, 24>>::default().forward_mut(src_4d);

        // Rank1<100> -> Rank3<10, 5, 2>
        let src_1d = dev.zeros::<Rank1<100>>();
        let _: Tensor<Rank3<10, 5, 2>, TestDtype, _> =
            Reshape::<Rank3<10, 5, 2>>::default().forward_mut(src_1d);

        // Rank2<5, 24> -> Rank4<5, 4, 3, 2>
        let src_2d = dev.zeros::<Rank2<5, 24>>();
        let _: Tensor<Rank4<5, 4, 3, 2>, TestDtype, _> =
            Reshape::<Rank4<5, 4, 3, 2>>::default().forward_mut(src_2d);
    }
}