1use ndarray::ShapeBuilder;
27
28use crate::array::owned::Array;
29use crate::array::view::ArrayView;
30use crate::dimension::{Dimension, IxDyn};
31use crate::dtype::Element;
32use crate::error::{FerrayError, FerrayResult};
33
34pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
54 let ndim = a.len().max(b.len());
55 let mut result = vec![0usize; ndim];
56
57 for i in 0..ndim {
58 let da = if i < ndim - a.len() {
59 1
60 } else {
61 a[i - (ndim - a.len())]
62 };
63 let db = if i < ndim - b.len() {
64 1
65 } else {
66 b[i - (ndim - b.len())]
67 };
68
69 if da == db {
70 result[i] = da;
71 } else if da == 1 {
72 result[i] = db;
73 } else if db == 1 {
74 result[i] = da;
75 } else {
76 return Err(FerrayError::broadcast_failure(a, b));
77 }
78 }
79 Ok(result)
80}
81
82pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
90 if shapes.is_empty() {
91 return Ok(vec![]);
92 }
93 let mut result = shapes[0].to_vec();
94 for &s in &shapes[1..] {
95 result = broadcast_shapes(&result, s)?;
96 }
97 Ok(result)
98}
99
100pub fn broadcast_strides(
111 src_shape: &[usize],
112 src_strides: &[isize],
113 target_shape: &[usize],
114) -> FerrayResult<Vec<isize>> {
115 let tndim = target_shape.len();
116 let sndim = src_shape.len();
117
118 if tndim < sndim {
119 return Err(FerrayError::shape_mismatch(format!(
120 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}: target has fewer dimensions"
121 )));
122 }
123
124 let pad = tndim - sndim;
125 let mut out_strides = vec![0isize; tndim];
126
127 for i in 0..tndim {
128 if i < pad {
129 out_strides[i] = 0;
131 } else {
132 let si = i - pad;
133 let src_dim = src_shape[si];
134 let tgt_dim = target_shape[i];
135
136 if src_dim == tgt_dim {
137 out_strides[i] = src_strides[si];
138 } else if src_dim == 1 {
139 out_strides[i] = 0;
141 } else {
142 return Err(FerrayError::shape_mismatch(format!(
143 "cannot broadcast dimension {si} (size {src_dim}) to size {tgt_dim}"
144 )));
145 }
146 }
147 }
148
149 Ok(out_strides)
150}
151
152pub fn broadcast_to<'a, T: Element, D: Dimension>(
161 array: &'a Array<T, D>,
162 target_shape: &[usize],
163) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
164 let src_shape = array.shape();
165 let src_strides = array.strides();
166
167 let result_shape = broadcast_shapes(src_shape, target_shape)?;
169 if result_shape != target_shape {
170 return Err(FerrayError::shape_mismatch(format!(
171 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
172 )));
173 }
174
175 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
176
177 for (i, &s) in new_strides.iter().enumerate() {
186 if s < 0 {
187 return Err(FerrayError::shape_mismatch(format!(
188 "cannot broadcast with negative stride {s} on axis {i}; \
189 call .to_owned() on the reversed/transposed array first"
190 )));
191 }
192 }
193
194 let nd_shape = ndarray::IxDyn(target_shape);
196 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
197
198 let ptr = array.as_ptr();
200 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
204
205 Ok(ArrayView::from_ndarray(nd_view))
206}
207
208pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
213 view: &ArrayView<'a, T, D>,
214 target_shape: &[usize],
215) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
216 let src_shape = view.shape();
217 let src_strides = view.strides();
218
219 let result_shape = broadcast_shapes(src_shape, target_shape)?;
220 if result_shape != target_shape {
221 return Err(FerrayError::shape_mismatch(format!(
222 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
223 )));
224 }
225
226 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
227
228 for (i, &s) in new_strides.iter().enumerate() {
229 if s < 0 {
230 return Err(FerrayError::shape_mismatch(format!(
231 "cannot broadcast view with negative stride {s} on axis {i}; \
232 call .to_owned() on the reversed/transposed view first"
233 )));
234 }
235 }
236
237 let nd_shape = ndarray::IxDyn(target_shape);
238 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
239
240 let ptr = view.as_ptr();
241 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
244
245 Ok(ArrayView::from_ndarray(nd_view))
246}
247
248pub fn broadcast_arrays<T: Element, D: Dimension>(
256 arrays: &[Array<T, D>],
257) -> FerrayResult<Vec<ArrayView<'_, T, IxDyn>>> {
258 if arrays.is_empty() {
259 return Ok(vec![]);
260 }
261
262 let shapes: Vec<&[usize]> = arrays
264 .iter()
265 .map(super::super::array::owned::Array::shape)
266 .collect();
267 let target = broadcast_shapes_multi(&shapes)?;
268
269 let mut result = Vec::with_capacity(arrays.len());
271 for arr in arrays {
272 result.push(broadcast_to(arr, &target)?);
273 }
274 Ok(result)
275}
276
277impl<T: Element, D: Dimension> Array<T, D> {
282 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
290 broadcast_to(self, target_shape)
291 }
292}
293
294impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
295 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
300 let src_shape = self.shape();
301 let src_strides = self.strides();
302
303 let result_shape = broadcast_shapes(src_shape, target_shape)?;
304 if result_shape != target_shape {
305 return Err(FerrayError::shape_mismatch(format!(
306 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
307 )));
308 }
309
310 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
311
312 for (i, &s) in new_strides.iter().enumerate() {
313 if s < 0 {
314 return Err(FerrayError::shape_mismatch(format!(
315 "cannot broadcast view with negative stride {s} on axis {i}; \
316 make the array contiguous first"
317 )));
318 }
319 }
320
321 let nd_shape = ndarray::IxDyn(target_shape);
322 let nd_strides =
323 ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
324
325 let ptr = self.as_ptr();
326 let nd_view =
329 unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
330
331 Ok(ArrayView::from_ndarray(nd_view))
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // --- broadcast_shapes / broadcast_shapes_multi ---

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axis() {
        // A length-0 axis pairs with 1 (giving 0) but not with any n > 1.
        assert_eq!(broadcast_shapes(&[0], &[1]).unwrap(), vec![0]);
        assert_eq!(broadcast_shapes(&[2, 0], &[1]).unwrap(), vec![2, 0]);
        assert!(broadcast_shapes(&[0], &[3]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_shapes_multi_single() {
        let shapes: [&[usize]; 1] = [&[2, 3]];
        assert_eq!(broadcast_shapes_multi(&shapes).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_incompatible() {
        assert!(broadcast_shapes_multi(&[&[2, 1], &[3], &[4]]).is_err());
    }

    // --- broadcast_strides ---

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_passes_negative_through() {
        // Rejecting negative strides is left to the view-building functions.
        let strides = broadcast_strides(&[3], &[-1], &[2, 3]).unwrap();
        assert_eq!(strides, vec![0, -1]);
    }

    #[test]
    fn broadcast_strides_errors() {
        // Target has fewer dimensions than the source.
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
        // Source axis is neither equal to the target axis nor 1.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // --- broadcast_to / broadcast_view_to / broadcast_arrays ---

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length_axis() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[0, 3]).unwrap();
        assert_eq!(view.shape(), &[0, 3]);
        assert_eq!(view.size(), 0);
        assert_eq!(view.iter().count(), 0);
    }

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);

        let first: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            first,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(
            second,
            vec![10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    #[test]
    fn broadcast_arrays_empty() {
        let empty: [Array<f64, Ix1>; 0] = [];
        assert!(broadcast_arrays(&empty).unwrap().is_empty());
    }

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
        // The stretched axis revisits the same memory...
        assert_eq!(view.strides()[1], 0);
        // ...so every source element is seen once per repetition (3 times).
        assert_eq!(view.iter().copied().sum::<i32>(), 3 * 36);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }

    #[test]
    fn broadcast_view_to_rebroadcasts_view() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        // The first view already has a zero stride; broadcast it again.
        let first = broadcast_to(&arr, &[3]).unwrap();
        let second = broadcast_view_to(&first, &[2, 3]).unwrap();
        assert_eq!(second.shape(), &[2, 3]);
        assert_eq!(second.as_ptr(), arr.as_ptr());
        let data: Vec<f64> = second.iter().copied().collect();
        assert_eq!(data, vec![42.0; 6]);
    }

    #[test]
    fn array_view_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[3]).unwrap();
        let wide = view.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(wide.shape(), &[2, 3]);
        let data: Vec<f64> = wide.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        assert!(view.broadcast_to(&[2, 4]).is_err());
    }
}