redstone_ml/common/reshape.rs

1use crate::{AxisType, StridedMemory, RawDataType};
2use crate::util::to_vec::ToVec;
3
4
5pub trait Reshape<T: RawDataType>: StridedMemory {
6    type Output;
7
8    /// Provides a non-owning view of the ndarray with the specified shape and stride.
9    /// The data pointed to by the view is shared with the original ndarray.
10    ///
11    /// # Safety
12    /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
13    ///   by the original ndarray.
14    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output;
15
16    /// Provides a non-owning view of the ndarray that shares its data with the original ndarray.
17    ///
18    /// # Example
19    /// ```
20    /// # use redstone_ml::*;
21    ///
22    /// let ndarray = NdArray::new(vec![1, 2, 3, 4]);
23    /// let view = (&ndarray).view();
24    /// assert!(view.is_view())
25    /// ```
26    fn view(self) -> Self::Output {
27        let shape = self.shape().to_vec();
28        let stride = self.stride().to_vec();
29        unsafe { self.reshaped_view(shape, stride) }
30    }
31
32    /// Reshapes the ndarray into the specified shape.
33    ///
34    /// This method returns a view.
35    ///
36    /// # Panics
37    ///
38    /// * If the total number of elements in the current ndarray does not match the
39    ///   total number of elements in `new_shape`.
40    ///
41    /// # Example
42    ///
43    /// ```
44    /// # use redstone_ml::*;
45    ///
46    /// let ndarray = NdArray::new([[4, 5], [6, 7], [8, 9]]);  // shape is [3, 2]
47    /// let reshaped_array = ndarray.reshape([1, 2, 3]);
48    /// assert_eq!(reshaped_array, NdArray::new([[[4, 5, 6], [7, 8, 9]]]));
49    ///
50    /// let ndarray = NdArray::new([0, 1, 2, 3]);
51    /// let reshaped_array = (&ndarray ).reshape([2, 2]);  // reshape without consuming ndarray
52    /// assert_eq!(ndarray.shape(), &[4]);
53    /// assert_eq!(reshaped_array, NdArray::new([[0, 1], [2, 3]]));
54    /// ```
55    fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output {
56        let new_shape = new_shape.to_vec();
57        if new_shape == self.shape() {
58            return self.view();
59        }
60        
61        assert!(self.is_uniformly_strided(), 
62                "reshape requires uniformly strided array. This array has shape {:?} and stride {:?}.\
63         Try `array.clone().reshape()` instead.", self.shape(), self.stride());
64        
65        if self.size() != new_shape.iter().product() {
66            panic!("total number of elements must not change during reshape");
67        }
68
69        let mut new_stride = vec![0; new_shape.len()];
70        let mut acc = self.stride()[self.ndims() - 1];
71        for (i, &dim) in new_shape.iter().rev().enumerate() {
72            new_stride[new_shape.len() - 1 - i] = acc;
73            acc *= dim;
74        }
75
76        unsafe { self.reshaped_view(new_shape, new_stride) }
77    }
78
79    /// Removes all singleton dimensions (dimensions of size 1) from the ndarray's shape.
80    ///
81    /// This method returns a view.
82    ///
83    /// # Example
84    /// ```
85    /// # use redstone_ml::*;
86    ///
87    /// let ndarray = NdArray::new([[[[1], [3]], [[1], [4]]]]);  // shape [1, 2, 2, 1]
88    /// let squeezed = ndarray.squeeze();
89    /// assert_eq!(squeezed, NdArray::new([[1, 3], [1, 4]]));  // shape [2, 2]
90    ///
91    /// let ndarray = NdArray::new([[3], [5], [7], [9]]);
92    /// let squeezed = (&ndarray ).squeeze();  // squeeze without consuming ndarray
93    /// assert_eq!(ndarray.shape(), &[4, 1]);
94    /// assert_eq!(squeezed, NdArray::new([3, 5, 7, 9]));
95    /// ```
96    fn squeeze(self) -> Self::Output {
97        let mut shape = self.shape().to_vec();
98        let mut stride = self.stride().to_vec();
99
100        (shape, stride) = shape.iter()
101                               .zip(stride.iter())
102                               .filter(|&(&axis_length, _)| axis_length != 1)
103                               .unzip();
104
105        unsafe { self.reshaped_view(shape, stride) }
106    }
107
108    /// Adds a singleton dimension (dimensions of size 1) to the ndarray at the specified axis.
109    ///
110    /// This method returns a view.
111    ///
112    /// # Example
113    ///
114    /// ```
115    /// # use redstone_ml::*;
116    ///
117    /// let ndarray = NdArray::new([2, 3]);  // shape is [2]
118    /// let unsqueezed = ndarray.unsqueeze(-1);  // add dimension after the last axis
119    /// assert_eq!(unsqueezed.shape(), &[2, 1]);
120    ///
121    /// let ndarray = NdArray::new([[1, 2, 3], [9, 8, 7]]);  // shape is [2, 3]
122    /// let unsqueezed = (&ndarray ).unsqueeze(1);  // unsqueeze without consuming ndarray
123    /// assert_eq!(ndarray.shape(), &[2, 3]);
124    /// assert_eq!(unsqueezed.shape(), &[2, 1, 3]);
125    /// ```
126    fn unsqueeze(self, axis: impl AxisType) -> Self::Output {
127        let axis = axis.as_absolute(self.ndims() + 1);
128
129        let mut shape = self.shape().to_vec();
130        let mut stride = self.stride().to_vec();
131
132        if axis == self.ndims() {
133            shape.push(1);
134            stride.push(1)
135        } else {
136            shape.insert(axis, 1);
137            stride.insert(axis, stride[axis] * shape[axis + 1]);
138        }
139
140        unsafe { self.reshaped_view(shape, stride) }
141    }
142
143    /// Transposes the array along the first 2 dimensions.
144    ///
145    /// # Panics
146    /// - If the array is 1-dimensional or a scalar.
147    ///
148    /// # Examples
149    /// ```
150    /// # use redstone_ml::*;
151    ///
152    /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
153    ///
154    /// let transposed = array.T();
155    /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
156    /// ```
157    #[allow(non_snake_case)]
158    fn T(self) -> Self::Output {
159        self.transpose(0, 1)
160    }
161
162    /// Returns a transposed version of the array, swapping the specified axes.
163    ///
164    /// # Panics
165    /// - If `axis1` or `axis2` are out of bounds
166    ///
167    /// # Examples
168    /// ```
169    /// # use redstone_ml::*;
170    ///
171    /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
172    ///
173    /// let transposed = array.transpose(0, 1);
174    /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
175    /// ```
176    fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
177        let axis1 = axis1.as_absolute(self.ndims());
178        let axis2 = axis2.as_absolute(self.ndims());
179
180        let mut shape = self.shape().to_vec();
181        let mut stride = self.stride().to_vec();
182
183        shape.swap(axis1, axis2);
184        stride.swap(axis1, axis2);
185
186        unsafe { self.reshaped_view(shape, stride) }
187    }
188}