Skip to main content

ferray_core/dimension/
broadcast.rs

1// ferray-core: Broadcasting logic (REQ-9, REQ-10, REQ-11)
2//
3// Implements NumPy's full broadcasting rules:
4//   1. Prepend 1s to shape of lower-dim array
5//   2. Stretch size-1 dimensions
6//   3. Error on size mismatch where neither is 1
7//
8// Broadcasting NEVER materializes the expanded array — it uses virtual
9// expansion via strides (setting stride = 0 for broadcast dimensions).
10
11use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19/// Compute the broadcast shape from two shapes, following NumPy rules.
20///
21/// The result shape has `max(a.len(), b.len())` dimensions. Shorter shapes
22/// are left-padded with 1s. For each axis, the result dimension is the
23/// larger of the two inputs; if neither is 1 and they differ, an error
24/// is returned.
25///
26/// # Examples
27/// ```
28/// # use ferray_core::dimension::broadcast::broadcast_shapes;
29/// let result = broadcast_shapes(&[4, 3], &[3]).unwrap();
30/// assert_eq!(result, vec![4, 3]);
31///
32/// let result = broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap();
33/// assert_eq!(result, vec![2, 3, 4]);
34/// ```
35///
36/// # Errors
37/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
38pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39    let ndim = a.len().max(b.len());
40    let mut result = vec![0usize; ndim];
41
42    for i in 0..ndim {
43        let da = if i < ndim - a.len() {
44            1
45        } else {
46            a[i - (ndim - a.len())]
47        };
48        let db = if i < ndim - b.len() {
49            1
50        } else {
51            b[i - (ndim - b.len())]
52        };
53
54        if da == db {
55            result[i] = da;
56        } else if da == 1 {
57            result[i] = db;
58        } else if db == 1 {
59            result[i] = da;
60        } else {
61            return Err(FerrayError::broadcast_failure(a, b));
62        }
63    }
64    Ok(result)
65}
66
67/// Compute the broadcast shape from multiple shapes.
68///
69/// This is the N-ary version of [`broadcast_shapes`]. It folds pairwise
70/// over all input shapes.
71///
72/// # Errors
73/// Returns `FerrayError::BroadcastFailure` if any pair is incompatible.
74pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75    if shapes.is_empty() {
76        return Ok(vec![]);
77    }
78    let mut result = shapes[0].to_vec();
79    for &s in &shapes[1..] {
80        result = broadcast_shapes(&result, s)?;
81    }
82    Ok(result)
83}
84
85/// Compute the strides for broadcasting a source shape to a target shape.
86///
87/// For dimensions where the source has size 1 but the target is larger,
88/// the stride is set to 0 (virtual expansion). For matching dimensions,
89/// the original stride is preserved. The source shape is left-padded with
90/// 1s (stride 0) as needed.
91///
92/// # Errors
93/// Returns `FerrayError::BroadcastFailure` if the source cannot be broadcast
94/// to the target (i.e., a source dimension is neither 1 nor equal to target).
95pub fn broadcast_strides(
96    src_shape: &[usize],
97    src_strides: &[isize],
98    target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100    let tndim = target_shape.len();
101    let sndim = src_shape.len();
102
103    if tndim < sndim {
104        return Err(FerrayError::shape_mismatch(format!(
105            "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106            src_shape, target_shape
107        )));
108    }
109
110    let pad = tndim - sndim;
111    let mut out_strides = vec![0isize; tndim];
112
113    for i in 0..tndim {
114        if i < pad {
115            // Prepended dimension: virtual, stride = 0
116            out_strides[i] = 0;
117        } else {
118            let si = i - pad;
119            let src_dim = src_shape[si];
120            let tgt_dim = target_shape[i];
121
122            if src_dim == tgt_dim {
123                out_strides[i] = src_strides[si];
124            } else if src_dim == 1 {
125                // Broadcast: virtual expansion
126                out_strides[i] = 0;
127            } else {
128                return Err(FerrayError::shape_mismatch(format!(
129                    "cannot broadcast dimension {} (size {}) to size {}",
130                    si, src_dim, tgt_dim
131                )));
132            }
133        }
134    }
135
136    Ok(out_strides)
137}
138
139/// Broadcast an array to a target shape, returning a view.
140///
141/// The returned view uses stride-0 tricks to virtually expand size-1
142/// dimensions — no data is copied. The view borrows from the source array.
143///
144/// # Errors
145/// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
146/// to the given shape.
147pub fn broadcast_to<'a, T: Element, D: Dimension>(
148    array: &'a Array<T, D>,
149    target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151    let src_shape = array.shape();
152    let src_strides = array.strides();
153
154    // Validate broadcast compatibility
155    let result_shape = broadcast_shapes(src_shape, target_shape)?;
156    if result_shape != target_shape {
157        return Err(FerrayError::shape_mismatch(format!(
158            "cannot broadcast shape {:?} to shape {:?}",
159            src_shape, target_shape
160        )));
161    }
162
163    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165    // Validate all strides are non-negative before casting to usize.
166    // Negative strides (from reversed/transposed views) cannot be represented
167    // as usize and would wrap to huge values, causing out-of-bounds access.
168    //
169    // Note: owned Array<T, D> is always C-contiguous and never has negative
170    // strides, so this check only triggers for unusual ndarray internals.
171    // For ArrayViews with negative strides, use broadcast_view_to instead
172    // (which has the same limitation — call .to_owned() on the view first).
173    for (i, &s) in new_strides.iter().enumerate() {
174        if s < 0 {
175            return Err(FerrayError::shape_mismatch(format!(
176                "cannot broadcast with negative stride {s} on axis {i}; \
177                 call .to_owned() on the reversed/transposed array first",
178                s = s,
179                i = i
180            )));
181        }
182    }
183
184    // Build ndarray view with computed strides
185    let nd_shape = ndarray::IxDyn(target_shape);
186    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
187
188    // Use from_shape_ptr with the broadcast strides
189    let ptr = array.as_ptr();
190    // SAFETY: broadcast strides are validated non-negative above and ensure
191    // we only access valid memory from the source array. Stride-0 dimensions
192    // repeat the same element.
193    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
194
195    Ok(ArrayView::from_ndarray(nd_view))
196}
197
198/// Broadcast an `ArrayView` to a target shape, returning a new view.
199///
200/// # Errors
201/// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
202pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
203    view: &ArrayView<'a, T, D>,
204    target_shape: &[usize],
205) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
206    let src_shape = view.shape();
207    let src_strides = view.strides();
208
209    let result_shape = broadcast_shapes(src_shape, target_shape)?;
210    if result_shape != target_shape {
211        return Err(FerrayError::shape_mismatch(format!(
212            "cannot broadcast shape {:?} to shape {:?}",
213            src_shape, target_shape
214        )));
215    }
216
217    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
218
219    for (i, &s) in new_strides.iter().enumerate() {
220        if s < 0 {
221            return Err(FerrayError::shape_mismatch(format!(
222                "cannot broadcast view with negative stride {} on axis {}; \
223                 call .to_owned() on the reversed/transposed view first",
224                s, i
225            )));
226        }
227    }
228
229    let nd_shape = ndarray::IxDyn(target_shape);
230    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
231
232    let ptr = view.as_ptr();
233    // SAFETY: strides validated non-negative above; broadcast strides ensure
234    // only valid source memory is accessed.
235    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
236
237    Ok(ArrayView::from_ndarray(nd_view))
238}
239
240/// Broadcast multiple arrays to a common shape.
241///
242/// Returns a vector of `ArrayView<IxDyn>` views, all sharing the same
243/// broadcast shape. No data is copied.
244///
245/// # Errors
246/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
247pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
248    arrays: &'a [Array<T, D>],
249) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
250    if arrays.is_empty() {
251        return Ok(vec![]);
252    }
253
254    // Compute common broadcast shape
255    let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
256    let target = broadcast_shapes_multi(&shapes)?;
257
258    // Broadcast each array to the common shape
259    let mut result = Vec::with_capacity(arrays.len());
260    for arr in arrays {
261        result.push(broadcast_to(arr, &target)?);
262    }
263    Ok(result)
264}
265
266// ---------------------------------------------------------------------------
267// Methods on Array for broadcasting
268// ---------------------------------------------------------------------------
269
270impl<T: Element, D: Dimension> Array<T, D> {
271    /// Broadcast this array to the given shape, returning a dynamic-rank view.
272    ///
273    /// Uses stride-0 tricks for virtual expansion — no data is copied.
274    ///
275    /// # Errors
276    /// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
277    /// to the target shape.
278    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
279        broadcast_to(self, target_shape)
280    }
281}
282
283impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
284    /// Broadcast this view to the given shape, returning a dynamic-rank view.
285    ///
286    /// # Errors
287    /// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
288    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
289        let src_shape = self.shape();
290        let src_strides = self.strides();
291
292        let result_shape = broadcast_shapes(src_shape, target_shape)?;
293        if result_shape != target_shape {
294            return Err(FerrayError::shape_mismatch(format!(
295                "cannot broadcast shape {:?} to shape {:?}",
296                src_shape, target_shape
297            )));
298        }
299
300        let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
301
302        for (i, &s) in new_strides.iter().enumerate() {
303            if s < 0 {
304                return Err(FerrayError::shape_mismatch(format!(
305                    "cannot broadcast view with negative stride {} on axis {}; \
306                     make the array contiguous first",
307                    s, i
308                )));
309            }
310        }
311
312        let nd_shape = ndarray::IxDyn(target_shape);
313        let nd_strides =
314            ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
315
316        let ptr = self.as_ptr();
317        // SAFETY: strides validated non-negative above; broadcast strides ensure
318        // only valid source memory is accessed.
319        let nd_view =
320            unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
321
322        Ok(ArrayView::from_ndarray(nd_view))
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // broadcast_shapes tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        // A 0-d shape broadcasts against anything.
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
        assert_eq!(broadcast_shapes(&[], &[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        // (4,3) + (3,) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        // (4,1) * (4,3) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        // (1,3) + (2,1) -> (2,3)
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length() {
        // As in NumPy, a size-1 axis stretches to size 0, but no other size does.
        assert_eq!(broadcast_shapes(&[1], &[0]).unwrap(), vec![0]);
        assert_eq!(broadcast_shapes(&[0], &[1]).unwrap(), vec![0]);
        assert_eq!(broadcast_shapes(&[3, 0], &[1]).unwrap(), vec![3, 0]);
        assert!(broadcast_shapes(&[2], &[0]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_shapes_multi_single() {
        assert_eq!(
            broadcast_shapes_multi(&[&[4, 1, 2]]).unwrap(),
            vec![4, 1, 2]
        );
    }

    #[test]
    fn broadcast_shapes_multi_incompatible() {
        // (2,1) + (3,) -> (2,3), which is incompatible with (4,1).
        assert!(broadcast_shapes_multi(&[&[2, 1], &[3], &[4, 1]]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_strides tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        // A C-contiguous (3,4) array has strides (4,1); same-shape keeps them.
        let strides = broadcast_strides(&[3, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![4, 1]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        // shape (1,4) with strides (4,1) -> target (3,4)
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        // shape (4,) with strides (1,) -> target (3, 4)
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_size_one_to_zero() {
        // shape (1,) -> target (0,): the stretched axis gets stride 0.
        let strides = broadcast_strides(&[1], &[1], &[0]).unwrap();
        assert_eq!(strides, vec![0]);
    }

    #[test]
    fn broadcast_strides_errors() {
        // Target with fewer dimensions than the source.
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
        // Non-1 source dimension that differs from the target.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_to tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        // All rows should be the same
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        // (3,1) -> (3,4)
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        // Verify that broadcast_to does NOT copy data
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view shares the same base pointer
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        // (1,) -> (5,)
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length() {
        // (1,) -> (0,) is allowed and yields an empty view.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[0]).unwrap();
        assert_eq!(view.shape(), &[0]);
        assert_eq!(view.size(), 0);
    }

    // -----------------------------------------------------------------------
    // broadcast_view_to tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_view_to_adds_leading_axis() {
        // (3,) -> (2,3) -> (2,2,3), still sharing the original data.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[2, 3]).unwrap();
        let wider = broadcast_view_to(&view, &[2, 2, 3]).unwrap();
        assert_eq!(wider.shape(), &[2, 2, 3]);
        assert_eq!(wider.as_ptr(), arr.as_ptr());

        let data: Vec<f64> = wider.iter().copied().collect();
        assert_eq!(data, [1.0, 2.0, 3.0].repeat(4));
    }

    #[test]
    fn broadcast_view_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[2, 3]).unwrap();
        assert!(broadcast_view_to(&view, &[2, 4]).is_err());
        // A leading axis cannot be dropped again.
        assert!(broadcast_view_to(&view, &[3]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_arrays tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);

        let first: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            first,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(second, [10.0, 20.0, 30.0].repeat(4));
    }

    #[test]
    fn broadcast_arrays_empty() {
        let arrays: Vec<Array<f64, Ix1>> = Vec::new();
        assert!(broadcast_arrays(&arrays).unwrap().is_empty());
    }

    #[test]
    fn broadcast_arrays_incompatible() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(broadcast_arrays(&[a, b]).is_err());
    }

    // -----------------------------------------------------------------------
    // Method tests
    // -----------------------------------------------------------------------

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);

        // Each (1,4) row is repeated three times within its outer block.
        let data: Vec<i32> = view.iter().copied().collect();
        let expected: Vec<i32> = [[1, 2, 3, 4].repeat(3), [5, 6, 7, 8].repeat(3)].concat();
        assert_eq!(data, expected);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());

        // A non-1 extent cannot collapse to 1 either.
        let row = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(row.broadcast_to(&[1]).is_err());
    }

    #[test]
    fn view_broadcast_to_method() {
        // (3,1) -> (3,4) -> (2,3,4): re-broadcasting keeps the stride-0 axis.
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[3, 4]).unwrap();
        let wider = view.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(wider.shape(), &[2, 3, 4]);
        assert_eq!(wider.as_ptr(), arr.as_ptr());

        let block: Vec<f64> = [[1.0; 4], [2.0; 4], [3.0; 4]].concat();
        let data: Vec<f64> = wider.iter().copied().collect();
        assert_eq!(data, [block.clone(), block].concat());

        assert!(view.broadcast_to(&[3]).is_err());
        assert!(view.broadcast_to(&[2, 5, 4]).is_err());
    }
}