Skip to main content

ferray_core/dimension/
broadcast.rs

1// ferray-core: Broadcasting logic (REQ-9, REQ-10, REQ-11)
2//
3// Implements NumPy's full broadcasting rules:
4//   1. Prepend 1s to shape of lower-dim array
5//   2. Stretch size-1 dimensions
6//   3. Error on size mismatch where neither is 1
7//
8// Broadcasting NEVER materializes the expanded array — it uses virtual
9// expansion via strides (setting stride = 0 for broadcast dimensions).
10
11use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19/// Compute the broadcast shape from two shapes, following NumPy rules.
20///
21/// The result shape has `max(a.len(), b.len())` dimensions. Shorter shapes
22/// are left-padded with 1s. For each axis, the result dimension is the
23/// larger of the two inputs; if neither is 1 and they differ, an error
24/// is returned.
25///
26/// # Examples
27/// ```
28/// # use ferray_core::dimension::broadcast::broadcast_shapes;
29/// let result = broadcast_shapes(&[4, 3], &[3]).unwrap();
30/// assert_eq!(result, vec![4, 3]);
31///
32/// let result = broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap();
33/// assert_eq!(result, vec![2, 3, 4]);
34/// ```
35///
36/// # Errors
37/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
38pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39    let ndim = a.len().max(b.len());
40    let mut result = vec![0usize; ndim];
41
42    for i in 0..ndim {
43        let da = if i < ndim - a.len() {
44            1
45        } else {
46            a[i - (ndim - a.len())]
47        };
48        let db = if i < ndim - b.len() {
49            1
50        } else {
51            b[i - (ndim - b.len())]
52        };
53
54        if da == db {
55            result[i] = da;
56        } else if da == 1 {
57            result[i] = db;
58        } else if db == 1 {
59            result[i] = da;
60        } else {
61            return Err(FerrayError::broadcast_failure(a, b));
62        }
63    }
64    Ok(result)
65}
66
67/// Compute the broadcast shape from multiple shapes.
68///
69/// This is the N-ary version of [`broadcast_shapes`]. It folds pairwise
70/// over all input shapes.
71///
72/// # Errors
73/// Returns `FerrayError::BroadcastFailure` if any pair is incompatible.
74pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75    if shapes.is_empty() {
76        return Ok(vec![]);
77    }
78    let mut result = shapes[0].to_vec();
79    for &s in &shapes[1..] {
80        result = broadcast_shapes(&result, s)?;
81    }
82    Ok(result)
83}
84
85/// Compute the strides for broadcasting a source shape to a target shape.
86///
87/// For dimensions where the source has size 1 but the target is larger,
88/// the stride is set to 0 (virtual expansion). For matching dimensions,
89/// the original stride is preserved. The source shape is left-padded with
90/// 1s (stride 0) as needed.
91///
92/// # Errors
93/// Returns `FerrayError::BroadcastFailure` if the source cannot be broadcast
94/// to the target (i.e., a source dimension is neither 1 nor equal to target).
95pub fn broadcast_strides(
96    src_shape: &[usize],
97    src_strides: &[isize],
98    target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100    let tndim = target_shape.len();
101    let sndim = src_shape.len();
102
103    if tndim < sndim {
104        return Err(FerrayError::shape_mismatch(format!(
105            "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106            src_shape, target_shape
107        )));
108    }
109
110    let pad = tndim - sndim;
111    let mut out_strides = vec![0isize; tndim];
112
113    for i in 0..tndim {
114        if i < pad {
115            // Prepended dimension: virtual, stride = 0
116            out_strides[i] = 0;
117        } else {
118            let si = i - pad;
119            let src_dim = src_shape[si];
120            let tgt_dim = target_shape[i];
121
122            if src_dim == tgt_dim {
123                out_strides[i] = src_strides[si];
124            } else if src_dim == 1 {
125                // Broadcast: virtual expansion
126                out_strides[i] = 0;
127            } else {
128                return Err(FerrayError::shape_mismatch(format!(
129                    "cannot broadcast dimension {} (size {}) to size {}",
130                    si, src_dim, tgt_dim
131                )));
132            }
133        }
134    }
135
136    Ok(out_strides)
137}
138
139/// Broadcast an array to a target shape, returning a view.
140///
141/// The returned view uses stride-0 tricks to virtually expand size-1
142/// dimensions — no data is copied. The view borrows from the source array.
143///
144/// # Errors
145/// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
146/// to the given shape.
147pub fn broadcast_to<'a, T: Element, D: Dimension>(
148    array: &'a Array<T, D>,
149    target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151    let src_shape = array.shape();
152    let src_strides = array.strides();
153
154    // Validate broadcast compatibility
155    let result_shape = broadcast_shapes(src_shape, target_shape)?;
156    if result_shape != target_shape {
157        return Err(FerrayError::shape_mismatch(format!(
158            "cannot broadcast shape {:?} to shape {:?}",
159            src_shape, target_shape
160        )));
161    }
162
163    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165    // Validate all strides are non-negative before casting to usize.
166    // Negative strides (from reversed/transposed views) cannot be represented
167    // as usize and would wrap to huge values, causing out-of-bounds access.
168    for (i, &s) in new_strides.iter().enumerate() {
169        if s < 0 {
170            return Err(FerrayError::shape_mismatch(format!(
171                "cannot broadcast array with negative stride {} on axis {}; \
172                 make the array contiguous first",
173                s, i
174            )));
175        }
176    }
177
178    // Build ndarray view with computed strides
179    let nd_shape = ndarray::IxDyn(target_shape);
180    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
181
182    // Use from_shape_ptr with the broadcast strides
183    let ptr = array.as_ptr();
184    // SAFETY: broadcast strides are validated non-negative above and ensure
185    // we only access valid memory from the source array. Stride-0 dimensions
186    // repeat the same element.
187    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
188
189    Ok(ArrayView::from_ndarray(nd_view))
190}
191
192/// Broadcast an `ArrayView` to a target shape, returning a new view.
193///
194/// # Errors
195/// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
196pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
197    view: &ArrayView<'a, T, D>,
198    target_shape: &[usize],
199) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
200    let src_shape = view.shape();
201    let src_strides = view.strides();
202
203    let result_shape = broadcast_shapes(src_shape, target_shape)?;
204    if result_shape != target_shape {
205        return Err(FerrayError::shape_mismatch(format!(
206            "cannot broadcast shape {:?} to shape {:?}",
207            src_shape, target_shape
208        )));
209    }
210
211    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
212
213    for (i, &s) in new_strides.iter().enumerate() {
214        if s < 0 {
215            return Err(FerrayError::shape_mismatch(format!(
216                "cannot broadcast view with negative stride {} on axis {}; \
217                 make the array contiguous first",
218                s, i
219            )));
220        }
221    }
222
223    let nd_shape = ndarray::IxDyn(target_shape);
224    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
225
226    let ptr = view.as_ptr();
227    // SAFETY: strides validated non-negative above; broadcast strides ensure
228    // only valid source memory is accessed.
229    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
230
231    Ok(ArrayView::from_ndarray(nd_view))
232}
233
234/// Broadcast multiple arrays to a common shape.
235///
236/// Returns a vector of `ArrayView<IxDyn>` views, all sharing the same
237/// broadcast shape. No data is copied.
238///
239/// # Errors
240/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
241pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
242    arrays: &'a [Array<T, D>],
243) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
244    if arrays.is_empty() {
245        return Ok(vec![]);
246    }
247
248    // Compute common broadcast shape
249    let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
250    let target = broadcast_shapes_multi(&shapes)?;
251
252    // Broadcast each array to the common shape
253    let mut result = Vec::with_capacity(arrays.len());
254    for arr in arrays {
255        result.push(broadcast_to(arr, &target)?);
256    }
257    Ok(result)
258}
259
260// ---------------------------------------------------------------------------
261// Methods on Array for broadcasting
262// ---------------------------------------------------------------------------
263
264impl<T: Element, D: Dimension> Array<T, D> {
265    /// Broadcast this array to the given shape, returning a dynamic-rank view.
266    ///
267    /// Uses stride-0 tricks for virtual expansion — no data is copied.
268    ///
269    /// # Errors
270    /// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
271    /// to the target shape.
272    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
273        broadcast_to(self, target_shape)
274    }
275}
276
277impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
278    /// Broadcast this view to the given shape, returning a dynamic-rank view.
279    ///
280    /// # Errors
281    /// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
282    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
283        let src_shape = self.shape();
284        let src_strides = self.strides();
285
286        let result_shape = broadcast_shapes(src_shape, target_shape)?;
287        if result_shape != target_shape {
288            return Err(FerrayError::shape_mismatch(format!(
289                "cannot broadcast shape {:?} to shape {:?}",
290                src_shape, target_shape
291            )));
292        }
293
294        let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
295
296        for (i, &s) in new_strides.iter().enumerate() {
297            if s < 0 {
298                return Err(FerrayError::shape_mismatch(format!(
299                    "cannot broadcast view with negative stride {} on axis {}; \
300                     make the array contiguous first",
301                    s, i
302                )));
303            }
304        }
305
306        let nd_shape = ndarray::IxDyn(target_shape);
307        let nd_strides =
308            ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
309
310        let ptr = self.as_ptr();
311        // SAFETY: strides validated non-negative above; broadcast strides ensure
312        // only valid source memory is accessed.
313        let nd_view =
314            unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
315
316        Ok(ArrayView::from_ndarray(nd_view))
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // broadcast_shapes tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        // (4,3) + (3,) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        // (4,1) * (4,3) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        // (1,3) + (2,1) -> (2,3)
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length() {
        // A size-1 axis stretches to length 0; 0 against any other length fails.
        assert_eq!(broadcast_shapes(&[0], &[1]).unwrap(), vec![0]);
        assert_eq!(broadcast_shapes(&[0, 3], &[3]).unwrap(), vec![0, 3]);
        assert!(broadcast_shapes(&[0], &[2]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_shapes_multi_incompatible() {
        // (2,3) and (3,) agree, but (4,1) clashes on the leading axis.
        assert!(broadcast_shapes_multi(&[&[2, 3], &[3], &[4, 1]]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_strides tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        // C-contiguous (3,4) strides are passed through unchanged.
        let strides = broadcast_strides(&[3, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![4, 1]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        // shape (1,4) with strides (4,1) -> target (3,4)
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        // shape (4,) with strides (1,) -> target (3, 4)
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_rejects_incompatible() {
        // A source axis of length 3 cannot become 4.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
        // The target may not have fewer dimensions than the source.
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_to tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        // All rows should be the same
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        // (3,1) -> (3,4)
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        // Verify that broadcast_to does NOT copy data
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view shares the same base pointer
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_rejects_shrinking_axis() {
        // (3,) and (1,) are compatible, but their broadcast is (3,), not (1,).
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[1]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        // (1,) -> (5,)
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length_target() {
        // (3,) -> (0, 3): a valid, empty broadcast.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[0, 3]).unwrap();
        assert_eq!(view.shape(), &[0, 3]);
        assert_eq!(view.size(), 0);
    }

    // -----------------------------------------------------------------------
    // broadcast_arrays tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    #[test]
    fn broadcast_arrays_values() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();

        let first: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            first,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(second, [10.0, 20.0, 30.0].repeat(4));
    }

    #[test]
    fn broadcast_arrays_empty() {
        let views = broadcast_arrays::<f64, Ix1>(&[]).unwrap();
        assert!(views.is_empty());
    }

    // -----------------------------------------------------------------------
    // Method tests
    // -----------------------------------------------------------------------

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }

    #[test]
    fn view_broadcast_chain() {
        // A broadcast view can itself be broadcast further, still without copying.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();

        let wider = view.broadcast_to(&[2, 2, 3]).unwrap();
        assert_eq!(wider.shape(), &[2, 2, 3]);
        let data: Vec<f64> = wider.iter().copied().collect();
        assert_eq!(data, [1.0, 2.0, 3.0].repeat(4));

        let free = broadcast_view_to(&view, &[4, 2, 3]).unwrap();
        assert_eq!(free.shape(), &[4, 2, 3]);
        assert_eq!(free.as_ptr(), arr.as_ptr());

        assert!(broadcast_view_to(&view, &[2, 4]).is_err());
    }
}