redstone_ml/ndarray/broadcast.rs
1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::util::functions::pad;
4use crate::{NdArray, Reshape};
5
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8 /// Broadcasts the `NdArray` to the specified shape.
9 ///
10 /// This method returns a *readonly* view of the ndarray with the desired shape.
11 /// Broadcasting is done by left-padding the ndarray's shape with ones until they reach the
12 /// desired dimension. Then, any axes with length 1 are repeated to match the target shape.
13 ///
14 /// For example, suppose the ndarray's shape is `[2, 3]` and the broadcast shape is `[3, 2, 3]`.
15 /// Then the ndarray's shape becomes `[1, 2, 3]` after padding and `[3, 2, 3]` after repeating
16 /// the first axis.
17 ///
18 /// # Panics
19 /// This method panics if the target shape is incompatible with the ndarray.
20 ///
21 /// - If `shape.len()` is less than the dimensionality of the ndarray.
22 /// - If a dimension in `shape` does not equal the corresponding dimension in the ndarray's `shape`
23 /// and cannot be broadcasted (i.e., it is not 1 or does not match).
24 ///
25 /// # Example
26 ///
27 /// ```
28 /// # use redstone_ml::*;
29 /// let ndarray = NdArray::new([1, 2, 3]); // shape is [3]
30 /// let broadcasted_array = ndarray.broadcast_to(&[2, 3]);
31 ///
32 /// assert_eq!(broadcasted_array.shape(), &[2, 3]);
33 /// ```
34 pub fn broadcast_to(&'a self, shape: &[usize]) -> NdArray<'a, T> {
35 let broadcast_shape = broadcast_shape(&self.shape, shape);
36 let broadcast_stride = broadcast_stride(&self.stride, &broadcast_shape, &self.shape);
37
38 let mut result = unsafe { self.reshaped_view(broadcast_shape, broadcast_stride) };
39 result.flags -= NdArrayFlags::Writeable;
40 result
41 }
42}
43
/// Adjusts `shape` and `stride` to match an `ndims`-dimensional view of the ndarray.
///
/// Exactly `ndims - shape.len()` entries are prepended to each slice: ones for `shape`
/// and zeros for `stride`. The existing entries are kept in order at the end.
///
/// The padding is built locally with the standard library so that the number of prepended
/// entries is unambiguous and the length check happens in every build profile.
///
/// # Panics
/// - If `shape.len() > ndims`
///
/// # Example
/// ```ignore
/// let shape = vec![2, 3];
/// let stride = vec![3, 1];
/// let ndims = 4;
///
/// let (padded_shape, padded_stride) = pad_dimensions(&shape, &stride, ndims);
///
/// assert_eq!(padded_shape, vec![1, 1, 2, 3]);
/// assert_eq!(padded_stride, vec![0, 0, 3, 1]);
/// ```
fn pad_dimensions(shape: &[usize], stride: &[usize], ndims: usize) -> (Vec<usize>, Vec<usize>) {
    // A plain `ndims - shape.len()` would wrap around in release builds instead of
    // panicking, so the subtraction is checked explicitly.
    let n = match ndims.checked_sub(shape.len()) {
        Some(missing) => missing,
        None => panic!("cannot pad shape {shape:?} to {ndims} dimensions"),
    };

    // Prepends `n` copies of `fill` to `values`.
    let left_pad = |values: &[usize], fill: usize| -> Vec<usize> {
        std::iter::repeat(fill)
            .take(n)
            .chain(values.iter().copied())
            .collect()
    };

    (left_pad(shape, 1), left_pad(stride, 0))
}
70
/// Validates that `shape` can be broadcast to `to` and returns `to` as an owned vector.
///
/// `shape` is compared against the trailing `shape.len()` entries of `to`, which is
/// equivalent to left-padding `shape` with ones. Each axis of `shape` must either have
/// length 1 (so that it can be repeated) or already match the target length.
///
/// For example, `[2, 3]` can be broadcast to `[3, 2, 3]`: it is treated as `[1, 2, 3]` and
/// the first axis is repeated three times.
///
/// # Panics
/// - If `to` has fewer dimensions than `shape`.
/// - If an axis of `shape` is neither 1 nor equal to the corresponding axis of `to`.
fn broadcast_shape(shape: &[usize], to: &[usize]) -> Vec<usize> {
    if shape.len() > to.len() {
        panic!("cannot broadcast {shape:?} to shape {to:?} with fewer dimensions")
    }

    // Right-align `shape` against the target: only the trailing axes need checking.
    let trailing = &to[to.len() - shape.len()..];

    let incompatible = shape
        .iter()
        .zip(trailing)
        .any(|(&current, &wanted)| current != 1 && current != wanted);

    if incompatible {
        panic!("broadcasting {shape:?} is not compatible with the desired shape {to:?}");
    }

    to.to_vec()
}
100
/// Computes the strides of a view that broadcasts `original_shape` to `broadcast_shape`.
///
/// This is done by left-padding the original stride with zeros until it has as many entries
/// as `broadcast_shape`. Of the original axes, those with length 1 also receive a stride
/// of 0, while axes whose length already matches keep their entry from `stride`.
///
/// # Panics
/// - If `broadcast_shape` has fewer dimensions than `original_shape`.
/// - If an axis of `original_shape` is neither 1 nor equal to the corresponding
///   (right-aligned) axis of `broadcast_shape`.
/// - If `stride` has no entry for an axis that must keep its stride.
///
/// # Examples
///
/// ```ignore
/// let stride = vec![4, 1];
/// let original_shape = vec![2, 3];
/// let broadcast_shape = vec![3, 2, 3];
///
/// let result = broadcast_stride(&stride, &broadcast_shape, &original_shape);
/// assert_eq!(result, vec![0, 4, 1]);
/// ```
pub(crate) fn broadcast_stride(
    stride: &[usize],
    broadcast_shape: &[usize],
    original_shape: &[usize],
) -> Vec<usize> {
    let target_ndims = broadcast_shape.len();

    if original_shape.len() > target_ndims {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    // Number of axes that exist only in the broadcast shape.
    let leading = target_ndims - original_shape.len();

    let mut result: Vec<usize> = Vec::with_capacity(target_ndims);

    // Newly introduced axes are pure repetition, hence a zero stride.
    result.extend(std::iter::repeat(0).take(leading));

    let trailing_targets = &broadcast_shape[leading..];
    let original_axes = original_shape.iter().zip(trailing_targets).enumerate();

    result.extend(original_axes.map(|(index, (&from_len, &to_len))| {
        if from_len == 1 {
            // A length-1 axis is repeated, so moving along it never advances.
            0
        } else if from_len == to_len {
            stride[index]
        } else {
            panic!("broadcasting {original_shape:?} is not compatible with the desired shape {broadcast_shape:?}");
        }
    }));

    result
}
149
150/// Broadcasts two compatible shapes together and returns the resulting shape.
151///
152/// Broadcasting follows the rules of NumPy-style broadcasting:
153/// - The smaller shape is left-padded with ones until it matches the length of the other shape
154/// - If one of the shapes is of length 1 at a particular axis, it can broadcast to the length of the other shape at that axis.
155/// - If both shapes have differing lengths at a certain axis and neither is 1, the two shapes are deemed incompatible for broadcasting.
156///
157/// For example, if `first` is `[8, 1, 6]` and `second` is `[7, 1]`, then `second` is left-padded
158/// to become `[1, 7, 1]`. The middle axis of `first` is repeated to have dimension 7 and the
159/// first and last axes of `second` are repeated to have dimensions 8 and 6 respectively.
160/// The resulting shape is `[8, 7, 6]`.
161///
162/// # Panics
163/// - If the two shapes are incompatible for broadcasting
164///
165/// # Examples
166/// ```ignore
167/// let shape1 = vec![8, 1, 6];
168/// let shape2 = vec![7, 1];
169/// let result = broadcast_shapes(&shape1, &shape2);
170/// assert_eq!(result, vec![8, 7, 6]);
171/// ```
172pub(crate) fn broadcast_shapes(first: &[usize], second: &[usize]) -> Vec<usize> {
173 let mut shape1;
174 let mut shape2;
175
176 // pad shapes with ones to match in length
177 if first.len() > second.len() {
178 shape1 = pad(second, 1, first.len());
179 shape2 = first.to_vec();
180 } else {
181 shape1 = pad(first, 1, second.len());
182 shape2 = second.to_vec();
183 }
184
185 for axis in 0..shape1.len() {
186 // If one of the shapes is 1 at a particular axis,
187 // it can be repeated to match the length of the other's shape at that axis
188 if shape1[axis] == 1 {
189 shape1[axis] = shape2[axis];
190 } else if shape2[axis] == 1 {
191 shape2[axis] = shape1[axis];
192 }
193
194 // if neither shape is 1 along axis, and they don't match, the shapes cannot be broadcast
195 else if shape1[axis] != shape2[axis] {
196 panic!("broadcasting {first:?} is not compatible with the desired shape {second:?}");
197 }
198 }
199
200 shape1
201}
202
/// Lists the axes of `broadcast_shape` that were produced by broadcasting `original_shape`.
///
/// `original_shape` is right-aligned against `broadcast_shape`. An axis is reported when
/// - it has no counterpart in `original_shape` (a leading axis added by padding), or
/// - its counterpart has length 1 while the broadcast length is greater than 1.
///
/// Axes are returned in increasing order.
///
/// # Panics
/// - If `broadcast_shape` has fewer dimensions than `original_shape`.
///
/// # Example
///
/// ```ignore
/// let broadcast_shape = vec![4, 3, 2];
/// let original_shape = vec![3, 1];
/// let axes = get_broadcasted_axes(&broadcast_shape, &original_shape);
/// assert_eq!(axes, vec![0, 2]);
/// ```
///
/// In this example:
/// - Axis `0` of `broadcast_shape` (size `4`) is reported because `original_shape` has no
///   such axis.
/// - Axis `2` of `broadcast_shape` (size `2`) is reported because `original_shape[1]` is `1`.
pub(crate) fn get_broadcasted_axes(
    broadcast_shape: &[usize],
    original_shape: &[usize],
) -> Vec<isize> {
    if original_shape.len() > broadcast_shape.len() {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    // Number of leading axes that only exist in the broadcast shape.
    let leading = broadcast_shape.len() - original_shape.len();

    broadcast_shape
        .iter()
        .enumerate()
        .filter(|&(axis, &target_len)| {
            // Leading axes always count; the `||` short-circuits before the index
            // `axis - leading` could underflow.
            axis < leading || (original_shape[axis - leading] == 1 && target_len > 1)
        })
        .map(|(axis, _)| axis as isize)
        .collect()
}
243
#[cfg(test)]
mod tests {
    use crate::broadcast::{broadcast_shapes, get_broadcasted_axes};

    /// `[5, 1]` is padded to `[1, 5, 1]` and merged with `[2, 1, 3]`.
    #[test]
    fn test_broadcast_shapes() {
        let output = broadcast_shapes(&[5, 1], &[2, 1, 3]);

        assert_eq!(output, vec![2, 5, 3]);
    }

    /// Each case names the shape of a broadcast gradient, the shape it was broadcast from,
    /// and the axes that would have to be summed to get back to the original shape.
    #[test]
    fn test_get_broadcasted_axes() {
        fn check(grad_shape: &[usize], original_shape: &[usize], axes_to_sum: &[isize]) {
            assert_eq!(get_broadcasted_axes(grad_shape, original_shape), axes_to_sum);
        }

        // trailing axis repeated
        check(&[3, 3], &[3, 1], &[1]);

        // leading axis added by padding
        check(&[2, 3], &[3], &[0]);

        // leading axis added and trailing axis repeated
        check(&[8, 7, 6], &[7, 1], &[0, 2]);

        // first and last axes repeated, same number of dimensions
        check(&[4, 5, 6], &[1, 5, 1], &[0, 2]);

        // only the first axis repeated
        check(&[5, 6], &[1, 6], &[0]);

        // only the last axis repeated
        check(&[5, 6], &[5, 1], &[1]);
    }
}