Skip to main content

ferray_core/dimension/
broadcast.rs

// ferray-core: Broadcasting logic (REQ-9, REQ-10, REQ-11)
//
// Implements NumPy's full broadcasting rules:
//   1. Prepend 1s to the shape of the lower-dimensional array
//   2. Stretch size-1 dimensions to the other operand's length
//   3. Error on a size mismatch where neither dimension is 1
//
// Broadcasting NEVER materializes the expanded array — it uses virtual
// expansion via strides (setting stride = 0 for broadcast dimensions).
//
// ## REQ status (broadcasting, NumPy parity)
//  - REQ-9 (full broadcasting rules: prepend 1s, stretch size-1 dims, error on
//    incompatible mismatch) — SHIPPED: `broadcast_shapes` (this file)
//    left-pads the shorter shape with 1s, stretches each size-1 axis to the
//    other shape's length (so 1 vs 0 yields 0, not the max), and returns
//    `FerrayError::BroadcastFailure` when neither dim is 1 and they differ;
//    `broadcast_shapes_multi` folds it over N shapes.
//  - REQ-10 (never materialize the broadcast array — virtual stride expansion) —
//    SHIPPED: `broadcast_strides` (this file) sets stride = 0 for broadcast
//    axes, and `broadcast_to`/`broadcast_view_to` produce views without copying.
//    Non-test consumer: `nditer.rs` zips two operands via `broadcast_shapes`
//    + `broadcast_to` with no intermediate allocation.
//  - REQ-11 (provide `broadcast_to`/`broadcast_arrays`/`broadcast_shapes`) —
//    SHIPPED: `broadcast_to`, `broadcast_view_to`, `broadcast_arrays`, and
//    `broadcast_shapes` (this file).
25
26use ndarray::ShapeBuilder;
27
28use crate::array::owned::Array;
29use crate::array::view::ArrayView;
30use crate::dimension::{Dimension, IxDyn};
31use crate::dtype::Element;
32use crate::error::{FerrayError, FerrayResult};
33
34/// Compute the broadcast shape from two shapes, following `NumPy` rules.
35///
36/// The result shape has `max(a.len(), b.len())` dimensions. Shorter shapes
37/// are left-padded with 1s. For each axis, the result dimension is the
38/// larger of the two inputs; if neither is 1 and they differ, an error
39/// is returned.
40///
41/// # Examples
42/// ```
43/// # use ferray_core::dimension::broadcast::broadcast_shapes;
44/// let result = broadcast_shapes(&[4, 3], &[3]).unwrap();
45/// assert_eq!(result, vec![4, 3]);
46///
47/// let result = broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap();
48/// assert_eq!(result, vec![2, 3, 4]);
49/// ```
50///
51/// # Errors
52/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
53pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
54    let ndim = a.len().max(b.len());
55    let mut result = vec![0usize; ndim];
56
57    for i in 0..ndim {
58        let da = if i < ndim - a.len() {
59            1
60        } else {
61            a[i - (ndim - a.len())]
62        };
63        let db = if i < ndim - b.len() {
64            1
65        } else {
66            b[i - (ndim - b.len())]
67        };
68
69        if da == db {
70            result[i] = da;
71        } else if da == 1 {
72            result[i] = db;
73        } else if db == 1 {
74            result[i] = da;
75        } else {
76            return Err(FerrayError::broadcast_failure(a, b));
77        }
78    }
79    Ok(result)
80}
81
82/// Compute the broadcast shape from multiple shapes.
83///
84/// This is the N-ary version of [`broadcast_shapes`]. It folds pairwise
85/// over all input shapes.
86///
87/// # Errors
88/// Returns `FerrayError::BroadcastFailure` if any pair is incompatible.
89pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
90    if shapes.is_empty() {
91        return Ok(vec![]);
92    }
93    let mut result = shapes[0].to_vec();
94    for &s in &shapes[1..] {
95        result = broadcast_shapes(&result, s)?;
96    }
97    Ok(result)
98}
99
100/// Compute the strides for broadcasting a source shape to a target shape.
101///
102/// For dimensions where the source has size 1 but the target is larger,
103/// the stride is set to 0 (virtual expansion). For matching dimensions,
104/// the original stride is preserved. The source shape is left-padded with
105/// 1s (stride 0) as needed.
106///
107/// # Errors
108/// Returns `FerrayError::BroadcastFailure` if the source cannot be broadcast
109/// to the target (i.e., a source dimension is neither 1 nor equal to target).
110pub fn broadcast_strides(
111    src_shape: &[usize],
112    src_strides: &[isize],
113    target_shape: &[usize],
114) -> FerrayResult<Vec<isize>> {
115    let tndim = target_shape.len();
116    let sndim = src_shape.len();
117
118    if tndim < sndim {
119        return Err(FerrayError::shape_mismatch(format!(
120            "cannot broadcast shape {src_shape:?} to shape {target_shape:?}: target has fewer dimensions"
121        )));
122    }
123
124    let pad = tndim - sndim;
125    let mut out_strides = vec![0isize; tndim];
126
127    for i in 0..tndim {
128        if i < pad {
129            // Prepended dimension: virtual, stride = 0
130            out_strides[i] = 0;
131        } else {
132            let si = i - pad;
133            let src_dim = src_shape[si];
134            let tgt_dim = target_shape[i];
135
136            if src_dim == tgt_dim {
137                out_strides[i] = src_strides[si];
138            } else if src_dim == 1 {
139                // Broadcast: virtual expansion
140                out_strides[i] = 0;
141            } else {
142                return Err(FerrayError::shape_mismatch(format!(
143                    "cannot broadcast dimension {si} (size {src_dim}) to size {tgt_dim}"
144                )));
145            }
146        }
147    }
148
149    Ok(out_strides)
150}
151
152/// Broadcast an array to a target shape, returning a view.
153///
154/// The returned view uses stride-0 tricks to virtually expand size-1
155/// dimensions — no data is copied. The view borrows from the source array.
156///
157/// # Errors
158/// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
159/// to the given shape.
160pub fn broadcast_to<'a, T: Element, D: Dimension>(
161    array: &'a Array<T, D>,
162    target_shape: &[usize],
163) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
164    let src_shape = array.shape();
165    let src_strides = array.strides();
166
167    // Validate broadcast compatibility
168    let result_shape = broadcast_shapes(src_shape, target_shape)?;
169    if result_shape != target_shape {
170        return Err(FerrayError::shape_mismatch(format!(
171            "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
172        )));
173    }
174
175    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
176
177    // Validate all strides are non-negative before casting to usize.
178    // Negative strides (from reversed/transposed views) cannot be represented
179    // as usize and would wrap to huge values, causing out-of-bounds access.
180    //
181    // Note: owned Array<T, D> is always C-contiguous and never has negative
182    // strides, so this check only triggers for unusual ndarray internals.
183    // For ArrayViews with negative strides, use broadcast_view_to instead
184    // (which has the same limitation — call .to_owned() on the view first).
185    for (i, &s) in new_strides.iter().enumerate() {
186        if s < 0 {
187            return Err(FerrayError::shape_mismatch(format!(
188                "cannot broadcast with negative stride {s} on axis {i}; \
189                 call .to_owned() on the reversed/transposed array first"
190            )));
191        }
192    }
193
194    // Build ndarray view with computed strides
195    let nd_shape = ndarray::IxDyn(target_shape);
196    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
197
198    // Use from_shape_ptr with the broadcast strides
199    let ptr = array.as_ptr();
200    // SAFETY: broadcast strides are validated non-negative above and ensure
201    // we only access valid memory from the source array. Stride-0 dimensions
202    // repeat the same element.
203    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
204
205    Ok(ArrayView::from_ndarray(nd_view))
206}
207
208/// Broadcast an `ArrayView` to a target shape, returning a new view.
209///
210/// # Errors
211/// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
212pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
213    view: &ArrayView<'a, T, D>,
214    target_shape: &[usize],
215) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
216    let src_shape = view.shape();
217    let src_strides = view.strides();
218
219    let result_shape = broadcast_shapes(src_shape, target_shape)?;
220    if result_shape != target_shape {
221        return Err(FerrayError::shape_mismatch(format!(
222            "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
223        )));
224    }
225
226    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
227
228    for (i, &s) in new_strides.iter().enumerate() {
229        if s < 0 {
230            return Err(FerrayError::shape_mismatch(format!(
231                "cannot broadcast view with negative stride {s} on axis {i}; \
232                 call .to_owned() on the reversed/transposed view first"
233            )));
234        }
235    }
236
237    let nd_shape = ndarray::IxDyn(target_shape);
238    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
239
240    let ptr = view.as_ptr();
241    // SAFETY: strides validated non-negative above; broadcast strides ensure
242    // only valid source memory is accessed.
243    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
244
245    Ok(ArrayView::from_ndarray(nd_view))
246}
247
248/// Broadcast multiple arrays to a common shape.
249///
250/// Returns a vector of `ArrayView<IxDyn>` views, all sharing the same
251/// broadcast shape. No data is copied.
252///
253/// # Errors
254/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
255pub fn broadcast_arrays<T: Element, D: Dimension>(
256    arrays: &[Array<T, D>],
257) -> FerrayResult<Vec<ArrayView<'_, T, IxDyn>>> {
258    if arrays.is_empty() {
259        return Ok(vec![]);
260    }
261
262    // Compute common broadcast shape
263    let shapes: Vec<&[usize]> = arrays
264        .iter()
265        .map(super::super::array::owned::Array::shape)
266        .collect();
267    let target = broadcast_shapes_multi(&shapes)?;
268
269    // Broadcast each array to the common shape
270    let mut result = Vec::with_capacity(arrays.len());
271    for arr in arrays {
272        result.push(broadcast_to(arr, &target)?);
273    }
274    Ok(result)
275}
276
277// ---------------------------------------------------------------------------
278// Methods on Array for broadcasting
279// ---------------------------------------------------------------------------
280
281impl<T: Element, D: Dimension> Array<T, D> {
282    /// Broadcast this array to the given shape, returning a dynamic-rank view.
283    ///
284    /// Uses stride-0 tricks for virtual expansion — no data is copied.
285    ///
286    /// # Errors
287    /// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
288    /// to the target shape.
289    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
290        broadcast_to(self, target_shape)
291    }
292}
293
294impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
295    /// Broadcast this view to the given shape, returning a dynamic-rank view.
296    ///
297    /// # Errors
298    /// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
299    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
300        let src_shape = self.shape();
301        let src_strides = self.strides();
302
303        let result_shape = broadcast_shapes(src_shape, target_shape)?;
304        if result_shape != target_shape {
305            return Err(FerrayError::shape_mismatch(format!(
306                "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
307            )));
308        }
309
310        let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
311
312        for (i, &s) in new_strides.iter().enumerate() {
313            if s < 0 {
314                return Err(FerrayError::shape_mismatch(format!(
315                    "cannot broadcast view with negative stride {s} on axis {i}; \
316                     make the array contiguous first"
317                )));
318            }
319        }
320
321        let nd_shape = ndarray::IxDyn(target_shape);
322        let nd_strides =
323            ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
324
325        let ptr = self.as_ptr();
326        // SAFETY: strides validated non-negative above; broadcast strides ensure
327        // only valid source memory is accessed.
328        let nd_view =
329            unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
330
331        Ok(ArrayView::from_ndarray(nd_view))
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // broadcast_shapes tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        // (4,3) + (3,) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        // (4,1) * (4,3) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        // (1,3) + (2,1) -> (2,3)
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axis() {
        // A size-1 axis stretches to length 0 (NumPy parity); 0 vs n>1 fails.
        assert_eq!(broadcast_shapes(&[0], &[1]).unwrap(), vec![0]);
        assert_eq!(broadcast_shapes(&[1, 0], &[3, 1]).unwrap(), vec![3, 0]);
        assert!(broadcast_shapes(&[0], &[3]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_shapes_multi_incompatible() {
        // (2,1) + (3,) -> (2,3), which then clashes with (4,).
        assert!(broadcast_shapes_multi(&[&[2, 1], &[3], &[4]]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_strides tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        // shape (1,4) with strides (4,1) -> target (3,4)
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        // shape (4,) with strides (1,) -> target (3, 4)
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_rejects_incompatible() {
        // Target with fewer dimensions than the source.
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
        // Source axis that is neither 1 nor equal to the target axis.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_to tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        // All rows should be the same
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        // (3,1) -> (3,4)
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        // Verify that broadcast_to does NOT copy data
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view shares the same base pointer
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_cannot_shrink_axis() {
        // (3,) and (1,) are compatible, but their common shape (3,) is not the
        // requested target, so broadcasting "down" to (1,) must fail.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[1]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        // (1,) -> (5,)
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length_target() {
        // (1,) -> (0,) is legal: the single element is repeated zero times.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![7.0]).unwrap();
        let view = broadcast_to(&arr, &[0]).unwrap();
        assert_eq!(view.shape(), &[0]);
        assert_eq!(view.size(), 0);
    }

    // -----------------------------------------------------------------------
    // broadcast_arrays tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    #[test]
    fn broadcast_arrays_values() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();

        // The column is repeated across each row.
        let col: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            col,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );

        // The row is repeated down each column.
        let row: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(
            row,
            vec![10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    #[test]
    fn broadcast_arrays_empty() {
        let arrays: [Array<f64, Ix1>; 0] = [];
        assert!(broadcast_arrays(&arrays).unwrap().is_empty());
    }

    // -----------------------------------------------------------------------
    // Method tests
    // -----------------------------------------------------------------------

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }
}
530}