Skip to main content

ferray_core/dimension/
broadcast.rs

// ferray-core: Broadcasting logic (REQ-9, REQ-10, REQ-11)
//
// Implements NumPy's full broadcasting rules:
//   1. Prepend 1s to the shape of the lower-dimensional array
//   2. Stretch size-1 dimensions to match the other operand
//   3. Error on a size mismatch where neither dimension is 1
//
// Broadcasting NEVER materializes the expanded array — it uses virtual
// expansion via strides (setting stride = 0 for broadcast dimensions).
10
11use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19/// Compute the broadcast shape from two shapes, following NumPy rules.
20///
21/// The result shape has `max(a.len(), b.len())` dimensions. Shorter shapes
22/// are left-padded with 1s. For each axis, the result dimension is the
23/// larger of the two inputs; if neither is 1 and they differ, an error
24/// is returned.
25///
26/// # Examples
27/// ```
28/// # use ferray_core::dimension::broadcast::broadcast_shapes;
29/// let result = broadcast_shapes(&[4, 3], &[3]).unwrap();
30/// assert_eq!(result, vec![4, 3]);
31///
32/// let result = broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap();
33/// assert_eq!(result, vec![2, 3, 4]);
34/// ```
35///
36/// # Errors
37/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
38pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39    let ndim = a.len().max(b.len());
40    let mut result = vec![0usize; ndim];
41
42    for i in 0..ndim {
43        let da = if i < ndim - a.len() {
44            1
45        } else {
46            a[i - (ndim - a.len())]
47        };
48        let db = if i < ndim - b.len() {
49            1
50        } else {
51            b[i - (ndim - b.len())]
52        };
53
54        if da == db {
55            result[i] = da;
56        } else if da == 1 {
57            result[i] = db;
58        } else if db == 1 {
59            result[i] = da;
60        } else {
61            return Err(FerrayError::broadcast_failure(a, b));
62        }
63    }
64    Ok(result)
65}
66
67/// Compute the broadcast shape from multiple shapes.
68///
69/// This is the N-ary version of [`broadcast_shapes`]. It folds pairwise
70/// over all input shapes.
71///
72/// # Errors
73/// Returns `FerrayError::BroadcastFailure` if any pair is incompatible.
74pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75    if shapes.is_empty() {
76        return Ok(vec![]);
77    }
78    let mut result = shapes[0].to_vec();
79    for &s in &shapes[1..] {
80        result = broadcast_shapes(&result, s)?;
81    }
82    Ok(result)
83}
84
85/// Compute the strides for broadcasting a source shape to a target shape.
86///
87/// For dimensions where the source has size 1 but the target is larger,
88/// the stride is set to 0 (virtual expansion). For matching dimensions,
89/// the original stride is preserved. The source shape is left-padded with
90/// 1s (stride 0) as needed.
91///
92/// # Errors
93/// Returns `FerrayError::BroadcastFailure` if the source cannot be broadcast
94/// to the target (i.e., a source dimension is neither 1 nor equal to target).
95pub fn broadcast_strides(
96    src_shape: &[usize],
97    src_strides: &[isize],
98    target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100    let tndim = target_shape.len();
101    let sndim = src_shape.len();
102
103    if tndim < sndim {
104        return Err(FerrayError::shape_mismatch(format!(
105            "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106            src_shape, target_shape
107        )));
108    }
109
110    let pad = tndim - sndim;
111    let mut out_strides = vec![0isize; tndim];
112
113    for i in 0..tndim {
114        if i < pad {
115            // Prepended dimension: virtual, stride = 0
116            out_strides[i] = 0;
117        } else {
118            let si = i - pad;
119            let src_dim = src_shape[si];
120            let tgt_dim = target_shape[i];
121
122            if src_dim == tgt_dim {
123                out_strides[i] = src_strides[si];
124            } else if src_dim == 1 {
125                // Broadcast: virtual expansion
126                out_strides[i] = 0;
127            } else {
128                return Err(FerrayError::shape_mismatch(format!(
129                    "cannot broadcast dimension {} (size {}) to size {}",
130                    si, src_dim, tgt_dim
131                )));
132            }
133        }
134    }
135
136    Ok(out_strides)
137}
138
139/// Broadcast an array to a target shape, returning a view.
140///
141/// The returned view uses stride-0 tricks to virtually expand size-1
142/// dimensions — no data is copied. The view borrows from the source array.
143///
144/// # Errors
145/// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
146/// to the given shape.
147pub fn broadcast_to<'a, T: Element, D: Dimension>(
148    array: &'a Array<T, D>,
149    target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151    let src_shape = array.shape();
152    let src_strides = array.strides();
153
154    // Validate broadcast compatibility
155    let result_shape = broadcast_shapes(src_shape, target_shape)?;
156    if result_shape != target_shape {
157        return Err(FerrayError::shape_mismatch(format!(
158            "cannot broadcast shape {:?} to shape {:?}",
159            src_shape, target_shape
160        )));
161    }
162
163    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165    // Build ndarray view with computed strides
166    // We need to create an IxDyn view with the broadcast strides
167    let nd_shape = ndarray::IxDyn(target_shape);
168    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
169
170    // Use from_shape_ptr with the broadcast strides
171    let ptr = array.as_ptr();
172    // SAFETY: the broadcast strides ensure we only access valid memory
173    // from the source array. Stride-0 dimensions repeat the same element.
174    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
175
176    Ok(ArrayView::from_ndarray(nd_view))
177}
178
179/// Broadcast an `ArrayView` to a target shape, returning a new view.
180///
181/// # Errors
182/// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
183pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
184    view: &ArrayView<'a, T, D>,
185    target_shape: &[usize],
186) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
187    let src_shape = view.shape();
188    let src_strides = view.strides();
189
190    let result_shape = broadcast_shapes(src_shape, target_shape)?;
191    if result_shape != target_shape {
192        return Err(FerrayError::shape_mismatch(format!(
193            "cannot broadcast shape {:?} to shape {:?}",
194            src_shape, target_shape
195        )));
196    }
197
198    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
199
200    let nd_shape = ndarray::IxDyn(target_shape);
201    let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
202
203    let ptr = view.as_ptr();
204    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
205
206    Ok(ArrayView::from_ndarray(nd_view))
207}
208
209/// Broadcast multiple arrays to a common shape.
210///
211/// Returns a vector of `ArrayView<IxDyn>` views, all sharing the same
212/// broadcast shape. No data is copied.
213///
214/// # Errors
215/// Returns `FerrayError::BroadcastFailure` if shapes are incompatible.
216pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
217    arrays: &'a [Array<T, D>],
218) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
219    if arrays.is_empty() {
220        return Ok(vec![]);
221    }
222
223    // Compute common broadcast shape
224    let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
225    let target = broadcast_shapes_multi(&shapes)?;
226
227    // Broadcast each array to the common shape
228    let mut result = Vec::with_capacity(arrays.len());
229    for arr in arrays {
230        result.push(broadcast_to(arr, &target)?);
231    }
232    Ok(result)
233}
234
// ---------------------------------------------------------------------------
// Methods on Array and ArrayView for broadcasting
// ---------------------------------------------------------------------------
238
239impl<T: Element, D: Dimension> Array<T, D> {
240    /// Broadcast this array to the given shape, returning a dynamic-rank view.
241    ///
242    /// Uses stride-0 tricks for virtual expansion — no data is copied.
243    ///
244    /// # Errors
245    /// Returns `FerrayError::BroadcastFailure` if the array cannot be broadcast
246    /// to the target shape.
247    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
248        broadcast_to(self, target_shape)
249    }
250}
251
252impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
253    /// Broadcast this view to the given shape, returning a dynamic-rank view.
254    ///
255    /// # Errors
256    /// Returns `FerrayError::BroadcastFailure` if the view cannot be broadcast.
257    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
258        let src_shape = self.shape();
259        let src_strides = self.strides();
260
261        let result_shape = broadcast_shapes(src_shape, target_shape)?;
262        if result_shape != target_shape {
263            return Err(FerrayError::shape_mismatch(format!(
264                "cannot broadcast shape {:?} to shape {:?}",
265                src_shape, target_shape
266            )));
267        }
268
269        let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
270
271        let nd_shape = ndarray::IxDyn(target_shape);
272        let nd_strides =
273            ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
274
275        let ptr = self.as_ptr();
276        let nd_view =
277            unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
278
279        Ok(ArrayView::from_ndarray(nd_view))
280    }
281}
282
#[cfg(test)]
mod tests {
    //! Unit tests for shape/stride broadcasting and the view constructors.

    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // broadcast_shapes tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        // (4,3) + (3,) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        // (4,1) * (4,3) -> (4,3)
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        // (1,3) + (2,1) -> (2,3)
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        // The element type is spelled out so the comparison does not depend
        // on type inference for an untyped `vec![]`.
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    // -----------------------------------------------------------------------
    // broadcast_strides tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        // shape (1,4) with strides (4,1) -> target (3,4)
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        // shape (4,) with strides (1,) -> target (3, 4)
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_target_too_short() {
        // A (3,4) source cannot be broadcast to the lower-rank target (4,).
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
    }

    #[test]
    fn broadcast_strides_incompatible() {
        // Source size 3 is neither 1 nor equal to the target size 4.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // -----------------------------------------------------------------------
    // broadcast_to tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        // All rows should be the same
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        // (3,1) -> (3,4)
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        // Verify that broadcast_to does NOT copy data
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view shares the same base pointer
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        // (1,) -> (5,)
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    // -----------------------------------------------------------------------
    // broadcast_arrays tests
    // -----------------------------------------------------------------------

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    // -----------------------------------------------------------------------
    // Method tests
    // -----------------------------------------------------------------------

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        // (2,1,4) + (3,4) -> (2,3,4)
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }
}