1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
use crate::{AxisType, StridedMemory, RawDataType};
use crate::util::to_vec::ToVec;
/// Shape-manipulating operations that produce views instead of copying data.
///
/// Implementors only provide [`reshaped_view`](Reshape::reshaped_view). Every other method
/// is a provided method that derives a new shape/stride pair from `self.shape()` and
/// `self.stride()` and forwards it to `reshaped_view`.
pub trait Reshape<T: RawDataType>: StridedMemory {
    /// The type returned by every method of this trait.
    type Output;

    /// Provides a non-owning view of the ndarray with the specified shape and stride.
    /// The data pointed to by the view is shared with the original ndarray.
    ///
    /// # Safety
    /// - Ensure the memory layout referenced by `shape` and `stride` is valid and owned
    ///   by the original ndarray.
    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output;

    /// Provides a non-owning view of the ndarray that shares its data with the original ndarray.
    ///
    /// The view has the same shape and stride as `self`.
    ///
    /// # Example
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let ndarray = NdArray::new(vec![1, 2, 3, 4]);
    /// let view = (&ndarray).view();
    /// assert!(view.is_view());
    /// ```
    fn view(self) -> Self::Output {
        let shape = self.shape().to_vec();
        let stride = self.stride().to_vec();

        // SAFETY: shape and stride are copied verbatim from `self`, so the view describes
        // exactly the layout `self` already describes.
        unsafe { self.reshaped_view(shape, stride) }
    }

    /// Reshapes the ndarray into the specified shape.
    ///
    /// This method returns a view. If `new_shape` equals the current shape, the result is
    /// the same as calling `view()` and none of the checks below are performed.
    ///
    /// # Panics
    ///
    /// * If the shape changes and `is_uniformly_strided()` returns `false`.
    /// * If the total number of elements in the current ndarray does not match the
    ///   total number of elements in `new_shape`.
    ///
    /// # Example
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let ndarray = NdArray::new([[4, 5], [6, 7], [8, 9]]);  // shape is [3, 2]
    /// let reshaped_array = ndarray.reshape([1, 2, 3]);
    /// assert_eq!(reshaped_array, NdArray::new([[[4, 5, 6], [7, 8, 9]]]));
    ///
    /// let ndarray = NdArray::new([0, 1, 2, 3]);
    /// let reshaped_array = (&ndarray).reshape([2, 2]);  // reshape without consuming ndarray
    /// assert_eq!(ndarray.shape(), &[4]);
    /// assert_eq!(reshaped_array, NdArray::new([[0, 1], [2, 3]]));
    /// ```
    fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output {
        let new_shape = new_shape.to_vec();

        if new_shape == self.shape() {
            return self.view();
        }

        assert!(self.is_uniformly_strided(),
                "reshape requires uniformly strided array. This array has shape {:?} and stride {:?}. \
                 Try `array.clone().reshape()` instead.", self.shape(), self.stride());

        if self.size() != new_shape.iter().product() {
            panic!("total number of elements must not change during reshape");
        }

        // The innermost stride is the spacing between consecutive elements. A 0-dimensional
        // array has no strides at all, so fall back to a unit stride rather than indexing
        // with `ndims() - 1`, which would underflow.
        let innermost_stride = self.stride().last().copied().unwrap_or(1);

        // Walk the axes from innermost to outermost, accumulating the row-major stride.
        let mut new_stride = vec![0; new_shape.len()];
        let mut acc = innermost_stride;
        for (slot, &dim) in new_stride.iter_mut().zip(new_shape.iter()).rev() {
            *slot = acc;
            acc *= dim;
        }

        // SAFETY: the element count is unchanged and the new strides are row-major multiples
        // of the innermost stride. This assumes `is_uniformly_strided()` guarantees that the
        // elements are laid out at a constant spacing equal to that stride.
        unsafe { self.reshaped_view(new_shape, new_stride) }
    }

