1use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39 let ndim = a.len().max(b.len());
40 let mut result = vec![0usize; ndim];
41
42 for i in 0..ndim {
43 let da = if i < ndim - a.len() {
44 1
45 } else {
46 a[i - (ndim - a.len())]
47 };
48 let db = if i < ndim - b.len() {
49 1
50 } else {
51 b[i - (ndim - b.len())]
52 };
53
54 if da == db {
55 result[i] = da;
56 } else if da == 1 {
57 result[i] = db;
58 } else if db == 1 {
59 result[i] = da;
60 } else {
61 return Err(FerrayError::broadcast_failure(a, b));
62 }
63 }
64 Ok(result)
65}
66
67pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75 if shapes.is_empty() {
76 return Ok(vec![]);
77 }
78 let mut result = shapes[0].to_vec();
79 for &s in &shapes[1..] {
80 result = broadcast_shapes(&result, s)?;
81 }
82 Ok(result)
83}
84
85pub fn broadcast_strides(
96 src_shape: &[usize],
97 src_strides: &[isize],
98 target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100 let tndim = target_shape.len();
101 let sndim = src_shape.len();
102
103 if tndim < sndim {
104 return Err(FerrayError::shape_mismatch(format!(
105 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}: target has fewer dimensions"
106 )));
107 }
108
109 let pad = tndim - sndim;
110 let mut out_strides = vec![0isize; tndim];
111
112 for i in 0..tndim {
113 if i < pad {
114 out_strides[i] = 0;
116 } else {
117 let si = i - pad;
118 let src_dim = src_shape[si];
119 let tgt_dim = target_shape[i];
120
121 if src_dim == tgt_dim {
122 out_strides[i] = src_strides[si];
123 } else if src_dim == 1 {
124 out_strides[i] = 0;
126 } else {
127 return Err(FerrayError::shape_mismatch(format!(
128 "cannot broadcast dimension {si} (size {src_dim}) to size {tgt_dim}"
129 )));
130 }
131 }
132 }
133
134 Ok(out_strides)
135}
136
137pub fn broadcast_to<'a, T: Element, D: Dimension>(
146 array: &'a Array<T, D>,
147 target_shape: &[usize],
148) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
149 let src_shape = array.shape();
150 let src_strides = array.strides();
151
152 let result_shape = broadcast_shapes(src_shape, target_shape)?;
154 if result_shape != target_shape {
155 return Err(FerrayError::shape_mismatch(format!(
156 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
157 )));
158 }
159
160 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
161
162 for (i, &s) in new_strides.iter().enumerate() {
171 if s < 0 {
172 return Err(FerrayError::shape_mismatch(format!(
173 "cannot broadcast with negative stride {s} on axis {i}; \
174 call .to_owned() on the reversed/transposed array first"
175 )));
176 }
177 }
178
179 let nd_shape = ndarray::IxDyn(target_shape);
181 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
182
183 let ptr = array.as_ptr();
185 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
189
190 Ok(ArrayView::from_ndarray(nd_view))
191}
192
193pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
198 view: &ArrayView<'a, T, D>,
199 target_shape: &[usize],
200) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
201 let src_shape = view.shape();
202 let src_strides = view.strides();
203
204 let result_shape = broadcast_shapes(src_shape, target_shape)?;
205 if result_shape != target_shape {
206 return Err(FerrayError::shape_mismatch(format!(
207 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
208 )));
209 }
210
211 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
212
213 for (i, &s) in new_strides.iter().enumerate() {
214 if s < 0 {
215 return Err(FerrayError::shape_mismatch(format!(
216 "cannot broadcast view with negative stride {s} on axis {i}; \
217 call .to_owned() on the reversed/transposed view first"
218 )));
219 }
220 }
221
222 let nd_shape = ndarray::IxDyn(target_shape);
223 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
224
225 let ptr = view.as_ptr();
226 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
229
230 Ok(ArrayView::from_ndarray(nd_view))
231}
232
233pub fn broadcast_arrays<T: Element, D: Dimension>(
241 arrays: &[Array<T, D>],
242) -> FerrayResult<Vec<ArrayView<'_, T, IxDyn>>> {
243 if arrays.is_empty() {
244 return Ok(vec![]);
245 }
246
247 let shapes: Vec<&[usize]> = arrays
249 .iter()
250 .map(super::super::array::owned::Array::shape)
251 .collect();
252 let target = broadcast_shapes_multi(&shapes)?;
253
254 let mut result = Vec::with_capacity(arrays.len());
256 for arr in arrays {
257 result.push(broadcast_to(arr, &target)?);
258 }
259 Ok(result)
260}
261
262impl<T: Element, D: Dimension> Array<T, D> {
267 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
275 broadcast_to(self, target_shape)
276 }
277}
278
279impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
280 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
285 let src_shape = self.shape();
286 let src_strides = self.strides();
287
288 let result_shape = broadcast_shapes(src_shape, target_shape)?;
289 if result_shape != target_shape {
290 return Err(FerrayError::shape_mismatch(format!(
291 "cannot broadcast shape {src_shape:?} to shape {target_shape:?}"
292 )));
293 }
294
295 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
296
297 for (i, &s) in new_strides.iter().enumerate() {
298 if s < 0 {
299 return Err(FerrayError::shape_mismatch(format!(
300 "cannot broadcast view with negative stride {s} on axis {i}; \
301 make the array contiguous first"
302 )));
303 }
304 }
305
306 let nd_shape = ndarray::IxDyn(target_shape);
307 let nd_strides =
308 ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
309
310 let ptr = self.as_ptr();
311 let nd_view =
314 unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
315
316 Ok(ArrayView::from_ndarray(nd_view))
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // ---- broadcast_shapes ------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axes() {
        // A length-0 axis broadcasts against 1 or 0 (result 0), but not > 1.
        assert_eq!(broadcast_shapes(&[0, 3], &[1, 3]).unwrap(), vec![0, 3]);
        assert_eq!(broadcast_shapes(&[0], &[0]).unwrap(), vec![0]);
        assert!(broadcast_shapes(&[0], &[2]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    // ---- broadcast_strides -----------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_errors() {
        // Target with fewer dimensions than the source.
        assert!(broadcast_strides(&[2, 3], &[3, 1], &[3]).is_err());
        // Non-unit source axis that differs from the target axis.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // ---- broadcast_to / broadcast_arrays (owned arrays) ------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view must alias the original buffer rather than copy it.
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length_axis() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[0, 3]).unwrap();
        assert_eq!(view.shape(), &[0, 3]);
        assert_eq!(view.size(), 0);
        assert_eq!(view.iter().count(), 0);
    }

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }

    // ---- broadcasting existing views --------------------------------------

    #[test]
    fn broadcast_view_to_from_view() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        // Same-shape broadcast is just a convenient way to obtain a view.
        let base = broadcast_to(&arr, &[3]).unwrap();
        let view = broadcast_view_to(&base, &[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);

        assert!(broadcast_view_to(&base, &[2, 4]).is_err());
    }

    #[test]
    fn view_broadcast_to_method() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let base = broadcast_to(&arr, &[3, 1]).unwrap();
        let view = base.broadcast_to(&[3, 2]).unwrap();
        assert_eq!(view.shape(), &[3, 2]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);

        assert!(base.broadcast_to(&[3]).is_err());
    }
}