concision_traits/ops/
reshape.rs

1/*
2    Appellation: reshape <module>
3    Created At: 2025.11.26:13:10:09
4    Contrib: @FL03
5*/
6
/// Inserts a new axis into an array-like value, i.e. "unsqueezes" it.
///
/// Implementors add one dimension at the position given by `axis`, so the
/// result has one more dimension than the input. This is handy when an
/// operation requires its input to have a particular number of dimensions.
pub trait Unsqueeze {
    /// The value produced once the extra axis has been inserted.
    type Output;

    /// Returns `self` with a new axis inserted at index `axis`.
    fn unsqueeze(self, axis: usize) -> Self::Output;
}
15
/// A unary operation that removes a single axis from a shape-like value.
///
/// It is meant for multidimensional arrays and tensor-like structures: the
/// result describes one fewer dimension than the input.
pub trait DecrementAxis {
    /// The lower-rank value produced by removing an axis.
    type Output;

    /// Returns the value obtained by dropping one axis; `self` is only borrowed.
    fn dec_axis(&self) -> Self::Output;
}
23
/// A unary operation that adds a single axis to a shape-like value.
///
/// This is the counterpart of removing an axis: the result describes one more
/// dimension than the input.
pub trait IncrementAxis {
    /// The higher-rank value produced by adding an axis.
    type Output;

    /// Consumes `self` and returns it with one additional axis.
    fn inc_axis(self) -> Self::Output;
}
31
32/*
33 ************* Implementations *************
34*/
35use ndarray::{ArrayBase, Axis, Dimension, RawData, RawDataClone, RemoveAxis};
36
37impl<D, E> DecrementAxis for D
38where
39    D: RemoveAxis<Smaller = E>,
40    E: Dimension,
41{
42    type Output = E;
43
44    fn dec_axis(&self) -> Self::Output {
45        self.remove_axis(Axis(self.ndim() - 1))
46    }
47}
48
49impl<D, E> IncrementAxis for D
50where
51    D: Dimension<Larger = E>,
52    E: Dimension,
53{
54    type Output = E;
55
56    fn inc_axis(self) -> Self::Output {
57        self.insert_axis(Axis(self.ndim()))
58    }
59}
60
61impl<S, D, A> Unsqueeze for ArrayBase<S, D, A>
62where
63    D: Dimension,
64    S: RawData<Elem = A>,
65{
66    type Output = ArrayBase<S, D::Larger>;
67
68    fn unsqueeze(self, axis: usize) -> Self::Output {
69        self.insert_axis(Axis(axis))
70    }
71}
72
73impl<S, D, A> Unsqueeze for &ArrayBase<S, D, A>
74where
75    D: Dimension,
76    S: RawDataClone<Elem = A>,
77{
78    type Output = ArrayBase<S, D::Larger>;
79
80    fn unsqueeze(self, axis: usize) -> Self::Output {
81        self.clone().insert_axis(Axis(axis))
82    }
83}