// redstone_ml/ndarray/broadcast.rs

1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::util::functions::pad;
4use crate::{NdArray, Reshape};
5
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8    /// Broadcasts the `NdArray` to the specified shape.
9    ///
10    /// This method returns a *readonly* view of the ndarray with the desired shape.
11    /// Broadcasting is done by left-padding the ndarray's shape with ones until they reach the
12    /// desired dimension. Then, any axes with length 1 are repeated to match the target shape.
13    ///
14    /// For example, suppose the ndarray's shape is `[2, 3]` and the broadcast shape is `[3, 2, 3]`.
15    /// Then the ndarray's shape becomes `[1, 2, 3]` after padding and `[3, 2, 3]` after repeating
16    /// the first axis.
17    ///
18    /// # Panics
19    /// This method panics if the target shape is incompatible with the ndarray.
20    ///
21    /// - If `shape.len()` is less than the dimensionality of the ndarray.
22    /// - If a dimension in `shape` does not equal the corresponding dimension in the ndarray's `shape`
23    ///   and cannot be broadcasted (i.e., it is not 1 or does not match).
24    ///
25    /// # Example
26    ///
27    /// ```
28    /// # use redstone_ml::*;
29    /// let ndarray = NdArray::new([1, 2, 3]);  // shape is [3]
30    /// let broadcasted_array = ndarray.broadcast_to(&[2, 3]);
31    ///
32    /// assert_eq!(broadcasted_array.shape(), &[2, 3]);
33    /// ```
34    pub fn broadcast_to(&'a self, shape: &[usize]) -> NdArray<'a, T> {
35        let broadcast_shape = broadcast_shape(&self.shape, shape);
36        let broadcast_stride = broadcast_stride(&self.stride, &broadcast_shape, &self.shape);
37
38        let mut result = unsafe { self.reshaped_view(broadcast_shape, broadcast_stride) };
39        result.flags -= NdArrayFlags::Writeable;
40        result
41    }
42}
43
/// Adjusts `shape` and `stride` to describe an `ndims`-dimensional view of the same ndarray.
///
/// Leading axes are inserted by left-padding `shape` with ones and `stride` with zeros until
/// `shape` has exactly `ndims` entries (`stride` receives the same number of leading zeros).
/// An inserted axis has length 1 and stride 0, so it never moves through memory.
///
/// # Panics
/// - If `shape.len() > ndims`
///
/// # Example
/// ```ignore
/// let shape = vec![2, 3];
/// let stride = vec![3, 1];
/// let ndims = 4;
///
/// let (padded_shape, padded_stride) = pad_dimensions(&shape, &stride, ndims);
///
/// assert_eq!(padded_shape, vec![1, 1, 2, 3]);
/// assert_eq!(padded_stride, vec![0, 0, 3, 1]);
/// ```
fn pad_dimensions(shape: &[usize], stride: &[usize], ndims: usize) -> (Vec<usize>, Vec<usize>) {
    // Checked explicitly: a bare `ndims - shape.len()` silently wraps in release builds.
    assert!(
        shape.len() <= ndims,
        "cannot pad shape {shape:?} to {ndims} dimensions because it already has more"
    );

    // Number of leading axes to insert. The padding is built here rather than through the
    // generic `pad` helper, whose length argument is interpreted elsewhere in this file as a
    // target length, not a count.
    let n = ndims - shape.len();

    let padded_shape: Vec<usize> = std::iter::repeat(1)
        .take(n)
        .chain(shape.iter().copied())
        .collect();
    let padded_stride: Vec<usize> = std::iter::repeat(0)
        .take(n)
        .chain(stride.iter().copied())
        .collect();

    (padded_shape, padded_stride)
}

/// Validates that `shape` can be broadcast to `to` and returns the target shape.
///
/// `shape` is conceptually left-padded with ones until it has as many axes as `to`. After
/// that, every axis of `shape` must either already match `to` or have length 1 (so it can be
/// repeated).
///
/// For example, `[2, 3]` broadcast to `[3, 2, 3]` is padded to `[1, 2, 3]`, and its first axis
/// is repeated to give `[3, 2, 3]`.
///
/// # Panics
/// - If `to` has fewer dimensions than `shape`.
/// - If an axis of `shape` has a length that is neither 1 nor equal to the corresponding
///   (right-aligned) axis of `to`.
fn broadcast_shape(shape: &[usize], to: &[usize]) -> Vec<usize> {
    if shape.len() > to.len() {
        panic!("cannot broadcast {shape:?} to shape {to:?} with fewer dimensions")
    }

    // Axes are aligned from the right: `shape` is compared with the trailing axes of `to`.
    let trailing = &to[to.len() - shape.len()..];
    let compatible = shape
        .iter()
        .zip(trailing)
        .all(|(&from, &target)| from == 1 || from == target);

    if !compatible {
        panic!("broadcasting {shape:?} is not compatible with the desired shape {to:?}");
    }

    to.to_vec()
}

