1use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39 let ndim = a.len().max(b.len());
40 let mut result = vec![0usize; ndim];
41
42 for i in 0..ndim {
43 let da = if i < ndim - a.len() {
44 1
45 } else {
46 a[i - (ndim - a.len())]
47 };
48 let db = if i < ndim - b.len() {
49 1
50 } else {
51 b[i - (ndim - b.len())]
52 };
53
54 if da == db {
55 result[i] = da;
56 } else if da == 1 {
57 result[i] = db;
58 } else if db == 1 {
59 result[i] = da;
60 } else {
61 return Err(FerrayError::broadcast_failure(a, b));
62 }
63 }
64 Ok(result)
65}
66
67pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75 if shapes.is_empty() {
76 return Ok(vec![]);
77 }
78 let mut result = shapes[0].to_vec();
79 for &s in &shapes[1..] {
80 result = broadcast_shapes(&result, s)?;
81 }
82 Ok(result)
83}
84
85pub fn broadcast_strides(
96 src_shape: &[usize],
97 src_strides: &[isize],
98 target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100 let tndim = target_shape.len();
101 let sndim = src_shape.len();
102
103 if tndim < sndim {
104 return Err(FerrayError::shape_mismatch(format!(
105 "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106 src_shape, target_shape
107 )));
108 }
109
110 let pad = tndim - sndim;
111 let mut out_strides = vec![0isize; tndim];
112
113 for i in 0..tndim {
114 if i < pad {
115 out_strides[i] = 0;
117 } else {
118 let si = i - pad;
119 let src_dim = src_shape[si];
120 let tgt_dim = target_shape[i];
121
122 if src_dim == tgt_dim {
123 out_strides[i] = src_strides[si];
124 } else if src_dim == 1 {
125 out_strides[i] = 0;
127 } else {
128 return Err(FerrayError::shape_mismatch(format!(
129 "cannot broadcast dimension {} (size {}) to size {}",
130 si, src_dim, tgt_dim
131 )));
132 }
133 }
134 }
135
136 Ok(out_strides)
137}
138
139pub fn broadcast_to<'a, T: Element, D: Dimension>(
148 array: &'a Array<T, D>,
149 target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151 let src_shape = array.shape();
152 let src_strides = array.strides();
153
154 let result_shape = broadcast_shapes(src_shape, target_shape)?;
156 if result_shape != target_shape {
157 return Err(FerrayError::shape_mismatch(format!(
158 "cannot broadcast shape {:?} to shape {:?}",
159 src_shape, target_shape
160 )));
161 }
162
163 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165 for (i, &s) in new_strides.iter().enumerate() {
174 if s < 0 {
175 return Err(FerrayError::shape_mismatch(format!(
176 "cannot broadcast with negative stride {s} on axis {i}; \
177 call .to_owned() on the reversed/transposed array first",
178 s = s,
179 i = i
180 )));
181 }
182 }
183
184 let nd_shape = ndarray::IxDyn(target_shape);
186 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
187
188 let ptr = array.as_ptr();
190 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
194
195 Ok(ArrayView::from_ndarray(nd_view))
196}
197
198pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
203 view: &ArrayView<'a, T, D>,
204 target_shape: &[usize],
205) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
206 let src_shape = view.shape();
207 let src_strides = view.strides();
208
209 let result_shape = broadcast_shapes(src_shape, target_shape)?;
210 if result_shape != target_shape {
211 return Err(FerrayError::shape_mismatch(format!(
212 "cannot broadcast shape {:?} to shape {:?}",
213 src_shape, target_shape
214 )));
215 }
216
217 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
218
219 for (i, &s) in new_strides.iter().enumerate() {
220 if s < 0 {
221 return Err(FerrayError::shape_mismatch(format!(
222 "cannot broadcast view with negative stride {} on axis {}; \
223 call .to_owned() on the reversed/transposed view first",
224 s, i
225 )));
226 }
227 }
228
229 let nd_shape = ndarray::IxDyn(target_shape);
230 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
231
232 let ptr = view.as_ptr();
233 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
236
237 Ok(ArrayView::from_ndarray(nd_view))
238}
239
240pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
248 arrays: &'a [Array<T, D>],
249) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
250 if arrays.is_empty() {
251 return Ok(vec![]);
252 }
253
254 let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
256 let target = broadcast_shapes_multi(&shapes)?;
257
258 let mut result = Vec::with_capacity(arrays.len());
260 for arr in arrays {
261 result.push(broadcast_to(arr, &target)?);
262 }
263 Ok(result)
264}
265
266impl<T: Element, D: Dimension> Array<T, D> {
271 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
279 broadcast_to(self, target_shape)
280 }
281}
282
283impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
284 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
289 let src_shape = self.shape();
290 let src_strides = self.strides();
291
292 let result_shape = broadcast_shapes(src_shape, target_shape)?;
293 if result_shape != target_shape {
294 return Err(FerrayError::shape_mismatch(format!(
295 "cannot broadcast shape {:?} to shape {:?}",
296 src_shape, target_shape
297 )));
298 }
299
300 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
301
302 for (i, &s) in new_strides.iter().enumerate() {
303 if s < 0 {
304 return Err(FerrayError::shape_mismatch(format!(
305 "cannot broadcast view with negative stride {} on axis {}; \
306 make the array contiguous first",
307 s, i
308 )));
309 }
310 }
311
312 let nd_shape = ndarray::IxDyn(target_shape);
313 let nd_strides =
314 ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
315
316 let ptr = self.as_ptr();
317 let nd_view =
320 unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
321
322 Ok(ArrayView::from_ndarray(nd_view))
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    /// Owned 1-D fixture `[1.0, 2.0, 3.0]` shared by several tests.
    fn one_two_three() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    // ---- broadcast_shapes ------------------------------------------------

    #[test]
    fn broadcast_shapes_same() {
        let got = broadcast_shapes(&[3, 4], &[3, 4]).unwrap();
        assert_eq!(got, vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        // An empty shape broadcasts against anything, on either side.
        let right_scalar = broadcast_shapes(&[3, 4], &[]).unwrap();
        assert_eq!(right_scalar, vec![3, 4]);

        let left_scalar = broadcast_shapes(&[], &[5]).unwrap();
        assert_eq!(left_scalar, vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        let got = broadcast_shapes(&[4, 3], &[3]).unwrap();
        assert_eq!(got, vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        let got = broadcast_shapes(&[4, 1], &[4, 3]).unwrap();
        assert_eq!(got, vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        let got = broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap();
        assert_eq!(got, vec![2, 3, 4]);
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        let got = broadcast_shapes(&[1, 3], &[2, 1]).unwrap();
        assert_eq!(got, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        let mismatched_1d = broadcast_shapes(&[3], &[4]);
        assert!(mismatched_1d.is_err());

        let mismatched_2d = broadcast_shapes(&[2, 3], &[4, 3]);
        assert!(mismatched_2d.is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let shapes: [&[usize]; 3] = [&[2, 1], &[3], &[1, 3]];
        assert_eq!(broadcast_shapes_multi(&shapes).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        let got = broadcast_shapes_multi(&[]).unwrap();
        assert_eq!(got, Vec::<usize>::new());
    }

    // ---- broadcast_strides -----------------------------------------------

    #[test]
    fn broadcast_strides_identity() {
        let got = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(got, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        // The stretched leading axis gets stride 0.
        let got = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(got, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        // The prepended leading axis gets stride 0.
        let got = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(got, vec![0, 1]);
    }

    // ---- broadcast_to ----------------------------------------------------

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = one_two_three();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        // The row repeats four times.
        let expected: Vec<f64> = [1.0, 2.0, 3.0].repeat(4);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, expected);
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        // Each column entry repeats along its row.
        let expected: Vec<f64> = vec![
            1.0, 1.0, 1.0, 1.0, //
            2.0, 2.0, 2.0, 2.0, //
            3.0, 3.0, 3.0, 3.0,
        ];
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, expected);
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = one_two_three();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        // The view starts at the array's own storage.
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = one_two_three();
        let outcome = broadcast_to(&arr, &[4, 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    // ---- broadcast_arrays ------------------------------------------------

    #[test]
    fn broadcast_arrays_test() {
        let column =
            Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let row = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [column, row];

        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    // ---- method forms ----------------------------------------------------

    #[test]
    fn array_broadcast_to_method() {
        let arr = one_two_three();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let values = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let a = Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), values).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        let outcome = arr.broadcast_to(&[3]);
        assert!(outcome.is_err());
    }
}