    /// Removes all singleton dimensions (dimensions of size 1) from the ndarray's shape.
    ///
    /// This method returns a view.
    ///
    /// # Example
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let ndarray = NdArray::new([[[[1], [3]], [[1], [4]]]]);  // shape [1, 2, 2, 1]
    /// let squeezed = ndarray.squeeze();
    /// assert_eq!(squeezed, NdArray::new([[1, 3], [1, 4]]));  // shape [2, 2]
    ///
    /// let ndarray = NdArray::new([[3], [5], [7], [9]]);
    /// let squeezed = (&ndarray).squeeze();  // squeeze without consuming ndarray
    /// assert_eq!(ndarray.shape(), &[4, 1]);
    /// assert_eq!(squeezed, NdArray::new([3, 5, 7, 9]));
    /// ```
    fn squeeze(self) -> Self::Output {
        // Keep only the (length, stride) pairs of axes whose length is not 1.
        let (shape, stride): (Vec<usize>, Vec<usize>) = self.shape().iter()
            .zip(self.stride().iter())
            .filter(|&(&axis_length, _)| axis_length != 1)
            .map(|(&axis_length, &axis_stride)| (axis_length, axis_stride))
            .unzip();

        // SAFETY: an axis of length 1 is only ever indexed with 0, so dropping it does not
        // change which offsets the remaining axes can reach.
        unsafe { self.reshaped_view(shape, stride) }
    }

    /// Adds a singleton dimension (dimensions of size 1) to the ndarray at the specified axis.
    ///
    /// `axis` is resolved with `AxisType::as_absolute` against `ndims() + 1`, so the new
    /// dimension may also be placed after the last existing axis.
    ///
    /// This method returns a view.
    ///
    /// # Example
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let ndarray = NdArray::new([2, 3]);  // shape is [2]
    /// let unsqueezed = ndarray.unsqueeze(-1);  // add dimension after the last axis
    /// assert_eq!(unsqueezed.shape(), &[2, 1]);
    ///
    /// let ndarray = NdArray::new([[1, 2, 3], [9, 8, 7]]);  // shape is [2, 3]
    /// let unsqueezed = (&ndarray).unsqueeze(1);  // unsqueeze without consuming ndarray
    /// assert_eq!(ndarray.shape(), &[2, 3]);
    /// assert_eq!(unsqueezed.shape(), &[2, 1, 3]);
    /// ```
    fn unsqueeze(self, axis: impl AxisType) -> Self::Output {
        let ndims = self.ndims();
        let axis = axis.as_absolute(ndims + 1);

        let mut shape = self.shape().to_vec();
        let mut stride = self.stride().to_vec();

        // A new trailing axis gets stride 1; otherwise the new axis spans the whole extent
        // of the axis it is inserted in front of.
        // NOTE(review): for an array whose innermost stride is not 1, a trailing stride of 1
        // may not be what later stride-based checks expect — confirm against the implementors.
        let inserted_stride = if axis == ndims {
            1
        } else {
            stride[axis] * shape[axis]
        };

        shape.insert(axis, 1);
        stride.insert(axis, inserted_stride);

        // SAFETY: the inserted axis has length 1 and is only ever indexed with 0, so its
        // stride never contributes to an element offset.
        unsafe { self.reshaped_view(shape, stride) }
    }

    /// Transposes the array along the first 2 dimensions.
    ///
    /// Equivalent to `transpose(0, 1)`.
    ///
    /// # Panics
    /// - If the array is 1-dimensional or a scalar.
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
    ///
    /// let transposed = array.T();
    /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
    /// ```
    // The upper-case name mirrors the conventional `.T` transpose accessor.
    #[allow(non_snake_case)]
    fn T(self) -> Self::Output {
        self.transpose(0, 1)
    }

    /// Returns a transposed version of the array, swapping the specified axes.
    ///
    /// This method returns a view: only the shape and stride entries of the two axes
    /// are exchanged.
    ///
    /// # Panics
    /// - If `axis1` or `axis2` are out of bounds
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let array = NdArray::new([[2, 3, 4], [10, 20, 30]]);
    ///
    /// let transposed = array.transpose(0, 1);
    /// assert_eq!(transposed, NdArray::new([[2, 10], [3, 20], [4, 30]]));
    /// ```
    fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
        let ndims = self.ndims();
        let axis1 = axis1.as_absolute(ndims);
        let axis2 = axis2.as_absolute(ndims);

        let mut shape = self.shape().to_vec();
        let mut stride = self.stride().to_vec();

        shape.swap(axis1, axis2);
        stride.swap(axis1, axis2);

        // SAFETY: shape and stride are permuted together, so every axis keeps its own
        // (length, stride) pair and the same elements stay addressable.
        unsafe { self.reshaped_view(shape, stride) }
    }
}