// concision_traits/ops/reshape.rs

/// Insert a new axis into a value, raising its dimensionality by one.
///
/// The `ArrayBase` implementations in this module forward to ndarray's
/// `insert_axis`, so the accepted range of `axis` (and any panic on an invalid
/// one) is whatever that method enforces.
pub trait Unsqueeze {
    /// Insert a new axis at index `axis`, consuming `self`.
    fn unsqueeze(self, axis: usize) -> Self::Output;

    /// The higher-dimensional value produced by [`Unsqueeze::unsqueeze`].
    type Output;
}
15
/// Reduce the dimensionality of a shape by one.
///
/// The blanket implementation for ndarray `RemoveAxis` types in this module
/// drops the last axis.
pub trait DecrementAxis {
    /// Return a new value with one fewer axis; `self` is only borrowed.
    fn dec_axis(&self) -> Self::Output;

    /// The lower-dimensional value produced by [`DecrementAxis::dec_axis`].
    type Output;
}
23
/// Raise the dimensionality of a shape by one.
///
/// The blanket implementation for ndarray `Dimension` types in this module
/// appends a new trailing axis.
pub trait IncrementAxis {
    /// Consume `self` and return a value with one additional axis.
    fn inc_axis(self) -> Self::Output;

    /// The higher-dimensional value produced by [`IncrementAxis::inc_axis`].
    type Output;
}
31
32use ndarray::{ArrayBase, Axis, Dimension, RawData, RawDataClone, RemoveAxis};
36
37impl<D, E> DecrementAxis for D
38where
39 D: RemoveAxis<Smaller = E>,
40 E: Dimension,
41{
42 type Output = E;
43
44 fn dec_axis(&self) -> Self::Output {
45 self.remove_axis(Axis(self.ndim() - 1))
46 }
47}
48
49impl<D, E> IncrementAxis for D
50where
51 D: Dimension<Larger = E>,
52 E: Dimension,
53{
54 type Output = E;
55
56 fn inc_axis(self) -> Self::Output {
57 self.insert_axis(Axis(self.ndim()))
58 }
59}
60
61impl<S, D, A> Unsqueeze for ArrayBase<S, D, A>
62where
63 D: Dimension,
64 S: RawData<Elem = A>,
65{
66 type Output = ArrayBase<S, D::Larger>;
67
68 fn unsqueeze(self, axis: usize) -> Self::Output {
69 self.insert_axis(Axis(axis))
70 }
71}
72
73impl<S, D, A> Unsqueeze for &ArrayBase<S, D, A>
74where
75 D: Dimension,
76 S: RawDataClone<Elem = A>,
77{
78 type Output = ArrayBase<S, D::Larger>;
79
80 fn unsqueeze(self, axis: usize) -> Self::Output {
81 self.clone().insert_axis(Axis(axis))
82 }
83}