use ndarray::ShapeBuilder;
use crate::array::owned::Array;
use crate::array::view::ArrayView;
use crate::dimension::{Dimension, IxDyn};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
/// Compute the shape produced by broadcasting `a` against `b` (NumPy rules).
///
/// The two shapes are aligned at their trailing axes. The shorter one is treated as if
/// it were left-padded with 1s. On every aligned axis the lengths must be equal, or one
/// of them must be 1, in which case the other length is used.
///
/// # Errors
///
/// Returns `FerrayError::broadcast_failure(a, b)` if some aligned pair of axes has two
/// different lengths, neither of which is 1.
pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> FerrayResult<Vec<usize>> {
    let ndim = a.len().max(b.len());
    let (offset_a, offset_b) = (ndim - a.len(), ndim - b.len());

    // Length of `axis` once `shape` has been left-padded with 1s out to `ndim` axes.
    let padded_len = |shape: &[usize], offset: usize, axis: usize| -> usize {
        axis.checked_sub(offset).map_or(1, |k| shape[k])
    };

    (0..ndim)
        .map(|axis| {
            match (padded_len(a, offset_a, axis), padded_len(b, offset_b, axis)) {
                (x, y) if x == y => Ok(x),
                (1, y) => Ok(y),
                (x, 1) => Ok(x),
                _ => Err(FerrayError::broadcast_failure(a, b)),
            }
        })
        .collect()
}
/// Broadcast any number of shapes together by folding [`broadcast_shapes`] left to right.
///
/// An empty list yields an empty (0-d) shape. A single shape is returned unchanged.
///
/// # Errors
///
/// Propagates the first error from [`broadcast_shapes`]. That error describes the
/// running result and the shape that failed to combine with it.
pub fn broadcast_shapes_multi(shapes: &[&[usize]]) -> FerrayResult<Vec<usize>> {
    match shapes.split_first() {
        None => Ok(Vec::new()),
        Some((first, rest)) => rest
            .iter()
            .try_fold(first.to_vec(), |acc, &shape| broadcast_shapes(&acc, shape)),
    }
}
/// Compute the strides that view the layout `src_shape`/`src_strides` as `target_shape`.
///
/// Source axes are right-aligned with the target axes (NumPy broadcasting rules):
/// - leading target axes with no source counterpart get stride 0;
/// - source axes of length 1 stretched to a different length get stride 0;
/// - axes whose lengths already match keep their source stride.
///
/// Every index along a stride-0 axis maps to the same element, so no data is copied.
///
/// # Errors
///
/// Returns a shape-mismatch error if:
/// - `target_shape` has fewer axes than `src_shape`;
/// - `src_strides` and `src_shape` have different lengths;
/// - a source axis whose length is not 1 differs from the matching target axis.
pub fn broadcast_strides(
    src_shape: &[usize],
    src_strides: &[isize],
    target_shape: &[usize],
) -> FerrayResult<Vec<isize>> {
    let tndim = target_shape.len();
    let sndim = src_shape.len();
    if tndim < sndim {
        return Err(FerrayError::shape_mismatch(format!(
            "cannot broadcast shape {:?} to shape {:?}: target has fewer dimensions",
            src_shape, target_shape
        )));
    }
    // Strides are paired with shape entries axis by axis below. A length mismatch would
    // otherwise panic (too few strides) or silently drop entries (too many).
    if src_strides.len() != sndim {
        return Err(FerrayError::shape_mismatch(format!(
            "stride count {} does not match dimension count {} for shape {:?}",
            src_strides.len(),
            sndim,
            src_shape
        )));
    }

    let pad = tndim - sndim;
    let mut out_strides = Vec::with_capacity(tndim);
    // Prepended axes do not exist in the source, so they never advance the pointer.
    out_strides.resize(pad, 0);
    for (si, ((&src_dim, &src_stride), &tgt_dim)) in src_shape
        .iter()
        .zip(src_strides)
        .zip(&target_shape[pad..])
        .enumerate()
    {
        let stride = if src_dim == tgt_dim {
            src_stride
        } else if src_dim == 1 {
            // Stretch a length-1 axis by repeating its single element.
            0
        } else {
            return Err(FerrayError::shape_mismatch(format!(
                "cannot broadcast dimension {} (size {}) to size {}",
                si, src_dim, tgt_dim
            )));
        };
        out_strides.push(stride);
    }
    Ok(out_strides)
}
/// Shared implementation behind [`broadcast_to`], [`broadcast_view_to`] and
/// `ArrayView::broadcast_to`.
///
/// Validates `target_shape` against the source layout and builds a zero-copy view in
/// which every broadcast axis has stride 0. `what` ("array" or "view") names the source
/// in the negative-stride error message.
///
/// # Safety
///
/// `ptr` together with `src_shape`/`src_strides` must describe the caller's existing
/// layout. Every element reachable from `ptr` through that shape and those strides must
/// stay valid to read, and must not be mutably aliased, for the whole lifetime `'a`
/// chosen by the caller.
unsafe fn broadcast_raw_view<'a, T: Element>(
    ptr: *const T,
    src_shape: &[usize],
    src_strides: &[isize],
    target_shape: &[usize],
    what: &str,
) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
    let result_shape = broadcast_shapes(src_shape, target_shape)?;
    if result_shape != target_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "cannot broadcast shape {:?} to shape {:?}",
            src_shape, target_shape
        )));
    }

    let new_strides = broadcast_strides(src_shape, src_strides, target_shape)?;

    // `from_shape_ptr` only supports non-negative strides.
    if let Some((axis, &stride)) = new_strides.iter().enumerate().find(|&(_, &s)| s < 0) {
        return Err(FerrayError::shape_mismatch(format!(
            "cannot broadcast {} with negative stride {} on axis {}; \
             make the array contiguous first",
            what, stride, axis
        )));
    }

    // `from_shape_ptr` also requires the product of the non-zero axis lengths to fit in
    // `isize`. Broadcast axes cost no memory, so without this check a safe call such as
    // `broadcast_to(&a, &[usize::MAX, 2])` would violate that precondition (UB).
    let element_count = target_shape
        .iter()
        .filter(|&&d| d != 0)
        .try_fold(1usize, |acc, &d| acc.checked_mul(d));
    if !matches!(element_count, Some(n) if n <= isize::MAX as usize) {
        return Err(FerrayError::shape_mismatch(format!(
            "cannot broadcast shape {:?} to shape {:?}: element count overflows isize",
            src_shape, target_shape
        )));
    }

    // Every stride is non-negative (checked above), so the cast is lossless.
    let nd_strides: Vec<usize> = new_strides.iter().map(|&s| s as usize).collect();
    let layout = ndarray::IxDyn(target_shape).strides(ndarray::IxDyn(&nd_strides));
    // SAFETY:
    // - The caller guarantees `ptr`/`src_shape`/`src_strides` describe elements that are
    //   readable and not mutably aliased for `'a`.
    // - `broadcast_strides` copies a source stride only when the axis keeps its source
    //   length; every other axis has stride 0. So the view reaches no element outside
    //   the source layout.
    // - All strides are non-negative, so `ptr` is the lowest address the view touches.
    // - The non-zero axis-length product fits in `isize`.
    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(layout, ptr) };
    Ok(ArrayView::from_ndarray(nd_view))
}

