redstone_ml/common/reshape.rs
1use crate::{AxisType, StridedMemory, RawDataType};
2use crate::util::to_vec::ToVec;
3
4
5pub trait Reshape<T: RawDataType>: StridedMemory {
6 type Output;
7
8 /// Provides a non-owning view of the ndarray with the specified shape and stride.
9 /// The data pointed to by the view is shared with the original ndarray.
10 ///
11 /// # Safety
12 /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
13 /// by the original ndarray.
14 unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output;
15
16 /// Provides a non-owning view of the ndarray that shares its data with the original ndarray.
17 ///
18 /// # Example
19 /// ```
20 /// # use redstone_ml::*;
21 ///
22 /// let ndarray = NdArray::new(vec![1, 2, 3, 4]);
23 /// let view = (&ndarray).view();
24 /// assert!(view.is_view())
25 /// ```
26 fn view(self) -> Self::Output {
27 let shape = self.shape().to_vec();
28 let stride = self.stride().to_vec();
29 unsafe { self.reshaped_view(shape, stride) }
30 }
31
32 /// Reshapes the ndarray into the specified shape.
33 ///
34 /// This method returns a view.
35 ///
36 /// # Panics
37 ///
38 /// * If the total number of elements in the current ndarray does not match the
39 /// total number of elements in `new_shape`.
40 ///
41 /// # Example
42 ///
43 /// ```
44 /// # use redstone_ml::*;
45 ///
46 /// let ndarray = NdArray::new([[4, 5], [6, 7], [8, 9]]); // shape is [3, 2]
47 /// let reshaped_array = ndarray.reshape([1, 2, 3]);
48 /// assert_eq!(reshaped_array, NdArray::new([[[4, 5, 6], [7, 8, 9]]]));
49 ///
50 /// let ndarray = NdArray::new([0, 1, 2, 3]);
51 /// let reshaped_array = (&ndarray ).reshape([2, 2]); // reshape without consuming ndarray
52 /// assert_eq!(ndarray.shape(), &[4]);
53 /// assert_eq!(reshaped_array, NdArray::new([[0, 1], [2, 3]]));
54 /// ```
55 fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output {
56 let new_shape = new_shape.to_vec();
57 if new_shape == self.shape() {
58 return self.view();
59 }
60
61 assert!(self.is_uniformly_strided(),
62 "reshape requires uniformly strided array. This array has shape {:?} and stride {:?}.\
63 Try `array.clone().reshape()` instead.", self.shape(), self.stride());
64
65 if self.size() != new_shape.iter().product() {
66 panic!("total number of elements must not change during reshape");
67 }
68
69 let mut new_stride = vec![0; new_shape.len()];
70 let mut acc = self.stride()[self.ndims() - 1];
71 for (i, &dim) in new_shape.iter().rev().enumerate() {
72 new_stride[new_shape.len() - 1 - i] = acc;
73 acc *= dim;
74 }
75
76 unsafe { self.reshaped_view(new_shape, new_stride) }
77 }
78
79 /// Removes all singleton dimensions (dimensions of size 1) from the ndarray's shape.
80 ///
81 /// This method returns a view.
82 ///
83 /// # Example
84 /// ```
85 /// # use redstone_ml::*;
86 ///
87 /// let ndarray = NdArray::new([[[[1], [3]], [[1], [4]]]]); // shape [1, 2, 2, 1]
88 /// let squeezed = ndarray.squeeze();
89 /// assert_eq!(squeezed, NdArray::new([[1, 3], [1, 4]])); // shape [2, 2]
90 ///
91 /// let ndarray = NdArray::new([[3], [5], [7], [9]]);
92 /// let squeezed = (&ndarray ).squeeze(); // squeeze without consuming ndarray
93 /// assert_eq!(ndarray.shape(), &[4, 1]);
94 /// assert_eq!(squeezed, NdArray::new([3, 5, 7, 9]));
95 /// ```
96 fn squeeze(self) -> Self::Output {
97 let mut shape = self.shape().to_vec();
98 let mut stride = self.stride().to_vec();
99
100 (shape, stride) = shape.iter()
101 .zip(stride.iter())
102 .filter(|&(&axis_length, _)| axis_length != 1)
103 .unzip();
104
105 unsafe { self.reshaped_view(shape, stride) }
106 }
107
108 /// Adds a singleton dimension (dimensions of size 1) to the ndarray at the specified axis.
109 ///
110 /// This method returns a view.
111 ///
112 /// # Example
113 ///
114 /// ```
115 /// # use redstone_ml::*;
116 ///
117 /// let ndarray = NdArray::new([2, 3]); // shape is [2]
118 /// let unsqueezed = ndarray.unsqueeze(-1); // add dimension after the last axis
119 /// assert_eq!(unsqueezed.shape(), &[2, 1]);
120 ///
121 /// let ndarray = NdArray::new([[1, 2, 3], [9, 8, 7]]); // shape is [2, 3]
122 /// let unsqueezed = (&ndarray ).unsqueeze(1); // unsqueeze without consuming ndarray
123 /// assert_eq!(ndarray.shape(), &[2, 3]);
124 /// assert_eq!(unsqueezed.shape(), &[2, 1, 3]);
125 /// ```
126 fn unsqueeze(self, axis: impl AxisType) -> Self::Output {
127 let axis = axis.as_absolute(self.ndims() + 1);
128
129 let mut shape = self.shape().to_vec();
130 let mut stride = self.stride().to_vec();
131
132 if axis == self.ndims() {
133 shape.push(1);
134 stride.push(1)
135 } else {
136 shape.insert(axis, 1);
137 stride.insert(axis, stride[axis] * shape[axis + 1]);
138 }
139
140 unsafe { self.reshaped_view(shape, stride) }
141 }
142
143 /// Transposes the array along the first 2 dimensions.
144 ///
145 /// # Panics
146 /// - If the array is 1-dimensional or a scalar.
147 ///
148 /// # Examples
149 /// ```
150 /// # use redstone_ml::*;
151 ///
152 /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
153 ///
154 /// let transposed = array.T();
155 /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
156 /// ```
157 #[allow(non_snake_case)]
158 fn T(self) -> Self::Output {
159 self.transpose(0, 1)
160 }
161
162 /// Returns a transposed version of the array, swapping the specified axes.
163 ///
164 /// # Panics
165 /// - If `axis1` or `axis2` are out of bounds
166 ///
167 /// # Examples
168 /// ```
169 /// # use redstone_ml::*;
170 ///
171 /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
172 ///
173 /// let transposed = array.transpose(0, 1);
174 /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
175 /// ```
176 fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
177 let axis1 = axis1.as_absolute(self.ndims());
178 let axis2 = axis2.as_absolute(self.ndims());
179
180 let mut shape = self.shape().to_vec();
181 let mut stride = self.stride().to_vec();
182
183 shape.swap(axis1, axis2);
184 stride.swap(axis1, axis2);
185
186 unsafe { self.reshaped_view(shape, stride) }
187 }
188}