1use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39 let ndim = a.len().max(b.len());
40 let mut result = vec![0usize; ndim];
41
42 for i in 0..ndim {
43 let da = if i < ndim - a.len() {
44 1
45 } else {
46 a[i - (ndim - a.len())]
47 };
48 let db = if i < ndim - b.len() {
49 1
50 } else {
51 b[i - (ndim - b.len())]
52 };
53
54 if da == db {
55 result[i] = da;
56 } else if da == 1 {
57 result[i] = db;
58 } else if db == 1 {
59 result[i] = da;
60 } else {
61 return Err(FerrayError::broadcast_failure(a, b));
62 }
63 }
64 Ok(result)
65}
66
67pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75 if shapes.is_empty() {
76 return Ok(vec![]);
77 }
78 let mut result = shapes[0].to_vec();
79 for &s in &shapes[1..] {
80 result = broadcast_shapes(&result, s)?;
81 }
82 Ok(result)
83}
84
85pub fn broadcast_strides(
96 src_shape: &[usize],
97 src_strides: &[isize],
98 target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100 let tndim = target_shape.len();
101 let sndim = src_shape.len();
102
103 if tndim < sndim {
104 return Err(FerrayError::shape_mismatch(format!(
105 "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106 src_shape, target_shape
107 )));
108 }
109
110 let pad = tndim - sndim;
111 let mut out_strides = vec![0isize; tndim];
112
113 for i in 0..tndim {
114 if i < pad {
115 out_strides[i] = 0;
117 } else {
118 let si = i - pad;
119 let src_dim = src_shape[si];
120 let tgt_dim = target_shape[i];
121
122 if src_dim == tgt_dim {
123 out_strides[i] = src_strides[si];
124 } else if src_dim == 1 {
125 out_strides[i] = 0;
127 } else {
128 return Err(FerrayError::shape_mismatch(format!(
129 "cannot broadcast dimension {} (size {}) to size {}",
130 si, src_dim, tgt_dim
131 )));
132 }
133 }
134 }
135
136 Ok(out_strides)
137}
138
139pub fn broadcast_to<'a, T: Element, D: Dimension>(
148 array: &'a Array<T, D>,
149 target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151 let src_shape = array.shape();
152 let src_strides = array.strides();
153
154 let result_shape = broadcast_shapes(src_shape, target_shape)?;
156 if result_shape != target_shape {
157 return Err(FerrayError::shape_mismatch(format!(
158 "cannot broadcast shape {:?} to shape {:?}",
159 src_shape, target_shape
160 )));
161 }
162
163 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165 for (i, &s) in new_strides.iter().enumerate() {
169 if s < 0 {
170 return Err(FerrayError::shape_mismatch(format!(
171 "cannot broadcast array with negative stride {} on axis {}; \
172 make the array contiguous first",
173 s, i
174 )));
175 }
176 }
177
178 let nd_shape = ndarray::IxDyn(target_shape);
180 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
181
182 let ptr = array.as_ptr();
184 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
188
189 Ok(ArrayView::from_ndarray(nd_view))
190}
191
192pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
197 view: &ArrayView<'a, T, D>,
198 target_shape: &[usize],
199) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
200 let src_shape = view.shape();
201 let src_strides = view.strides();
202
203 let result_shape = broadcast_shapes(src_shape, target_shape)?;
204 if result_shape != target_shape {
205 return Err(FerrayError::shape_mismatch(format!(
206 "cannot broadcast shape {:?} to shape {:?}",
207 src_shape, target_shape
208 )));
209 }
210
211 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
212
213 for (i, &s) in new_strides.iter().enumerate() {
214 if s < 0 {
215 return Err(FerrayError::shape_mismatch(format!(
216 "cannot broadcast view with negative stride {} on axis {}; \
217 make the array contiguous first",
218 s, i
219 )));
220 }
221 }
222
223 let nd_shape = ndarray::IxDyn(target_shape);
224 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
225
226 let ptr = view.as_ptr();
227 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
230
231 Ok(ArrayView::from_ndarray(nd_view))
232}
233
234pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
242 arrays: &'a [Array<T, D>],
243) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
244 if arrays.is_empty() {
245 return Ok(vec![]);
246 }
247
248 let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
250 let target = broadcast_shapes_multi(&shapes)?;
251
252 let mut result = Vec::with_capacity(arrays.len());
254 for arr in arrays {
255 result.push(broadcast_to(arr, &target)?);
256 }
257 Ok(result)
258}
259
260impl<T: Element, D: Dimension> Array<T, D> {
265 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
273 broadcast_to(self, target_shape)
274 }
275}
276
277impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
278 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
283 let src_shape = self.shape();
284 let src_strides = self.strides();
285
286 let result_shape = broadcast_shapes(src_shape, target_shape)?;
287 if result_shape != target_shape {
288 return Err(FerrayError::shape_mismatch(format!(
289 "cannot broadcast shape {:?} to shape {:?}",
290 src_shape, target_shape
291 )));
292 }
293
294 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
295
296 for (i, &s) in new_strides.iter().enumerate() {
297 if s < 0 {
298 return Err(FerrayError::shape_mismatch(format!(
299 "cannot broadcast view with negative stride {} on axis {}; \
300 make the array contiguous first",
301 s, i
302 )));
303 }
304 }
305
306 let nd_shape = ndarray::IxDyn(target_shape);
307 let nd_strides =
308 ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
309
310 let ptr = self.as_ptr();
311 let nd_view =
314 unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
315
316 Ok(ArrayView::from_ndarray(nd_view))
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // ---- broadcast_shapes / broadcast_shapes_multi ----

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axis() {
        // A length-0 axis broadcasts against 1 but not against another length.
        assert_eq!(broadcast_shapes(&[0, 1], &[1, 3]).unwrap(), vec![0, 3]);
        assert!(broadcast_shapes(&[0], &[3]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), vec![]);
    }

    #[test]
    fn broadcast_shapes_multi_incompatible() {
        // [2, 1] and [3] combine to [2, 3], which then clashes with [4, 1].
        assert!(broadcast_shapes_multi(&[&[2, 1], &[3], &[4, 1]]).is_err());
    }

    // ---- broadcast_strides ----

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_rejects_fewer_target_dims() {
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
    }

    #[test]
    fn broadcast_strides_rejects_incompatible_dim() {
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    // ---- broadcast_to (free function and Array method) ----

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_zero_length_target() {
        // A length-1 axis may be stretched to length 0, giving an empty view.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[0]).unwrap();
        assert_eq!(view.shape(), &[0]);
        assert_eq!(view.size(), 0);
    }

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);

        let first: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            first,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(second, [10.0, 20.0, 30.0].repeat(4));
    }

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_3d_values() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        let data: Vec<i32> = view.iter().copied().collect();
        let expected = [[1, 2, 3, 4].repeat(3), [5, 6, 7, 8].repeat(3)].concat();
        assert_eq!(data, expected);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }

    // ---- broadcast_view_to / ArrayView::broadcast_to ----

    #[test]
    fn broadcast_view_to_free_fn() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let view = broadcast_to(&arr, &[2, 1]).unwrap();
        let wide = broadcast_view_to(&view, &[2, 3]).unwrap();
        assert_eq!(wide.shape(), &[2, 3]);
        let data: Vec<f64> = wide.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
        assert!(broadcast_view_to(&view, &[3, 3]).is_err());
    }

    #[test]
    fn view_broadcast_to_chains_without_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        let wider = view.broadcast_to(&[4, 2, 3]).unwrap();
        assert_eq!(wider.shape(), &[4, 2, 3]);
        assert_eq!(wider.as_ptr(), arr.as_ptr());
        let data: Vec<f64> = wider.iter().copied().collect();
        assert_eq!(data, [1.0, 2.0, 3.0].repeat(8));
    }

    #[test]
    fn view_broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert!(view.broadcast_to(&[3]).is_err());
        assert!(view.broadcast_to(&[2, 4]).is_err());
    }
}