/// Computes the strides of a view that broadcasts `original_shape` to `broadcast_shape`.
///
/// Axes prepended by broadcasting get a stride of 0. Of the original axes, those of length 1
/// are repeated and therefore also get a stride of 0. Every other axis must already match the
/// target length and keeps its original stride from `stride`.
///
/// # Panics
/// - If `broadcast_shape` has fewer dimensions than `original_shape`.
/// - If an axis of `original_shape` has a length that is neither 1 nor equal to the
///   corresponding (right-aligned) axis of `broadcast_shape`.
///
/// # Examples
///
/// ```ignore
/// let stride = vec![4, 1];
/// let original_shape = vec![2, 3];
/// let broadcast_shape = vec![3, 2, 3];
///
/// let result = broadcast_stride(&stride, &broadcast_shape, &original_shape);
/// assert_eq!(result, vec![0, 4, 1]);
/// ```
pub(crate) fn broadcast_stride(stride: &[usize],
                               broadcast_shape: &[usize],
                               original_shape: &[usize]) -> Vec<usize> {
    let ndims = broadcast_shape.len();

    if original_shape.len() > ndims {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    // Number of axes added in front of the original ones.
    let offset = ndims - original_shape.len();

    // Prepended axes never advance through memory.
    let mut result = vec![0; offset];
    result.reserve(original_shape.len());

    for (i, &len) in original_shape.iter().enumerate() {
        let target_len = broadcast_shape[offset + i];

        let axis_stride = match len {
            1 => 0,
            _ if len == target_len => stride[i],
            _ => panic!("broadcasting {original_shape:?} is not compatible with the desired shape {broadcast_shape:?}"),
        };

        result.push(axis_stride);
    }

    result
}

150/// Broadcasts two compatible shapes together and returns the resulting shape.
151///
152/// Broadcasting follows the rules of NumPy-style broadcasting:
153/// - The smaller shape is left-padded with ones until it matches the length of the other shape
154/// - If one of the shapes is of length 1 at a particular axis, it can broadcast to the length of the other shape at that axis.
155/// - If both shapes have differing lengths at a certain axis and neither is 1, the two shapes are deemed incompatible for broadcasting.
156///
157/// For example, if `first` is `[8, 1, 6]` and `second` is `[7, 1]`, then `second` is left-padded
158/// to become `[1, 7, 1]`. The middle axis of `first` is repeated to have dimension 7 and the
159/// first and last axes of `second` are repeated to have dimensions 8 and 6 respectively.
160/// The resulting shape is `[8, 7, 6]`.
161///
162/// # Panics
163/// - If the two shapes are incompatible for broadcasting
164///
165/// # Examples
166/// ```ignore
167/// let shape1 = vec![8, 1, 6];
168/// let shape2 = vec![7, 1];
169/// let result = broadcast_shapes(&shape1, &shape2);
170/// assert_eq!(result, vec![8, 7, 6]);
171/// ```
172pub(crate) fn broadcast_shapes(first: &[usize], second: &[usize]) -> Vec<usize> {
173    let mut shape1;
174    let mut shape2;
175
176    // pad shapes with ones to match in length
177    if first.len() > second.len() {
178        shape1 = pad(second, 1, first.len());
179        shape2 = first.to_vec();
180    } else {
181        shape1 = pad(first, 1, second.len());
182        shape2 = second.to_vec();
183    }
184
185    for axis in 0..shape1.len() {
186        // If one of the shapes is 1 at a particular axis,
187        // it can be repeated to match the length of the other's shape at that axis   
188        if shape1[axis] == 1 {
189            shape1[axis] = shape2[axis];
190        } else if shape2[axis] == 1 {
191            shape2[axis] = shape1[axis];
192        }
193
194        // if neither shape is 1 along axis, and they don't match, the shapes cannot be broadcast
195        else if shape1[axis] != shape2[axis] {
196            panic!("broadcasting {first:?} is not compatible with the desired shape {second:?}");
197        }
198    }
199
200    shape1
201}
202
/// Determines the axes that are broadcast when broadcasting `original_shape` to
/// `broadcast_shape`.
///
/// An axis of `broadcast_shape` is reported when either:
/// - it was prepended, because `original_shape` has fewer dimensions, or
/// - the corresponding (right-aligned) axis of `original_shape` has length 1 and the target
///   length is anything other than 1. This includes 0, since a length-1 axis may broadcast to
///   an empty axis.
///
/// Axes are returned in ascending order.
///
/// # Panics
/// - If `broadcast_shape` has fewer dimensions than `original_shape`.
///
/// # Example
///
/// ```ignore
/// let broadcast_shape = vec![4, 3, 2];
/// let original_shape = vec![3, 1];
/// let axes = get_broadcasted_axes(&broadcast_shape, &original_shape);
/// assert_eq!(axes, vec![0, 2]);
/// ```
///
/// In this example:
/// - Dimension `0` in the `broadcast_shape` (size `4`) is broadcast because `original_shape` is
///   missing that dimension.
/// - Dimension `2` in the `broadcast_shape` (size `2`) is broadcast because `original_shape[1]`
///   is `1`.
pub(crate) fn get_broadcasted_axes(broadcast_shape: &[usize],
                                   original_shape: &[usize]) -> Vec<isize> {
    if broadcast_shape.len() < original_shape.len() {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    let ndims_diff = broadcast_shape.len() - original_shape.len();

    broadcast_shape
        .iter()
        .enumerate()
        .filter(|&(axis, &to_dim)| {
            // Prepended axes are always broadcast. Short-circuiting also guarantees that
            // `axis - ndims_diff` below never underflows.
            axis < ndims_diff
                // A length-1 axis stretched to any other length (including 0) was broadcast.
                || (original_shape[axis - ndims_diff] == 1 && to_dim != 1)
        })
        .map(|(axis, _)| axis as isize)
        .collect()
}

#[cfg(test)]
mod tests {
    // `super` resolves to the enclosing module wherever it is mounted in the crate, unlike a
    // hard-coded `crate::broadcast` path.
    use super::{broadcast_shape, broadcast_shapes, broadcast_stride, get_broadcasted_axes};

    #[test]
    fn test_broadcast_shapes() {
        let shape1 = vec![5, 1];
        let shape2 = vec![2, 1, 3];

        let correct = vec![2, 5, 3];
        let output = broadcast_shapes(&shape1, &shape2);

        assert_eq!(output, correct);

        // longer first argument, and equal-length shapes
        assert_eq!(broadcast_shapes(&[8, 1, 6], &[7, 1]), vec![8, 7, 6]);
        assert_eq!(broadcast_shapes(&[2, 1], &[1, 3]), vec![2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_broadcast_shapes_incompatible() {
        broadcast_shapes(&[3], &[4]);
    }

    #[test]
    fn test_broadcast_shape() {
        assert_eq!(broadcast_shape(&[2, 3], &[3, 2, 3]), vec![3, 2, 3]);
        assert_eq!(broadcast_shape(&[1, 3], &[4, 3]), vec![4, 3]);
    }

    #[test]
    #[should_panic]
    fn test_broadcast_shape_incompatible() {
        broadcast_shape(&[2, 3], &[3, 4]);
    }

    #[test]
    fn test_broadcast_stride() {
        assert_eq!(broadcast_stride(&[4, 1], &[3, 2, 3], &[2, 3]), vec![0, 4, 1]);
        assert_eq!(broadcast_stride(&[3, 1], &[2, 4], &[1, 4]), vec![0, 1]);
    }

    #[test]
    fn test_get_broadcasted_axes() {
        // grad_shape: [3, 3]
        // original_shape: [3, 1]
        // axes to sum: [1]
        assert_eq!(get_broadcasted_axes(&[3, 3], &[3, 1]), vec![1]);

        // grad_shape: [2, 3]
        // original_shape: [3]
        // axes to sum: [0]
        assert_eq!(get_broadcasted_axes(&[2, 3], &[3]), vec![0]);

        // grad_shape: [8, 7, 6]
        // original_shape: [7, 1]
        // axes to sum: [0, 2]
        assert_eq!(get_broadcasted_axes(&[8, 7, 6], &[7, 1]), vec![0, 2]);

        // grad_shape: [4, 5, 6]
        // original_shape: [1, 5, 1]
        // axes to sum: [0, 2]
        assert_eq!(get_broadcasted_axes(&[4, 5, 6], &[1, 5, 1]), vec![0, 2]);

        // grad_shape: [5, 6]
        // original_shape: [1, 6]
        // axes to sum: [0]
        assert_eq!(get_broadcasted_axes(&[5, 6], &[1, 6]), vec![0]);

        // grad_shape: [5, 6]
        // original_shape: [5, 1]
        // axes to sum: [1]
        assert_eq!(get_broadcasted_axes(&[5, 6], &[5, 1]), vec![1]);
    }
}