1use ndarray::ShapeBuilder;
12
13use crate::array::owned::Array;
14use crate::array::view::ArrayView;
15use crate::dimension::{Dimension, IxDyn};
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
39 let ndim = a.len().max(b.len());
40 let mut result = vec![0usize; ndim];
41
42 for i in 0..ndim {
43 let da = if i < ndim - a.len() {
44 1
45 } else {
46 a[i - (ndim - a.len())]
47 };
48 let db = if i < ndim - b.len() {
49 1
50 } else {
51 b[i - (ndim - b.len())]
52 };
53
54 if da == db {
55 result[i] = da;
56 } else if da == 1 {
57 result[i] = db;
58 } else if db == 1 {
59 result[i] = da;
60 } else {
61 return Err(FerrayError::broadcast_failure(a, b));
62 }
63 }
64 Ok(result)
65}
66
67pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
75 if shapes.is_empty() {
76 return Ok(vec![]);
77 }
78 let mut result = shapes[0].to_vec();
79 for &s in &shapes[1..] {
80 result = broadcast_shapes(&result, s)?;
81 }
82 Ok(result)
83}
84
85pub fn broadcast_strides(
96 src_shape: &[usize],
97 src_strides: &[isize],
98 target_shape: &[usize],
99) -> FerrayResult<Vec<isize>> {
100 let tndim = target_shape.len();
101 let sndim = src_shape.len();
102
103 if tndim < sndim {
104 return Err(FerrayError::shape_mismatch(format!(
105 "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
106 src_shape, target_shape
107 )));
108 }
109
110 let pad = tndim - sndim;
111 let mut out_strides = vec![0isize; tndim];
112
113 for i in 0..tndim {
114 if i < pad {
115 out_strides[i] = 0;
117 } else {
118 let si = i - pad;
119 let src_dim = src_shape[si];
120 let tgt_dim = target_shape[i];
121
122 if src_dim == tgt_dim {
123 out_strides[i] = src_strides[si];
124 } else if src_dim == 1 {
125 out_strides[i] = 0;
127 } else {
128 return Err(FerrayError::shape_mismatch(format!(
129 "cannot broadcast dimension {} (size {}) to size {}",
130 si, src_dim, tgt_dim
131 )));
132 }
133 }
134 }
135
136 Ok(out_strides)
137}
138
139pub fn broadcast_to<'a, T: Element, D: Dimension>(
148 array: &'a Array<T, D>,
149 target_shape: &[usize],
150) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
151 let src_shape = array.shape();
152 let src_strides = array.strides();
153
154 let result_shape = broadcast_shapes(src_shape, target_shape)?;
156 if result_shape != target_shape {
157 return Err(FerrayError::shape_mismatch(format!(
158 "cannot broadcast shape {:?} to shape {:?}",
159 src_shape, target_shape
160 )));
161 }
162
163 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
164
165 let nd_shape = ndarray::IxDyn(target_shape);
168 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
169
170 let ptr = array.as_ptr();
172 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
175
176 Ok(ArrayView::from_ndarray(nd_view))
177}
178
179pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
184 view: &ArrayView<'a, T, D>,
185 target_shape: &[usize],
186) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
187 let src_shape = view.shape();
188 let src_strides = view.strides();
189
190 let result_shape = broadcast_shapes(src_shape, target_shape)?;
191 if result_shape != target_shape {
192 return Err(FerrayError::shape_mismatch(format!(
193 "cannot broadcast shape {:?} to shape {:?}",
194 src_shape, target_shape
195 )));
196 }
197
198 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
199
200 let nd_shape = ndarray::IxDyn(target_shape);
201 let nd_strides = ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
202
203 let ptr = view.as_ptr();
204 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
205
206 Ok(ArrayView::from_ndarray(nd_view))
207}
208
209pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
217 arrays: &'a [Array<T, D>],
218) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
219 if arrays.is_empty() {
220 return Ok(vec![]);
221 }
222
223 let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
225 let target = broadcast_shapes_multi(&shapes)?;
226
227 let mut result = Vec::with_capacity(arrays.len());
229 for arr in arrays {
230 result.push(broadcast_to(arr, &target)?);
231 }
232 Ok(result)
233}
234
235impl<T: Element, D: Dimension> Array<T, D> {
240 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
248 broadcast_to(self, target_shape)
249 }
250}
251
252impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
253 pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
258 let src_shape = self.shape();
259 let src_strides = self.strides();
260
261 let result_shape = broadcast_shapes(src_shape, target_shape)?;
262 if result_shape != target_shape {
263 return Err(FerrayError::shape_mismatch(format!(
264 "cannot broadcast shape {:?} to shape {:?}",
265 src_shape, target_shape
266 )));
267 }
268
269 let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;
270
271 let nd_shape = ndarray::IxDyn(target_shape);
272 let nd_strides =
273 ndarray::IxDyn(&new_strides.iter().map(|&s| s as usize).collect::<Vec<_>>());
274
275 let ptr = self.as_ptr();
276 let nd_view =
277 unsafe { ndarray::ArrayView::from_shape_ptr(nd_shape.strides(nd_strides), ptr) };
278
279 Ok(ArrayView::from_ndarray(nd_view))
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axes() {
        // A length-0 axis follows the usual rule: it matches 0 or 1 only.
        assert_eq!(broadcast_shapes(&[0, 3], &[1, 3]).unwrap(), vec![0, 3]);
        assert_eq!(broadcast_shapes(&[1], &[0]).unwrap(), vec![0]);
        assert!(broadcast_shapes(&[0], &[2]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), vec![]);
    }

    #[test]
    fn broadcast_shapes_multi_single() {
        assert_eq!(broadcast_shapes_multi(&[&[2, 3]]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_preserves_negative_strides() {
        let strides = broadcast_strides(&[3], &[-1], &[2, 3]).unwrap();
        assert_eq!(strides, vec![0, -1]);
    }

    #[test]
    fn broadcast_strides_errors() {
        // Target has fewer dimensions than the source.
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
        // Non-unit source axis that differs from the target axis.
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);

        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_to_empty_target() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[0, 3]).unwrap();
        assert_eq!(view.shape(), &[0, 3]);
        assert_eq!(view.size(), 0);
    }

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
    }

    #[test]
    fn broadcast_arrays_values() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();

        let first: Vec<f64> = views[0].iter().copied().collect();
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(first, vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
        assert_eq!(second, vec![10.0, 20.0, 30.0, 10.0, 20.0, 30.0]);
    }

    #[test]
    fn broadcast_arrays_empty() {
        let arrays: [Array<f64, Ix1>; 0] = [];
        assert!(broadcast_arrays(&arrays).unwrap().is_empty());
    }

    #[test]
    fn broadcast_arrays_incompatible() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
        let arrays = [a, b];
        assert!(broadcast_arrays(&arrays).is_err());
    }

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }
}