/// Broadcast `array` to `target_shape` without copying any data.
///
/// The returned view borrows `array`. Broadcast axes have stride 0, so every index along
/// them refers to the same source element.
///
/// # Errors
///
/// Returns an error if:
/// - `array`'s shape cannot be broadcast to exactly `target_shape`;
/// - `array` has a negative stride;
/// - `target_shape` describes more than `isize::MAX` elements.
pub fn broadcast_to<'a, T: Element, D: Dimension>(
    array: &'a Array<T, D>,
    target_shape: &[usize],
) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
    // SAFETY: `as_ptr`, `shape` and `strides` all describe `array`'s own elements.
    // The `&'a` borrow keeps those elements alive and free of mutable aliases for `'a`.
    unsafe {
        broadcast_raw_view(
            array.as_ptr(),
            array.shape(),
            array.strides(),
            target_shape,
            "array",
        )
    }
}

/// Broadcast an existing view to `target_shape` without copying any data.
///
/// The result borrows the same elements as `view`, for the same lifetime `'a`.
///
/// # Errors
///
/// Same conditions as [`broadcast_to`].
pub fn broadcast_view_to<'a, T: Element, D: Dimension>(
    view: &ArrayView<'a, T, D>,
    target_shape: &[usize],
) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
    // SAFETY: `view` is a shared view over elements borrowed for `'a`. Its `as_ptr`,
    // `shape` and `strides` describe exactly those elements.
    unsafe {
        broadcast_raw_view(
            view.as_ptr(),
            view.shape(),
            view.strides(),
            target_shape,
            "view",
        )
    }
}

/// Broadcast every array in `arrays` to their common broadcast shape.
///
/// Returns one zero-copy view per input, in input order. An empty input yields an empty
/// `Vec`.
///
/// # Errors
///
/// Returns an error if the shapes are not mutually broadcastable, or if any single
/// array cannot be viewed at the common shape (see [`broadcast_to`]).
pub fn broadcast_arrays<'a, T: Element, D: Dimension>(
    arrays: &'a [Array<T, D>],
) -> FerrayResult<Vec<ArrayView<'a, T, IxDyn>>> {
    if arrays.is_empty() {
        return Ok(vec![]);
    }
    let shapes: Vec<&[usize]> = arrays.iter().map(|a| a.shape()).collect();
    let target = broadcast_shapes_multi(&shapes)?;
    arrays
        .iter()
        .map(|arr| broadcast_to(arr, &target))
        .collect()
}

impl<T: Element, D: Dimension> Array<T, D> {
    /// Method form of the free function [`broadcast_to`].
    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
        broadcast_to(self, target_shape)
    }
}

impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
    /// Method form of [`broadcast_view_to`]. The result borrows the same data for `'a`.
    pub fn broadcast_to(&self, target_shape: &[usize]) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
        broadcast_view_to(self, target_shape)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    #[test]
    fn broadcast_shapes_same() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).unwrap(), vec![3, 4]);
    }

    #[test]
    fn broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[3, 4], &[]).unwrap(), vec![3, 4]);
        assert_eq!(broadcast_shapes(&[], &[5]).unwrap(), vec![5]);
    }

    #[test]
    fn broadcast_shapes_prepend_ones() {
        assert_eq!(broadcast_shapes(&[4, 3], &[3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_stretch_ones() {
        assert_eq!(broadcast_shapes(&[4, 1], &[4, 3]).unwrap(), vec![4, 3]);
    }

    #[test]
    fn broadcast_shapes_3d() {
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 4]).unwrap(),
            vec![2, 3, 4]
        );
    }

    #[test]
    fn broadcast_shapes_both_ones() {
        assert_eq!(broadcast_shapes(&[1, 3], &[2, 1]).unwrap(), vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_incompatible() {
        assert!(broadcast_shapes(&[3], &[4]).is_err());
        assert!(broadcast_shapes(&[2, 3], &[4, 3]).is_err());
    }

    #[test]
    fn broadcast_shapes_zero_length_axis() {
        // A length-1 axis stretches to length 0 like to any other length.
        // A length-2 axis does not.
        assert_eq!(broadcast_shapes(&[1, 3], &[0, 3]).unwrap(), vec![0, 3]);
        assert!(broadcast_shapes(&[2], &[0]).is_err());
    }

    #[test]
    fn broadcast_shapes_multi_test() {
        let result = broadcast_shapes_multi(&[&[2, 1], &[3], &[1, 3]]).unwrap();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn broadcast_shapes_multi_empty() {
        assert_eq!(broadcast_shapes_multi(&[]).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn broadcast_strides_identity() {
        let strides = broadcast_strides(&[3, 4], &[3, 4], &[3, 4]).unwrap();
        assert_eq!(strides, vec![3, 4]);
    }

    #[test]
    fn broadcast_strides_expand_ones() {
        let strides = broadcast_strides(&[1, 4], &[4, 1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_prepend() {
        let strides = broadcast_strides(&[4], &[1], &[3, 4]).unwrap();
        assert_eq!(strides, vec![0, 1]);
    }

    #[test]
    fn broadcast_strides_fewer_target_dims_is_err() {
        assert!(broadcast_strides(&[3, 4], &[4, 1], &[4]).is_err());
    }

    #[test]
    fn broadcast_strides_incompatible_is_err() {
        assert!(broadcast_strides(&[3], &[1], &[4]).is_err());
    }

    #[test]
    fn broadcast_to_1d_to_2d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[4, 3]).unwrap();
        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.size(), 12);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_column_to_2d() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[3, 4]).unwrap();
        assert_eq!(view.shape(), &[3, 4]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
        );
    }

    #[test]
    fn broadcast_to_no_materialization() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[1000, 3]).unwrap();
        assert_eq!(view.shape(), &[1000, 3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn broadcast_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(broadcast_to(&arr, &[4, 5]).is_err());
    }

    #[test]
    fn broadcast_to_scalar() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![42.0]).unwrap();
        let view = broadcast_to(&arr, &[5]).unwrap();
        assert_eq!(view.shape(), &[5]);
        let data: Vec<f64> = view.iter().copied().collect();
        assert_eq!(data, vec![42.0; 5]);
    }

    #[test]
    fn broadcast_view_to_rebroadcasts_view() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[2, 3]).unwrap();
        let wider = broadcast_view_to(&view, &[2, 2, 3]).unwrap();
        assert_eq!(wider.shape(), &[2, 2, 3]);
        // Still a view over the original buffer.
        assert_eq!(wider.as_ptr(), arr.as_ptr());
        let data: Vec<f64> = wider.iter().copied().collect();
        assert_eq!(
            data,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn broadcast_view_to_incompatible() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = broadcast_to(&arr, &[2, 3]).unwrap();
        assert!(broadcast_view_to(&view, &[2, 4]).is_err());
    }

    #[test]
    fn view_broadcast_to_method() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let view = arr.broadcast_to(&[2, 1]).unwrap();
        let stretched = view.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(stretched.shape(), &[2, 3]);
        let data: Vec<f64> = stretched.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
    }

    #[test]
    fn broadcast_arrays_test() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);
        let first: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            first,
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
        );
        let second: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(
            second,
            vec![10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    #[test]
    fn broadcast_arrays_empty() {
        let arrays: Vec<Array<f64, Ix1>> = Vec::new();
        assert!(broadcast_arrays(&arrays).unwrap().is_empty());
    }

    #[test]
    fn array_broadcast_to_method() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_3d() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let view = a.broadcast_to(&[2, 3, 4]).unwrap();
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert_eq!(view.size(), 24);
    }

    #[test]
    fn broadcast_to_same_shape() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let view = arr.broadcast_to(&[2, 3]).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
    }

    #[test]
    fn broadcast_to_cannot_shrink() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![1.0; 12]).unwrap();
        assert!(arr.broadcast_to(&[3]).is_err());
    }
}